"test/srt/test_two_batch_overlap.py" did not exist on "a68cb201dd5f4ae6155b324d22054bbb0de15fba"
pretrain_bert.py 4.63 KB
Newer Older
Raul Puri's avatar
Raul Puri committed
1
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Pretrain BERT"""

from functools import partial

import torch
import torch.nn.functional as F

from megatron import get_args
from megatron import print_rank_0
from megatron import get_timers
from megatron import mpu
from megatron.data.dataset_utils import build_train_valid_test_datasets
from megatron.model import BertModel
from megatron.training import pretrain
from megatron.utils import average_losses_across_data_parallel_group


def model_provider(pre_process=True, post_process=True):
    """Build the model."""

    print_rank_0('building BERT model ...')

    args = get_args()
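    # Two token types (sentence A/B segments) are only needed when the
    # binary sentence-order head is used.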
    num_tokentypes = 2 if args.bert_binary_head else 0
    model = BertModel(
        num_tokentypes=num_tokentypes,
        add_binary_head=args.bert_binary_head,
        parallel_output=True,
        pre_process=pre_process,
        post_process=post_process)

    return model


def get_batch(data_iterator):
    """Build the batch."""

    # Items and their type.
    keys = ['text', 'types', 'labels', 'is_random', 'loss_mask', 'padding_mask']
    datatype = torch.int64

    # Broadcast data.
    if data_iterator is not None:
        data = next(data_iterator)
    else:
        data = None
    data_b = mpu.broadcast_data(keys, data, datatype)

    # Unpack.
    tokens = data_b['text'].long()
    types = data_b['types'].long()
    sentence_order = data_b['is_random'].long()
    loss_mask = data_b['loss_mask'].float()
    lm_labels = data_b['labels'].long()
    padding_mask = data_b['padding_mask'].long()

    return tokens, types, sentence_order, loss_mask, lm_labels, padding_mask


def loss_func(loss_mask, sentence_order, output_tensor):
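    """Compute the masked-LM loss and, when the binary head is active,
    the sentence-order-prediction (SOP) loss."""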
    lm_loss_, sop_logits = output_tensor

    lm_loss_ = lm_loss_.float()
    loss_mask = loss_mask.float()
    lm_loss = torch.sum(
        lm_loss_.view(-1) * loss_mask.reshape(-1)) / loss_mask.sum()

    if sop_logits is not None:
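        # sop_logits is None when the model was built without the binary
        # head; labels of -1 are excluded via ignore_index.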
        sop_loss = F.cross_entropy(sop_logits.view(-1, 2).float(),
                                   sentence_order.view(-1),
                                   ignore_index=-1)
        sop_loss = sop_loss.float()
        loss = lm_loss + sop_loss
        averaged_losses = average_losses_across_data_parallel_group(
            [lm_loss, sop_loss])
        return loss, {'lm loss': averaged_losses[0],
                      'sop loss': averaged_losses[1]}

    else:
        loss = lm_loss
        averaged_losses = average_losses_across_data_parallel_group(
            [lm_loss])
        return loss, {'lm loss': averaged_losses[0]}


def forward_step(data_iterator, model):
    """Forward step."""
    args = get_args()
    timers = get_timers()

    # Get the batch.
    timers('batch-generator').start()
    tokens, types, sentence_order, loss_mask, lm_labels, padding_mask = get_batch(
        data_iterator)
    timers('batch-generator').stop()

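    # Without the binary head the model has no token-type embeddings,
    # so drop the token types from the batch.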
    if not args.bert_binary_head:
        types = None

    # Forward pass through the model.
    output_tensor = model(tokens, padding_mask, tokentype_ids=types,
                          lm_labels=lm_labels)

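    # Return the output and a loss closure; the training loop applies
    # the closure to the output tensor to obtain the loss.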
    return output_tensor, partial(loss_func, loss_mask, sentence_order)


def train_valid_test_datasets_provider(train_val_test_num_samples):
    """Build train, valid, and test datasets."""
    args = get_args()

    print_rank_0('> building train, validation, and test datasets '
                 'for BERT ...')
    train_ds, valid_ds, test_ds = build_train_valid_test_datasets(
        data_prefix=args.data_path,
        data_impl=args.data_impl,
        splits_string=args.split,
        train_valid_test_num_samples=train_val_test_num_samples,
        max_seq_length=args.seq_length,
        masked_lm_prob=args.mask_prob,
        short_seq_prob=args.short_seq_prob,
        seed=args.seed,
        skip_warmup=(not args.mmap_warmup),
        binary_head=args.bert_binary_head)
    print_rank_0("> finished creating BERT datasets ...")

    return train_ds, valid_ds, test_ds


if __name__ == "__main__":

    pretrain(train_valid_test_datasets_provider, model_provider, forward_step,
             args_defaults={'tokenizer_type': 'BertWordPieceLowerCase'})