import itertools
import random

import numpy as np
from torch.utils.data import Dataset

from megatron import get_tokenizer
from megatron import get_args
from megatron.data.dataset_utils import get_indexed_dataset_
from megatron.data.realm_dataset_utils import get_block_samples_mapping


def get_ict_dataset(use_titles=True, query_in_block_prob=1):
    """Get a dataset which uses block samples mappings to get ICT/block indexing data (via get_block())
    rather than for training, since it is only built with a single epoch sample mapping.
    """
    args = get_args()
    block_dataset = get_indexed_dataset_(args.data_path, 'mmap', True)
    titles_dataset = get_indexed_dataset_(args.titles_data_path, 'mmap', True)

    kwargs = dict(
        name='full',
        block_dataset=block_dataset,
        title_dataset=titles_dataset,
        data_prefix=args.data_path,
        num_epochs=1,
        max_num_samples=None,
        max_seq_length=args.seq_length,
        seed=1,
        query_in_block_prob=query_in_block_prob,
        use_titles=use_titles,
        use_one_sent_docs=args.use_one_sent_docs
    )
    dataset = ICTDataset(**kwargs)
    return dataset
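
# A minimal usage sketch (an assumption, not part of this module): it presumes
# Megatron's global args and tokenizer are already initialized (e.g. via
# initialize_megatron()) with --data-path, --titles-data-path and --seq-length set.
#
#     dataset = get_ict_dataset(use_titles=True)
#     tokens, pad_mask = dataset.get_block(start_idx=0, end_idx=3, doc_idx=0)
#
# The index values are illustrative; real values come from a samples mapping.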


class ICTDataset(Dataset):
    """Dataset containing sentences and their blocks for an inverse cloze task."""
    def __init__(self, name, block_dataset, title_dataset, data_prefix,
                 num_epochs, max_num_samples, max_seq_length, query_in_block_prob,
                 seed, use_titles=True, use_one_sent_docs=False):
        self.name = name
        self.seed = seed
        self.max_seq_length = max_seq_length
        self.query_in_block_prob = query_in_block_prob
        self.block_dataset = block_dataset
        self.title_dataset = title_dataset
        self.rng = random.Random(self.seed)
        self.use_titles = use_titles
        self.use_one_sent_docs = use_one_sent_docs

        self.samples_mapping = get_block_samples_mapping(
            block_dataset, title_dataset, data_prefix, num_epochs,
            max_num_samples, max_seq_length, seed, name, use_one_sent_docs)
        self.tokenizer = get_tokenizer()
        self.vocab_id_list = list(self.tokenizer.inv_vocab.keys())
        self.vocab_id_to_token_list = self.tokenizer.inv_vocab
        self.cls_id = self.tokenizer.cls
        self.sep_id = self.tokenizer.sep
        self.mask_id = self.tokenizer.mask
        self.pad_id = self.tokenizer.pad

    def __len__(self):
        return len(self.samples_mapping)

    def __getitem__(self, idx):
        """Get an ICT example of a pseudo-query and the block of text from which it was extracted"""
        sample_data = self.samples_mapping[idx]
        start_idx, end_idx, doc_idx, block_idx = sample_data.as_tuple()

        if self.use_titles:
            title = self.title_dataset[int(doc_idx)]
            # 3 special tokens: [CLS], [SEP] after the title, [SEP] after the block
            title_pad_offset = 3 + len(title)
        else:
            title = None
            # 2 special tokens: [CLS] and [SEP]
            title_pad_offset = 2
        block = [self.block_dataset[i] for i in range(start_idx, end_idx)]
        assert len(block) > 1 or self.use_one_sent_docs or self.query_in_block_prob == 1

        # randint() is inclusive for Python rng
        rand_sent_idx = self.rng.randint(0, len(block) - 1)

        # keep the query in the context query_in_block_prob fraction of the time.
        if self.rng.random() < self.query_in_block_prob:
            query = block[rand_sent_idx].copy()
        else:
            query = block.pop(rand_sent_idx)

        # still need to truncate because a block is only closed off once the
        # cumulative sentence length has exceeded max_seq_length
        query = query[:self.max_seq_length - 2]
        # flatten the remaining sentences into one token sequence, leaving room
        # for the title and special tokens
        block = list(itertools.chain(*block))[:self.max_seq_length - title_pad_offset]

        query_tokens, query_pad_mask = self.concat_and_pad_tokens(query)
        block_tokens, block_pad_mask = self.concat_and_pad_tokens(block, title)
        block_data = sample_data.as_array()

        sample = {
            'query_tokens': query_tokens,
            'query_pad_mask': query_pad_mask,
            'block_tokens': block_tokens,
            'block_pad_mask': block_pad_mask,
            'block_data': block_data,
        }

        return sample
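
    # Illustrative layout of a returned sample (dtype assumed to be the numpy
    # default int): query_tokens/query_pad_mask and block_tokens/block_pad_mask
    # are 1-D arrays of length max_seq_length; block_data holds the sample's
    # (start_idx, end_idx, doc_idx, block_idx) indexing row as an array.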

    def get_block(self, start_idx, end_idx, doc_idx):
        """Get the IDs for an evidence block plus the title of the corresponding document"""
        block = [self.block_dataset[i] for i in range(start_idx, end_idx)]
        title = self.title_dataset[int(doc_idx)]

        block = list(itertools.chain(*block))[:self.max_seq_length - (3 + len(title))]
        block_tokens, block_pad_mask = self.concat_and_pad_tokens(block, title)

        return block_tokens, block_pad_mask

    def get_null_block(self):
        """Get empty block and title - used in REALM pretraining"""
        block, title = [], []
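        # note: with empty tokens and an empty (non-None) title this produces
        # [CLS] [SEP] [SEP] followed by pad_id up to max_seq_length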
        block_tokens, block_pad_mask = self.concat_and_pad_tokens(block, title)

        return block_tokens, block_pad_mask

    def concat_and_pad_tokens(self, tokens, title=None):
        """Concatenate with special tokens and pad the sequence to self.max_seq_length"""
        tokens = list(tokens)
        if title is None:
            tokens = [self.cls_id] + tokens + [self.sep_id]
        else:
            title = list(title)
            tokens = [self.cls_id] + title + [self.sep_id] + tokens + [self.sep_id]
        assert len(tokens) <= self.max_seq_length

        num_pad = self.max_seq_length - len(tokens)
        pad_mask = [1] * len(tokens) + [0] * num_pad
        tokens += [self.pad_id] * num_pad

        return np.array(tokens), np.array(pad_mask)
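
    # Worked example (token ids are made up): with max_seq_length=8,
    # tokens=[7, 8] and title=[5], this returns
    #   tokens   -> [cls_id, 5, sep_id, 7, 8, sep_id, pad_id, pad_id]
    #   pad_mask -> [1, 1, 1, 1, 1, 1, 0, 0]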