bert_dataset.py
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.

"""BERT Style dataset."""

import numpy as np
import torch

from megatron import (
    get_args,
    get_tokenizer,
    mpu,
    print_rank_0
)
from megatron.data.dataset_utils import (
    get_samples_mapping,
    get_a_and_b_segments,
    truncate_segments,
    create_tokens_and_tokentypes,
    create_masked_lm_predictions
)

class BertDataset(torch.utils.data.Dataset):
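    """BERT-style pretraining dataset.

    Wraps an indexed sentence dataset and yields masked-LM training samples,
    optionally with a next-sentence-prediction label when binary_head is set.
    """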

    def __init__(self, name, indexed_dataset, data_prefix,
                 num_epochs, max_num_samples, masked_lm_prob,
                 max_seq_length, short_seq_prob, seed, binary_head):

        # Params to store.
        self.name = name
        self.seed = seed
        self.masked_lm_prob = masked_lm_prob
        self.max_seq_length = max_seq_length
        self.binary_head = binary_head

        # Dataset.
        self.indexed_dataset = indexed_dataset

        # Build the samples mapping.
        self.samples_mapping = get_samples_mapping(self.indexed_dataset,
                                                   data_prefix,
                                                   num_epochs,
                                                   max_num_samples,
                                                   self.max_seq_length - 3, # account for added tokens
                                                   short_seq_prob,
                                                   self.seed,
                                                   self.name,
                                                   self.binary_head)

        # Vocab stuff.
        tokenizer = get_tokenizer()
        self.vocab_id_list = list(tokenizer.inv_vocab.keys())
        self.vocab_id_to_token_dict = tokenizer.inv_vocab
        self.cls_id = tokenizer.cls
        self.sep_id = tokenizer.sep
        self.mask_id = tokenizer.mask
        self.pad_id = tokenizer.pad

    def __len__(self):
        return self.samples_mapping.shape[0]

    def __getitem__(self, idx):
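        # Each row of the samples mapping is (start sentence index,
        # end sentence index, target sequence length) into the indexed dataset.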
        start_idx, end_idx, seq_length = self.samples_mapping[idx]
        sample = [self.indexed_dataset[i] for i in range(start_idx, end_idx)]
        # Note that this rng state should be numpy and not python since
        # python randint is inclusive whereas the numpy one is exclusive.
        # We % 2**32 since numpy requires the seed to be between 0 and 2**32 - 1
        np_rng = np.random.RandomState(seed=((self.seed + idx) % 2**32))
        return build_training_sample(sample, seq_length,
                                     self.max_seq_length,  # needed for padding
                                     self.vocab_id_list,
                                     self.vocab_id_to_token_dict,
                                     self.cls_id, self.sep_id,
                                     self.mask_id, self.pad_id,
                                     self.masked_lm_prob, np_rng,
                                     self.binary_head)




def build_training_sample(sample,
                          target_seq_length, max_seq_length,
                          vocab_id_list, vocab_id_to_token_dict,
                          cls_id, sep_id, mask_id, pad_id,
                          masked_lm_prob, np_rng, binary_head):
    """Biuld training sample.

    Arguments:
        sample: A list of sentences in which each sentence is a list token ids.
        target_seq_length: Desired sequence length.
        max_seq_length: Maximum length of the sequence. All values are padded to
            this length.
        vocab_id_list: List of vocabulary ids. Used to pick a random id.
        vocab_id_to_token_dict: A dictionary from vocab ids to text tokens.
        cls_id: Start of example id.
        sep_id: Separator id.
        mask_id: Mask token id.
        pad_id: Padding token id.
        masked_lm_prob: Probability to mask tokens.
        np_rng: Random number generator. Note that this rng state should be
              numpy and not python since python randint is inclusive for
              the upper bound whereas the numpy one is exclusive.
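        binary_head: If true, the sample is split into two segments (A and B)
              and an `is_random` next-sentence-prediction label is produced.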
    """

    if binary_head:
        # We assume that we have at least two sentences in the sample
        assert len(sample) > 1
    assert target_seq_length <= max_seq_length

    # Divide sample into two segments (A and B).
    if binary_head:
        tokens_a, tokens_b, is_next_random = get_a_and_b_segments(sample,
                                                                  np_rng)
    else:
        tokens_a = []
        for j in range(len(sample)):
            tokens_a.extend(sample[j])
        tokens_b = []
        is_next_random = False

    # Truncate to `target_seq_length`.
    max_num_tokens = target_seq_length
    truncated = truncate_segments(tokens_a, tokens_b, len(tokens_a),
                                  len(tokens_b), max_num_tokens, np_rng)

    # Build tokens and tokentypes.
    tokens, tokentypes = create_tokens_and_tokentypes(tokens_a, tokens_b,
                                                      cls_id, sep_id)

    # Masking.
    max_predictions_per_seq = masked_lm_prob * max_num_tokens
    (tokens, masked_positions, masked_labels, _, _) = create_masked_lm_predictions(
        tokens, vocab_id_list, vocab_id_to_token_dict, masked_lm_prob,
        cls_id, sep_id, mask_id, max_predictions_per_seq, np_rng)

    # Padding.
    tokens_np, tokentypes_np, labels_np, padding_mask_np, loss_mask_np \
        = pad_and_convert_to_numpy(tokens, tokentypes, masked_positions,
                                   masked_labels, pad_id, max_seq_length)

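    # Assemble the final sample. 'text' holds the (masked) token ids, 'types'
    # the segment (token type) ids, 'labels' the original ids at masked
    # positions (-1 elsewhere), 'loss_mask' selects those positions,
    # 'is_random' is the next-sentence-prediction target, and 'truncated'
    # records whether the pair had to be shortened.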
    train_sample = {
        'text': tokens_np,
        'types': tokentypes_np,
        'labels': labels_np,
        'is_random': int(is_next_random),
        'loss_mask': loss_mask_np,
        'padding_mask': padding_mask_np,
        'truncated': int(truncated)}
    return train_sample


def pad_and_convert_to_numpy(tokens, tokentypes, masked_positions,
                             masked_labels, pad_id, max_seq_length):
    """Pad sequences and convert them to numpy."""

    # Some checks.
    num_tokens = len(tokens)
    padding_length = max_seq_length - num_tokens
    assert padding_length >= 0, \
        f"num_tokens ({num_tokens}) is greater than " \
        f"max_seq_length ({max_seq_length})."
    assert len(tokentypes) == num_tokens
    assert len(masked_positions) == len(masked_labels)

    # Tokens and token types.
    filler = [pad_id] * padding_length
    tokens_np = np.array(tokens + filler, dtype=np.int64)
    tokentypes_np = np.array(tokentypes + filler, dtype=np.int64)

    # Padding mask.
    padding_mask_np = np.array([1] * num_tokens + [0] * padding_length,
                               dtype=np.int64)

    # Labels and loss mask.
    labels = [-1] * max_seq_length
    loss_mask = [0] * max_seq_length
    for i in range(len(masked_positions)):
        assert masked_positions[i] < num_tokens
        labels[masked_positions[i]] = masked_labels[i]
        loss_mask[masked_positions[i]] = 1
    labels_np = np.array(labels, dtype=np.int64)
    loss_mask_np = np.array(loss_mask, dtype=np.int64)

    return tokens_np, tokentypes_np, labels_np, padding_mask_np, loss_mask_np
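

# The block below is an illustrative sketch, not part of the original module:
# a tiny, self-contained check of pad_and_convert_to_numpy. The token ids and
# lengths are made-up values; only the padding and masking behaviour they
# exercise is meant to be representative.
if __name__ == '__main__':
    # Five "real" tokens padded out to a maximum length of 8 with pad_id=0.
    toy_tokens = [101, 7, 103, 9, 102]
    toy_tokentypes = [0, 0, 0, 0, 0]
    toy_masked_positions = [2]   # position that was masked out
    toy_masked_labels = [8]      # original id at that position
    tokens_np, tokentypes_np, labels_np, padding_mask_np, loss_mask_np = \
        pad_and_convert_to_numpy(toy_tokens, toy_tokentypes,
                                 toy_masked_positions, toy_masked_labels,
                                 pad_id=0, max_seq_length=8)
    assert tokens_np.tolist() == [101, 7, 103, 9, 102, 0, 0, 0]
    assert padding_mask_np.tolist() == [1, 1, 1, 1, 1, 0, 0, 0]
    assert labels_np[2] == 8 and loss_mask_np[2] == 1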