# realm_dataset.py
import itertools
import random

import numpy as np
from torch.utils.data import Dataset

from megatron import get_tokenizer
from megatron.data.realm_dataset_utils import build_realm_training_sample, get_block_samples_mapping


class REALMDataset(Dataset):
    """Dataset containing simple masked sentences for masked language modeling.

    The dataset should yield sentences just like the regular BertDataset
    However, this dataset also needs to be able to return a set of blocks
    given their start and end indices.

    Presumably

    """
    def __init__(self, name, block_dataset, title_dataset, data_prefix,
                 num_epochs, max_num_samples, masked_lm_prob,
                 max_seq_length, short_seq_prob, seed):
        self.name = name
        self.seed = seed
        self.max_seq_length = max_seq_length
        self.masked_lm_prob = masked_lm_prob
        self.block_dataset = block_dataset
        self.title_dataset = title_dataset
        self.short_seq_prob = short_seq_prob
        self.rng = random.Random(self.seed)

        self.samples_mapping = get_block_samples_mapping(
            block_dataset, title_dataset, data_prefix, num_epochs,
            max_num_samples, max_seq_length, seed, name)
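        # each row of samples_mapping is (start_index, end_index, doc_index, block_index);
        # start/end index into block_dataset and doc_index into title_dataset (see __getitem__)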

        self.tokenizer = get_tokenizer()
        self.vocab_id_list = list(self.tokenizer.inv_vocab.keys())
        self.vocab_id_to_token_list = self.tokenizer.inv_vocab
        self.cls_id = self.tokenizer.cls
        self.sep_id = self.tokenizer.sep
        self.mask_id = self.tokenizer.mask
        self.pad_id = self.tokenizer.pad

    def __len__(self):
        return self.samples_mapping.shape[0]

    def __getitem__(self, idx):
        start_idx, end_idx, doc_idx, block_idx = self.samples_mapping[idx]
        block = [list(self.block_dataset[i]) for i in range(start_idx, end_idx)]
        assert len(block) > 1
        np_rng = np.random.RandomState(seed=(self.seed + idx))

        sample = build_realm_training_sample(block,
                                             self.max_seq_length,
                                             self.vocab_id_list,
                                             self.vocab_id_to_token_list,
                                             self.cls_id,
                                             self.sep_id,
                                             self.mask_id,
                                             self.pad_id,
                                             self.masked_lm_prob,
                                             np_rng)
        sample.update({'query_block_indices': np.array([block_idx]).astype(np.int64)})
        return sample
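
# A minimal usage sketch, not part of the original file: it assumes megatron has been
# initialized (so get_tokenizer() works) and that block_dataset / title_dataset are
# already-built indexed datasets; the argument values below are hypothetical.
#
#   dataset = REALMDataset('realm', block_dataset, title_dataset, 'my-data-prefix',
#                          num_epochs=1, max_num_samples=1000000, masked_lm_prob=0.15,
#                          max_seq_length=288, short_seq_prob=0.1, seed=1234)
#   sample = dataset[0]   # dict of arrays, including 'query_block_indices'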


class ICTDataset(Dataset):
    """Dataset containing sentences and their blocks for an inverse cloze task."""
    def __init__(self, name, block_dataset, title_dataset, data_prefix,
                 num_epochs, max_num_samples, max_seq_length,
                 query_in_block_prob, short_seq_prob, seed, use_titles=True):
        self.name = name
        self.seed = seed
        self.max_seq_length = max_seq_length
        self.query_in_block_prob = query_in_block_prob
        self.block_dataset = block_dataset
        self.title_dataset = title_dataset
        self.short_seq_prob = short_seq_prob
        self.rng = random.Random(self.seed)
        self.use_titles = use_titles

        self.samples_mapping = get_block_samples_mapping(
            block_dataset, title_dataset, data_prefix, num_epochs,
            max_num_samples, max_seq_length, seed, name)
        self.tokenizer = get_tokenizer()
        self.vocab_id_list = list(self.tokenizer.inv_vocab.keys())
        self.vocab_id_to_token_list = self.tokenizer.inv_vocab
        self.cls_id = self.tokenizer.cls
        self.sep_id = self.tokenizer.sep
        self.mask_id = self.tokenizer.mask
        self.pad_id = self.tokenizer.pad

    def __len__(self):
        return self.samples_mapping.shape[0]

    def __getitem__(self, idx):
        start_idx, end_idx, doc_idx, block_idx = self.samples_mapping[idx]
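        # title_pad_offset accounts for the special tokens that concat_and_pad_tokens
        # will add: [CLS] title [SEP] block [SEP] (3 + len(title) extra ids) when titles
        # are used, versus [CLS] block [SEP] (2 extra ids) otherwise.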
        if self.use_titles:
            title = list(self.title_dataset[int(doc_idx)])
            title_pad_offset = 3 + len(title)
        else:
            title = None
            title_pad_offset = 2
        block = [list(self.block_dataset[i]) for i in range(start_idx, end_idx)]
        assert len(block) > 1

        rand_sent_idx = self.rng.randint(0, len(block) - 1)

        # keep the query sentence inside its block with probability query_in_block_prob
        # (e.g. 10% of the time).
        if self.rng.random() < self.query_in_block_prob:
            query = block[rand_sent_idx].copy()
        else:
            query = block.pop(rand_sent_idx)
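        # Popping the query removes that sentence from the block, so the model must
        # match query and block semantically rather than by exact sentence overlap,
        # which is the point of the inverse cloze task.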

        # still need to truncate because a block is only closed once its
        # accumulated sentence lengths have exceeded max_seq_length.
        query = query[:self.max_seq_length - 2]
        block = list(itertools.chain(*block))[:self.max_seq_length - title_pad_offset]

        query_tokens, query_pad_mask = self.concat_and_pad_tokens(query)
        block_tokens, block_pad_mask = self.concat_and_pad_tokens(block, title)

        sample = {
            'query_tokens': np.array(query_tokens),
            'query_pad_mask': np.array(query_pad_mask),
            'block_tokens': np.array(block_tokens),
            'block_pad_mask': np.array(block_pad_mask),
            'block_data': np.array([start_idx, end_idx, doc_idx, block_idx]).astype(np.int64)
        }

        return sample

    def encode_text(self, text):
        return self.tokenizer.tokenize(text)

    def decode_tokens(self, token_ids):
        tokens = self.tokenizer.tokenizer.convert_ids_to_tokens(token_ids)
        return ' '.join(token for token in tokens if token != '[PAD]')

    def get_block(self, start_idx, end_idx, doc_idx):
        """Get the IDs for an evidence block plus the title of the corresponding document"""
        block = [list(self.block_dataset[i]) for i in range(start_idx, end_idx)]
        title = list(self.title_dataset[int(doc_idx)])

        block = list(itertools.chain(*block))[:self.max_seq_length - (3 + len(title))]
        block_tokens, block_pad_mask = self.concat_and_pad_tokens(block, title)

        return (block_tokens, block_pad_mask)

    def get_null_block(self):
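        """Get the token ids and pad mask for an empty evidence block
        (special tokens plus padding only)."""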
        block, title = [], []
        block_tokens, block_pad_mask = self.concat_and_pad_tokens(block, title)

        return (block_tokens, block_pad_mask)

    def concat_and_pad_tokens(self, tokens, title=None):
        """concat with special tokens and pad sequence to self.max_seq_length"""
        if title is None:
            tokens = [self.cls_id] + tokens + [self.sep_id]
        else:
            tokens = [self.cls_id] + title + [self.sep_id] + tokens + [self.sep_id]
        assert len(tokens) <= self.max_seq_length, len(tokens)

        num_pad = self.max_seq_length - len(tokens)
        pad_mask = [1] * len(tokens) + [0] * num_pad
        tokens += [self.pad_id] * num_pad
        return tokens, pad_mask
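
# A minimal worked example of concat_and_pad_tokens, assuming max_seq_length=8 and
# hypothetical special token ids cls=101, sep=102, pad=0:
#
#   tokens=[7, 8, 9], title=[5, 6]
#     -> tokens   = [101, 5, 6, 102, 7, 8, 9, 102]
#        pad_mask = [1, 1, 1, 1, 1, 1, 1, 1]
#
#   tokens=[7, 8], title=None
#     -> tokens   = [101, 7, 8, 102, 0, 0, 0, 0]
#        pad_mask = [1, 1, 1, 1, 0, 0, 0, 0]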