# ict_dataset.py
import itertools
import random

import numpy as np
from torch.utils.data import Dataset

from megatron import get_tokenizer
from megatron import get_args
from megatron.data.dataset_utils import get_indexed_dataset_
from megatron.data.realm_dataset_utils import get_block_samples_mapping

def make_attention_mask(source_block, target_block):
    """
    Returns a 2-D attention mask of shape (source_length, target_length).
    A position may be attended to only if both its source and target token
    ids are >= 1, i.e. neither is padding (the pad id is assumed to be 0).
    :param source_block: 1-D array of token ids
    :param target_block: 1-D array of token ids
    """
    mask = (target_block[None, :] >= 1) * (source_block[:, None] >= 1)
    mask = mask.astype(np.int64)
    # (source_length, target_length)
    return mask
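
# A minimal sketch of the mask semantics, assuming the pad id is 0 so that
# any id >= 1 marks a real token:
#
#   >>> src = np.array([101, 7, 0])   # padded source sequence
#   >>> tgt = np.array([102, 0])      # padded target sequence
#   >>> make_attention_mask(src, tgt)
#   array([[1, 0],
#          [1, 0],
#          [0, 0]])
#
# i.e. real positions may attend to real positions, and every row or column
# that lines up with padding is zeroed out.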

def get_ict_dataset(use_titles=True, query_in_block_prob=1):
    """Get a dataset that uses the block samples mapping for ICT/block
    indexing (via get_block()) rather than for training, since it is built
    with only a single-epoch samples mapping.
    """
    args = get_args()
    block_dataset = get_indexed_dataset_(args.data_path, 'mmap', True)
    titles_dataset = get_indexed_dataset_(args.titles_data_path, 'mmap', True)

    kwargs = dict(
        name='full',
        block_dataset=block_dataset,
        title_dataset=titles_dataset,
        data_prefix=args.data_path,
        num_epochs=1,
        max_num_samples=None,
        max_seq_length=args.seq_length,
        seed=1,
        query_in_block_prob=query_in_block_prob,
        use_titles=use_titles,
        use_one_sent_docs=args.use_one_sent_docs
    )
    dataset = ICTDataset(**kwargs)
    return dataset


class ICTDataset(Dataset):
    """Dataset containing sentences and their blocks for an inverse cloze task."""
    def __init__(self, name, block_dataset, title_dataset, data_prefix,
                 num_epochs, max_num_samples, max_seq_length, query_in_block_prob,
                 seed, use_titles=True, use_one_sent_docs=False, binary_head=False):
        self.name = name
        self.seed = seed
        self.max_seq_length = max_seq_length
        self.query_in_block_prob = query_in_block_prob
        self.block_dataset = block_dataset
        self.title_dataset = title_dataset
        self.rng = random.Random(self.seed)
        self.use_titles = use_titles
        self.use_one_sent_docs = use_one_sent_docs

        self.samples_mapping = get_block_samples_mapping(
            block_dataset, title_dataset, data_prefix, num_epochs,
            max_num_samples, max_seq_length, seed, name, use_one_sent_docs)
        self.tokenizer = get_tokenizer()
        self.vocab_id_list = list(self.tokenizer.inv_vocab.keys())
        self.vocab_id_to_token_list = self.tokenizer.inv_vocab
        self.cls_id = self.tokenizer.cls
        self.sep_id = self.tokenizer.sep
        self.mask_id = self.tokenizer.mask
        self.pad_id = self.tokenizer.pad

    def __len__(self):
        return len(self.samples_mapping)

    def __getitem__(self, idx):
        """Get an ICT example of a pseudo-query and the block of text from which it was extracted"""
        sample_data = self.samples_mapping[idx]
        start_idx, end_idx, doc_idx, block_idx = sample_data.as_tuple()

        if self.use_titles:
            title = self.title_dataset[int(doc_idx)]
            # room for [CLS] + title + [SEP] ... [SEP]: three special tokens
            title_pad_offset = 3 + len(title)
        else:
            title = None
            # room for [CLS] ... [SEP]: two special tokens
            title_pad_offset = 2
        block = [self.block_dataset[i] for i in range(start_idx, end_idx)]
        # a one-sentence block is only usable if one-sentence docs are allowed
        # or the query is guaranteed to stay inside the block
        assert len(block) > 1 or self.use_one_sent_docs or self.query_in_block_prob == 1

        # rng.randint(a, b) is inclusive of both endpoints (unlike np.random.randint)
        rand_sent_idx = self.rng.randint(0, len(block) - 1)

        # keep the query inside its block a query_in_block_prob fraction of the time
        if self.rng.random() < self.query_in_block_prob:
            query = block[rand_sent_idx].copy()
        else:
            query = block.pop(rand_sent_idx)

        # still need to truncate, because a block is only closed off once the
        # total sentence length has exceeded max_seq_length
        query = query[:self.max_seq_length - 2]  # leave room for [CLS] and [SEP]
        block = list(itertools.chain(*block))[:self.max_seq_length - title_pad_offset]

        query_tokens, query_pad_mask = self.concat_and_pad_tokens(query)
        context_tokens, context_pad_mask = self.concat_and_pad_tokens(block, title)

        query_mask = make_attention_mask(query_tokens, query_tokens)
        context_mask = make_attention_mask(context_tokens, context_tokens)

        block_data = sample_data.as_array()

        sample = {
            'query_tokens': query_tokens,
            'query_mask': query_mask,
            'query_pad_mask': query_pad_mask,
            'context_tokens': context_tokens,
            'context_mask': context_mask,
            'context_pad_mask': context_pad_mask,
            'block_data': block_data,
        }

        return sample
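
    # Shape sketch for the sample above (illustrative, assuming
    # max_seq_length == 256):
    #   query_tokens / context_tokens:     int arrays of shape (256,)
    #   query_pad_mask / context_pad_mask: 0/1 arrays of shape (256,)
    #   query_mask / context_mask:         0/1 arrays of shape (256, 256)
    #   block_data: the (start_idx, end_idx, doc_idx, block_idx) row as an array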

    def get_block(self, start_idx, end_idx, doc_idx):
        """Get the IDs for an evidence block plus the title of the corresponding document"""
        block = [self.block_dataset[i] for i in range(start_idx, end_idx)]
        title = self.title_dataset[int(doc_idx)]

        block = list(itertools.chain(*block))[:self.max_seq_length - (3 + len(title))]
        block_tokens, block_pad_mask = self.concat_and_pad_tokens(block, title)

        return block_tokens, block_pad_mask

    def get_null_block(self):
        """Get empty block and title - used in REALM pretraining"""
        block, title = [], []
        block_tokens, block_pad_mask = self.concat_and_pad_tokens(block, title)

        return block_tokens, block_pad_mask
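
    # With empty inputs, concat_and_pad_tokens() below degenerates to
    # [CLS] [SEP] [SEP] plus padding; e.g. with the standard BERT ids
    # (cls=101, sep=102, pad=0) and max_seq_length == 6:
    #   block_tokens   -> array([101, 102, 102, 0, 0, 0])
    #   block_pad_mask -> array([1, 1, 1, 0, 0, 0])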

    def concat_and_pad_tokens(self, tokens, title=None):
        """Concat with special tokens and pad sequence to self.max_seq_length"""
        tokens = list(tokens)
        if title is None:
            tokens = [self.cls_id] + tokens + [self.sep_id]
        else:
            title = list(title)
            tokens = [self.cls_id] + title + [self.sep_id] + tokens + [self.sep_id]
        assert len(tokens) <= self.max_seq_length

        num_pad = self.max_seq_length - len(tokens)
        pad_mask = [1] * len(tokens) + [0] * num_pad
        tokens += [self.pad_id] * num_pad

        return np.array(tokens), np.array(pad_mask)
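
    # A worked example of the layout above, assuming max_seq_length == 8 and
    # the standard BERT special ids cls=101, sep=102, pad=0 (`ds` is any
    # ICTDataset instance):
    #
    #   >>> ds.concat_and_pad_tokens([7, 8, 9], title=[5])
    #   (array([101,   5, 102,   7,   8,   9, 102,   0]),
    #    array([1, 1, 1, 1, 1, 1, 1, 0]))
    #
    # so a context is [CLS] <title> [SEP] <block> [SEP] <pads> and a
    # title-less query is [CLS] <tokens> [SEP] <pads>.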