# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Pretrain BERT"""

import torch
import torch.nn.functional as F
from functools import partial

from megatron import get_args
from megatron import print_rank_0
from megatron import get_timers
from megatron import mpu
from megatron.data.dataset_utils import build_train_valid_test_datasets
from megatron.model import BertModel
from megatron.training import pretrain
from megatron.utils import average_losses_across_data_parallel_group


def model_provider(pre_process=True, post_process=True):
    """Build the model."""

    print_rank_0('building BERT model ...')

    args = get_args()
    # Token-type (segment) embeddings are only needed when the binary
    # sentence-order head is in use.
    num_tokentypes = 2 if args.bert_binary_head else 0
    model = BertModel(
        num_tokentypes=num_tokentypes,
        add_binary_head=args.bert_binary_head,
        parallel_output=True,
        pre_process=pre_process,
        post_process=post_process)

    return model
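
# A note on pre_process/post_process: under pipeline parallelism, only the
# first stage builds the embedding layers (pre_process=True) and only the
# last stage builds the LM/SOP heads (post_process=True). The pretrain()
# machinery is expected to pass the appropriate flags for each pipeline rank.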


def get_batch(data_iterator):
    """Build the batch."""

    # Items and their type.
    keys = ['text', 'types', 'labels', 'is_random', 'loss_mask', 'padding_mask']
    datatype = torch.int64

    # Broadcast data.
    if data_iterator is not None:
        data = next(data_iterator)
    else:
        data = None
    data_b = mpu.broadcast_data(keys, data, datatype)

    # Unpack.
    tokens = data_b['text'].long()
    types = data_b['types'].long()
    sentence_order = data_b['is_random'].long()
    loss_mask = data_b['loss_mask'].float()
    lm_labels = data_b['labels'].long()
    padding_mask = data_b['padding_mask'].long()

    return tokens, types, sentence_order, loss_mask, lm_labels, padding_mask
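
# Note on mpu.broadcast_data above: roughly speaking, only the first rank in
# each (tensor) model-parallel group actually draws from the data iterator;
# the tensors are then broadcast to the remaining ranks of the group, which
# is why a rank without an iterator can pass data=None. A minimal sketch of
# the calling convention (illustrative values, not part of this script):
#
#   batch = {'text': torch.randint(0, 30000, (4, 128))}
#   data_b = mpu.broadcast_data(['text'], batch, torch.int64)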


def loss_func(loss_mask, sentence_order, output_tensor):
    """Compute the masked-LM loss and, if present, the SOP loss."""
    lm_loss_, sop_logits = output_tensor

    # Average the per-token LM loss over the masked positions only.
    lm_loss_ = lm_loss_.float()
    loss_mask = loss_mask.float()
    lm_loss = torch.sum(
        lm_loss_.view(-1) * loss_mask.reshape(-1)) / loss_mask.sum()

    if sop_logits is not None:
        # Sentence-order-prediction head: binary classification, with
        # -1 labels ignored.
        sop_loss = F.cross_entropy(sop_logits.view(-1, 2).float(),
                                   sentence_order.view(-1),
                                   ignore_index=-1)
        sop_loss = sop_loss.float()
        loss = lm_loss + sop_loss
        averaged_losses = average_losses_across_data_parallel_group(
            [lm_loss, sop_loss])
        return loss, {'lm loss': averaged_losses[0],
                      'sop loss': averaged_losses[1]}

    else:
        loss = lm_loss
        averaged_losses = average_losses_across_data_parallel_group(
            [lm_loss])
        return loss, {'lm loss': averaged_losses[0]}
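
# The masked-LM reduction above is sum(per-token loss * mask) / sum(mask),
# i.e. a mean over masked positions only. A tiny worked example (values are
# made up for illustration):
#
#   lm_loss_  = torch.tensor([[0.5, 2.0], [1.0, 3.0]])  # per-token losses
#   loss_mask = torch.tensor([[1.0, 0.0], [1.0, 1.0]])  # 1 = masked position
#   lm_loss = torch.sum(lm_loss_.view(-1) * loss_mask.view(-1)) / loss_mask.sum()
#   # -> (0.5 + 1.0 + 3.0) / 3 = 1.5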


def forward_step(data_iterator, model):
    """Forward step."""
    args = get_args()
    timers = get_timers()

    # Get the batch.
    timers('batch-generator').start()
    tokens, types, sentence_order, loss_mask, lm_labels, padding_mask = get_batch(
        data_iterator)
    timers('batch-generator').stop()

    # Token types are only meaningful when the binary (SOP) head is used.
    if not args.bert_binary_head:
        types = None

    # Forward pass through the model.
    output_tensor = model(tokens, padding_mask, tokentype_ids=types,
                          lm_labels=lm_labels)

    return output_tensor, partial(loss_func, loss_mask, sentence_order)
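
# The training loop expects forward_step to return the model output together
# with a loss function that takes only that output; binding loss_mask and
# sentence_order with functools.partial achieves this. Roughly how the
# trainer consumes the pair (a sketch, not the actual Megatron internals):
#
#   output, loss_fn = forward_step(data_iterator, model)
#   loss, stats = loss_fn(output)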


def train_valid_test_datasets_provider(train_val_test_num_samples):
    """Build train, valid, and test datasets."""
    args = get_args()

    print_rank_0('> building train, validation, and test datasets '
                 'for BERT ...')
    train_ds, valid_ds, test_ds = build_train_valid_test_datasets(
        data_prefix=args.data_path,
        data_impl=args.data_impl,
        splits_string=args.split,
        train_valid_test_num_samples=train_val_test_num_samples,
        max_seq_length=args.seq_length,
        masked_lm_prob=args.mask_prob,
        short_seq_prob=args.short_seq_prob,
        seed=args.seed,
        skip_warmup=(not args.mmap_warmup),
        binary_head=args.bert_binary_head)
    print_rank_0("> finished creating BERT datasets ...")

    return train_ds, valid_ds, test_ds
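
# train_val_test_num_samples is computed upstream by the pretrain() machinery
# from the requested training iterations, batch size, and evaluation
# schedule; this provider only has to map those sample counts onto concrete
# datasets built from args.data_path.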


if __name__ == "__main__":

    pretrain(train_valid_test_datasets_provider, model_provider, forward_step,
             args_defaults={'tokenizer_type': 'BertWordPieceLowerCase'})
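
# A hypothetical launch sketch (not exhaustive; exact flags and values depend
# on your Megatron-LM version, hardware, and preprocessed data, so treat the
# paths and numbers below as placeholders):
#
#   python pretrain_bert.py \
#       --num-layers 24 --hidden-size 1024 --num-attention-heads 16 \
#       --seq-length 512 --max-position-embeddings 512 \
#       --data-path <path-to-preprocessed-data-prefix> \
#       --vocab-file <path-to-vocab.txt> \
#       --lr 0.0001 --train-iters 1000000 --split 949,50,1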