pretrain_bert.py 7.94 KB
Newer Older
Raul Puri's avatar
Raul Puri committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
# coding=utf-8
# Copyright (c) 2019, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

Mohammad Shoeybi's avatar
Mohammad Shoeybi committed
16
"""Pretrain BERT"""
Raul Puri's avatar
Raul Puri committed
17
18

import torch
19
import torch.nn.functional as F
Raul Puri's avatar
Raul Puri committed
20

Mohammad Shoeybi's avatar
Mohammad Shoeybi committed
21
from megatron import mpu
22
from megatron.model import BertModel
23
from megatron import print_rank_0
Mohammad Shoeybi's avatar
Mohammad Shoeybi committed
24
from megatron.utils import reduce_losses
Mohammad Shoeybi's avatar
Mohammad Shoeybi committed
25
from megatron.utils import vocab_size_with_padding
26
from megatron.training import run
Mohammad Shoeybi's avatar
Mohammad Shoeybi committed
27
from megatron.data.bert_dataset import build_train_valid_test_datasets
28
from megatron.data_utils.samplers import DistributedBatchSampler
29

Raul Puri's avatar
Raul Puri committed
30

31
def model_provider(args):
    """Construct the BERT model described by the command-line ``args``.

    The binary head is always enabled (``add_binary_head=True``) and
    ``parallel_output=True`` is requested. The latter presumably keeps the LM
    logits partitioned across model-parallel ranks, which would match their
    use with ``mpu.vocab_parallel_cross_entropy``; confirm against BertModel.

    Returns:
        The constructed ``BertModel`` instance.
    """
    print_rank_0('building BERT model ...')

    bert_config = dict(
        num_layers=args.num_layers,
        vocab_size=args.vocab_size,
        hidden_size=args.hidden_size,
        num_attention_heads=args.num_attention_heads,
        # The hidden dropout rate is reused for both embeddings and outputs.
        embedding_dropout_prob=args.hidden_dropout,
        attention_dropout_prob=args.attention_dropout,
        output_dropout_prob=args.hidden_dropout,
        max_sequence_length=args.max_position_embeddings,
        checkpoint_activations=args.checkpoint_activations,
        checkpoint_num_layers=args.checkpoint_num_layers,
        add_binary_head=True,
        layernorm_epsilon=args.layernorm_epsilon,
        # args.tokentype_size is filled in by get_train_val_test_data.
        num_tokentypes=args.tokentype_size,
        parallel_output=True,
        apply_query_key_layer_scaling=args.apply_query_key_layer_scaling,
        attention_softmax_in_fp32=args.attention_softmax_in_fp32,
    )
    return BertModel(**bert_config)
Raul Puri's avatar
Raul Puri committed
55
56


57
def get_batch(data_iterator, timers):
    """Fetch one batch and share it via ``mpu.broadcast_data``.

    A rank whose ``data_iterator`` is ``None`` does not read data itself and
    passes ``None`` to ``mpu.broadcast_data``. Presumably it receives the
    tensors from the source rank of its model-parallel group; verify against
    the mpu implementation.

    Args:
        data_iterator: iterator yielding dict batches, or ``None``.
        timers: callable returning a named timer with ``start``/``stop``.

    Returns:
        Tuple ``(tokens, types, sentence_order, loss_mask, lm_labels,
        padding_mask)``. ``loss_mask`` is float; all others are int64.
    """
    field_names = ['text', 'types', 'labels', 'is_random', 'loss_mask',
                   'padding_mask']
    broadcast_dtype = torch.int64

    # Only the time spent pulling from the iterator is attributed to the
    # 'data loader' timer; the broadcast happens outside it.
    timers('data loader').start()
    raw_batch = None if data_iterator is None else next(data_iterator)
    timers('data loader').stop()
    shared = mpu.broadcast_data(field_names, raw_batch, broadcast_dtype)

    tokens = shared['text'].long()
    types = shared['types'].long()
    # The 'is_random' field serves as the sentence-order target.
    sentence_order = shared['is_random'].long()
    # Float so it can weight the per-token LM loss in forward_step.
    loss_mask = shared['loss_mask'].float()
    lm_labels = shared['labels'].long()
    padding_mask = shared['padding_mask'].long()
    return tokens, types, sentence_order, loss_mask, lm_labels, padding_mask
Raul Puri's avatar
Raul Puri committed
81
82


83
def forward_step(data_iterator, model, args, timers):
    """Run one BERT forward pass and compute the pretraining loss.

    The returned loss is the masked-LM loss plus the sentence-order
    prediction (SOP) loss.

    Returns:
        ``(loss, stats)``. ``loss`` is the unreduced sum of both losses.
        ``stats`` maps 'lm loss' and 'sop loss' to the corresponding entries
        returned by ``reduce_losses``.
    """
    timers('batch generator').start()
    batch = get_batch(data_iterator, timers)
    timers('batch generator').stop()
    tokens, types, sentence_order, loss_mask, lm_labels, padding_mask = batch

    lm_logits, sop_logits = model(tokens, padding_mask, tokentype_ids=types)

    # Two-way sentence-order classification; targets equal to -1 are ignored.
    flat_sop_logits = sop_logits.view(-1, 2).contiguous().float()
    flat_sop_targets = sentence_order.view(-1).contiguous()
    sop_loss = F.cross_entropy(flat_sop_logits, flat_sop_targets,
                               ignore_index=-1)

    # Per-token LM loss, weighted by loss_mask and normalized by its sum.
    per_token_lm_loss = mpu.vocab_parallel_cross_entropy(
        lm_logits.contiguous().float(), lm_labels.contiguous())
    weighted = per_token_lm_loss.view(-1) * loss_mask.reshape(-1)
    lm_loss = torch.sum(weighted) / loss_mask.sum()

    loss = lm_loss + sop_loss
    reduced = reduce_losses([lm_loss, sop_loss])
    return loss, {'lm loss': reduced[0], 'sop loss': reduced[1]}
109
110
111
112
113


def get_train_val_test_data(args):
    """Load the data on rank zero and broadcast the number of tokens to all GPUs.

    Datasets and data loaders are built only on rank 0 of each model-parallel
    group; other ranks keep ``None`` loaders. The padded vocabulary size, the
    number of token types (hard-coded to 2), and the do_train / do_valid /
    do_test flags are then broadcast from the group's source rank. They are
    stored on ``args`` as ``vocab_size``, ``tokentype_size``, ``do_train``,
    ``do_valid`` and ``do_test``.

    Returns:
        ``(train_data, valid_data, test_data)``: ``torch.utils.data.DataLoader``
        objects, or ``None`` where a loader was not built on this rank.

    Raises:
        RuntimeError: on the building rank, if no train, validation or test
            dataset was built, so the vocabulary size cannot be determined.
    """
    (train_data, valid_data, test_data) = (None, None, None)

    # Data loader only on rank 0 of each model parallel group.
    if mpu.get_model_parallel_rank() == 0:
        print_rank_0('> building train, validation, and test datasets '
                     'for BERT ...')

        data_parallel_size = mpu.get_data_parallel_world_size()
        data_parallel_rank = mpu.get_data_parallel_rank()
        global_batch_size = args.batch_size * data_parallel_size

        # Number of train/valid/test samples.
        train_iters = args.train_iters
        # Presumably one evaluation per eval_interval plus a final one,
        # hence the +1.
        eval_iters = (train_iters // args.eval_interval + 1) * args.eval_iters
        test_iters = args.eval_iters
        train_val_test_num_samples = [train_iters * global_batch_size,
                                      eval_iters * global_batch_size,
                                      test_iters * global_batch_size]
        print_rank_0(' > datasets target sizes (minimum size):')
        print_rank_0('    train:      {}'.format(train_val_test_num_samples[0]))
        print_rank_0('    validation: {}'.format(train_val_test_num_samples[1]))
        print_rank_0('    test:       {}'.format(train_val_test_num_samples[2]))

        train_ds, valid_ds, test_ds = build_train_valid_test_datasets(
            vocab_file=args.vocab_file,
            data_prefix=args.data_path,
            data_impl=args.data_impl,
            splits_string=args.split,
            train_valid_test_num_samples=train_val_test_num_samples,
            max_seq_length=args.seq_length,
            masked_lm_prob=args.mask_prob,
            short_seq_prob=args.short_seq_prob,
            seed=args.seed,
            skip_warmup=(not args.mmap_warmup))
        print_rank_0("> finished creating BERT datasets ...")

        def make_data_loader_(dataset):
            """Wrap ``dataset`` in a sequential, data-parallel DataLoader.

            Returns ``None`` for a falsy dataset (``None`` or empty).
            """
            if not dataset:
                return None
            # Use a simple sampler with distributed batch sampler.
            sampler = torch.utils.data.SequentialSampler(dataset)
            batch_sampler = DistributedBatchSampler(
                sampler=sampler,
                batch_size=global_batch_size,
                drop_last=True,
                rank=data_parallel_rank,
                world_size=data_parallel_size)
            # Torch dataloader.
            return torch.utils.data.DataLoader(dataset,
                                               batch_sampler=batch_sampler,
                                               num_workers=args.num_workers,
                                               pin_memory=True)

        train_data = make_data_loader_(train_ds)
        valid_data = make_data_loader_(valid_ds)
        test_data = make_data_loader_(test_ds)

        do_train = train_data is not None and args.train_iters > 0
        do_valid = valid_data is not None and args.eval_iters > 0
        do_test = test_data is not None and args.eval_iters > 0

        # Read the vocabulary size from whichever dataset was actually built.
        # A split can come back as None (make_data_loader_ already guards for
        # that), e.g. an eval-only run with no train split. Calling
        # train_ds.num_tokens() unconditionally would then fail with an
        # AttributeError.
        vocab_source = next((ds for ds in (train_ds, valid_ds, test_ds)
                             if ds is not None), None)
        if vocab_source is None:
            raise RuntimeError('no BERT train, validation, or test dataset '
                               'was built; cannot determine the vocabulary '
                               'size')
        # Need to broadcast num_tokens and num_type_tokens.
        num_tokens = vocab_size_with_padding(vocab_source.num_tokens(), args)
        token_counts = torch.cuda.LongTensor([num_tokens,
                                              2, # hard coded num_type_tokens
                                              int(do_train),
                                              int(do_valid),
                                              int(do_test)])
    else:
        # Placeholder that receives the broadcast values.
        token_counts = torch.cuda.LongTensor([0, 0, 0, 0, 0])

    # Broadcast num tokens.
    torch.distributed.broadcast(token_counts,
                                mpu.get_model_parallel_src_rank(),
                                group=mpu.get_model_parallel_group())
    args.vocab_size = token_counts[0].item()
    args.tokentype_size = token_counts[1].item()
    args.do_train = token_counts[2].item()
    args.do_valid = token_counts[3].item()
    args.do_test = token_counts[4].item()

    return train_data, valid_data, test_data
Raul Puri's avatar
Raul Puri committed
195
196
197


if __name__ == "__main__":

    # Entry point: hand the BERT-specific data, model and step providers to
    # the generic Megatron training driver, defaulting the tokenizer to
    # lower-cased BERT WordPiece.
    run('Pretrain BERT model', get_train_val_test_data,
        model_provider, forward_step,
        args_defaults={'tokenizer_type': 'BertWordPieceLowerCase'})