pretrain_bert.py 6.18 KB
Newer Older
Raul Puri's avatar
Raul Puri committed
1
# coding=utf-8
Mohammad's avatar
Mohammad committed
2
# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
Raul Puri's avatar
Raul Puri committed
3
4
5
6
7
8
9
10
11
12
13
14
15
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

Mohammad Shoeybi's avatar
Mohammad Shoeybi committed
16
"""Pretrain BERT"""
Raul Puri's avatar
Raul Puri committed
17
18

import torch
19
import torch.nn.functional as F
Raul Puri's avatar
Raul Puri committed
20

Neel Kant's avatar
Neel Kant committed
21
22
from megatron import get_args
from megatron import print_rank_0
Mohammad's avatar
Mohammad committed
23
from megatron import get_timers
Mohammad Shoeybi's avatar
Mohammad Shoeybi committed
24
from megatron import mpu
25
from megatron.data.dataset_utils import build_train_valid_test_datasets
26
27
28
29
from megatron.model import (BertModel,
                            BertModelFirstStage,
                            BertModelIntermediateStage,
                            BertModelLastStage)
Mohammad's avatar
Mohammad committed
30
from megatron.training import pretrain
31
from megatron.utils import average_losses_across_data_parallel_group
Mohammad's avatar
Mohammad committed
32
33


Mohammad's avatar
Mohammad committed
34
def model_provider():
    """Build the BERT model (or model chunks) owned by this rank.

    When the pipeline-parallel world size is greater than one, each rank
    builds only the stage matching its position in the pipeline
    (first / intermediate / last). If
    ``args.virtual_pipeline_model_parallel_size`` is set, one stage is built
    per virtual pipeline rank and the stages are returned as a list.
    Without pipeline parallelism a single, complete ``BertModel`` is built.

    Returns:
        A model module, or a list of model modules in the virtual-pipeline
        case.
    """
    print_rank_0('building BERT model ...')

    # Fetch the global args exactly once; the nested builder below closes
    # over this binding (the original re-fetched it a second time).
    args = get_args()
    # Two token types are used only when the binary head is enabled.
    num_tokentypes = 2 if args.bert_binary_head else 0

    def model_provider_pipelined():
        # Choose the model class from this rank's position in the pipeline.
        if mpu.is_pipeline_first_stage():
            model = BertModelFirstStage(
                num_tokentypes=num_tokentypes)
        elif mpu.is_pipeline_last_stage():
            model = BertModelLastStage(
                num_tokentypes=num_tokentypes,
                add_binary_head=args.bert_binary_head,
                parallel_output=True)
        else:
            model = BertModelIntermediateStage(
                num_tokentypes=num_tokentypes)
        return model

    if mpu.get_pipeline_model_parallel_world_size() > 1:
        if args.virtual_pipeline_model_parallel_size is not None:
            # Build one chunk per virtual rank; the virtual rank must be set
            # before building so the stage-position queries above see it.
            model = []
            for i in range(args.virtual_pipeline_model_parallel_size):
                mpu.set_virtual_pipeline_model_parallel_rank(i)
                model.append(model_provider_pipelined())
        else:
            model = model_provider_pipelined()
    else:
        model = BertModel(
            num_tokentypes=num_tokentypes,
            add_binary_head=args.bert_binary_head,
            parallel_output=True)

    return model
Raul Puri's avatar
Raul Puri committed
72
73


Mohammad's avatar
Mohammad committed
74
def get_batch(data_iterator):
    """Fetch the next BERT batch and distribute it via ``mpu.broadcast_data``.

    When ``data_iterator`` is None nothing is read locally and ``None`` is
    handed to ``mpu.broadcast_data`` (presumably the batch then arrives from
    another rank — confirm against mpu).

    Returns:
        Tuple ``(tokens, types, sentence_order, loss_mask, lm_labels,
        padding_mask)``. ``loss_mask`` is cast to float; every other entry
        is cast to int64 (``.long()``).
    """
    field_names = ['text', 'types', 'labels', 'is_random', 'loss_mask',
                   'padding_mask']

    local_batch = next(data_iterator) if data_iterator is not None else None
    batch = mpu.broadcast_data(field_names, local_batch, torch.int64)

    # Integer-valued fields, in the order they are returned below.
    tokens, types, sentence_order, lm_labels, padding_mask = (
        batch[name].long()
        for name in ('text', 'types', 'is_random', 'labels', 'padding_mask'))
    # The loss mask is used as a multiplicative weight, hence float.
    loss_mask = batch['loss_mask'].float()

    return tokens, types, sentence_order, loss_mask, lm_labels, padding_mask
Raul Puri's avatar
Raul Puri committed
97
98


99
def forward_step(data_iterator, model, input_tensor):
    """Run one BERT forward step on this rank's pipeline stage.

    Args:
        data_iterator: batch iterator passed straight to ``get_batch``
            (which also accepts None).
        model: the model (chunk) owned by this pipeline stage.
        input_tensor: activation from the previous stage; asserted to be
            None on the first stage and non-None on every other stage.

    Returns:
        On the last pipeline stage, ``(loss, losses_dict)`` where ``loss`` is
        the LM loss (plus the sentence-order loss when the model returns
        SOP logits) and ``losses_dict`` holds the values produced by
        ``average_losses_across_data_parallel_group`` under 'lm loss' (and
        'sop loss'). On any other stage, the raw model output.
    """
    args = get_args()
    timers = get_timers()

    # Time only the batch fetch/broadcast.
    timers('batch-generator').start()
    (tokens, types, sentence_order,
     loss_mask, lm_labels, padding_mask) = get_batch(data_iterator)
    timers('batch-generator').stop()

    # Token-type ids are only passed on when the binary head is enabled.
    if not args.bert_binary_head:
        types = None

    is_first = mpu.is_pipeline_first_stage()
    is_last = mpu.is_pipeline_last_stage()

    # The first stage consumes tokens (plus token types); later stages
    # consume the activation from their predecessor. Only the last stage
    # receives the labels.
    if is_first:
        assert input_tensor is None
        stage_input = tokens
        call_kwargs = {'tokentype_ids': types}
    else:
        assert input_tensor is not None
        stage_input = input_tensor
        call_kwargs = {}
    if is_last:
        call_kwargs['lm_labels'] = lm_labels

    output_tensor = model(stage_input, padding_mask, **call_kwargs)

    if not is_last:
        return output_tensor

    lm_loss_, sop_logits = output_tensor

    lm_loss_ = lm_loss_.float()
    loss_mask = loss_mask.float()
    # loss_mask-weighted average of the per-token LM loss.
    lm_loss = torch.sum(
        lm_loss_.view(-1) * loss_mask.reshape(-1)) / loss_mask.sum()

    if sop_logits is None:
        loss = lm_loss
        averaged_losses = average_losses_across_data_parallel_group(
            [lm_loss])
        return loss, {'lm loss': averaged_losses[0]}

    # Two-way sentence-order classification; targets of -1 are ignored.
    sop_loss = F.cross_entropy(sop_logits.view(-1, 2).float(),
                               sentence_order.view(-1),
                               ignore_index=-1)
    sop_loss = sop_loss.float()
    loss = lm_loss + sop_loss
    averaged_losses = average_losses_across_data_parallel_group(
        [lm_loss, sop_loss])
    return loss, {'lm loss': averaged_losses[0],
                  'sop loss': averaged_losses[1]}
154
155


156
157
def train_valid_test_datasets_provider(train_val_test_num_samples):
    """Create the BERT train / validation / test datasets from global args.

    Args:
        train_val_test_num_samples: per-split sample counts, forwarded
            unchanged to ``build_train_valid_test_datasets``.

    Returns:
        ``(train_ds, valid_ds, test_ds)`` as produced by
        ``build_train_valid_test_datasets``.
    """
    args = get_args()

    print_rank_0('> building train, validation, and test datasets '
                 'for BERT ...')

    # Every dataset option is taken straight from the parsed arguments.
    dataset_options = dict(
        data_prefix=args.data_path,
        data_impl=args.data_impl,
        splits_string=args.split,
        train_valid_test_num_samples=train_val_test_num_samples,
        max_seq_length=args.seq_length,
        masked_lm_prob=args.mask_prob,
        short_seq_prob=args.short_seq_prob,
        seed=args.seed,
        skip_warmup=(not args.mmap_warmup),
        binary_head=args.bert_binary_head,
    )
    train_ds, valid_ds, test_ds = build_train_valid_test_datasets(
        **dataset_options)

    print_rank_0("> finished creating BERT datasets ...")

    return train_ds, valid_ds, test_ds
Raul Puri's avatar
Raul Puri committed
176
177
178


if __name__ == "__main__":
    # Hand the BERT-specific dataset, model and step providers to the
    # generic training driver, defaulting the tokenizer via args_defaults.
    pretrain(
        train_valid_test_datasets_provider,
        model_provider,
        forward_step,
        args_defaults={'tokenizer_type': 'BertWordPieceLowerCase'},
    )