# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.

"""Processing large data for pretraining."""
import argparse
import math
import json
import os
import sys
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__),
                                             os.path.pardir)))
import time
import gzip
import glob
import torch
import numpy as np
import multiprocessing
try:
    import nltk
    nltk_available = True
except ImportError:
    nltk_available = False

from megatron.tokenizer import build_tokenizer
from megatron.data import indexed_dataset


# https://stackoverflow.com/questions/33139531/preserve-empty-lines-with-nltks-punkt-tokenizer
if nltk_available:
    # Only define the custom splitter when nltk imported successfully;
    # the class statement itself references nltk.
    class CustomLanguageVars(nltk.tokenize.punkt.PunktLanguageVars):

        _period_context_fmt = r"""
            \S*                          # some word material
            %(SentEndChars)s             # a potential sentence ending
            \s*                          #  <-- THIS is what I changed
            (?=(?P<after_tok>
                %(NonWord)s              # either other punctuation
                |
                (?P<next_tok>\S+)        #  <-- Normally you would have \s+ here
            ))"""

class IdentitySplitter(object):
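    """No-op sentence splitter used when --split-sentences is not requested."""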
    def tokenize(self, *text):
        return text


class Encoder(object):
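    """Multiprocessing worker: initializer() builds the tokenizer (and the
    optional sentence splitter) once per worker process and stores them as
    class attributes, so they are not pickled with every task."""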
    def __init__(self, args):
        self.args = args

    def initializer(self):
        # Use Encoder class as a container for global data
        Encoder.tokenizer = build_tokenizer(self.args)
        if self.args.split_sentences:
            if not nltk_available:
                print("NLTK is not available to split sentences.")
                exit()
            library = "tokenizers/punkt/{}.pickle".format(self.args.lang)
            splitter = nltk.load(library)
            if self.args.keep_newlines:
                # this prevents punkt from eating newlines after sentences
                Encoder.splitter = nltk.tokenize.punkt.PunktSentenceTokenizer(
                    train_text = splitter._params,
                    lang_vars = CustomLanguageVars())
            else:
                Encoder.splitter = splitter

        else:
            Encoder.splitter = IdentitySplitter()

    def split(self, json_line):
        data = json.loads(json_line)
        output = {}
        for key in self.args.json_keys:
            text = data[key]
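            # Very long documents are split into 1M-character chunks before
            # sentence splitting; the per-chunk sentence lists are then flattened.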
            max_len = 1000000
            tokens_list = [Encoder.splitter.tokenize(text[i:i+max_len]) for i in range(0, len(text), max_len)]
            output[key] = [tokens for partial in tokens_list for tokens in partial]
        return json.dumps(output), len(json_line)

    def encode(self, json_line):
        data = json.loads(json_line)
        ids = {}
        lens = {}
        for key in self.args.json_keys:
            text = data[key]
            if isinstance(text, list):
                sentences = text
            else:
                sentences = [text]
            doc_ids = []
            sentence_lens = []
            for sentence in sentences:
                sentence_ids = Encoder.tokenizer.tokenize(sentence)
                if len(sentence_ids) > 0:
                    doc_ids.extend(sentence_ids)
                    sentence_lens.append(len(sentence_ids))
            if len(doc_ids) > 0 and self.args.append_eod:
                doc_ids.append(Encoder.tokenizer.eod)
                sentence_lens[-1] += 1
            ids[key] = doc_ids
            lens[key] = sentence_lens
        return ids, lens, len(json_line)


class Partition(object):
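    """Process one input partition: optionally split its documents into
    sentences (written to a *_ss* file), then tokenize it into indexed-dataset
    .bin/.idx files, each step driven by a multiprocessing pool."""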
    def __init__(self, args, workers):
        self.args = args
        self.workers = workers

    def print_processing_stats(self, count, proc_start, total_bytes_processed):
        if count % self.args.log_interval == 0:
            current = time.time()
            elapsed = current - proc_start
            mbs = total_bytes_processed/elapsed/1024/1024
            print(f"Processed {count} documents",
                  f"({count/elapsed} docs/s, {mbs} MB/s).",
                  file=sys.stderr)

    def split_sentences(self, file_name):
        input_file_name, output_file_name = file_name
        print("Opening", input_file_name)
        fin = open(input_file_name, 'r', encoding='utf-8')
        fout = open(output_file_name, 'w', encoding='utf-8')

        encoder = Encoder(self.args)
        pool = multiprocessing.Pool(self.workers, initializer=encoder.initializer)
        split_docs = pool.imap(encoder.split, fin, 32)

        proc_start = time.time()
        total_bytes_processed = 0
        for i, (doc, bytes_processed) in enumerate(split_docs, start=1):
            total_bytes_processed += bytes_processed
            fout.write(doc + "\n")
            self.print_processing_stats(i, proc_start, total_bytes_processed)

        fin.close()
        fout.close()


    def process_json_file(self, file_name):
        input_file_name, output_prefix = file_name
        print("Opening", input_file_name)
        fin = open(input_file_name, 'r', encoding='utf-8')

        startup_start = time.time()
        encoder = Encoder(self.args)
        tokenizer = build_tokenizer(self.args)
        pool = multiprocessing.Pool(self.workers, initializer=encoder.initializer)
        encoded_docs = pool.imap(encoder.encode, fin, 32)

        level = "document"
        if self.args.split_sentences:
            level = "sentence"

        output_bin_files = {}
        output_idx_files = {}
        builders = {}

        for key in self.args.json_keys:
            output_bin_files[key] = "{}_{}_{}.bin".format(output_prefix,
                                                          key, level)
            output_idx_files[key] = "{}_{}_{}.idx".format(output_prefix,
                                                          key, level)
            builders[key] = indexed_dataset.make_builder(output_bin_files[key],
                                                   impl=self.args.dataset_impl,
                                                   vocab_size=tokenizer.vocab_size)

        startup_end = time.time()
        proc_start = time.time()
        total_bytes_processed = 0
        print("Time to startup:", startup_end - startup_start)
        for i, (doc, sentence_lens, bytes_processed) in enumerate(encoded_docs, start=1):
            total_bytes_processed += bytes_processed
            for key in doc.keys():
                builders[key].add_doc(doc[key], sentence_lens[key])
            self.print_processing_stats(i, proc_start, total_bytes_processed)

        fin.close()
        # Finalize a builder for every json key so each .idx file gets written.
        for key in self.args.json_keys:
            builders[key].finalize(output_idx_files[key])


def get_args():
    parser = argparse.ArgumentParser()
    group = parser.add_argument_group(title='input data')
    group.add_argument('--input', type=str, required=True,
                       help='Path to input JSON')
    group.add_argument('--json-keys', nargs='+', default=['text'],
                       help='Space-separated list of keys to extract from the json.')
    group.add_argument('--split-sentences', action='store_true',
                       help='Split documents into sentences.')
    group.add_argument('--keep-newlines', action='store_true',
                       help='Keep newlines between sentences when splitting.')

    group = parser.add_argument_group(title='tokenizer')
    group.add_argument('--tokenizer-type', type=str, required=True,
                       choices=['BertWordPieceLowerCase','BertWordPieceCase',
                                'GPT2BPETokenizer', 'SentencePieceTokenizer',
                                'GPTSentencePieceTokenizer', 'NullTokenizer'],
                       help='What type of tokenizer to use.')
    group.add_argument('--tokenizer-model', type=str, default=None,
                       help='YTTM tokenizer model.')
    group.add_argument('--vocab-file', type=str, default=None,
                       help='Path to the vocab file')
    group.add_argument('--vocab-size', type=int, default=786,
                       help='Size of vocab for use with NullTokenizer.')
    group.add_argument('--merge-file', type=str, default=None,
                       help='Path to the BPE merge file (if necessary).')
    group.add_argument('--append-eod', action='store_true',
                       help='Append an <eod> token to the end of a document.')
    group.add_argument('--lang', type=str, default='english',
                       help='Language to use for NLTK-powered sentence splitting.')
    group = parser.add_argument_group(title='output data')
    group.add_argument('--output-prefix', type=str, required=True,
                       help='Path to binary output file without suffix')
    group.add_argument('--dataset-impl', type=str, default='mmap',
                       choices=['lazy', 'cached', 'mmap'])

    group = parser.add_argument_group(title='runtime')
    group.add_argument('--workers', type=int, required=True,
                       help=('Number of worker processes to launch. '
                             'A good default for fast pre-processing '
                             'is: (workers * partitions) = available CPU cores.'))
    group.add_argument('--partitions', type=int, default=1,
                       help='Number of file partitions')
    group.add_argument('--log-interval', type=int, default=1000,
                       help='Interval between progress updates')
    group.add_argument('--keep-sequential-samples', action='store_true',
                       help='Ensure ordering of samples in .jsonl files is '
                            'preserved when using partitions>1.')
    args = parser.parse_args()
    args.keep_empty = False

    if args.tokenizer_type.lower().startswith('bert') and not args.split_sentences:
        print("Are you sure you don't want to split sentences?")

    # some default/dummy values for the tokenizer
    args.rank = 1
    args.make_vocab_size_divisible_by = 128
    args.tensor_model_parallel_size = 1
    args.vocab_extra_ids = 0

    return args


def get_file_name(args, file_id):
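    """Derive per-partition file names: for an input "corpus.json" and id 3
    this yields "corpus_3.json", "corpus_ss_3.json", and "<output_prefix>_3"."""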
    file_name, extension = os.path.splitext(args.input)
    input_file_name = file_name + "_" + str(file_id) + extension
    sentence_split_file = file_name + "_ss_" + str(file_id) + extension
    output_prefix = args.output_prefix + "_" + str(file_id)
    file_names = {
        'partition': input_file_name,
        'sentence_split': sentence_split_file,
        'output_prefix': output_prefix}
    return file_names


def check_files_exist(in_ss_out_names, key, num_partitions):
    for i in range(num_partitions):
        if not os.path.exists(in_ss_out_names[i][key]):
            return False
    return True


def main():
    args = get_args()

    if args.split_sentences:
        if nltk_available:
            nltk.download("punkt", quiet=True)
        else:
            raise Exception(
                "nltk library required for sentence splitting is not available.")

    in_ss_out_names = []
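    # With one partition the input file is used directly; with more, --input is
    # treated as a glob over the parent .json/.jsonl files.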
    if args.partitions == 1:
        file_name, extension = os.path.splitext(args.input)
        sentence_split_file = file_name + "_ss" + extension
        file_names = {
            'partition': args.input,
            'sentence_split': sentence_split_file,
            'output_prefix': args.output_prefix}
        in_ss_out_names.append(file_names)
    else:
        in_file_names = glob.glob(args.input)

        # Count total number of lines across .jsonl files
        if args.keep_sequential_samples:
            total_sample_count = 0
            for filename in in_file_names:
                with open(filename, "r") as fin:
                    for fc, _ in enumerate(fin):
                        pass
                total_sample_count += (fc + 1)
            partition_size = math.ceil(total_sample_count / args.partitions)

        # create .jsonl partition files
        for idx in range(args.partitions):
            in_ss_out_name = get_file_name(args, idx)
            in_ss_out_names.append(in_ss_out_name)

        # check to see if partitions were already created
        partitions_present = check_files_exist(in_ss_out_names, 'partition', args.partitions)

        # check to see if partitions with split sentences already created
        split_sentences_present = check_files_exist(in_ss_out_names, 'sentence_split', args.partitions)

        if not partitions_present and not split_sentences_present:
            # populate .jsonl partition files from parent files
            partitioned_input_files = []
            for idx in range(args.partitions):
                partitioned_input_file = open(in_ss_out_names[idx]['partition'], 'w', encoding='utf-8')
                partitioned_input_files.append(partitioned_input_file)

            index = 0
            if args.keep_sequential_samples:
                line_count = 0
            for in_file_name in in_file_names:
                # support for gzip files
                if in_file_name.endswith(".gz"):
                    fin = gzip.open(in_file_name, 'rt')
                else:
                    fin = open(in_file_name, 'r', encoding='utf-8')

                for line in fin:
                    partitioned_input_files[index].write(line)
                    if args.keep_sequential_samples:
                        line_count += 1
                        if line_count % partition_size == 0:
                            index += 1
                    else:
                        index = (index + 1) % args.partitions

                fin.close()

            for idx in range(args.partitions):
                partitioned_input_files[idx].close()

    assert args.workers % args.partitions == 0
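    # Each partition is handled by its own Partition instance driving an equal
    # share of the worker pool (workers // partitions processes).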
    partition = Partition(args, args.workers//args.partitions)

    # check to see if partitions with split sentences already created
    split_sentences_present = check_files_exist(in_ss_out_names, 'sentence_split', args.partitions)

    # split sentences in partition files
    if args.split_sentences and not split_sentences_present:
        processes = []
        for name in in_ss_out_names:
            p = multiprocessing.Process(target=partition.split_sentences,
                                        args=((name['partition'], name['sentence_split']),))
            p.start()
            processes.append(p)

        for p in processes:
            p.join()

        if args.partitions == 1:
            return


    # encode partition files in parallel
    processes = []
    input_key = 'sentence_split' if args.split_sentences else 'partition'
    for name in in_ss_out_names:
        p = multiprocessing.Process(target=partition.process_json_file,
                                    args=((name[input_key], name['output_prefix']),))
        p.start()
        processes.append(p)

    for p in processes:
        p.join()

    if args.partitions == 1:
        return

    # merge bin/idx partitions
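    # Each partition produced its own {output_prefix}_{id}_{key}_{level}.bin/.idx
    # pair; append those shards into one dataset per json key, then finalize.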
    level = "document"
    if args.split_sentences:
        level = "sentence"

    output_bin_files = {}
    output_idx_files = {}
    builders = {}
    tokenizer = build_tokenizer(args)

    for key in args.json_keys:
        output_bin_files[key] = "{}_{}_{}.bin".format(args.output_prefix,
                                                      key, level)
        output_idx_files[key] = "{}_{}_{}.idx".format(args.output_prefix,
                                                      key, level)
        builders[key] = indexed_dataset.make_builder(output_bin_files[key],
                                                     impl=args.dataset_impl,
                                                     vocab_size=tokenizer.vocab_size)
        for name in in_ss_out_names:
            partition_output_prefix = name['output_prefix']
            full_partition_output_prefix = "{}_{}_{}".format(partition_output_prefix,
                                                             key, level)
            builders[key].merge_file_(full_partition_output_prefix)
        builders[key].finalize(output_idx_files[key])


if __name__ == '__main__':

    main()