# Copyright (c) 2023, NVIDIA CORPORATION.  All rights reserved.

from collections import defaultdict
from concurrent.futures import as_completed, ProcessPoolExecutor
from functools import reduce
import glob
import json
import numpy as np
import os
from pathlib import Path
import threading
import torch
from tqdm import tqdm
import types

from megatron import get_retro_args, print_rank_0
from megatron.data.indexed_dataset import make_dataset as make_indexed_dataset
from megatron.tokenizer.tokenizer import (
    _BertWordPieceTokenizer,
    _GPT2BPETokenizer,
)
from tools.bert_embedding.utils import get_missing_blocks_by_rank
from tools.retro.external_libs import h5py
from tools.retro.utils import get_gpt_tokenizer, get_bert_tokenizer

from .utils import (
    get_indexed_dataset_infos,
    get_indexed_dataset_infos_path,
    get_individual_db_dir,
    get_individual_chunk_db,
    get_individual_doc_offsets,
    get_merged_dataset,
    get_merged_db_path_map,
    save_indexed_dataset_infos,
)


def init_indexed_dataset_infos():
    '''Collect one meta-info record per indexed dataset of the blend.

    `args.data_path` is read as a flat sequence of alternating
    (ratio, prefix) entries. For every pair a dict is produced with the
    keys "ratio", "prefix", "path" (prefix + ".bin"), "name" (basename of
    the prefix), "db_dir" and "dataset" (the opened indexed dataset).

    Returns:
        List of info dicts, in the order the datasets appear in
        `args.data_path`.

    Raises:
        AssertionError: if `args.data_path` has an odd number of entries,
            or if a dataset's ".bin" file does not exist.
    '''

    data_path = get_retro_args().data_path

    assert len(data_path) % 2 == 0, \
        "currently, only blendable dataset is supported."

    # Even positions hold the ratios, odd positions the matching prefixes.
    ratio_strs = data_path[0::2]
    prefixes = data_path[1::2]

    infos = []
    for ratio_str, prefix in zip(ratio_strs, prefixes):
        ratio = float(ratio_str)
        bin_path = prefix + ".bin"
        ds_name = os.path.basename(prefix)
        assert os.path.exists(bin_path), "couldn't find '%s'." % bin_path

        info = {}
        info["ratio"] = ratio
        info["prefix"] = prefix
        info["path"] = bin_path
        info["name"] = ds_name
        info["db_dir"] = get_individual_db_dir(ds_name)
        info["dataset"] = make_indexed_dataset(prefix, "mmap", True)
        infos.append(info)

    return infos


def build_partial_db(
        dataset_idx,
        n_datasets,
        indexed_dataset,
        block_id,
        n_blocks,
        block,
        proc_id,
        n_procs,
        tokenizers,
):
    '''Process a document index range of the indexed dataset.

    The chunk database is built in parallel blocks, since de-tokenizing &
    re-tokenizing for Bert-length computation is expensive. This method
    iterates each document and extracts sequential 'chunk-length' sequences
    from each document.

    Args:
        dataset_idx: Index of the dataset (used for progress output only).
        n_datasets: Total number of datasets (progress output only).
        indexed_dataset: Dataset exposing `get(idx)` and
            `get(idx=, offset=, length=)`.
        block_id: Index of the block (progress output only).
        n_blocks: Total number of blocks (progress output only).
        block: Dict whose "range" entry is the (start, end) document id
            range of this block.
        proc_id: Index of this worker within the block, in [0, n_procs).
        n_procs: Number of workers the block's documents are split across.
        tokenizers: Namespace with `gpt` (needs `eod`, `detokenize`) and
            `bert` (needs `tokenize`) tokenizers.

    Returns:
        Tuple `(proc_id, chunk_db_valid, chunk_db_invalid, doc_size_map)`.
        Both chunk lists hold `(doc_id, chunk_start_idx, chunk_end_idx,
        bert_chunk_length)` tuples; `doc_size_map` maps each processed doc
        id to its number of valid chunks.
    '''

    args = get_retro_args()
    chunk_length = args.retro_gpt_chunk_length

    # Document start/end indexes (this worker's slice of the block).
    doc_range = block["range"]
    n_docs = doc_range[1] - doc_range[0]
    n_docs_per_proc = int(np.ceil(n_docs / n_procs))
    doc_start_id = doc_range[0] + proc_id * n_docs_per_proc
    doc_end_id = min(doc_range[1], doc_start_id + n_docs_per_proc)

    # Only rank 0 reports progress.
    show_progress = \
        torch.distributed.get_rank() == 0 and proc_id in range(n_procs)
    if show_progress:
        print(" > building partial chunk db, proc %d / %d, docs %d:%d / %d."%(
            proc_id,
            n_procs,
            doc_start_id,
            doc_end_id,
            n_docs,
        ))

    # Progress bars (snapshot of overall progress). The description never
    # changes, so it is set once, and only when a real progress bar exists.
    doc_id_iter = range(doc_start_id, doc_end_id)
    if show_progress:
        doc_id_iter = tqdm(doc_id_iter)
        doc_id_iter.set_description(
            "ds %d / %d, block %d / %d, proc %d / %d." % (
                dataset_idx,
                n_datasets,
                block_id,
                n_blocks,
                proc_id,
                n_procs))

    # Iterate documents & parse chunks.
    chunk_db_valid = []
    chunk_db_invalid = []
    doc_size_map = {}
    for doc_id in doc_id_iter:

        # Remove EOD token (an empty document has nothing to strip).
        doc = indexed_dataset.get(doc_id)
        if len(doc) > 0 and doc[-1].item() == tokenizers.gpt.eod:
            doc = doc[:-1]
        doc_len = len(doc)

        # Re-tokenize each chunk to Bert/Wordpiece (empty bert -> 'invalid').
        doc_size_map[doc_id] = 0
        for chunk_start_idx in range(0, doc_len, chunk_length):

            # Re-tokenize.
            chunk_end_idx = min(doc_len, chunk_start_idx + chunk_length)
            gpt_token_ids = indexed_dataset.get(
                idx=doc_id,
                offset=chunk_start_idx,
                length=chunk_end_idx - chunk_start_idx,
            )
            text = tokenizers.gpt.detokenize(gpt_token_ids.tolist())
            bert_token_ids = tokenizers.bert.tokenize(text)

            # 'Valid' for non-empty Bert chunks; 'invalid' otherwise.
            if len(bert_token_ids) == 0:
                _chunk_db = chunk_db_invalid
            else:
                _chunk_db = chunk_db_valid
                doc_size_map[doc_id] += 1
            _chunk_db.append((
                doc_id,
                chunk_start_idx,
                chunk_end_idx,
                len(bert_token_ids),
            ))

    return proc_id, chunk_db_valid, chunk_db_invalid, doc_size_map


def _concat_partial_chunk_dbs(partial_chunk_dbs):
    '''Combine per-worker results (ordered by proc_id) into numpy arrays.

    Each entry is the `(proc_id, chunks_valid, chunks_invalid,
    doc_size_map)` tuple returned by `build_partial_db`. Returns
    `(chunk_db_valid, chunk_db_invalid, doc_offsets)`, where `doc_offsets`
    has one `(doc_id, cumulative valid chunk count)` row per document,
    sorted by doc id.
    '''

    chunk_db_valid = np.array(
        [item for partial in partial_chunk_dbs for item in partial[1]],
        dtype="uint32")
    chunk_db_invalid = np.array(
        [item for partial in partial_chunk_dbs for item in partial[2]],
        dtype="uint32")

    # Document offsets.
    doc_sizes = sorted(
        ((d, s)
         for partial in partial_chunk_dbs
         for d, s in partial[3].items()),
        key=lambda item: item[0])
    doc_ids = np.array([item[0] for item in doc_sizes], dtype="uint64")
    cum_sizes = np.cumsum([item[1] for item in doc_sizes]).astype("uint64")
    doc_offsets = np.stack((doc_ids, cum_sizes), axis=1)

    return chunk_db_valid, chunk_db_invalid, doc_offsets


def build_individual_db(dataset_idx, n_datasets, dataset_info, tokenizers):
    '''Process a single indexed dataset & extract chunks.

    The blocks returned for this rank by `get_missing_blocks_by_rank` are
    built one at a time: a block's documents are split across a process
    pool, the partial results are concatenated, and the block is written
    to its own hdf5 file with the datasets "chunks_valid",
    "chunks_invalid" and "doc_offsets".

    Args:
        dataset_idx: Index of this dataset (progress output only).
        n_datasets: Total number of datasets (progress output only).
        dataset_info: Info dict; reads "db_dir" and "dataset".
        tokenizers: Namespace of tokenizers, forwarded to
            `build_partial_db`.
    '''

    args = get_retro_args()

    # Make directory.
    db_dir = dataset_info["db_dir"]
    os.makedirs(db_dir, exist_ok=True)

    # Indexed dataset.
    indexed_dataset = dataset_info["dataset"]

    # Missing db blocks. A block without valid chunks is stored as a 1-D
    # empty array, hence the `(0,)` case in the validator.
    n_missing_world, missing_db_blocks = get_missing_blocks_by_rank(
        db_dir,
        len(indexed_dataset),
        args.retro_doc_block_size,
        validate=lambda f : f["chunks_valid"].shape == (0,) \
            or f["chunks_valid"].shape[1] == 4)

    # Prevent missing-path-write race condition.
    torch.distributed.barrier()

    if not missing_db_blocks:
        return

    # Num processes (fewer workers per rank as more blocks are missing
    # across the world).
    if n_missing_world == 1:
        n_procs = 128
    elif n_missing_world <= 2:
        n_procs = 64
    elif n_missing_world <= 4:
        n_procs = 32
    elif n_missing_world <= 8:
        n_procs = 16
    else:
        n_procs = 8

    # Process documents in parallel.
    n_blocks = len(missing_db_blocks)
    with ProcessPoolExecutor(max_workers=n_procs) as executor:
        for block_idx, block in enumerate(missing_db_blocks):

            # A `None` block means no work for this rank in this round; the
            # rank still takes part in the barrier below.
            if block is not None:

                db_path = block["path"]

                # Build partial dbs.
                print_rank_0(' > build partial dbs.')
                futures = [
                    executor.submit(
                        build_partial_db,
                        dataset_idx,
                        n_datasets,
                        indexed_dataset,
                        block_idx,
                        n_blocks,
                        block,
                        proc_id, # not true process id
                        n_procs,
                        tokenizers,
                    )
                    for proc_id in range(n_procs)
                ]

                # Futures were submitted in proc_id order, so reading them
                # in list order yields results already sorted by proc_id.
                partial_chunk_dbs = [future.result() for future in futures]

                # Concatenate chunks & convert to numpy.
                print_rank_0(' > converting chunk db to numpy.')
                chunk_db_valid, chunk_db_invalid, doc_offsets = \
                    _concat_partial_chunk_dbs(partial_chunk_dbs)

                # Save DB.
                print_rank_0(" > saving individual db.")
                with h5py.File(db_path, "w") as f:
                    f.create_dataset("chunks_valid", data=chunk_db_valid)
                    f.create_dataset("chunks_invalid", data=chunk_db_invalid)
                    f.create_dataset("doc_offsets", data=doc_offsets)

            # Wait for all ranks to finish block.
            print_rank_0(" > waiting for all ranks to finish block.")
            torch.distributed.barrier()

    print_rank_0(" > finished saving individual db.")


def build_individual_dbs(indexed_dataset_infos):
    '''Build the chunk db of every indexed dataset, one dataset at a time.

    The GPT and Bert tokenizers are created once and shared by all
    datasets.
    '''

    args = get_retro_args()

    n_datasets = len(indexed_dataset_infos)

    # Tokenizers (GPT first, then Bert).
    gpt_tokenizer = get_gpt_tokenizer()
    bert_tokenizer = get_bert_tokenizer()
    tokenizers = types.SimpleNamespace(gpt=gpt_tokenizer, bert=bert_tokenizer)

    # Build individual DBs.
    print_rank_0(" > build individual chunk dbs.")
    for idx, info in enumerate(indexed_dataset_infos):

        # Progress.
        progress = (idx, n_datasets, info["name"])
        print_rank_0(
            " > building individual db, dataset %d / %d ... '%s'." % progress)

        # Process single dataset.
        build_individual_db(idx, n_datasets, info, tokenizers)


def update_chunk_counts(indexed_dataset_infos):
    '''Set n_chunks_train & n_chunks sampled for each individual DB.

    Runs on rank 0 only. Each info dict is updated in place with the keys
    "n_docs", "n_docs_train", "n_chunks", "n_chunks_train",
    "n_chunks_invalid" and "n_chunks_sampled", computed from the "*.hdf5"
    block files found in the dataset's "db_dir".

    Raises:
        AssertionError: if the training fraction parsed from `args.split`
            is outside (0, 1], or if the resulting counts are inconsistent
            (n_sampled > n_train or n_train > n_total).
    '''

    args = get_retro_args()

    if torch.distributed.get_rank() != 0:
        return

    # Data ratio sum (for setting index training chunks).
    data_ratio_sum = sum(d["ratio"] for d in indexed_dataset_infos)

    # Training split size (split at document level).
    train_fraction = float(args.split.split(",")[0]) / 100
    assert train_fraction > 0 and train_fraction <= 1

    # Set n_chunks (including n_chunks_sampled for unambiguity).
    print_rank_0(" > compute n_chunks.")
    for ds_index, ds_info in enumerate(indexed_dataset_infos):

        db_dir = ds_info["db_dir"]
        db_paths = sorted(glob.glob(db_dir + "/*.hdf5"))

        # Update counts.
        ds_info["n_docs"] = len(ds_info["dataset"].doc_idx) - 1
        ds_info["n_docs_train"] = int(train_fraction * ds_info["n_docs"])
        ds_info["n_chunks"] = 0 # previously, 'n_chunks_valid'
        ds_info["n_chunks_train"] = 0
        ds_info["n_chunks_invalid"] = 0
        for db_path in tqdm(db_paths, "%d/%d, %s" % (
                ds_index, len(indexed_dataset_infos), ds_info["name"])):
            with h5py.File(db_path, "r") as f:
                chunks_valid = f["chunks_valid"]
                n_valid = len(chunks_valid)
                ds_info["n_chunks"] += n_valid
                ds_info["n_chunks_invalid"] += len(f["chunks_invalid"])

                # A block with no valid chunks is stored as a 1-D array of
                # shape (0,), which cannot be indexed by column.
                if n_valid > 0:
                    # Column 0 holds the doc id; docs below n_docs_train
                    # belong to the training split.
                    ds_info["n_chunks_train"] += \
                        (chunks_valid[:, 0] < ds_info["n_docs_train"]) \
                        .sum().item()

        ds_info["n_chunks_sampled"] = int(args.retro_index_ntrain *
                                          ds_info["ratio"] / data_ratio_sum)

        # Verify counts.
        assert ds_info["n_chunks_train"] <= ds_info["n_chunks"], \
            "n_train (%d) > n_total (%d)." % (
                ds_info["n_chunks_train"], ds_info["n_chunks"])
        assert ds_info["n_chunks_sampled"] <= ds_info["n_chunks_train"], \
            "n_sampled (%d) > n_train (%d)." % (
                ds_info["n_chunks_sampled"], ds_info["n_chunks_train"])


def merge_dbs(indexed_dataset_infos, db_type):
    '''Merge individual DBs into single DB.

    Runs on rank 0 only. The merged hdf5 file holds a "chunks" dataset of
    shape (n_chunks, 5), an "n_written" counter and, for the "train" db
    only, a "doc_offsets" dataset of shape (n_docs, 3). An existing merged
    file is kept when its allocated and written chunk counts both match
    the expected count; otherwise it is deleted and rebuilt.

    Args:
        indexed_dataset_infos: Info dicts carrying the chunk/doc counts
            ("n_chunks", "n_chunks_train", "n_chunks_sampled",
            "n_docs_train") and "name".
        db_type: One of "sampled", "train" or "valid".

    Raises:
        Exception: if `db_type` is not one of the supported values.
    '''

    if torch.distributed.get_rank() != 0:
        return

    print(" > build %s chunk db." % db_type)

    # Count chunks. Only the "train" db carries document offsets.
    if db_type == "sampled":
        n_chunks_key = "n_chunks_sampled"
        n_docs_key = None
    elif db_type == "train":
        n_chunks_key = "n_chunks_train"
        n_docs_key = "n_docs_train"
    elif db_type == "valid":
        n_chunks_key = None
        n_docs_key = None
    else:
        raise Exception("handle db_type '%s'." % db_type)

    if db_type == "valid":
        # Validation chunks are whatever is left after the training chunks.
        n_chunks = sum(m["n_chunks"] - m["n_chunks_train"]
                       for m in indexed_dataset_infos)
    else:
        n_chunks = sum(m[n_chunks_key] for m in indexed_dataset_infos)
    n_docs = None if n_docs_key is None else \
        sum(m[n_docs_key] for m in indexed_dataset_infos)

    # DB path.
    db_path = get_merged_db_path_map()[db_type]

    # Delete existing chunk db if incorrect size.
    if os.path.exists(db_path):

        try:
            # Read-only is sufficient here; the context manager closes the
            # handle on every path, including unexpected exceptions.
            with h5py.File(db_path, "r") as f:
                n_alloc = len(f["chunks"])           # total allocated
                n_written = f["n_written"][0].item() # total written
            is_complete = n_chunks == n_alloc and n_chunks == n_written
        except (OSError, KeyError):
            # Unreadable file, or a file missing the expected datasets.
            is_complete = False

        if not is_complete:
            os.remove(db_path)

    # Build merged chunk db.
    if not os.path.exists(db_path):

        os.makedirs(os.path.dirname(db_path), exist_ok=True)
        with h5py.File(db_path, "w") as f:

            # Initialize output arrays.
            merged_chunk_db = \
                f.create_dataset("chunks", (n_chunks, 5), dtype="uint32")
            merged_doc_offsets = None if n_docs_key is None else \
                f.create_dataset("doc_offsets", (n_docs, 3), dtype="uint64")
            n_written = f.create_dataset("n_written", (1,), dtype="uint64")
            n_written[0] = 0

            # Iterate indexed datasets & collect chunks.
            chunk_start_index = 0
            doc_start_index = 0
            doc_start_offset = 0
            for ds_idx, ds_info in enumerate(indexed_dataset_infos):
                print(" > merging dbs; '%s', dataset %d / %d ... '%s'." %
                      (db_type, ds_idx, len(indexed_dataset_infos),
                       ds_info["name"]))
                individual_chunk_db = get_individual_chunk_db(ds_idx, ds_info)
                individual_doc_offsets = None if n_docs_key is None else \
                    get_individual_doc_offsets(ds_idx, ds_info)

                if db_type == "valid":
                    # No doc offsets for "valid" (n_docs_key is None), so
                    # only the chunk db needs slicing.
                    individual_chunk_db = \
                        individual_chunk_db[ds_info["n_chunks_train"]:]
                else:
                    individual_chunk_db = \
                        individual_chunk_db[:ds_info[n_chunks_key]]
                    if individual_doc_offsets is not None:
                        # Copy, since the offsets are shifted in place below.
                        individual_doc_offsets = np.copy(
                            individual_doc_offsets[:ds_info[n_docs_key]])

                chunk_end_index = chunk_start_index + len(individual_chunk_db)
                merged_chunk_db[chunk_start_index:chunk_end_index] = \
                    individual_chunk_db
                chunk_start_index = chunk_end_index
                # Progress marker; a partially-written file fails the size
                # check above on the next run and gets rebuilt.
                n_written[0] = chunk_start_index

                # A dataset contributing zero documents leaves the running
                # doc index/offset untouched.
                if n_docs_key is not None and len(individual_doc_offsets) > 0:
                    # Column 2 holds the per-dataset chunk offset; shift it
                    # so offsets keep increasing across datasets.
                    individual_doc_offsets[:, 2] += doc_start_offset
                    doc_end_index = \
                        doc_start_index + individual_doc_offsets.shape[0]
                    merged_doc_offsets[doc_start_index:doc_end_index] = \
                        individual_doc_offsets
                    doc_start_index = doc_end_index
                    doc_start_offset = individual_doc_offsets[-1, 2].item()


def build_db():
    '''Build the chunk 'DBs' (hdf5 files) for all indexed datasets.

    Every rank takes part in building the individual per-dataset dbs.
    Afterwards rank 0 alone computes and saves the chunk counts (unless
    the infos file already exists), reloads the saved infos, and merges
    the individual dbs into the 'sampled', 'train' and 'valid' dbs.
    '''

    # Indexed dataset info, then the per-dataset dbs (all ranks).
    infos = init_indexed_dataset_infos()
    build_individual_dbs(infos)

    # Single-process going forward.
    if torch.distributed.get_rank() != 0:
        return

    # Update n_chunks & save indexed dataset infos.
    infos_already_saved = os.path.exists(get_indexed_dataset_infos_path())
    if not infos_already_saved:
        update_chunk_counts(infos)
        save_indexed_dataset_infos(infos)
    infos = get_indexed_dataset_infos()

    # Merge dbs.
    for db_type in ("sampled", "train", "valid"):
        merge_dbs(infos, db_type)