# Copyright (c) 2022, NVIDIA CORPORATION.  All rights reserved.

"""Pretrain BERT"""

from functools import partial

import torch
import torch.nn.functional as F

from megatron import get_args
from megatron import print_rank_0
from megatron import get_timers
from megatron import mpu
from megatron.data.dataset_utils import build_train_valid_test_datasets
from megatron.model import BertModel, ModelType
from megatron.training import pretrain
from megatron.utils import average_losses_across_data_parallel_group


def model_provider(pre_process=True, post_process=True):
    """Build the model."""

    print_rank_0('building BERT model ...')

    args = get_args()
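    # With the binary (sentence-order) head, BERT needs the two token-type
    # (segment A/B) embeddings; without it, token-type embeddings are omitted.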
    num_tokentypes = 2 if args.bert_binary_head else 0
    model = BertModel(
        num_tokentypes=num_tokentypes,
        add_binary_head=args.bert_binary_head,
        parallel_output=True,
        pre_process=pre_process,
        post_process=post_process)

    return model


def get_batch(data_iterator):
    """Build the batch."""

    # Items and their type.
    keys = ['text', 'types', 'labels', 'is_random', 'loss_mask', 'padding_mask']
    datatype = torch.int64

    # Broadcast data.
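    # Only ranks that own a data iterator read from it; mpu.broadcast_data
    # then replicates the batch across the tensor-model-parallel group so
    # every rank sees identical inputs.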
    if data_iterator is not None:
        data = next(data_iterator)
    else:
        data = None
    data_b = mpu.broadcast_data(keys, data, datatype)

    # Unpack.
    tokens = data_b['text'].long()
    types = data_b['types'].long()
    sentence_order = data_b['is_random'].long()
    loss_mask = data_b['loss_mask'].float()
    lm_labels = data_b['labels'].long()
    padding_mask = data_b['padding_mask'].long()

    return tokens, types, sentence_order, loss_mask, lm_labels, padding_mask


def loss_func(loss_mask, sentence_order, output_tensor):
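    """Compute the masked-LM loss and, when the SOP head is present, the
    sentence-order loss."""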
    lm_loss_, sop_logits = output_tensor

    lm_loss_ = lm_loss_.float()
    loss_mask = loss_mask.float()
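    # Average the per-token LM loss over the masked positions only.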
    lm_loss = torch.sum(
        lm_loss_.view(-1) * loss_mask.reshape(-1)) / loss_mask.sum()

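    # Sentence-order prediction (SOP): binary classification over whether the
    # two segments appear in their original order; labels of -1 are ignored.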
    if sop_logits is not None:
        sop_loss = F.cross_entropy(sop_logits.view(-1, 2).float(),
                                   sentence_order.view(-1),
                                   ignore_index=-1)
        sop_loss = sop_loss.float()
        loss = lm_loss + sop_loss
        averaged_losses = average_losses_across_data_parallel_group(
            [lm_loss, sop_loss])
        return loss, {'lm loss': averaged_losses[0],
                      'sop loss': averaged_losses[1]}

    else:
        loss = lm_loss
        averaged_losses = average_losses_across_data_parallel_group(
            [lm_loss])
        return loss, {'lm loss': averaged_losses[0]}


def forward_step(data_iterator, model):
    """Forward step."""
    args = get_args()
    timers = get_timers()

    # Get the batch.
    timers('batch-generator', log_level=2).start()
    tokens, types, sentence_order, loss_mask, lm_labels, padding_mask = get_batch(
        data_iterator)
    timers('batch-generator').stop()

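    # Without the binary head the model is built with num_tokentypes=0, so
    # token-type (segment) ids must not be passed in.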
    if not args.bert_binary_head:
        types = None

    # Forward pass through the model.
    output_tensor = model(tokens, padding_mask, tokentype_ids=types,
                          lm_labels=lm_labels)

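    # Return the loss function as a closure; the training loop applies it to
    # the output tensor once the (possibly pipelined) forward pass completes.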
    return output_tensor, partial(loss_func, loss_mask, sentence_order)


def train_valid_test_datasets_provider(train_val_test_num_samples):
    """Build train, valid, and test datasets."""
    args = get_args()

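    # A single preprocessed corpus is partitioned into train/valid/test
    # splits according to args.split (e.g. "949,50,1").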
    print_rank_0('> building train, validation, and test datasets '
                 'for BERT ...')
    train_ds, valid_ds, test_ds = build_train_valid_test_datasets(
        data_prefix=args.data_path,
        data_impl=args.data_impl,
        splits_string=args.split,
        train_valid_test_num_samples=train_val_test_num_samples,
        max_seq_length=args.seq_length,
        masked_lm_prob=args.mask_prob,
        short_seq_prob=args.short_seq_prob,
        seed=args.seed,
        skip_warmup=(not args.mmap_warmup),
        binary_head=args.bert_binary_head)
    print_rank_0("> finished creating BERT datasets ...")

    return train_ds, valid_ds, test_ds


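# A minimal launch, sketched after the example scripts shipped with
# Megatron-LM (paths are placeholders; more flags may be required):
#
#   python pretrain_bert.py \
#       --num-layers 24 --hidden-size 1024 --num-attention-heads 16 \
#       --micro-batch-size 4 --global-batch-size 32 \
#       --seq-length 512 --max-position-embeddings 512 \
#       --train-iters 1000000 --lr 0.0001 --split 949,50,1 \
#       --vocab-file bert-vocab.txt --data-path my-bert_text_sentence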
if __name__ == "__main__":

    pretrain(train_valid_test_datasets_provider, model_provider,
             ModelType.encoder_or_decoder,
             forward_step, args_defaults={'tokenizer_type': 'BertWordPieceLowerCase'})