# coding=utf-8
# Copyright (c) 2019, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Pretrain BERT"""

import torch
import torch.nn.functional as F

from megatron import get_args
from megatron import get_timers
from megatron import mpu
from megatron import print_rank_0
from megatron.data.bert_dataset import build_train_valid_test_datasets
from megatron.data_utils.samplers import DistributedBatchSampler
from megatron.model import BertModel
from megatron.training import pretrain
from megatron.utils import reduce_losses


def model_provider():
    """Build the model."""
    args = get_args()

    print_rank_0('building BERT model ...')

    model = BertModel(
        num_layers=args.num_layers,
        vocab_size=args.padded_vocab_size,
        hidden_size=args.hidden_size,
        num_attention_heads=args.num_attention_heads,
        embedding_dropout_prob=args.hidden_dropout,
        attention_dropout_prob=args.attention_dropout,
        output_dropout_prob=args.hidden_dropout,
        max_sequence_length=args.max_position_embeddings,
        checkpoint_activations=args.checkpoint_activations,
        checkpoint_num_layers=args.checkpoint_num_layers,
        add_binary_head=True,
        layernorm_epsilon=args.layernorm_epsilon,
        num_tokentypes=2,
        parallel_output=True,
        apply_query_key_layer_scaling=args.apply_query_key_layer_scaling,
        attention_softmax_in_fp32=args.attention_softmax_in_fp32)

    return model


def get_batch(data_iterator):
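    """Build the batch and broadcast it to the model-parallel group."""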

    # Items and their type.
    keys = ['text', 'types', 'labels', 'is_random', 'loss_mask', 'padding_mask']
    datatype = torch.int64

    # Broadcast data.
    if data_iterator is not None:
        data = next(data_iterator)
    else:
        data = None
    data_b = mpu.broadcast_data(keys, data, datatype)

    # Unpack.
    tokens = data_b['text'].long()
    types = data_b['types'].long()
    sentence_order = data_b['is_random'].long()
    loss_mask = data_b['loss_mask'].float()
    lm_labels = data_b['labels'].long()
    padding_mask = data_b['padding_mask'].long()

    return tokens, types, sentence_order, loss_mask, lm_labels, padding_mask


def forward_step(data_iterator, model):
    """Forward step."""
    timers = get_timers()

    # Get the batch.
    timers('batch generator').start()
    tokens, types, sentence_order, loss_mask, lm_labels, padding_mask \
        = get_batch(data_iterator)
    timers('batch generator').stop()

    # Forward model.
    lm_logits, sop_logits = model(tokens, padding_mask, tokentype_ids=types)

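    # Sentence-order prediction (SOP) loss from the binary head; -1 targets are ignored.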
    sop_loss = F.cross_entropy(sop_logits.view(-1, 2).contiguous().float(),
                               sentence_order.view(-1).contiguous(),
                               ignore_index=-1)

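    # Masked LM loss (vocab-parallel), averaged over positions selected by loss_mask.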
    lm_loss_ = mpu.vocab_parallel_cross_entropy(lm_logits.contiguous().float(),
                                                lm_labels.contiguous())
    lm_loss = torch.sum(
        lm_loss_.view(-1) * loss_mask.reshape(-1)) / loss_mask.sum()

    loss = lm_loss + sop_loss

    reduced_losses = reduce_losses([lm_loss, sop_loss])

    return loss, {'lm loss': reduced_losses[0], 'sop loss': reduced_losses[1]}


def get_train_val_test_data():
    """Load the data on rank zero and boradcast number of tokens to all GPUS."""
    args = get_args()

    (train_data, valid_data, test_data) = (None, None, None)

    # Data loader only on rank 0 of each model parallel group.
    if mpu.get_model_parallel_rank() == 0:
        print_rank_0('> building train, validation, and test datasets '
                     'for BERT ...')

        data_parallel_size = mpu.get_data_parallel_world_size()
        data_parallel_rank = mpu.get_data_parallel_rank()
        global_batch_size = args.batch_size * data_parallel_size

        # Number of train/valid/test samples.
        train_iters = args.train_iters
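        # Enough eval batches to cover all periodic evals, plus one extra.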
        eval_iters = (train_iters // args.eval_interval + 1) * args.eval_iters
        test_iters = args.eval_iters
        train_val_test_num_samples = [train_iters * global_batch_size,
                                      eval_iters * global_batch_size,
                                      test_iters * global_batch_size]
        print_rank_0(' > datasets target sizes (minimum size):')
        print_rank_0('    train:      {}'.format(train_val_test_num_samples[0]))
        print_rank_0('    validation: {}'.format(train_val_test_num_samples[1]))
        print_rank_0('    test:       {}'.format(train_val_test_num_samples[2]))

        train_ds, valid_ds, test_ds = build_train_valid_test_datasets(
            vocab_file=args.vocab_file,
            data_prefix=args.data_path,
            data_impl=args.data_impl,
            splits_string=args.split,
            train_valid_test_num_samples=train_val_test_num_samples,
            max_seq_length=args.seq_length,
            masked_lm_prob=args.mask_prob,
            short_seq_prob=args.short_seq_prob,
            seed=args.seed,
            skip_warmup=(not args.mmap_warmup))
        print_rank_0("> finished creating BERT datasets ...")

        def make_data_loader_(dataset):
            if not dataset:
                return None
            # Use a simple sampler with distributed batch sampler.
            sampler = torch.utils.data.SequentialSampler(dataset)
            batch_sampler = DistributedBatchSampler(
                sampler=sampler,
                batch_size=global_batch_size,
                drop_last=True,
                rank=data_parallel_rank,
                world_size=data_parallel_size)
            # Torch dataloader.
            return torch.utils.data.DataLoader(dataset,
                                               batch_sampler=batch_sampler,
                                               num_workers=args.num_workers,
                                               pin_memory=True)

        train_data = make_data_loader_(train_ds)
        valid_data = make_data_loader_(valid_ds)
        test_data = make_data_loader_(test_ds)

        do_train = train_data is not None and args.train_iters > 0
        do_valid = valid_data is not None and args.eval_iters > 0
        do_test = test_data is not None and args.eval_iters > 0
        # Need to broadcast the do_train/do_valid/do_test flags.
        flags = torch.cuda.LongTensor(
            [int(do_train), int(do_valid), int(do_test)])
    else:
        flags = torch.cuda.LongTensor([0, 0, 0])

    # Broadcast the flags from rank zero of the model parallel group.
    torch.distributed.broadcast(flags,
                                mpu.get_model_parallel_src_rank(),
                                group=mpu.get_model_parallel_group())
    args.do_train = flags[0].item()
    args.do_valid = flags[1].item()
    args.do_test = flags[2].item()

    return train_data, valid_data, test_data


if __name__ == "__main__":
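    # Illustrative launch sketch (not a complete argument list; the flag names
    # assume the standard Megatron-LM argument parser and a preprocessed
    # dataset plus a WordPiece vocab file):
    #
    #   python pretrain_bert.py \
    #       --num-layers 24 --hidden-size 1024 --num-attention-heads 16 \
    #       --batch-size 4 --seq-length 512 --max-position-embeddings 512 \
    #       --train-iters 1000000 --split 949,50,1 \
    #       --data-path <data prefix> --vocab-file <vocab file>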

    pretrain(get_train_val_test_data, model_provider, forward_step,
             args_defaults={'tokenizer_type': 'BertWordPieceLowerCase'})