import itertools
import os
import random
import time

import numpy as np
import torch
from torch.utils.data import Dataset

from megatron import get_tokenizer
from megatron import print_rank_0
from megatron import mpu
from megatron.data import helpers


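# Each sample produced by this dataset pairs one "query" sentence with the
# block of surrounding sentences it was drawn from (plus the document title);
# this query/block pairing is the training signal for the inverse cloze task.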
class InverseClozeDataset(Dataset):
    """Dataset containing sentences and their blocks for an inverse cloze task."""

    def __init__(self, name, block_dataset, title_dataset, data_prefix,
                 num_epochs, max_num_samples, max_seq_length,
                 short_seq_prob, seed):
        self.name = name
        self.seed = seed
        self.max_seq_length = max_seq_length
        self.block_dataset = block_dataset
        self.title_dataset = title_dataset
        self.short_seq_prob = short_seq_prob
        self.rng = random.Random(self.seed)

        # Each row of samples_mapping is (start_idx, end_idx, doc_idx, block_idx).
        self.samples_mapping = self.get_samples_mapping(
            data_prefix, num_epochs, max_num_samples)

        self.tokenizer = get_tokenizer()
        self.vocab_id_list = list(self.tokenizer.inv_vocab.keys())
        self.vocab_id_to_token_list = self.tokenizer.inv_vocab
        self.cls_id = self.tokenizer.cls
        self.sep_id = self.tokenizer.sep
        self.mask_id = self.tokenizer.mask
        self.pad_id = self.tokenizer.pad

    def __len__(self):
        return self.samples_mapping.shape[0]

    def __getitem__(self, idx):
        start_idx, end_idx, doc_idx, block_idx = self.samples_mapping[idx]
        title = list(self.title_dataset[int(doc_idx)])
        block = [list(self.block_dataset[i]) for i in range(start_idx, end_idx)]
        assert len(block) > 1

        # avoid selecting the first or last sentence as the query when the block allows it.
        if len(block) == 2:
            rand_sent_idx = int(self.rng.random() > 0.5)
        else:
            rand_sent_idx = self.rng.randint(1, len(block) - 2)

        # keep the query in the context 10% of the time.
        if self.rng.random() < 0.1:
            query = block[rand_sent_idx].copy()
        else:
            query = block.pop(rand_sent_idx)
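        # Illustrative example (hypothetical sentences): if block is [s0, s1, s2, s3]
        # and s2 is picked as the query, then 90% of the time the block handed to the
        # model is s0 + s1 + s3 (query removed) and 10% of the time s2 stays in place.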

        # still need to truncate because a block is only closed once its cumulative
        # sentence length has exceeded max_seq_length. The -2 and -(3 + len(title))
        # offsets leave room for the special tokens added in concat_and_pad_tokens.
        query = query[:self.max_seq_length - 2]
        block = list(itertools.chain(*block))[:self.max_seq_length - (3 + len(title))]

        query_tokens, query_pad_mask = self.concat_and_pad_tokens(query)
        block_tokens, block_pad_mask = self.concat_and_pad_tokens(block, title)

        sample = {
            'query_tokens': np.array(query_tokens),
            'query_pad_mask': np.array(query_pad_mask),
            'block_tokens': np.array(block_tokens),
            'block_pad_mask': np.array(block_pad_mask),
            'block_data': np.array([start_idx, end_idx, doc_idx, block_idx]).astype(np.int64)
        }

        return sample

    def encode_text(self, text):
        return self.tokenizer.tokenize(text)

    def decode_tokens(self, token_ids):
        tokens = self.tokenizer.tokenizer.convert_ids_to_tokens(token_ids)
        return ' '.join(token for token in tokens if token != '[PAD]')

    def get_block(self, start_idx, end_idx, doc_idx):
        """Get the IDs for an evidence block plus the title of the corresponding document"""
        block = [list(self.block_dataset[i]) for i in range(start_idx, end_idx)]
        title = list(self.title_dataset[int(doc_idx)])

        block = list(itertools.chain(*block))[:self.max_seq_length - (3 + len(title))]
        block_tokens, block_pad_mask = self.concat_and_pad_tokens(block, title)

        return (block_tokens, block_pad_mask)

    def concat_and_pad_tokens(self, tokens, title=None):
        """concat with special tokens and pad sequence to self.max_seq_length"""
        # Queries become [CLS] tokens [SEP]; blocks become [CLS] tokens [SEP] title [SEP].
        tokens = [self.cls_id] + tokens + [self.sep_id]
        if title is not None:
            tokens += title + [self.sep_id]
        assert len(tokens) <= self.max_seq_length, len(tokens)

        num_pad = self.max_seq_length - len(tokens)
        pad_mask = [1] * len(tokens) + [0] * num_pad
        tokens += [self.pad_id] * num_pad
        return tokens, pad_mask
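        # Example output (illustrative only, with hypothetical ids cls=101, sep=102,
        # pad=0, title=[2023] and max_seq_length=10) for tokens=[7, 8, 9]:
        #   tokens   = [101, 7, 8, 9, 102, 2023, 102, 0, 0, 0]
        #   pad_mask = [  1, 1, 1, 1,   1,    1,   1, 0, 0, 0]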

    def get_samples_mapping(self, data_prefix, num_epochs, max_num_samples):
        if not num_epochs:
            if not max_num_samples:
                raise ValueError("Need to specify either max_num_samples "
                                 "or num_epochs")
            num_epochs = np.iinfo(np.int32).max - 1
        if not max_num_samples:
            max_num_samples = np.iinfo(np.int64).max - 1

        # Filename of the index mapping
        indexmap_filename = data_prefix
        indexmap_filename += '_{}_indexmap'.format(self.name)
        if num_epochs != (np.iinfo(np.int32).max - 1):
            indexmap_filename += '_{}ep'.format(num_epochs)
        if max_num_samples != (np.iinfo(np.int64).max - 1):
            indexmap_filename += '_{}mns'.format(max_num_samples)
        indexmap_filename += '_{}msl'.format(self.max_seq_length)
        indexmap_filename += '_{}s'.format(self.seed)
        indexmap_filename += '.npy'
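        # For example (hypothetical values): data_prefix='my-corpus', name='train',
        # max_num_samples=1000000, max_seq_length=288, seed=1234 gives
        # 'my-corpus_train_indexmap_1000000mns_288msl_1234s.npy'.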

        # Build the indexed mapping if not exist.
        if torch.distributed.get_rank() == 0 and \
                not os.path.isfile(indexmap_filename):
            print(' > WARNING: could not find index map file {}, building '
                  'the indices on rank 0 ...'.format(indexmap_filename))

            # Make sure the types match the helpers input types.
            assert self.block_dataset.doc_idx.dtype == np.int64
            assert self.block_dataset.sizes.dtype == np.int32

            # Build samples mapping
            verbose = torch.distributed.get_rank() == 0
            start_time = time.time()
            print_rank_0(' > building samples index mapping for {} ...'.format(
                self.name))
            samples_mapping = helpers.build_blocks_mapping(
                self.block_dataset.doc_idx,
                self.block_dataset.sizes,
                self.title_dataset.sizes,
                num_epochs,
                max_num_samples,
                self.max_seq_length-3,  # account for added tokens
                self.seed,
                verbose)
            print_rank_0(' > done building samples index mapping')
            np.save(indexmap_filename, samples_mapping, allow_pickle=True)
            print_rank_0(' > saved the index mapping in {}'.format(
                indexmap_filename))
            # Make sure all the ranks have built the mapping
            print_rank_0(' > elapsed time to build and save samples mapping '
                         '(seconds): {:4f}'.format(
                time.time() - start_time))
        # This should be a barrier, but the NCCL barrier assumes
        # device_index == rank, which does not hold in the
        # model-parallel case.
        counts = torch.cuda.LongTensor([1])
        torch.distributed.all_reduce(counts, group=mpu.get_data_parallel_group())
        assert counts[0].item() == torch.distributed.get_world_size(
            group=mpu.get_data_parallel_group())

        # Load indexed dataset.
        print_rank_0(' > loading indexed mapping from {}'.format(
            indexmap_filename))
        start_time = time.time()
        samples_mapping = np.load(indexmap_filename, allow_pickle=True)
        print_rank_0('    loaded indexed file in {:3.3f} seconds'.format(
            time.time() - start_time))
        print_rank_0('    total number of samples: {}'.format(
            samples_mapping.shape[0]))

        return samples_mapping
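

# A minimal usage sketch (illustrative). It assumes Megatron has already been
# initialized so that get_tokenizer() and torch.distributed are usable, and that
# block_ds / title_ds are indexed datasets built by Megatron's preprocessing;
# the data prefix and numbers below are hypothetical.
#
#     dataset = InverseClozeDataset(
#         name='train',
#         block_dataset=block_ds,
#         title_dataset=title_ds,
#         data_prefix='my-corpus',
#         num_epochs=None,
#         max_num_samples=1000000,
#         max_seq_length=288,
#         short_seq_prob=0.1,
#         seed=1234)
#     sample = dataset[0]
#     # sample['query_tokens'] and sample['block_tokens'] are padded token-id
#     # arrays of length max_seq_length; the '*_pad_mask' arrays mark real
#     # tokens with 1 and padding with 0.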