# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
import itertools
import random

import numpy as np
from torch.utils.data import Dataset

from megatron.training import get_tokenizer
from megatron.training import get_args
from megatron.legacy.data.dataset_utils import get_indexed_dataset_
from megatron.legacy.data.realm_dataset_utils import get_block_samples_mapping

def make_attention_mask(source_block, target_block):
    """
    Returns a 2-dimensional (2-D) attention mask
    :param source_block: 1-D array
    :param target_block: 1-D array
    """
    mask = (target_block[None, :] >= 1) * (source_block[:, None] >= 1)
    mask = mask.astype(np.int64)
    # (source_length, target_length)
    return mask
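
# Behavior sketch (illustration only, not part of the original module): ids >= 1
# are treated as real tokens, which assumes the pad id is 0. For a padded query
# attending to itself:
#
#   >>> q = np.array([101, 7592, 102, 0])
#   >>> make_attention_mask(q, q)
#   array([[1, 1, 1, 0],
#          [1, 1, 1, 0],
#          [1, 1, 1, 0],
#          [0, 0, 0, 0]])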

def get_ict_dataset(use_titles=True, query_in_block_prob=1):
    """Get a dataset which uses block samples mappings to get ICT/block indexing data (via get_block())
    rather than for training, since it is only built with a single epoch sample mapping.
    """
    args = get_args()
    block_dataset = get_indexed_dataset_(args.data_path, 'mmap', True)
    titles_dataset = get_indexed_dataset_(args.titles_data_path, 'mmap', True)

    kwargs = dict(
        name='full',
        block_dataset=block_dataset,
        title_dataset=titles_dataset,
        data_prefix=args.data_path,
        num_epochs=1,
        max_num_samples=None,
        max_seq_length=args.seq_length,
        seed=1,
        query_in_block_prob=query_in_block_prob,
        use_titles=use_titles,
        use_one_sent_docs=args.use_one_sent_docs
    )
    dataset = ICTDataset(**kwargs)
    return dataset
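
# Hedged usage sketch (assumes Megatron has been initialized so that get_args()
# returns parsed arguments, including data_path, titles_data_path, seq_length and
# use_one_sent_docs; start_idx/end_idx/doc_idx come from a block samples mapping
# record):
#
#   dataset = get_ict_dataset(use_titles=True, query_in_block_prob=1)
#   block_tokens, block_pad_mask = dataset.get_block(start_idx, end_idx, doc_idx)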


class ICTDataset(Dataset):
    """Dataset containing sentences and their blocks for an inverse cloze task."""
    def __init__(self, name, block_dataset, title_dataset, data_prefix,
                 num_epochs, max_num_samples, max_seq_length, query_in_block_prob,
                 seed, use_titles=True, use_one_sent_docs=False, binary_head=False):
        self.name = name
        self.seed = seed
        self.max_seq_length = max_seq_length
        self.query_in_block_prob = query_in_block_prob
        self.block_dataset = block_dataset
        self.title_dataset = title_dataset
        self.rng = random.Random(self.seed)
        self.use_titles = use_titles
        self.use_one_sent_docs = use_one_sent_docs

        self.samples_mapping = get_block_samples_mapping(
            block_dataset, title_dataset, data_prefix, num_epochs,
            max_num_samples, max_seq_length, seed, name, use_one_sent_docs)
        self.tokenizer = get_tokenizer()
        self.vocab_id_list = list(self.tokenizer.inv_vocab.keys())
        self.vocab_id_to_token_list = self.tokenizer.inv_vocab
        self.cls_id = self.tokenizer.cls
        self.sep_id = self.tokenizer.sep
        self.mask_id = self.tokenizer.mask
        self.pad_id = self.tokenizer.pad

    def __len__(self):
        return len(self.samples_mapping)

    def __getitem__(self, idx):
        """Get an ICT example of a pseudo-query and the block of text from which it was extracted"""
        sample_data = self.samples_mapping[idx]
        start_idx, end_idx, doc_idx, block_idx = sample_data.as_tuple()
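        # (start_idx, end_idx) delimit this block's sentences in block_dataset;
        # doc_idx selects the matching title in title_dataset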

        if self.use_titles:
            title = self.title_dataset[int(doc_idx)]
            title_pad_offset = 3 + len(title)
        else:
            title = None
            title_pad_offset = 2
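        # title_pad_offset counts the special tokens concat_and_pad_tokens will add:
        # [CLS] title [SEP] block [SEP] (3 + len(title)) vs. [CLS] block [SEP] (2)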
        block = [self.block_dataset[i] for i in range(start_idx, end_idx)]
        assert len(block) > 1 or self.use_one_sent_docs or self.query_in_block_prob == 1

        # randint() is inclusive on both ends for Python's rng
        rand_sent_idx = self.rng.randint(0, len(block) - 1)

        # keep the query in the context for a query_in_block_prob fraction of the time.
        if self.rng.random() < self.query_in_block_prob:
            query = block[rand_sent_idx].copy()
        else:
            query = block.pop(rand_sent_idx)

        # still need to truncate because a block is only closed once its
        # cumulative sentence length has exceeded max_seq_length.
        query = query[:self.max_seq_length - 2]
        # flatten the remaining sentences and truncate so that the special
        # tokens added in concat_and_pad_tokens still fit
        block = list(itertools.chain(*block))[:self.max_seq_length - title_pad_offset]

        query_tokens, query_pad_mask = self.concat_and_pad_tokens(query)
        context_tokens, context_pad_mask = self.concat_and_pad_tokens(block, title)

        query_mask = make_attention_mask(query_tokens, query_tokens)
        context_mask = make_attention_mask(context_tokens, context_tokens)

        # raw mapping record (start_idx, end_idx, doc_idx, block_idx), returned so
        # downstream consumers can identify this block
        block_data = sample_data.as_array()

        sample = {
            'query_tokens': query_tokens,
            'query_mask': query_mask,
            'query_pad_mask': query_pad_mask,
            'context_tokens': context_tokens,
            'context_mask': context_mask,
            'context_pad_mask': context_pad_mask,
            'block_data': block_data,
        }

        return sample

    def get_block(self, start_idx, end_idx, doc_idx):
        """Get the IDs for an evidence block plus the title of the corresponding document"""
        block = [self.block_dataset[i] for i in range(start_idx, end_idx)]
        title = self.title_dataset[int(doc_idx)]

        # 3 + len(title) mirrors title_pad_offset in __getitem__:
        # [CLS] title [SEP] block [SEP]
        block = list(itertools.chain(*block))[:self.max_seq_length - (3 + len(title))]
        block_tokens, block_pad_mask = self.concat_and_pad_tokens(block, title)

        return block_tokens, block_pad_mask

    def get_null_block(self):
        """Get empty block and title - used in REALM pretraining"""
        block, title = [], []
        block_tokens, block_pad_mask = self.concat_and_pad_tokens(block, title)

        return block_tokens, block_pad_mask

    def concat_and_pad_tokens(self, tokens, title=None):
        """Concat with special tokens and pad sequence to self.max_seq_length"""
        tokens = list(tokens)
        if title is None:
            tokens = [self.cls_id] + tokens + [self.sep_id]
        else:
            title = list(title)
            tokens = [self.cls_id] + title + [self.sep_id] + tokens + [self.sep_id]
        assert len(tokens) <= self.max_seq_length

        num_pad = self.max_seq_length - len(tokens)
        pad_mask = [1] * len(tokens) + [0] * num_pad
        tokens += [self.pad_id] * num_pad

        return np.array(tokens), np.array(pad_mask)
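
# Layout sketch for concat_and_pad_tokens (illustrative values only; assumes
# max_seq_length=8, cls_id=101, sep_id=102, pad_id=0):
#
#   tokens=[7592, 2088], title=None
#     -> tokens:   [101, 7592, 2088, 102, 0, 0, 0, 0]
#     -> pad_mask: [  1,    1,    1,   1, 0, 0, 0, 0]
#
#   tokens=[7592], title=[3793]
#     -> tokens:   [101, 3793, 102, 7592, 102, 0, 0, 0]
#     -> pad_mask: [  1,    1,   1,    1,   1, 0, 0, 0]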