preprocess_data.py 15.8 KB
Newer Older
Jared Casper's avatar
Jared Casper committed
1
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
Mohammad's avatar
Mohammad committed
2

liangjing's avatar
v1  
liangjing committed
3
"""Processing large data for pretraining."""
4
import argparse
liangjing's avatar
v1  
liangjing committed
5
import math
6
import json
Mohammad's avatar
Mohammad committed
7
import os
8
import sys
Mohammad's avatar
Mohammad committed
9
10
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__),
                                             os.path.pardir)))
11
import time
liangjing's avatar
v1  
liangjing committed
12
13
import gzip
import glob
14
import torch
liangjing's avatar
v1  
liangjing committed
15
16
import numpy as np
import multiprocessing
17
18
try:
    import nltk
liangjing's avatar
liangjing committed
19
    from nltk.tokenize.punkt import PunktLanguageVars
20
21
    nltk_available = True
except ImportError:
liangjing's avatar
liangjing committed
22
    PunktLanguageVars = object  # Fallback to the built-in object class
23
24
    nltk_available = False

liangjing's avatar
liangjing committed
25
26
from megatron.training.tokenizer import build_tokenizer
from megatron.core.datasets import indexed_dataset
27

Mohammad's avatar
Mohammad committed
28

29
# https://stackoverflow.com/questions/33139531/preserve-empty-lines-with-nltks-punkt-tokenizer
class CustomLanguageVars(PunktLanguageVars):
    """Punkt language variables with a modified period-context pattern.

    The only override is ``_period_context_fmt``: compared with the stock
    pattern (see the inline markers inside the regex), whitespace following a
    potential sentence ending is matched with ``\\s*`` as part of the match and
    the next token is matched with ``\\S+``.  ``Encoder.initializer`` uses this
    class when ``--keep-newlines`` is set so that Punkt does not drop newlines
    after sentences.

    When NLTK is not installed, ``PunktLanguageVars`` is a plain ``object``
    fallback and this class is never used for tokenization.
    """

    # NOTE: the ``#`` remarks below live inside the regex string itself
    # and are therefore part of the runtime value; do not edit them casually.
    _period_context_fmt = r"""
        \S*                          # some word material
        %(SentEndChars)s             # a potential sentence ending
        \s*                       #  <-- THIS is what I changed
        (?=(?P<after_tok>
            %(NonWord)s              # either other punctuation
            |
            (?P<next_tok>\S+)     #  <-- Normally you would have \s+ here
        ))"""

class IdentitySplitter:
    """No-op stand-in for a sentence splitter.

    Exposes the same ``tokenize`` entry point as a real splitter but performs
    no splitting at all.
    """

    def tokenize(self, *text):
        """Return the positional arguments unchanged, as a tuple."""
        unchanged = text
        return unchanged

liangjing's avatar
v1  
liangjing committed
46

47
48
49
50
51
52
53
54
55
56
57
class Encoder(object):
    """Worker-side helper that splits and tokenizes JSON lines.

    An instance only carries the parsed command-line ``args``.  The tokenizer
    and sentence splitter are stored as *class* attributes by ``initializer``
    so that every ``multiprocessing.Pool`` worker builds them exactly once and
    reuses them for each line it handles.
    """

    # Maximum number of characters handed to the splitter in one call; longer
    # texts are cut into consecutive slices of this size before splitting.
    _SPLIT_CHUNK_CHARS = 1000000

    def __init__(self, args):
        self.args = args

    def initializer(self):
        """Build the per-process tokenizer and sentence splitter.

        Intended to be passed as ``initializer=`` to ``multiprocessing.Pool``.
        Exits the process if sentence splitting is requested but NLTK could
        not be imported.
        """
        # Use Encoder class as a container for global data
        Encoder.tokenizer = build_tokenizer(self.args)

        if not self.args.split_sentences:
            Encoder.splitter = IdentitySplitter()
            return

        if not nltk_available:
            print("NLTK is not available to split sentences.")
            # ``exit()`` is a convenience of the interactive ``site`` module;
            # ``sys.exit`` is always available and reports failure.
            sys.exit(1)

        # Prefer an explicit NLTK_DATA directory; otherwise let NLTK resolve
        # the resource through its own search path.
        nltk_data = os.environ.get("NLTK_DATA")
        if nltk_data:
            library = os.path.join(nltk_data, "tokenizers", "punkt",
                                   f"{self.args.lang}.pickle")
            url = f"file:{library}"
        else:
            library = os.path.join("tokenizers", "punkt",
                                   f"{self.args.lang}.pickle")
            url = f"nltk:{library}"
        splitter = nltk.load(url)

        if self.args.keep_newlines:
            # this prevents punkt from eating newlines after sentences
            Encoder.splitter = nltk.tokenize.punkt.PunktSentenceTokenizer(
                train_text=splitter._params,
                lang_vars=CustomLanguageVars())
        else:
            Encoder.splitter = splitter

    def split(self, json_line):
        """Split the text under each configured JSON key into pieces.

        Args:
            json_line: One JSON document serialized as a string.

        Returns:
            A ``(json_string, length)`` tuple: the re-serialized document in
            which every configured key maps to the flat list of pieces
            returned by the splitter, and ``len(json_line)``.
        """
        data = json.loads(json_line)
        chunk = self._SPLIT_CHUNK_CHARS
        output = {}
        for key in self.args.json_keys:
            text = data[key]
            # The text is sliced purely by character offset, so a piece may
            # end at a slice boundary rather than at a sentence boundary.
            output[key] = [
                piece
                for start in range(0, len(text), chunk)
                for piece in Encoder.splitter.tokenize(text[start:start + chunk])
            ]
        return json.dumps(output), len(json_line)

    def encode(self, json_line):
        """Tokenize the text under each configured JSON key.

        A value that is a list is treated as a list of sentences; any other
        value is treated as a single sentence.  Sentences that tokenize to
        nothing are skipped.

        Args:
            json_line: One JSON document serialized as a string.

        Returns:
            A ``(ids, lens, length)`` tuple where ``ids[key]`` is the flat
            list of token ids for the document, ``lens[key]`` is the list of
            per-sentence token counts, and ``length`` is ``len(json_line)``.
        """
        data = json.loads(json_line)
        ids = {}
        lens = {}
        for key in self.args.json_keys:
            text = data[key]
            sentences = text if isinstance(text, list) else [text]

            doc_ids = []
            sentence_lens = []
            for sentence in sentences:
                sentence_ids = Encoder.tokenizer.tokenize(sentence)
                if len(sentence_ids) > 0:
                    doc_ids.extend(sentence_ids)
                    sentence_lens.append(len(sentence_ids))

            if len(doc_ids) > 0 and self.args.append_eod:
                # The end-of-document token is counted as part of the last
                # sentence so the lengths still sum to len(doc_ids).
                doc_ids.append(Encoder.tokenizer.eod)
                sentence_lens[-1] += 1

            ids[key] = doc_ids
            lens[key] = sentence_lens
        return ids, lens, len(json_line)


class Partition(object):
    """Processes one input partition using its own pool of worker processes.

    Attributes:
        args: Parsed command-line arguments.
        workers: Number of pool worker processes to use for this partition.
    """

    def __init__(self, args, workers):
        self.args = args
        self.workers = workers

    def print_processing_stats(self, count, proc_start, total_bytes_processed):
        """Print throughput to stderr every ``args.log_interval`` documents.

        Args:
            count: Number of documents processed so far.
            proc_start: ``time.time()`` value taken when processing started.
            total_bytes_processed: Sum of the line lengths processed so far.
        """
        if count % self.args.log_interval != 0:
            return
        elapsed = time.time() - proc_start
        if elapsed > 0:
            docs_per_s = count / elapsed
            mbs = total_bytes_processed / elapsed / 1024 / 1024
        else:
            # A coarse clock can report zero elapsed time for the first few
            # documents; avoid a ZeroDivisionError in that case.
            docs_per_s = mbs = float("inf")
        print(f"Processed {count} documents",
              f"({docs_per_s} docs/s, {mbs} MB/s).",
              file=sys.stderr)

    def split_sentences(self, file_name):
        """Write a sentence-split copy of a JSON-lines file.

        Args:
            file_name: ``(input_file_name, output_file_name)`` tuple.
        """
        input_file_name, output_file_name = file_name
        print("Opening", input_file_name)

        encoder = Encoder(self.args)
        # The context managers guarantee that the files are closed and the
        # pool is torn down even if a worker raises.  The output is written
        # as UTF-8 because ``process_json_file`` reads it back as UTF-8.
        with open(input_file_name, 'r', encoding='utf-8') as fin, \
                open(output_file_name, 'w', encoding='utf-8') as fout, \
                multiprocessing.Pool(self.workers,
                                     initializer=encoder.initializer) as pool:
            split_docs = pool.imap(encoder.split, fin, 32)

            proc_start = time.time()
            total_bytes_processed = 0
            for i, (doc, bytes_processed) in enumerate(split_docs, start=1):
                total_bytes_processed += bytes_processed
                fout.write(doc + "\n")
                self.print_processing_stats(i, proc_start, total_bytes_processed)

    def process_json_file(self, file_name):
        """Tokenize a JSON-lines file into indexed ``.bin``/``.idx`` datasets.

        One dataset is produced per entry of ``args.json_keys``; the output
        files are named ``{output_prefix}_{key}_{level}.bin`` / ``.idx`` where
        ``level`` is ``"sentence"`` when ``--split-sentences`` is set and
        ``"document"`` otherwise.

        Args:
            file_name: ``(input_file_name, output_prefix)`` tuple.
        """
        input_file_name, output_prefix = file_name
        print("Opening", input_file_name)

        startup_start = time.time()
        encoder = Encoder(self.args)
        # Built in the parent process too: only its vocab size is needed here,
        # to choose the dtype of the on-disk token arrays.
        tokenizer = build_tokenizer(self.args)
        level = "sentence" if self.args.split_sentences else "document"

        output_bin_files = {}
        output_idx_files = {}
        builders = {}

        with open(input_file_name, 'r', encoding='utf-8') as fin, \
                multiprocessing.Pool(self.workers,
                                     initializer=encoder.initializer) as pool:
            encoded_docs = pool.imap(encoder.encode, fin, 32)

            for key in self.args.json_keys:
                output_bin_files[key] = "{}_{}_{}.bin".format(output_prefix,
                                                              key, level)
                output_idx_files[key] = "{}_{}_{}.idx".format(output_prefix,
                                                              key, level)
                builders[key] = indexed_dataset.IndexedDatasetBuilder(
                    output_bin_files[key],
                    dtype=indexed_dataset.DType.optimal_dtype(tokenizer.vocab_size),
                )

            startup_end = time.time()
            proc_start = time.time()
            total_bytes_processed = 0
            print("Time to startup:", startup_end - startup_start)
            for i, (doc, sentence_lens, bytes_processed) in enumerate(encoded_docs, start=1):
                total_bytes_processed += bytes_processed
                for key in doc.keys():
                    builders[key].add_document(doc[key], sentence_lens[key])
                self.print_processing_stats(i, proc_start, total_bytes_processed)

        # Finalize every builder.  Finalizing only the loop variable left over
        # from the loop above would write an index for the last key alone when
        # several --json-keys are given.
        for key, builder in builders.items():
            builder.finalize(output_idx_files[key])

188
189
190
191
192
193
194
195
196
197
198
199
200
201
202

def get_args():
    """Parse the command line and return the argument namespace.

    Besides the user-facing options, a handful of fixed attributes that the
    tokenizer builder expects (``rank``, ``make_vocab_size_divisible_by``,
    ``tensor_model_parallel_size``, ``vocab_extra_ids``) and ``keep_empty``
    are set on the returned namespace.

    Returns:
        argparse.Namespace: The parsed and augmented arguments.
    """
    parser = argparse.ArgumentParser()
    group = parser.add_argument_group(title='input data')
    group.add_argument('--input', type=str, required=True,
                       help='Path to input JSON')
    group.add_argument('--json-keys', nargs='+', default=['text'],
                       help='space separate listed of keys to extract from json')
    group.add_argument('--split-sentences', action='store_true',
                       help='Split documents into sentences.')
    group.add_argument('--keep-newlines', action='store_true',
                       help='Keep newlines between sentences when splitting.')

    group = parser.add_argument_group(title='tokenizer')
    group.add_argument('--tokenizer-type', type=str, required=True,
                       choices=['BertWordPieceLowerCase', 'BertWordPieceCase',
                                'GPT2BPETokenizer', 'SentencePieceTokenizer',
                                'GPTSentencePieceTokenizer', 'Llama2Tokenizer',
                                'Llama3Tokenizer', 'MistralTokenizer', 'NullTokenizer'],
                       help='What type of tokenizer to use.')
    group.add_argument('--tokenizer-model', type=str, default=None,
                       help='YTTM tokenizer model.')
    group.add_argument('--vocab-file', type=str, default=None,
                       help='Path to the vocab file')
    # type=int keeps a value given on the command line consistent with the
    # integer default (argparse would otherwise hand back a string).
    group.add_argument('--vocab-size', type=int, default=786,
                       help='size of vocab for use with NullTokenizer')
    group.add_argument('--merge-file', type=str, default=None,
                       help='Path to the BPE merge file (if necessary).')
    group.add_argument('--append-eod', action='store_true',
                       help='Append an <eod> token to the end of a document.')
    group.add_argument('--lang', type=str, default='english',
                       help='Language to use for NLTK-powered sentence splitting.')

    group = parser.add_argument_group(title='output data')
    group.add_argument('--output-prefix', type=str, required=True,
                       help='Path to binary output file without suffix')

    group = parser.add_argument_group(title='runtime')
    group.add_argument('--workers', type=int, required=True,
                       help=('Number of worker processes to launch. '
                             'A good default for fast pre-processing '
                             'is: (workers * partitions) = available CPU cores.'))
    group.add_argument('--partitions', type=int, default=1,
                       help='Number of file partitions')
    group.add_argument('--log-interval', type=int, default=1000,
                       help='Interval between progress updates')
    group.add_argument('--keep-sequential-samples', action='store_true',
                       help='Ensure ordering of samples in .jsonl files is '
                            'preserved when using partitions>1.')

    args = parser.parse_args()
    args.keep_empty = False

    if args.tokenizer_type.lower().startswith('bert') and not args.split_sentences:
        print("Are you sure you don't want to split sentences?")

    # some default/dummy values for the tokenizer
    args.rank = 1
    args.make_vocab_size_divisible_by = 128
    args.tensor_model_parallel_size = 1
    args.vocab_extra_ids = 0

    return args

liangjing's avatar
v1  
liangjing committed
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269

def get_file_name(args, file_id):
    """Derive the per-partition file names for partition ``file_id``.

    Returns a dict with three entries: ``'partition'`` (the raw partition
    file), ``'sentence_split'`` (its sentence-split counterpart) and
    ``'output_prefix'`` (the prefix for that partition's binary output).
    The first two are built from ``args.input`` and keep its extension.
    """
    stem, ext = os.path.splitext(args.input)
    tag = str(file_id)
    return {
        'partition': f"{stem}_{tag}{ext}",
        'sentence_split': f"{stem}_ss_{tag}{ext}",
        'output_prefix': f"{args.output_prefix}_{tag}",
    }


def check_files_exist(in_ss_out_names, key, num_partitions):
    """Tell whether the ``key`` file of each of the first partitions exists.

    Looks at entries ``0 .. num_partitions - 1`` of ``in_ss_out_names`` and
    returns ``True`` only if ``entry[key]`` is an existing path for all of
    them (vacuously ``True`` when ``num_partitions`` is 0).
    """
    return all(
        os.path.exists(in_ss_out_names[part][key])
        for part in range(num_partitions)
    )


270
271
272
def _open_text_input(path):
    """Open ``path`` for reading as UTF-8 text; ``.gz`` files are decompressed."""
    if path.endswith(".gz"):
        return gzip.open(path, 'rt', encoding='utf-8')
    return open(path, 'r', encoding='utf-8')


def _count_samples(in_file_names):
    """Return the total number of lines across all input files."""
    total = 0
    for filename in in_file_names:
        # Use the same opener as the partitioning pass so that compressed
        # inputs are counted by their decompressed lines, and empty files
        # simply contribute zero.
        with _open_text_input(filename) as fin:
            total += sum(1 for _ in fin)
    return total


def _populate_partitions(args, in_file_names, partition_paths, partition_size):
    """Distribute the lines of the input files over the partition files.

    With ``--keep-sequential-samples`` each partition receives a contiguous
    run of ``partition_size`` lines; otherwise lines are dealt round-robin.
    """
    outputs = []
    try:
        for path in partition_paths:
            # Written as UTF-8 because the partitions are read back as UTF-8.
            outputs.append(open(path, 'w', encoding='utf-8'))

        index = 0
        line_count = 0
        for in_file_name in in_file_names:
            with _open_text_input(in_file_name) as fin:
                for line in fin:
                    # A file whose last line lacks a newline must not fuse
                    # with the next record written to the same partition.
                    if not line.endswith("\n"):
                        line += "\n"
                    outputs[index].write(line)
                    if args.keep_sequential_samples:
                        line_count += 1
                        if line_count % partition_size == 0:
                            index += 1
                    else:
                        index = (index + 1) % args.partitions
    finally:
        for out in outputs:
            out.close()


def _run_per_partition(target, jobs):
    """Run ``target(job)`` in one child process per job and wait for all."""
    processes = []
    for job in jobs:
        p = multiprocessing.Process(target=target, args=(job,))
        p.start()
        processes.append(p)

    for p in processes:
        p.join()


def main():
    """Entry point: partition, optionally sentence-split, encode and merge.

    Steps:
      1. Work out the partition / sentence-split / output names.  With more
         than one partition, ``--input`` is treated as a glob pattern and the
         matched files are distributed over per-partition ``.jsonl`` files
         (unless those already exist).
      2. If ``--split-sentences`` is set and the sentence-split files do not
         exist yet, create them.
      3. Encode every partition into ``.bin``/``.idx`` files.
      4. With more than one partition, merge the per-partition datasets into
         one dataset per JSON key.
    """
    args = get_args()

    if args.split_sentences:
        if nltk_available:
            nltk.download("punkt", quiet=True, download_dir=os.environ.get("NLTK_DATA"))
        else:
            raise Exception(
                "nltk library required for sentence splitting is not available.")

    in_ss_out_names = []
    if args.partitions == 1:
        file_name, extension = os.path.splitext(args.input)
        sentence_split_file = file_name + "_ss" + extension
        file_names = {
            'partition': args.input,
            'sentence_split': sentence_split_file,
            'output_prefix': args.output_prefix}
        in_ss_out_names.append(file_names)
    else:
        in_file_names = glob.glob(args.input)

        # Count total number of lines across .jsonl files
        partition_size = None
        if args.keep_sequential_samples:
            total_sample_count = _count_samples(in_file_names)
            partition_size = math.ceil(total_sample_count / args.partitions)

        # create .jsonl partition files
        for idx in range(args.partitions):
            in_ss_out_names.append(get_file_name(args, idx))

        # check to see if partitions were already created
        partitions_present = check_files_exist(in_ss_out_names, 'partition', args.partitions)

        # check to see if partitions with split sentences already created
        split_sentences_present = check_files_exist(in_ss_out_names, 'sentence_split', args.partitions)

        if not partitions_present and not split_sentences_present:
            # populate .jsonl partition files from parent files
            _populate_partitions(
                args,
                in_file_names,
                [name['partition'] for name in in_ss_out_names],
                partition_size)

    assert args.workers % args.partitions == 0, (
        "--workers must be a multiple of --partitions")
    partition = Partition(args, args.workers // args.partitions)

    # check to see if partitions with split sentences already created
    split_sentences_present = check_files_exist(in_ss_out_names, 'sentence_split', args.partitions)

    # split sentences in partition files
    if args.split_sentences and not split_sentences_present:
        _run_per_partition(
            partition.split_sentences,
            [(name['partition'], name['sentence_split']) for name in in_ss_out_names])

        # NOTE: with a single partition the run stops here, right after the
        # sentence-split file has been written; encoding happens on a later
        # invocation that finds that file already present.
        if args.partitions == 1:
            return

    # encode partition files in parallel
    input_key = 'sentence_split' if args.split_sentences else 'partition'
    _run_per_partition(
        partition.process_json_file,
        [(name[input_key], name['output_prefix']) for name in in_ss_out_names])

    if args.partitions == 1:
        return

    # merge bin/idx partitions
    level = "document"
    if args.split_sentences:
        level = "sentence"

    output_bin_files = {}
    output_idx_files = {}
    builders = {}
    tokenizer = build_tokenizer(args)

    for key in args.json_keys:
        output_bin_files[key] = "{}_{}_{}.bin".format(args.output_prefix,
                                                      key, level)
        output_idx_files[key] = "{}_{}_{}.idx".format(args.output_prefix,
                                                      key, level)
        builders[key] = indexed_dataset.IndexedDatasetBuilder(
            output_bin_files[key],
            dtype=indexed_dataset.DType.optimal_dtype(tokenizer.vocab_size),
        )

        # Append every partition's dataset for this key, in partition order.
        for name in in_ss_out_names:
            partition_output_prefix = name['output_prefix']
            full_partition_output_prefix = "{}_{}_{}".format(partition_output_prefix,
                                                             key, level)
            builders[key].add_index(full_partition_output_prefix)
        builders[key].finalize(output_idx_files[key])

liangjing's avatar
v1  
liangjing committed
407

408
# Run the preprocessing pipeline only when executed as a script.
if __name__ == '__main__':
    main()
liangjing's avatar
v1  
liangjing committed
411