realm_dataset_utils.py

import os
import time

import numpy as np
import torch

from megatron import mpu, print_rank_0


def join_str_list(str_list):
    """Join a list of strings, handling spaces appropriately"""
    result = ""
    for s in str_list:
        if s.startswith("##"):
            result += s[2:]
        else:
            result += " " + s
    return result
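
# Example (illustrative, not from the original module): joining WordPiece
# tokens from a BERT-style tokenizer. The first piece contributes a leading
# space, so callers typically strip() the result:
#
#     join_str_list(["hello", "world", "##s"])  # -> " hello worlds"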


def get_block_samples_mapping(block_dataset, title_dataset, data_prefix, num_epochs,
                              max_num_samples, max_seq_length, seed, name, use_one_sent_docs=False):
    """Get samples mapping for a dataset over fixed size blocks. This function also requires
    a dataset of the titles for the source documents since their lengths must be taken into account."""
    if not num_epochs:
        if not max_num_samples:
            raise ValueError("Need to specify either max_num_samples "
                             "or num_epochs")
        num_epochs = np.iinfo(np.int32).max - 1
    if not max_num_samples:
        max_num_samples = np.iinfo(np.int64).max - 1
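    # The near-max sentinel values mean "effectively unlimited"; they also let
    # the filename logic below detect which limit the caller actually set.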

    # Filename of the index mapping
    indexmap_filename = data_prefix
    indexmap_filename += '_{}_indexmap'.format(name)
    if num_epochs != (np.iinfo(np.int32).max - 1):
        indexmap_filename += '_{}ep'.format(num_epochs)
    if max_num_samples != (np.iinfo(np.int64).max - 1):
        indexmap_filename += '_{}mns'.format(max_num_samples)
    indexmap_filename += '_{}msl'.format(max_seq_length)
    indexmap_filename += '_{}s'.format(seed)
    if use_one_sent_docs:
        indexmap_filename += '_1sentok'
    indexmap_filename += '.npy'
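    # Illustrative result (argument values are assumed, not from this module):
    # data_prefix='/data/wiki', name='train', max_num_samples=1000000,
    # max_seq_length=288, seed=1234 yields
    # '/data/wiki_train_indexmap_1000000mns_288msl_1234s.npy'.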

    # Build the index mapping if it does not exist.
    if mpu.get_data_parallel_rank() == 0 and \
            not os.path.isfile(indexmap_filename):
        print(' > WARNING: could not find index map file {}, building '
              'the indices on rank 0 ...'.format(indexmap_filename))

        # Make sure the types match the helpers input types.
        assert block_dataset.doc_idx.dtype == np.int64
        assert block_dataset.sizes.dtype == np.int32

        # Build samples mapping
        verbose = torch.distributed.get_rank() == 0
        start_time = time.time()
        print_rank_0(' > building samples index mapping for {} ...'.format(
            name))
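        # Compile the C++ `helpers` extension at runtime if it has not been
        # built yet; both imports are deferred so ranks that never build a
        # mapping do not pay the compilation cost.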
        from megatron.data.dataset_utils import compile_helper
        compile_helper()
        from megatron.data import helpers
        samples_mapping = helpers.build_blocks_mapping(
            block_dataset.doc_idx,
            block_dataset.sizes,
            title_dataset.sizes,
            num_epochs,
            max_num_samples,
            max_seq_length - 3,  # account for added tokens ([CLS] and two [SEP])
            seed,
            verbose,
            use_one_sent_docs)
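        # `use_one_sent_docs` allows documents consisting of a single sentence
        # to form blocks on their own; by default such documents are skipped.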
        print_rank_0(' > done building samples index mapping')
        np.save(indexmap_filename, samples_mapping, allow_pickle=True)
        print_rank_0(' > saved the index mapping in {}'.format(
            indexmap_filename))
        print_rank_0(' > elapsed time to build and save samples mapping '
                     '(seconds): {:.4f}'.format(time.time() - start_time))
    # Make sure all the ranks have built the mapping before proceeding. This
    # all-reduce stands in for a barrier: the NCCL barrier assumes
    # device_index == rank, which does not hold in the model-parallel case.
    # After the sum, each rank sees the data-parallel world size, confirming
    # that every rank in the group reached this point.
    counts = torch.cuda.LongTensor([1])
    torch.distributed.all_reduce(counts, group=mpu.get_data_parallel_group())
    assert counts[0].item() == torch.distributed.get_world_size(
        group=mpu.get_data_parallel_group())

    # Load indexed dataset.
    print_rank_0(' > loading indexed mapping from {}'.format(
        indexmap_filename))
    start_time = time.time()
    samples_mapping = np.load(indexmap_filename, allow_pickle=True, mmap_mode='r')
    print_rank_0('    loaded indexed file in {:3.3f} seconds'.format(
        time.time() - start_time))
    print_rank_0('    total number of samples: {}'.format(
        samples_mapping.shape[0]))
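    # Each row of the mapping describes one block sample; in this codebase the
    # columns appear to be (start_idx, end_idx, doc_idx, block_idx). See
    # helpers.build_blocks_mapping for the authoritative layout.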

    return samples_mapping
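

# Minimal usage sketch (names and values below are illustrative assumptions:
# `block_dataset` must expose `.doc_idx` (int64) and `.sizes` (int32) arrays,
# and `title_dataset` must expose `.sizes`):
#
#     samples_mapping = get_block_samples_mapping(
#         block_dataset, title_dataset,
#         data_prefix='/data/wikipedia_blocks',
#         num_epochs=None, max_num_samples=1000000,
#         max_seq_length=288, seed=1234, name='train')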