# coding=utf-8
# Copyright (c) 2019, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Pretrain BERT"""

import torch
import torch.nn.functional as F

from configure_data import configure_data
from megatron import mpu
from megatron.model import BertModel
from megatron.utils import print_rank_0
from megatron.utils import reduce_losses
from megatron.utils import vocab_size_with_padding
from megatron.training import run


def model_provider(args):
    """Build the BERT model from the parsed arguments.

    ``args.vocab_size`` and ``args.tokentype_size`` are not read from the
    command line here: ``get_train_val_test_data`` assigns them from the
    broadcast tokenizer sizes. That function is therefore presumably invoked
    before this one (confirm the ordering in ``megatron.training.run``).

    Args:
        args: parsed training arguments (model shape, dropout, checkpointing
            and precision options).

    Returns:
        The constructed ``BertModel``.
    """
    print_rank_0('building BERT model ...')

    model = BertModel(
        num_layers=args.num_layers,
        vocab_size=args.vocab_size,
        hidden_size=args.hidden_size,
        num_attention_heads=args.num_attention_heads,
        # The hidden-dropout rate is reused for both embeddings and outputs.
        embedding_dropout_prob=args.hidden_dropout,
        attention_dropout_prob=args.attention_dropout,
        output_dropout_prob=args.hidden_dropout,
        max_sequence_length=args.max_position_embeddings,
        checkpoint_activations=args.checkpoint_activations,
        checkpoint_num_layers=args.checkpoint_num_layers,
        # Extra binary (next-sentence) head: forward_step unpacks the model
        # output as (lm_logits, nsp_logits).
        add_binary_head=True,
        layernorm_epsilon=args.layernorm_epsilon,
        num_tokentypes=args.tokentype_size,
        # Presumably leaves the LM logits partitioned across model-parallel
        # ranks, which is why forward_step uses
        # mpu.vocab_parallel_cross_entropy -- confirm in BertModel.
        parallel_output=True,
        apply_query_key_layer_scaling=args.apply_query_key_layer_scaling,
        attention_softmax_in_fp32=args.attention_softmax_in_fp32)

    return model


def get_batch(data_iterator, timers):
    """Fetch the next batch and broadcast it within the model-parallel group.

    A rank holding an iterator reads the next item from it. A rank given
    ``data_iterator=None`` reads nothing and receives the tensors through
    ``mpu.broadcast_data``. This is presumably the case on model-parallel
    ranks other than 0, which build no datasets in ``get_train_val_test_data``;
    confirm in the caller.

    Args:
        data_iterator: iterator over batches, or ``None``. Each batch is
            presumably a mapping containing the names listed in ``keys``
            below; verify against the dataset.
        timers: callable returning a named timer exposing ``start``/``stop``.

    Returns:
        Tuple ``(tokens, types, next_sentence, loss_mask, lm_labels,
        padding_mask)``. ``loss_mask`` is float; all others are int64.
    """
    # Names of the fields to broadcast, and the dtype they are sent as.
    keys = ['text', 'types', 'is_random', 'mask', 'mask_labels', 'pad_mask']
    datatype = torch.int64

    # Only the data loading itself is timed, not the broadcast.
    timers('data loader').start()
    if data_iterator is not None:
        data = next(data_iterator)
    else:
        data = None
    timers('data loader').stop()
    data_b = mpu.broadcast_data(keys, data, datatype)

    # Unpack into the tensors the forward step consumes.
    tokens = data_b['text'].long()
    types = data_b['types'].long()
    next_sentence = data_b['is_random'].long()
    # Float so it can weight the per-token LM loss directly.
    loss_mask = data_b['mask'].float()
    lm_labels = data_b['mask_labels'].long()
    padding_mask = data_b['pad_mask'].long()

    return tokens, types, next_sentence, loss_mask, lm_labels, padding_mask


def forward_step(data_iterator, model, args, timers):
    """Forward step: compute the masked-LM and next-sentence losses.

    Args:
        data_iterator: passed through to ``get_batch`` (may be ``None``).
        model: the model built by ``model_provider``; called as
            ``model(tokens, attention_mask, tokentype_ids=types)`` and
            returning ``(lm_logits, nsp_logits)``.
        args: not used in this function. It is presumably kept for the
            callback signature expected by ``megatron.training.run``.
        timers: callable returning a named timer exposing ``start``/``stop``.

    Returns:
        ``(loss, stats)``.
        ``loss`` is ``lm_loss + nsp_loss`` as computed on this rank.
        ``stats`` maps ``'lm loss'`` and ``'nsp loss'`` to the values returned
        by ``reduce_losses`` for the two losses.
    """
    # Get the batch.
    timers('batch generator').start()
    tokens, types, next_sentence, loss_mask, lm_labels, padding_mask = \
        get_batch(data_iterator, timers)
    timers('batch generator').stop()

    # Forward model. The padding mask is inverted before it is passed as the
    # attention mask. NOTE(review): this assumes pad_mask is 1 on padded
    # positions; confirm against the dataset.
    lm_logits, nsp_logits = model(tokens, 1 - padding_mask,
                                  tokentype_ids=types)

    # Next-sentence prediction: 2-way classification; targets of -1 are
    # excluded from the loss.
    nsp_loss = F.cross_entropy(nsp_logits.view(-1, 2).contiguous().float(),
                               next_sentence.view(-1).contiguous(),
                               ignore_index=-1)

    # Masked LM: per-token cross entropy over the (vocab-parallel) logits,
    # averaged over the positions selected by loss_mask.
    lm_loss_ = mpu.vocab_parallel_cross_entropy(lm_logits.contiguous().float(),
                                                lm_labels.contiguous())
    # Clamp the denominator so a batch with no masked positions gives a loss
    # of 0 instead of NaN (0/0). loss_mask is built from integer (int64)
    # data, so with 0/1 entries any non-empty mask already sums to >= 1.
    # Such masks are therefore unaffected.
    num_masked = torch.clamp(loss_mask.sum(), min=1.0)
    lm_loss = torch.sum(
        lm_loss_.view(-1) * loss_mask.reshape(-1)) / num_masked

    loss = lm_loss + nsp_loss

    reduced_losses = reduce_losses([lm_loss, nsp_loss])

    return loss, {'lm loss': reduced_losses[0], 'nsp loss': reduced_losses[1]}


def get_train_val_test_data(args):
    """Load the data on rank zero and broadcast the number of tokens to all GPUs.

    Only rank 0 of each model-parallel group builds the datasets. It then
    broadcasts the padded vocabulary size, the number of token types, and the
    ``do_train``/``do_valid``/``do_test`` flags to the rest of its group.

    Side effects on ``args`` (on every rank): sets ``vocab_size``,
    ``tokentype_size``, ``do_train``, ``do_valid`` and ``do_test``. The three
    flags become the ints received from the broadcast.

    Args:
        args: parsed training arguments; ``args.data_loader`` must be one of
            ``'raw'``, ``'lazy'`` or ``'tfrecords'``.

    Returns:
        ``(train_data, val_data, test_data)``. All three are ``None`` on
        model-parallel ranks other than 0.

    Raises:
        SystemExit: with status 1 if ``args.data_loader`` is unsupported.
    """
    (train_data, val_data, test_data) = (None, None, None)

    # Data loader only on rank 0 of each model parallel group.
    if mpu.get_model_parallel_rank() == 0:
        if args.data_loader not in ('raw', 'lazy', 'tfrecords'):
            print("Unsupported data loader for BERT.")
            # NOTE(review): the other ranks of this group still proceed to
            # the broadcast below, which this rank never joins.
            raise SystemExit(1)
        data_config = configure_data()
        data_config.set_defaults(data_set_type='BERT', transpose=False)
        (train_data, val_data, test_data), tokenizer = data_config.apply(args)
        num_tokens = vocab_size_with_padding(tokenizer.num_tokens, args)
        # Pack the values the other ranks need into one tensor so a single
        # broadcast suffices.
        token_counts = torch.cuda.LongTensor([num_tokens,
                                              tokenizer.num_type_tokens,
                                              int(args.do_train),
                                              int(args.do_valid),
                                              int(args.do_test)])
    else:
        # Placeholder buffer, overwritten by the broadcast below.
        token_counts = torch.cuda.LongTensor([0, 0, 0, 0, 0])

    # Broadcast num tokens.
    torch.distributed.broadcast(token_counts,
                                mpu.get_model_parallel_src_rank(),
                                group=mpu.get_model_parallel_group())
    num_tokens = token_counts[0].item()
    num_type_tokens = token_counts[1].item()
    args.do_train = token_counts[2].item()
    args.do_valid = token_counts[3].item()
    args.do_test = token_counts[4].item()

    args.vocab_size = num_tokens
    args.tokentype_size = num_type_tokens

    return train_data, val_data, test_data


if __name__ == "__main__":
    # Hand the dataset, model and forward-step callbacks to
    # megatron.training.run, which drives the pretraining.
    run('Pretrain BERT model', get_train_val_test_data,
        model_provider, forward_step)