# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
import os
import time

import numpy as np
import torch

from megatron.core import mpu, tensor_parallel
from megatron.legacy.data.dataset_utils import create_masked_lm_predictions, pad_and_convert_to_numpy
from megatron.training import get_args, get_tokenizer, print_rank_0


def get_one_epoch_dataloader(dataset, micro_batch_size=None):
    """Specifically one epoch to be used in an indexing job."""
    args = get_args()

    world_size = mpu.get_data_parallel_world_size()
    rank = mpu.get_data_parallel_rank()
    if micro_batch_size is None:
        micro_batch_size = args.micro_batch_size
    global_batch_size = micro_batch_size * world_size
    num_workers = args.num_workers

    sampler = torch.utils.data.SequentialSampler(dataset)
    # importantly, drop_last must be False to get all the data.
    assert False, 'DistributedBatchSampler deprecated, change the implementation'
    from megatron.legacy.data.samplers import DistributedBatchSampler
    batch_sampler = DistributedBatchSampler(sampler,
                                            batch_size=global_batch_size,
                                            drop_last=False,
                                            rank=rank,
                                            world_size=world_size)

    return torch.utils.data.DataLoader(dataset,
                                       batch_sampler=batch_sampler,
                                       num_workers=num_workers,
                                       pin_memory=True)

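# A hedged usage sketch (not part of the original module): once the deprecated
# DistributedBatchSampler path above is replaced, an indexing job would
# typically consume the loader like this. `ict_dataset` is a hypothetical
# dataset instance built elsewhere.
#
#   data_iter = iter(get_one_epoch_dataloader(ict_dataset))
#   query_tokens, query_pad_mask, block_tokens, block_pad_mask, block_indices = \
#       get_ict_batch(data_iter)
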

def get_ict_batch(data_iterator):
    # Items and their type.
    keys = ['query_tokens', 'query_pad_mask',
            'block_tokens', 'block_pad_mask', 'block_data']
    datatype = torch.int64

    # Broadcast data.
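    # In Megatron's usual pattern, only the tensor-model-parallel source rank
    # holds a real iterator; the other ranks pass data=None and receive the
    # tensors through tensor_parallel.broadcast_data below.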
    if data_iterator is None:
        data = None
    else:
        data = next(data_iterator)
    data_b = tensor_parallel.broadcast_data(keys, data, datatype)

    # Unpack.
    query_tokens = data_b['query_tokens'].long()
    query_pad_mask = data_b['query_pad_mask'].long()
    block_tokens = data_b['block_tokens'].long()
    block_pad_mask = data_b['block_pad_mask'].long()
    block_indices = data_b['block_data'].long()

    return query_tokens, query_pad_mask,\
           block_tokens, block_pad_mask, block_indices


def join_str_list(str_list):
    """Join a list of strings, handling spaces appropriately"""
    result = ""
    for s in str_list:
        if s.startswith("##"):
            result += s[2:]
        else:
            result += " " + s
    return result

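# Illustrative example (WordPiece-style "##" continuation markers assumed):
#   join_str_list(["hel", "##lo", "world"])  ->  " hello world"
# Note the leading space: every non-continuation token is prefixed with " ",
# so callers that need a clean string should strip() the result.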

class BlockSampleData(object):
    """A struct for fully describing a fixed-size block of data as used in REALM

    :param start_idx: index of the first sentence of the block
    :param end_idx: index of the last sentence of the block (which may be partially truncated during sample construction)
    :param doc_idx: index of the document from which the block comes in the original indexed dataset
    :param block_idx: a unique integer identifier assigned to every block
    """
    def __init__(self, start_idx, end_idx, doc_idx, block_idx):
        self.start_idx = start_idx
        self.end_idx = end_idx
        self.doc_idx = doc_idx
        self.block_idx = block_idx

    def as_array(self):
        return np.array([self.start_idx, self.end_idx, self.doc_idx, self.block_idx]).astype(np.int64)

    def as_tuple(self):
        return self.start_idx, self.end_idx, self.doc_idx, self.block_idx

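# Illustrative round trip (hypothetical values): BlockSampleData packs its four
# indices into an int64 row, which is the per-row format that
# BlockSamplesMapping below expects.
#   data = BlockSampleData(start_idx=0, end_idx=4, doc_idx=7, block_idx=123)
#   data.as_array()  # -> array([  0,   4,   7, 123]) with dtype int64
#   data.as_tuple()  # -> (0, 4, 7, 123)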

class BlockSamplesMapping(object):
    def __init__(self, mapping_array):
        # make sure that the array is compatible with BlockSampleData
        assert mapping_array.shape[1] == 4
        self.mapping_array = mapping_array

    def __len__(self):
        return self.mapping_array.shape[0]

    def __getitem__(self, idx):
        """Get the data associated with an indexed sample."""
        sample_data = BlockSampleData(*self.mapping_array[idx])
        return sample_data


def get_block_samples_mapping(block_dataset, title_dataset, data_prefix, num_epochs,
                              max_num_samples, max_seq_length, seed, name, use_one_sent_docs=False):
    """Get samples mapping for a dataset over fixed size blocks. This function also requires
    a dataset of the titles for the source documents since their lengths must be taken into account.

    :return: samples_mapping (BlockSamplesMapping)
    """

    if not num_epochs:
        if not max_num_samples:
            raise ValueError("Need to specify either max_num_samples "
                             "or num_epochs")
        num_epochs = np.iinfo(np.int32).max - 1
    if not max_num_samples:
        max_num_samples = np.iinfo(np.int64).max - 1

    # Filename of the index mapping
    indexmap_filename = data_prefix
    indexmap_filename += '_{}_indexmap'.format(name)
    if num_epochs != (np.iinfo(np.int32).max - 1):
        indexmap_filename += '_{}ep'.format(num_epochs)
    if max_num_samples != (np.iinfo(np.int64).max - 1):
        indexmap_filename += '_{}mns'.format(max_num_samples)
    indexmap_filename += '_{}msl'.format(max_seq_length)
    indexmap_filename += '_{}s'.format(seed)
    if use_one_sent_docs:
        indexmap_filename += '_1sentok'
    indexmap_filename += '.npy'
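    # The resulting cache file name encodes the build parameters, e.g. (with a
    # hypothetical prefix) 'my-corpus_train_indexmap_3ep_1000mns_512msl_1234s.npy',
    # plus the optional '_1sentok' suffix when use_one_sent_docs is set.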

    # Build the index mapping if it does not exist.
    if mpu.get_data_parallel_rank() == 0 and \
            not os.path.isfile(indexmap_filename):
        print(' > WARNING: could not find index map file {}, building '
              'the indices on rank 0 ...'.format(indexmap_filename))

        # Make sure the types match the helpers input types.
        assert block_dataset.document_indices.dtype == np.int64
        assert block_dataset.sequence_lengths.dtype == np.int32

        # Build samples mapping
        verbose = torch.distributed.get_rank() == 0
        start_time = time.time()
        print_rank_0(' > building samples index mapping for {} ...'.format(
            name))

        from megatron.core.datasets import helpers
        mapping_array = helpers.build_blocks_mapping(
            block_dataset.document_indices,
            block_dataset.sequence_lengths,
            title_dataset.sequence_lengths,
            num_epochs,
            max_num_samples,
            max_seq_length - 3,  # account for added tokens
            seed,
            verbose,
            use_one_sent_docs)


        print_rank_0(' > done building samples index mapping')
        np.save(indexmap_filename, mapping_array, allow_pickle=True)
        print_rank_0(' > saved the index mapping in {}'.format(
            indexmap_filename))
        print_rank_0(' > elapsed time to build and save samples mapping '
                     '(seconds): {:4f}'.format(
            time.time() - start_time))

    # Make sure all ranks have built the mapping before proceeding. This
    # should be a barrier, but the NCCL barrier assumes device_index == rank,
    # which does not hold under model parallelism, so an all_reduce over the
    # data-parallel group is used instead.
    counts = torch.tensor([1], dtype=torch.long, device='cuda')
    torch.distributed.all_reduce(counts, group=mpu.get_data_parallel_group())
    assert counts[0].item() == torch.distributed.get_world_size(
        group=mpu.get_data_parallel_group())

    # Load the index mapping.
    print_rank_0(' > loading indexed mapping from {}'.format(
        indexmap_filename))
    start_time = time.time()

    mapping_array = np.load(indexmap_filename, allow_pickle=True, mmap_mode='r')
    samples_mapping = BlockSamplesMapping(mapping_array)

    print_rank_0('    loaded indexed file in {:3.3f} seconds'.format(
        time.time() - start_time))
    print_rank_0('    total number of samples: {}'.format(
        mapping_array.shape[0]))

    return samples_mapping
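

# A hedged usage sketch (hypothetical arguments and dataset objects): build or
# load the cached block mapping for an ICT/REALM dataset and look up one block.
#
#   samples_mapping = get_block_samples_mapping(
#       block_dataset, title_dataset, data_prefix='my-corpus',
#       num_epochs=None, max_num_samples=1000000, max_seq_length=288,
#       seed=1234, name='train', use_one_sent_docs=False)
#   block = samples_mapping[0]                 # a BlockSampleData
#   start_idx, end_idx, doc_idx, block_idx = block.as_tuple()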