# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Pretrain BERT"""

import torch
import torch.nn.functional as F

from megatron import get_args
from megatron import print_rank_0
from megatron import get_timers
from megatron import mpu
from megatron.data.dataset_utils import build_train_valid_test_datasets
from megatron.model import BertModel, BertModelFirstStage, BertModelIntermediateStage, BertModelLastStage
from megatron.training import pretrain
from megatron.utils import average_losses_across_data_parallel_group


def model_provider():
    """Construct and return the BERT model for this rank.

    When inter-layer (pipeline) model parallelism is enabled, the model
    class is selected by this rank's position in the pipeline; otherwise
    the full ``BertModel`` is built.
    """
    print_rank_0('building BERT model ...')
    args = get_args()

    # No pipeline parallelism: every rank holds the whole model.
    if args.inter_layer_model_parallel_size <= 1:
        return BertModel(num_tokentypes=2,
                         add_binary_head=True,
                         parallel_output=True)

    # Pipeline parallelism: pick the stage-specific variant. Only the
    # last stage carries the LM/binary heads.
    if mpu.is_inter_layer_first_stage():
        return BertModelFirstStage(num_tokentypes=2)
    if mpu.is_inter_layer_last_stage():
        return BertModelLastStage(num_tokentypes=2,
                                  add_binary_head=True,
                                  parallel_output=True)
    return BertModelIntermediateStage(num_tokentypes=2)


def get_batch(data_iterator):
    """Fetch the next batch and broadcast it across the model-parallel group.

    Returns a ``(tokens, types, sentence_order, loss_mask, lm_labels,
    padding_mask)`` tuple of tensors.
    """
    # Fields present in every batch; all are transported as int64.
    keys = ['text', 'types', 'labels', 'is_random', 'loss_mask', 'padding_mask']

    # Only ranks that own a data iterator actually read from it; the other
    # ranks receive the tensors through the broadcast below.
    data = next(data_iterator) if data_iterator is not None else None
    data_b = mpu.broadcast_data(keys, data, torch.int64)

    # Unpack into the dtypes the forward pass expects.
    tokens = data_b['text'].long()
    types = data_b['types'].long()
    sentence_order = data_b['is_random'].long()
    loss_mask = data_b['loss_mask'].float()
    lm_labels = data_b['labels'].long()
    padding_mask = data_b['padding_mask'].long()

    return tokens, types, sentence_order, loss_mask, lm_labels, padding_mask


def forward_step(data_iterator, model, input_tensor):
    """Run one forward pass; on the last pipeline stage, compute the losses.

    Args:
        data_iterator: iterator over training batches (may be None on ranks
            that receive the batch via the broadcast inside ``get_batch``).
        model: the (stage-specific) BERT model for this rank.
        input_tensor: activations from the previous pipeline stage; must be
            None on the first stage and non-None on every later stage.

    Returns:
        On the last pipeline stage, a ``(loss, {'lm loss': ..., 'sop loss': ...})``
        tuple with the reported losses averaged across the data-parallel
        group; on other stages, the output activations to send downstream.
    """
    # NOTE: the original code fetched `args = get_args()` here but never
    # used it; the unused local has been removed.
    timers = get_timers()

    # Get the batch.
    timers('batch generator').start()
    tokens, types, sentence_order, loss_mask, lm_labels, padding_mask \
        = get_batch(data_iterator)
    timers('batch generator').stop()

    # Forward pass through the model. The first stage consumes token ids;
    # later stages consume the previous stage's activations. Only the last
    # stage needs the labels.
    if mpu.is_inter_layer_first_stage():
        assert input_tensor is None
        if mpu.is_inter_layer_last_stage():
            # Single-stage pipeline: this rank is both first and last.
            output_tensor = model(tokens, padding_mask, tokentype_ids=types,
                                  lm_labels=lm_labels)
        else:
            output_tensor = model(tokens, padding_mask, tokentype_ids=types)
    elif mpu.is_inter_layer_last_stage():
        assert input_tensor is not None
        output_tensor = model(input_tensor, padding_mask, lm_labels=lm_labels)
    else:
        assert input_tensor is not None
        output_tensor = model(input_tensor, padding_mask)

    if mpu.is_inter_layer_last_stage():
        lm_loss_, sop_logits = output_tensor

        # Sentence-order-prediction loss over the binary-head logits;
        # ignore_index=-1 skips entries marked invalid.
        sop_loss = F.cross_entropy(sop_logits.view(-1, 2).float(),
                                   sentence_order.view(-1),
                                   ignore_index=-1)
        sop_loss = sop_loss.float()

        # Masked-LM loss: average of per-token losses over masked positions.
        lm_loss_ = lm_loss_.float()
        loss_mask = loss_mask.float()
        lm_loss = torch.sum(
            lm_loss_.view(-1) * loss_mask.reshape(-1)) / loss_mask.sum()

        loss = lm_loss + sop_loss

        # Report losses averaged across the data-parallel group.
        averaged_losses = average_losses_across_data_parallel_group(
            [lm_loss, sop_loss])

        return loss, {'lm loss': averaged_losses[0],
                      'sop loss': averaged_losses[1]}
    return output_tensor


def train_valid_test_datasets_provider(train_val_test_num_samples):
    """Create the BERT train/validation/test datasets.

    Args:
        train_val_test_num_samples: per-split sample counts forwarded to
            the dataset builder.

    Returns:
        A (train, valid, test) dataset triple.
    """
    args = get_args()

    print_rank_0('> building train, validation, and test datasets '
                 'for BERT ...')
    datasets = build_train_valid_test_datasets(
        data_prefix=args.data_path,
        data_impl=args.data_impl,
        splits_string=args.split,
        train_valid_test_num_samples=train_val_test_num_samples,
        max_seq_length=args.seq_length,
        masked_lm_prob=args.mask_prob,
        short_seq_prob=args.short_seq_prob,
        seed=args.seed,
        skip_warmup=(not args.mmap_warmup))
    print_rank_0("> finished creating BERT datasets ...")

    return datasets


if __name__ == "__main__":
    # Launch BERT pretraining with the provider callbacks defined above,
    # defaulting to the lowercase BERT WordPiece tokenizer.
    pretrain(
        train_valid_test_datasets_provider,
        model_provider,
        forward_step,
        args_defaults={'tokenizer_type': 'BertWordPieceLowerCase'},
    )