# coding=utf-8
# Copyright (c) 2019, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Pretrain BERT"""

import torch
import torch.nn.functional as F

from megatron import get_args
from megatron import get_timers
from megatron import mpu
from megatron import print_rank_0
from megatron.data.bert_dataset import build_train_valid_test_datasets
from megatron.model import BertModel
from megatron.training import pretrain
from megatron.utils import make_data_loader
from megatron.utils import reduce_losses


def model_provider():
    """Build the model."""
    args = get_args()

    print_rank_0('building BERT model ...')

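    # num_tokentypes=2 covers BERT's two segment embeddings; add_binary_head
    # attaches the sentence-order-prediction head; parallel_output keeps the
    # output logits split across model parallel ranks.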
    model = BertModel(
        num_tokentypes=2,
        add_binary_head=True,
        parallel_output=True)

    return model


def get_batch(data_iterator):
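    """Build the batch."""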

    # Items and their type.
    keys = ['text', 'types', 'labels', 'is_random', 'loss_mask', 'padding_mask']
    datatype = torch.int64

    # Broadcast the batch within each model parallel group; ranks without a
    # data iterator receive the tensors from the group's source rank.
    if data_iterator is not None:
        data = next(data_iterator)
    else:
        data = None
    data_b = mpu.broadcast_data(keys, data, datatype)

    # Unpack.
    tokens = data_b['text'].long()
    types = data_b['types'].long()
    sentence_order = data_b['is_random'].long()
    loss_mask = data_b['loss_mask'].float()
    lm_labels = data_b['labels'].long()
    padding_mask = data_b['padding_mask'].long()

    return tokens, types, sentence_order, loss_mask, lm_labels, padding_mask


def forward_step(data_iterator, model):
    """Forward step."""
    timers = get_timers()

    # Get the batch.
    timers('batch generator').start()
    tokens, types, sentence_order, loss_mask, lm_labels, padding_mask \
        = get_batch(data_iterator)
    timers('batch generator').stop()

    # Forward model.
    lm_logits, sop_logits = model(tokens, padding_mask, tokentype_ids=types)

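    # Sentence-order prediction (SOP) loss from the binary head; labels of -1
    # are ignored.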
    sop_loss = F.cross_entropy(sop_logits.view(-1, 2).contiguous().float(),
                               sentence_order.view(-1).contiguous(),
                               ignore_index=-1)

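    # Masked-LM loss computed on the vocab-parallel logits, so the full
    # logits are never gathered onto a single rank.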
    lm_loss_ = mpu.vocab_parallel_cross_entropy(lm_logits.contiguous().float(),
                                                lm_labels.contiguous())
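    # Average the LM loss over the masked positions only.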
    lm_loss = torch.sum(
        lm_loss_.view(-1) * loss_mask.reshape(-1)) / loss_mask.sum()

    loss = lm_loss + sop_loss

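    # Average the losses across all ranks for logging.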
    reduced_losses = reduce_losses([lm_loss, sop_loss])

    return loss, {'lm loss': reduced_losses[0], 'sop loss': reduced_losses[1]}


def get_train_val_test_data():
    """Load the data on rank zero and boradcast number of tokens to all GPUS."""
    args = get_args()

    (train_data, valid_data, test_data) = (None, None, None)

    # Build the data loaders only on rank 0 of each model parallel group.
    if mpu.get_model_parallel_rank() == 0:
        print_rank_0('> building train, validation, and test datasets '
                     'for BERT ...')

        data_parallel_size = mpu.get_data_parallel_world_size()
        global_batch_size = args.batch_size * data_parallel_size

        # Number of train/valid/test samples.
        train_iters = args.train_iters
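        # One evaluation pass every eval_interval iterations plus a final
        # one, each running for eval_iters iterations.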
        eval_iters = (train_iters // args.eval_interval + 1) * args.eval_iters
        test_iters = args.eval_iters
        train_val_test_num_samples = [train_iters * global_batch_size,
                                      eval_iters * global_batch_size,
                                      test_iters * global_batch_size]
        print_rank_0(' > datasets target sizes (minimum size):')
        print_rank_0('    train:      {}'.format(train_val_test_num_samples[0]))
        print_rank_0('    validation: {}'.format(train_val_test_num_samples[1]))
        print_rank_0('    test:       {}'.format(train_val_test_num_samples[2]))

        train_ds, valid_ds, test_ds = build_train_valid_test_datasets(
            data_prefix=args.data_path,
            data_impl=args.data_impl,
            splits_string=args.split,
            train_valid_test_num_samples=train_val_test_num_samples,
            max_seq_length=args.seq_length,
            masked_lm_prob=args.mask_prob,
            short_seq_prob=args.short_seq_prob,
            seed=args.seed,
            skip_warmup=(not args.mmap_warmup))
        print_rank_0("> finished creating BERT datasets ...")

        train_data = make_data_loader(train_ds)
        valid_data = make_data_loader(valid_ds)
        test_data = make_data_loader(test_ds)

        do_train = train_data is not None and args.train_iters > 0
        do_valid = valid_data is not None and args.eval_iters > 0
        do_test = test_data is not None and args.eval_iters > 0
        # Pack the do-train/valid/test flags for broadcast.
        flags = torch.cuda.LongTensor(
            [int(do_train), int(do_valid), int(do_test)])
    else:
        flags = torch.cuda.LongTensor([0, 0, 0])

    # Broadcast the flags from rank 0 of each model parallel group.
    torch.distributed.broadcast(flags,
                                mpu.get_model_parallel_src_rank(),
                                group=mpu.get_model_parallel_group())
    args.do_train = flags[0].item()
    args.do_valid = flags[1].item()
    args.do_test = flags[2].item()

    return train_data, valid_data, test_data


if __name__ == "__main__":

    pretrain(get_train_val_test_data, model_provider, forward_step,
             args_defaults={'tokenizer_type': 'BertWordPieceLowerCase'})
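
    # Illustrative launch command (flag values are examples only; see
    # Megatron-LM's argument definitions for the full list):
    #   python pretrain_bert.py --num-layers 24 --hidden-size 1024 \
    #       --num-attention-heads 16 --seq-length 512 --batch-size 4 \
    #       --train-iters 1000000 --data-path <prefix> --vocab-file <vocab.txt> \
    #       --split 949,50,1 --lr 0.0001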