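"""Preprocess a JSON-lines corpus into a binary indexed dataset for BERT-style
pretraining.

Each input line is a JSON object whose `--json-key` field (default: "text")
holds one document. Documents are split into sentences with NLTK punkt,
tokenized with a BERT WordPiece vocab, and written to <output-prefix>.bin and
<output-prefix>.idx via `indexed_dataset`.

Example invocation (paths are illustrative, not part of the repo):

    python preprocess_data.py \
        --input corpus.json \
        --vocab vocab.txt \
        --output-prefix my_corpus \
        --workers 20 \
        --dataset-impl mmap
"""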
import argparse
import json
import multiprocessing
import nltk
import sys
import time

import torch

from bert_tokenization import FullTokenizer
import indexed_dataset

class CustomLanguageVars(nltk.tokenize.punkt.PunktLanguageVars):
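    """Punkt language vars tweaked so that the whitespace (including newlines)
    that follows a sentence terminator stays attached to the sentence span
    instead of being consumed; this is what --keep-newlines relies on."""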

    _period_context_fmt = r"""
        \S*                          # some word material
        %(SentEndChars)s             # a potential sentence ending
        \s*                       #  <-- THIS is what I changed
        (?=(?P<after_tok>
            %(NonWord)s              # either other punctuation
            |
            (?P<next_tok>\S+)     #  <-- Normally you would have \s+ here
        ))"""

class Encoder(object):
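    """Turns one JSON line into a list of per-sentence token id lists.

    `initializer()` runs once in each worker process and stashes the tokenizer
    and sentence splitter on the Encoder class, so the heavyweight objects are
    built per worker instead of being shipped with every task."""
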
    def __init__(self, args):
        self.args = args

    def initializer(self):
        # Use Encoder class as a container for global data
        Encoder.tokenizer = FullTokenizer(self.args.vocab, do_lower_case=True)
        splitter = nltk.load("tokenizers/punkt/english.pickle")
        if self.args.keep_newlines:
            # this prevents punkt from eating newlines after sentences
            Encoder.splitter = nltk.tokenize.punkt.PunktSentenceTokenizer(
                train_text=splitter._params,
                lang_vars=CustomLanguageVars())
        else:
            Encoder.splitter = splitter

    def encode(self, json_line):
        text = json.loads(json_line)[self.args.json_key]
        doc_ids = []
        for sentence in Encoder.splitter.tokenize(text):
            tokens = Encoder.tokenizer.tokenize(sentence)
            ids = Encoder.tokenizer.convert_tokens_to_ids(tokens)
            doc_ids.append(ids)
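        # an empty sentence is appended to mark the end of the document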
        doc_ids.append([])
        return doc_ids, len(json_line)

def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--input', type=str,
                        help='Path to input JSON file (one document per line)')
    parser.add_argument('--vocab', type=str, help='Path to vocab.txt')
    parser.add_argument('--json-key', type=str, default='text',
                        help='Key to extract from json')
    parser.add_argument('--output-prefix', type=str, help='Path to binary output file without suffix')
    parser.add_argument('--workers', type=int, default=20,
                        help='Number of worker processes to launch')
    parser.add_argument('--log-interval', type=int, default=100,
                        help='Interval between progress updates')
    parser.add_argument('--keep-newlines', action='store_true',
                        help='Keep newlines between sentences.')
    parser.add_argument('--dataset-impl', type=str, default='mmap',
                        choices=['lazy', 'cached', 'mmap'])
    args = parser.parse_args()
    args.keep_empty = False

    startup_start = time.time()

    print("Opening", args.input)
    fin = open(args.input, 'r', encoding='utf-8')

    # derive the vocab size from the vocab file (the dataset builder uses it
    # to pick an integer dtype for token ids) instead of hard-coding it
    tokenizer = FullTokenizer(args.vocab, do_lower_case=True)
    vocab_size = len(tokenizer.vocab)

    nltk.download("punkt", quiet=True)

    encoder = Encoder(args)
    pool = multiprocessing.Pool(args.workers, initializer=encoder.initializer)
    encoded_docs = pool.imap(encoder.encode, fin, 25)

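    # token ids are written to <prefix>.bin; item offsets/sizes go into <prefix>.idx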
    output_bin_file = "{}.bin".format(args.output_prefix)
    output_idx_file = "{}.idx".format(args.output_prefix)
    ds = indexed_dataset.make_builder(output_bin_file,
                                      impl=args.dataset_impl,
                                      vocab_size=vocab_size)

    startup_end = time.time()
    proc_start = time.time()
    total_bytes_processed = 0
    print("Time to startup:", startup_end - startup_start)
    for i, (doc, bytes_processed) in enumerate(encoded_docs, start=1):
        total_bytes_processed += bytes_processed
        for sentence in doc:
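            # debug output: dump each sentence's ids and the recovered tokens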
            print(sentence)
            print(tokenizer.convert_ids_to_tokens(sentence))
            ds.add_item(torch.IntTensor(sentence))
        if i % args.log_interval == 0:
            current = time.time()
            elapsed = current - proc_start
            mbs = total_bytes_processed/elapsed/1024/1024
            print(f'Processed {i} documents',
                  f"({i/elapsed} docs/s, {mbs} MB/s).",
                  file=sys.stderr)

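    # write the .idx index file; the token data has already been written to .bin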
    ds.finalize(output_idx_file)

if __name__ == '__main__':
    main()
    # print('processing data ...')

    # input_file = '/raid/mshoeybi/data/albert/sample/samples_11.json'
    # vocab_file = '/raid/mshoeybi/data/albert/bert_vocab/vocab.txt'

    # tokenizer = FullTokenizer(vocab_file, do_lower_case=True)
    # document_generator = document_generator_provider(input_file)
    # for sentences in document_generator:
    #     for sentence in sentences:
    #         tokens = tokenizer.tokenize(sentence)
    #         print(sentence)
    #         print(tokens)