import itertools
import random

import numpy as np
from torch.utils.data import Dataset

from megatron import get_tokenizer
from megatron.data.realm_dataset_utils import build_realm_training_sample, get_block_samples_mapping, join_str_list


class REALMDataset(Dataset):
    """Dataset containing simple masked sentences for masked language modeling.

    The dataset should yield sentences just like the regular BertDataset.
    However, this dataset also needs to be able to return a set of blocks
    given their start and end indices.
    """
    def __init__(self, name, block_dataset, title_dataset, data_prefix,
                 num_epochs, max_num_samples, masked_lm_prob,
                 max_seq_length, short_seq_prob, seed):
        self.name = name
        self.seed = seed
        self.max_seq_length = max_seq_length
        self.masked_lm_prob = masked_lm_prob
        self.block_dataset = block_dataset
        self.title_dataset = title_dataset
        self.short_seq_prob = short_seq_prob
        self.rng = random.Random(self.seed)

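        # mapping where each row gives (start_idx, end_idx, doc_idx, block_idx)
        # for one block of consecutive sentences in block_dataset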
        self.samples_mapping = get_block_samples_mapping(
            block_dataset, title_dataset, data_prefix, num_epochs,
            max_num_samples, max_seq_length, seed, name)

        self.tokenizer = get_tokenizer()
        self.vocab_id_list = list(self.tokenizer.inv_vocab.keys())
        self.vocab_id_to_token_list = self.tokenizer.inv_vocab
        self.cls_id = self.tokenizer.cls
        self.sep_id = self.tokenizer.sep
        self.mask_id = self.tokenizer.mask
        self.pad_id = self.tokenizer.pad

    def __len__(self):
        return self.samples_mapping.shape[0]

    def __getitem__(self, idx):
        start_idx, end_idx, doc_idx, block_idx = self.samples_mapping[idx]
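        # gather the token ids of every sentence that makes up this block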
        block = [list(self.block_dataset[i]) for i in range(start_idx, end_idx)]
        assert len(block) > 1
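        # per-sample RNG so that masking is reproducible for a given seed and index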
        np_rng = np.random.RandomState(seed=(self.seed + idx))

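        # build a BERT-style masked-LM training sample from the block's sentences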
        sample = build_realm_training_sample(block,
                                             self.max_seq_length,
                                             self.vocab_id_list,
                                             self.vocab_id_to_token_list,
                                             self.cls_id,
                                             self.sep_id,
                                             self.mask_id,
                                             self.pad_id,
                                             self.masked_lm_prob,
                                             np_rng)
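        # record which block this sample's query was drawn from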
        sample.update({'query_block_indices': np.array([block_idx]).astype(np.int64)})
        return sample

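# Illustrative usage sketch (the concrete argument values here are hypothetical,
# not defaults pulled from Megatron):
#
#   realm_ds = REALMDataset('realm', block_ds, title_ds, 'my-corpus-prefix',
#                           None, 1000000, 0.15, 288, 0.1, 1234)
#   sample = realm_ds[0]  # dict of arrays, including 'query_block_indices'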

class ICTDataset(Dataset):
    """Dataset containing sentences and their blocks for an inverse cloze task."""
    def __init__(self, name, block_dataset, title_dataset, data_prefix,
                 num_epochs, max_num_samples, max_seq_length,
                 query_in_block_prob, short_seq_prob, seed, use_titles=True):
        self.name = name
        self.seed = seed
        self.max_seq_length = max_seq_length
        self.query_in_block_prob = query_in_block_prob
        self.block_dataset = block_dataset
        self.title_dataset = title_dataset
        self.short_seq_prob = short_seq_prob
        self.rng = random.Random(self.seed)
        self.use_titles = use_titles

        self.samples_mapping = get_block_samples_mapping(
            block_dataset, title_dataset, data_prefix, num_epochs,
            max_num_samples, max_seq_length, seed, name)
        self.tokenizer = get_tokenizer()
        self.vocab_id_list = list(self.tokenizer.inv_vocab.keys())
        self.vocab_id_to_token_list = self.tokenizer.inv_vocab
        self.cls_id = self.tokenizer.cls
        self.sep_id = self.tokenizer.sep
        self.mask_id = self.tokenizer.mask
        self.pad_id = self.tokenizer.pad

    def __len__(self):
        return self.samples_mapping.shape[0]

    def __getitem__(self, idx):
        start_idx, end_idx, doc_idx, block_idx = self.samples_mapping[idx]
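        # reserve space for the special tokens: [CLS] title [SEP] ... [SEP] when
        # a title is used, just [CLS] ... [SEP] otherwise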
        if self.use_titles:
            title = list(self.title_dataset[int(doc_idx)])
            title_pad_offset = 3 + len(title)
        else:
            title = None
            title_pad_offset = 2
        block = [list(self.block_dataset[i]) for i in range(start_idx, end_idx)]
        assert len(block) > 1

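        # randomly pick one sentence from the block to serve as the query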
        rand_sent_idx = self.rng.randint(0, len(block) - 1)

        # keep the query inside its block with probability query_in_block_prob.
        if self.rng.random() < self.query_in_block_prob:
            query = block[rand_sent_idx].copy()
        else:
            query = block.pop(rand_sent_idx)

        # truncation is still needed because a block is only closed after the
        # total length of its sentences has exceeded max_seq_length.
        query = query[:self.max_seq_length - 2]
        block = list(itertools.chain(*block))[:self.max_seq_length - title_pad_offset]

        query_tokens, query_pad_mask = self.concat_and_pad_tokens(query)
        block_tokens, block_pad_mask = self.concat_and_pad_tokens(block, title)

        sample = {
            'query_tokens': np.array(query_tokens),
            'query_pad_mask': np.array(query_pad_mask),
            'block_tokens': np.array(block_tokens),
            'block_pad_mask': np.array(block_pad_mask),
            'block_data': np.array([start_idx, end_idx, doc_idx, block_idx]).astype(np.int64)
        }

        return sample

    def encode_text(self, text):
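        """Tokenize a string of text with the dataset's tokenizer."""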
        return self.tokenizer.tokenize(text)

    def decode_tokens(self, token_ids):
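        """Convert token ids back into a text string, dropping [PAD] tokens."""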
        tokens = self.tokenizer.tokenizer.convert_ids_to_tokens(token_ids)
        non_pads = [t for t in tokens if t != '[PAD]']
        return join_str_list(non_pads)

    def get_block(self, start_idx, end_idx, doc_idx):
        """Get the IDs for an evidence block plus the title of the corresponding document"""
        block = [list(self.block_dataset[i]) for i in range(start_idx, end_idx)]
        title = list(self.title_dataset[int(doc_idx)])

        block = list(itertools.chain(*block))[:self.max_seq_length - (3 + len(title))]
        block_tokens, block_pad_mask = self.concat_and_pad_tokens(block, title)

        return (block_tokens, block_pad_mask)

    def get_null_block(self):
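        """Get the token ids and pad mask for an empty (null) block."""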
        block, title = [], []
        block_tokens, block_pad_mask = self.concat_and_pad_tokens(block, title)

        return (block_tokens, block_pad_mask)

    def concat_and_pad_tokens(self, tokens, title=None):
        """concat with special tokens and pad sequence to self.max_seq_length"""
        if title is None:
            tokens = [self.cls_id] + tokens + [self.sep_id]
        else:
            tokens = [self.cls_id] + title + [self.sep_id] + tokens + [self.sep_id]
        assert len(tokens) <= self.max_seq_length, len(tokens)

        num_pad = self.max_seq_length - len(tokens)
        pad_mask = [1] * len(tokens) + [0] * num_pad
        tokens += [self.pad_id] * num_pad
        return tokens, pad_mask