# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""BERT Style dataset."""

import numpy as np
import torch

from megatron import (
    get_args,
    get_tokenizer,
    mpu,
    print_rank_0
)
from megatron.data.dataset_utils import (
    get_samples_mapping,
    get_a_and_b_segments,
    truncate_segments,
    create_tokens_and_tokentypes,
    create_masked_lm_predictions
)


class BertDataset(torch.utils.data.Dataset):

    def __init__(self, name, indexed_dataset, data_prefix,
                 num_epochs, max_num_samples, masked_lm_prob,
                 max_seq_length, short_seq_prob, seed, binary_head):

        # Params to store.
        self.name = name
        self.seed = seed
        self.masked_lm_prob = masked_lm_prob
        self.max_seq_length = max_seq_length
        self.binary_head = binary_head

        # Dataset.
        self.indexed_dataset = indexed_dataset

        # Build the samples mapping.
        self.samples_mapping = get_samples_mapping(self.indexed_dataset,
                                                   data_prefix,
                                                   num_epochs,
                                                   max_num_samples,
                                                   self.max_seq_length - 3, # account for added tokens ([CLS] and two [SEP])
                                                   short_seq_prob,
                                                   self.seed,
                                                   self.name,
                                                   self.binary_head)
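        # Each row of the mapping is a (start_index, end_index,
        # sequence_length) triple into `indexed_dataset`; it is unpacked
        # in `__getitem__` below.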

        # Vocab stuff.
        tokenizer = get_tokenizer()
        self.vocab_id_list = list(tokenizer.inv_vocab.keys())
        self.vocab_id_to_token_dict = tokenizer.inv_vocab
        self.cls_id = tokenizer.cls
        self.sep_id = tokenizer.sep
        self.mask_id = tokenizer.mask
        self.pad_id = tokenizer.pad

    def __len__(self):
        return self.samples_mapping.shape[0]

    def __getitem__(self, idx):
        start_idx, end_idx, seq_length = self.samples_mapping[idx]
        sample = [self.indexed_dataset[i] for i in range(start_idx, end_idx)]
        # Note that this rng state should be numpy and not python since
        # python randint is inclusive whereas the numpy one is exclusive.
        # We % 2**32 since numpy requires the seed to be between 0 and 2**32 - 1
        np_rng = np.random.RandomState(seed=((self.seed + idx) % 2**32))
        return build_training_sample(sample, seq_length,
                                     self.max_seq_length,  # needed for padding
                                     self.vocab_id_list,
                                     self.vocab_id_to_token_dict,
                                     self.cls_id, self.sep_id,
                                     self.mask_id, self.pad_id,
                                     self.masked_lm_prob, np_rng,
                                     self.binary_head)
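
# A minimal, hypothetical construction of `BertDataset` (the data prefix,
# sizes, and the `indexed_dataset` object are placeholders; in Megatron the
# indexed dataset is built elsewhere from preprocessed data):
#
#   dataset = BertDataset(name='train',
#                         indexed_dataset=indexed_dataset,
#                         data_prefix='my-corpus_text_sentence',
#                         num_epochs=None, max_num_samples=1000000,
#                         masked_lm_prob=0.15, max_seq_length=512,
#                         short_seq_prob=0.1, seed=1234, binary_head=True)
#   sample = dataset[0]  # dict with 'text', 'types', 'labels', 'is_random',
#                        # 'loss_mask', 'padding_mask', 'truncated'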


def build_training_sample(sample,
                          target_seq_length, max_seq_length,
                          vocab_id_list, vocab_id_to_token_dict,
                          cls_id, sep_id, mask_id, pad_id,
                          masked_lm_prob, np_rng, binary_head):
    """Biuld training sample.

    Arguments:
        sample: A list of sentences in which each sentence is a list of token ids.
        target_seq_length: Desired sequence length.
        max_seq_length: Maximum length of the sequence. All values are padded to
            this length.
        vocab_id_list: List of vocabulary ids. Used to pick a random id.
        vocab_id_to_token_dict: A dictionary from vocab ids to text tokens.
        cls_id: Start of example id.
        sep_id: Separator id.
        mask_id: Mask token id.
        pad_id: Padding token id.
        masked_lm_prob: Probability to mask tokens.
        np_rng: Random number generator. Note that this rng state should be
              numpy and not python since python randint is inclusive for
              the upper bound whereas the numpy one is exclusive.
        binary_head: If true, split the sample into two segments (A and B)
              and produce an `is_random` next-sentence label for the binary
              (next sentence prediction) head.
    """

    if binary_head:
        # We assume that we have at least two sentences in the sample
        assert len(sample) > 1
    assert target_seq_length <= max_seq_length

    # Divide sample into two segments (A and B).
    if binary_head:
        tokens_a, tokens_b, is_next_random = get_a_and_b_segments(sample,
                                                                  np_rng)
    else:
        tokens_a = []
        for j in range(len(sample)):
            tokens_a.extend(sample[j])
        tokens_b = []
        is_next_random = False

    # Truncate to `target_seq_length`.
    max_num_tokens = target_seq_length
    truncated = truncate_segments(tokens_a, tokens_b, len(tokens_a),
                                  len(tokens_b), max_num_tokens, np_rng)

    # Build tokens and tokentypes.
    tokens, tokentypes = create_tokens_and_tokentypes(tokens_a, tokens_b,
                                                      cls_id, sep_id)

    # Masking.
    max_predictions_per_seq = masked_lm_prob * max_num_tokens
    (tokens, masked_positions, masked_labels, _, _) = create_masked_lm_predictions(
        tokens, vocab_id_list, vocab_id_to_token_dict, masked_lm_prob,
        cls_id, sep_id, mask_id, max_predictions_per_seq, np_rng)

    # Padding.
    tokens_np, tokentypes_np, labels_np, padding_mask_np, loss_mask_np \
        = pad_and_convert_to_numpy(tokens, tokentypes, masked_positions,
                                   masked_labels, pad_id, max_seq_length)

    train_sample = {
        'text': tokens_np,
        'types': tokentypes_np,
        'labels': labels_np,
        'is_random': int(is_next_random),
        'loss_mask': loss_mask_np,
        'padding_mask': padding_mask_np,
        'truncated': int(truncated)}
    return train_sample
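
# An illustrative call to `build_training_sample` with toy token ids (all
# ids, including the special-token ids, are placeholders rather than a real
# tokenizer's values):
#
#   np_rng = np.random.RandomState(seed=1234)
#   out = build_training_sample(
#       sample=[[11, 12, 13], [14, 15, 16]],  # two sentences of token ids
#       target_seq_length=8, max_seq_length=16,
#       vocab_id_list=list(range(100)),
#       vocab_id_to_token_dict={i: str(i) for i in range(100)},
#       cls_id=1, sep_id=2, mask_id=3, pad_id=0,
#       masked_lm_prob=0.15, np_rng=np_rng, binary_head=True)
#   # All numpy arrays in `out` have length `max_seq_length` (16 here).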


def pad_and_convert_to_numpy(tokens, tokentypes, masked_positions,
                             masked_labels, pad_id, max_seq_length):
    """Pad sequences and convert them to numpy."""

    # Some checks.
    num_tokens = len(tokens)
    padding_length = max_seq_length - num_tokens
    assert padding_length >= 0
    assert len(tokentypes) == num_tokens
    assert len(masked_positions) == len(masked_labels)

    # Tokens and token types.
    filler = [pad_id] * padding_length
    tokens_np = np.array(tokens + filler, dtype=np.int64)
    tokentypes_np = np.array(tokentypes + filler, dtype=np.int64)

    # Padding mask.
    padding_mask_np = np.array([1] * num_tokens + [0] * padding_length,
                               dtype=np.int64)

    # Labels and loss mask.
    labels = [-1] * max_seq_length
    loss_mask = [0] * max_seq_length
    for i in range(len(masked_positions)):
        assert masked_positions[i] < num_tokens
        labels[masked_positions[i]] = masked_labels[i]
        loss_mask[masked_positions[i]] = 1
    labels_np = np.array(labels, dtype=np.int64)
    loss_mask_np = np.array(loss_mask, dtype=np.int64)

    return tokens_np, tokentypes_np, labels_np, padding_mask_np, loss_mask_np
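
# A worked example of the padding logic above (hypothetical token ids;
# pad_id=0, max_seq_length=8; position 2 was masked and its original
# label was 8):
#
#   pad_and_convert_to_numpy(tokens=[101, 7, 103, 9, 102],
#                            tokentypes=[0, 0, 0, 1, 1],
#                            masked_positions=[2], masked_labels=[8],
#                            pad_id=0, max_seq_length=8)
#
#   tokens_np       -> [101,  7, 103,  9, 102,  0,  0,  0]
#   tokentypes_np   -> [  0,  0,   0,  1,   1,  0,  0,  0]
#   labels_np       -> [ -1, -1,   8, -1,  -1, -1, -1, -1]
#   padding_mask_np -> [  1,  1,   1,  1,   1,  0,  0,  0]
#   loss_mask_np    -> [  0,  0,   1,  0,   0,  0,  0,  0]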