realm_dataset_utils.py 5.51 KB
Newer Older
1
2
3
4
5
6
import os
import time

import numpy as np
import torch

7
from megatron import mpu, print_rank_0
8
from megatron.data.dataset_utils import create_masked_lm_predictions, pad_and_convert_to_numpy
9
from megatron import get_args, get_tokenizer, print_rank_0, mpu
10
11
12
13
14
15
16
17
18
19
20
21
22


def join_str_list(str_list):
    """Join a list of strings, handling spaces appropriately"""
    # WordPiece-style tokens beginning with "##" are glued onto the
    # previous token; all other tokens are preceded by a space (so the
    # joined string keeps a leading space before the first word).
    pieces = []
    for token in str_list:
        if token.startswith("##"):
            pieces.append(token[2:])
        else:
            pieces.append(" " + token)
    return "".join(pieces)


23
24
class BlockSampleData(object):
    """A struct for fully describing a fixed-size block of data as used in REALM

    :param start_idx: for first sentence of the block
    :param end_idx: for last sentence of the block (may be partially truncated in sample construction)
    :param doc_idx: the index of the document from which the block comes in the original indexed dataset
    :param block_idx: a unique integer identifier given to every block.
    """
    def __init__(self, start_idx, end_idx, doc_idx, block_idx):
        self.start_idx = start_idx
        self.end_idx = end_idx
        self.doc_idx = doc_idx
        self.block_idx = block_idx

    def as_array(self):
        # Pack the four fields into a fixed-order int64 numpy vector.
        fields = [self.start_idx, self.end_idx, self.doc_idx, self.block_idx]
        return np.array(fields).astype(np.int64)

    def as_tuple(self):
        # Same field order as as_array(), but as a plain tuple.
        return (self.start_idx, self.end_idx,
                self.doc_idx, self.block_idx)
42
43


44
45
46
47
48
class BlockSamplesMapping(object):
    """Thin wrapper over a [num_samples, 4] index array of block descriptors."""

    def __init__(self, mapping_array):
        # make sure that the array is compatible with BlockSampleData
        # (exactly the 4 fields: start_idx, end_idx, doc_idx, block_idx)
        assert mapping_array.shape[1] == 4
        self.mapping_array = mapping_array
        self.shape = self.mapping_array.shape

    def __getitem__(self, idx):
        """Get the data associated with a particular sample."""
        row = self.mapping_array[idx]
        return BlockSampleData(row[0], row[1], row[2], row[3])
55
56
57


def get_block_samples_mapping(block_dataset, title_dataset, data_prefix, num_epochs,
                              max_num_samples, max_seq_length, seed, name, use_one_sent_docs=False):
    """Get samples mapping for a dataset over fixed size blocks. This function also requires
    a dataset of the titles for the source documents since their lengths must be taken into account.

    The mapping is built once on data-parallel rank 0 (via the compiled C++
    helpers), cached to a ``.npy`` file whose name encodes the build
    parameters, and then loaded by every rank.

    :param block_dataset: indexed dataset of the block text; must expose
        ``doc_idx`` (int64) and ``sizes`` (int32) arrays.
    :param title_dataset: indexed dataset of document titles; only its
        ``sizes`` array is used (titles consume sequence-length budget).
    :param data_prefix: path prefix used to name the cached index-map file.
    :param num_epochs: epochs to iterate over the data; may be falsy if
        max_num_samples is given instead.
    :param max_num_samples: cap on generated samples; may be falsy if
        num_epochs is given instead.
    :param max_seq_length: maximum tokens per sample (3 are reserved for
        added special tokens — see the ``max_seq_length - 3`` below).
    :param seed: RNG seed, also encoded in the cache filename.
    :param name: dataset-split name used in the cache filename.
    :param use_one_sent_docs: whether single-sentence documents are allowed.
    :return: samples_mapping (BlockSamplesMapping)
    """

    if not num_epochs:
        if not max_num_samples:
            raise ValueError("Need to specify either max_num_samples "
                             "or num_epochs")
        # Sentinel "unbounded" value; also recognized below when naming the file.
        num_epochs = np.iinfo(np.int32).max - 1
    if not max_num_samples:
        max_num_samples = np.iinfo(np.int64).max - 1

    # Filename of the index mapping.  Every parameter that affects the
    # mapping's contents is encoded so stale caches are never reused.
    indexmap_filename = data_prefix
    indexmap_filename += '_{}_indexmap'.format(name)
    if num_epochs != (np.iinfo(np.int32).max - 1):
        indexmap_filename += '_{}ep'.format(num_epochs)
    if max_num_samples != (np.iinfo(np.int64).max - 1):
        indexmap_filename += '_{}mns'.format(max_num_samples)
    indexmap_filename += '_{}msl'.format(max_seq_length)
    indexmap_filename += '_{}s'.format(seed)
    if use_one_sent_docs:
        indexmap_filename += '_1sentok'
    indexmap_filename += '.npy'

    # Build the indexed mapping if not exist.  Only data-parallel rank 0
    # builds; the other ranks wait at the all_reduce below and then load.
    if mpu.get_data_parallel_rank() == 0 and \
            not os.path.isfile(indexmap_filename):
        print(' > WARNING: could not find index map file {}, building '
              'the indices on rank 0 ...'.format(indexmap_filename))

        # Make sure the types match the helpers input types.
        assert block_dataset.doc_idx.dtype == np.int64
        assert block_dataset.sizes.dtype == np.int32

        # Build samples mapping
        verbose = torch.distributed.get_rank() == 0
        start_time = time.time()
        print_rank_0(' > building samples index mapping for {} ...'.format(
            name))

        # compile/bind the C++ helper code (imported lazily so that ranks
        # which never build the mapping do not need the compiled extension)
        from megatron.data.dataset_utils import compile_helper
        compile_helper()

        from megatron.data import helpers
        mapping_array = helpers.build_blocks_mapping(
            block_dataset.doc_idx,
            block_dataset.sizes,
            title_dataset.sizes,
            num_epochs,
            max_num_samples,
            max_seq_length - 3,  # account for added tokens
            seed,
            verbose,
            use_one_sent_docs)

        print_rank_0(' > done building samples index mapping')
        np.save(indexmap_filename, mapping_array, allow_pickle=True)
        print_rank_0(' > saved the index mapping in {}'.format(
            indexmap_filename))
        # Make sure all the ranks have built the mapping
        print_rank_0(' > elapsed time to build and save samples mapping '
                     '(seconds): {:4f}'.format(
            time.time() - start_time))

    # This should be a barrier but nccl barrier assumes
    # device_index=rank which is not the case for model
    # parallel case
    counts = torch.cuda.LongTensor([1])
    torch.distributed.all_reduce(counts, group=mpu.get_data_parallel_group())
    # Every rank contributed 1, so the sum must equal the group's world size.
    assert counts[0].item() == torch.distributed.get_world_size(
        group=mpu.get_data_parallel_group())

    # Load indexed dataset.
    # NOTE(review): allow_pickle=True on load trusts the cache file;
    # acceptable here since rank 0 of this job wrote it.
    print_rank_0(' > loading indexed mapping from {}'.format(
        indexmap_filename))
    start_time = time.time()

    mapping_array = np.load(indexmap_filename, allow_pickle=True)
    samples_mapping = BlockSamplesMapping(mapping_array)

    print_rank_0('    loaded indexed file in {:3.3f} seconds'.format(
        time.time() - start_time))
    print_rank_0('    total number of samples: {}'.format(
        samples_mapping.shape[0]))

    return samples_mapping