bert_dataset.py
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.

"""BERT Style dataset."""

import numpy as np
import torch

from megatron import (
    get_args,
    get_tokenizer,
    mpu,
    print_rank_0
)
from megatron.data.dataset_utils import (
    get_samples_mapping,
    get_a_and_b_segments,
    truncate_segments,
    create_tokens_and_tokentypes,
    create_masked_lm_predictions
)


class BertDataset(torch.utils.data.Dataset):

    def __init__(self, name, indexed_dataset, data_prefix,
                 num_epochs, max_num_samples, masked_lm_prob,
                 max_seq_length, short_seq_prob, seed, binary_head):

        # Params to store.
        self.name = name
        self.seed = seed
        self.masked_lm_prob = masked_lm_prob
        self.max_seq_length = max_seq_length
        self.binary_head = binary_head

        # Dataset.
        self.indexed_dataset = indexed_dataset

        # Build the samples mapping.
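        # Each mapping row is (start sentence index, end sentence index,
        # target sequence length) into the underlying indexed dataset.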
        self.samples_mapping = get_samples_mapping(self.indexed_dataset,
                                                   data_prefix,
                                                   num_epochs,
                                                   max_num_samples,
                                                   self.max_seq_length - 3, # account for the added [CLS]/[SEP] tokens
                                                   short_seq_prob,
                                                   self.seed,
                                                   self.name,
                                                   self.binary_head)

        # Vocab stuff.
        tokenizer = get_tokenizer()
        self.vocab_id_list = list(tokenizer.inv_vocab.keys())
        self.vocab_id_to_token_dict = tokenizer.inv_vocab
        self.cls_id = tokenizer.cls
        self.sep_id = tokenizer.sep
        self.mask_id = tokenizer.mask
        self.pad_id = tokenizer.pad

    def __len__(self):
        return self.samples_mapping.shape[0]

    def __getitem__(self, idx):
        start_idx, end_idx, seq_length = self.samples_mapping[idx]
        sample = [self.indexed_dataset[i] for i in range(start_idx, end_idx)]
        # Note that this rng state should be numpy and not python since
        # python randint is inclusive whereas the numpy one is exclusive.
        # We % 2**32 since numpy requires the seed to be between 0 and 2**32 - 1
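        # Seeding with (seed + idx) makes every sample deterministic,
        # independent of the order in which indices are requested.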
        np_rng = np.random.RandomState(seed=((self.seed + idx) % 2**32))
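        # build_training_sample returns a dict with the numpy arrays 'text',
        # 'types', 'labels', 'loss_mask', 'padding_mask' plus the integer
        # flags 'is_random' and 'truncated'.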
        return build_training_sample(sample, seq_length,
                                     self.max_seq_length,  # needed for padding
                                     self.vocab_id_list,
                                     self.vocab_id_to_token_dict,
                                     self.cls_id, self.sep_id,
                                     self.mask_id, self.pad_id,
                                     self.masked_lm_prob, np_rng,
                                     self.binary_head)


def build_training_sample(sample,
                          target_seq_length, max_seq_length,
                          vocab_id_list, vocab_id_to_token_dict,
                          cls_id, sep_id, mask_id, pad_id,
                          masked_lm_prob, np_rng, binary_head):
    """Biuld training sample.

    Arguments:
        sample: A list of sentences in which each sentence is a list of token ids.
        target_seq_length: Desired sequence length.
        max_seq_length: Maximum length of the sequence. All values are padded to
            this length.
        vocab_id_list: List of vocabulary ids. Used to pick a random id.
        vocab_id_to_token_dict: A dictionary from vocab ids to text tokens.
        cls_id: Start of example id.
        sep_id: Separator id.
        mask_id: Mask token id.
        pad_id: Padding token id.
        masked_lm_prob: Probability to mask tokens.
        np_rng: Random number generator. Note that this rng state should be
            numpy and not python since python randint is inclusive for
            the upper bound whereas the numpy one is exclusive.
        binary_head: If True, build the sample for the next-sentence
            prediction (binary) head; the returned dict then carries a
            meaningful `is_random` label.
    """

    if binary_head:
        # We assume that we have at least two sentences in the sample
        assert len(sample) > 1
    assert target_seq_length <= max_seq_length

    # Divide sample into two segments (A and B).
    if binary_head:
        tokens_a, tokens_b, is_next_random = get_a_and_b_segments(sample,
                                                                  np_rng)
    else:
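        # Without the binary (NSP) head, concatenate every sentence into a
        # single segment A; the next-sentence label is never random.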
        tokens_a = []
        for j in range(len(sample)):
            tokens_a.extend(sample[j])
        tokens_b = []
        is_next_random = False

    # Truncate to `target_seq_length`.
    max_num_tokens = target_seq_length
    truncated = truncate_segments(tokens_a, tokens_b, len(tokens_a),
                                  len(tokens_b), max_num_tokens, np_rng)

    # Build tokens and tokentypes.
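    # For a sentence pair this yields [CLS] A [SEP] B [SEP], with tokentype 0
    # for the [CLS]/A/[SEP] part and 1 for the B/[SEP] part.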
    tokens, tokentypes = create_tokens_and_tokentypes(tokens_a, tokens_b,
                                                      cls_id, sep_id)

    # Masking.
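    # Upper bound on how many tokens may be masked in this sequence.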
    max_predictions_per_seq = masked_lm_prob * max_num_tokens
    (tokens, masked_positions, masked_labels, _, _) = create_masked_lm_predictions(
        tokens, vocab_id_list, vocab_id_to_token_dict, masked_lm_prob,
        cls_id, sep_id, mask_id, max_predictions_per_seq, np_rng)

    # Padding.
    tokens_np, tokentypes_np, labels_np, padding_mask_np, loss_mask_np \
        = pad_and_convert_to_numpy(tokens, tokentypes, masked_positions,
                                   masked_labels, pad_id, max_seq_length)

    train_sample = {
        'text': tokens_np,
        'types': tokentypes_np,
        'labels': labels_np,
        'is_random': int(is_next_random),
        'loss_mask': loss_mask_np,
        'padding_mask': padding_mask_np,
        'truncated': int(truncated)}
    return train_sample


def pad_and_convert_to_numpy(tokens, tokentypes, masked_positions,
                             masked_labels, pad_id, max_seq_length):
    """Pad sequences and convert them to numpy."""

    # Some checks.
    num_tokens = len(tokens)
    padding_length = max_seq_length - num_tokens
    assert padding_length >= 0
    assert len(tokentypes) == num_tokens
    assert len(masked_positions) == len(masked_labels)

    # Tokens and token types.
    filler = [pad_id] * padding_length
    tokens_np = np.array(tokens + filler, dtype=np.int64)
    tokentypes_np = np.array(tokentypes + filler, dtype=np.int64)

    # Padding mask.
    padding_mask_np = np.array([1] * num_tokens + [0] * padding_length,
                               dtype=np.int64)

    # Labels and loss mask.
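    # A label of -1 marks an unmasked position; loss_mask selects only the
    # masked positions that contribute to the masked-LM loss.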
    labels = [-1] * max_seq_length
    loss_mask = [0] * max_seq_length
    for i in range(len(masked_positions)):
        assert masked_positions[i] < num_tokens
        labels[masked_positions[i]] = masked_labels[i]
        loss_mask[masked_positions[i]] = 1
    labels_np = np.array(labels, dtype=np.int64)
    loss_mask_np = np.array(loss_mask, dtype=np.int64)

    return tokens_np, tokentypes_np, labels_np, padding_mask_np, loss_mask_np
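

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the original module):
# pad_and_convert_to_numpy is a pure-numpy helper, so it can be exercised on
# its own with toy token ids. All ids below are made up for illustration.
if __name__ == '__main__':
    _tokens = [101, 7592, 103, 102]      # e.g. [CLS] hello [MASK] [SEP]
    _tokentypes = [0, 0, 0, 0]
    _masked_positions = [2]
    _masked_labels = [2088]              # id of the token that was masked out
    (_tokens_np, _tokentypes_np, _labels_np, _padding_mask_np,
     _loss_mask_np) = pad_and_convert_to_numpy(
         _tokens, _tokentypes, _masked_positions, _masked_labels,
         pad_id=0, max_seq_length=8)
    assert _tokens_np.shape == (8,) and _padding_mask_np.sum() == 4
    assert _labels_np[2] == 2088 and _loss_mask_np.sum() == 1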