# Copyright (c) 2022, NVIDIA CORPORATION.  All rights reserved.

"""Pretrain BERT"""

from functools import partial

import torch
import torch.nn.functional as F

from megatron import get_args
from megatron import print_rank_0
from megatron import get_timers
from megatron.core import tensor_parallel
from megatron.core.enums import ModelType
from megatron.data.dataset_utils import build_train_valid_test_datasets
from megatron.model import BertModel
from megatron.training import pretrain
from megatron.utils import average_losses_across_data_parallel_group


def model_provider(pre_process=True, post_process=True):
    """Build the model."""

    print_rank_0('building BERT model ...')

    args = get_args()
    num_tokentypes = 2 if args.bert_binary_head else 0
    model = BertModel(
        num_tokentypes=num_tokentypes,
        add_binary_head=args.bert_binary_head,
        parallel_output=True,
        pre_process=pre_process,
        post_process=post_process)

    return model


def get_batch(data_iterator):
    """Build the batch."""

    # Items and their type.
    keys = ['text', 'types', 'labels', 'is_random', 'loss_mask', 'padding_mask']
    datatype = torch.int64

    # Broadcast data.
    if data_iterator is not None:
        data = next(data_iterator)
    else:
        data = None
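    # broadcast_data sends the listed fields from the first tensor-parallel
    # rank (the only one with a data iterator) to the rest of its group.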
    data_b = tensor_parallel.broadcast_data(keys, data, datatype)

    # Unpack.
    tokens = data_b['text'].long()
    types = data_b['types'].long()
    sentence_order = data_b['is_random'].long()
    loss_mask = data_b['loss_mask'].float()
    lm_labels = data_b['labels'].long()
    padding_mask = data_b['padding_mask'].long()

    return tokens, types, sentence_order, loss_mask, lm_labels, padding_mask


def loss_func(loss_mask, sentence_order, output_tensor):
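    """Compute the masked-LM loss, plus the sentence-order-prediction (SOP)
    loss when the model has a binary head."""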
    lm_loss_, sop_logits = output_tensor

    lm_loss_ = lm_loss_.float()
    loss_mask = loss_mask.float()
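    # Average the per-token LM losses over the masked positions only.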
    lm_loss = torch.sum(
        lm_loss_.view(-1) * loss_mask.reshape(-1)) / loss_mask.sum()

    if sop_logits is not None:
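        # Binary SOP cross-entropy; ignore_index=-1 skips unlabeled samples.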
        sop_loss = F.cross_entropy(sop_logits.view(-1, 2).float(),
                                   sentence_order.view(-1),
                                   ignore_index=-1)
        sop_loss = sop_loss.float()
        loss = lm_loss + sop_loss
        averaged_losses = average_losses_across_data_parallel_group(
            [lm_loss, sop_loss])
        return loss, {'lm loss': averaged_losses[0],
                      'sop loss': averaged_losses[1]}

    else:
        loss = lm_loss
        averaged_losses = average_losses_across_data_parallel_group(
            [lm_loss])
        return loss, {'lm loss': averaged_losses[0]}


def forward_step(data_iterator, model):
    """Forward step."""
    args = get_args()
    timers = get_timers()

    # Get the batch.
    timers('batch-generator', log_level=2).start()
    tokens, types, sentence_order, loss_mask, lm_labels, padding_mask = get_batch(
        data_iterator)
    timers('batch-generator').stop()

    # The model was built with num_tokentypes=0 when the binary head is off,
    # so no token-type ids are passed in that case.
    if not args.bert_binary_head:
        types = None

    # Forward pass through the model.
    output_tensor = model(tokens, padding_mask, tokentype_ids=types,
                          lm_labels=lm_labels)

    return output_tensor, partial(loss_func, loss_mask, sentence_order)


def train_valid_test_datasets_provider(train_val_test_num_samples):
    """Build train, valid, and test datasets."""
    args = get_args()

    print_rank_0('> building train, validation, and test datasets '
                 'for BERT ...')
    train_ds, valid_ds, test_ds = build_train_valid_test_datasets(
        data_prefix=args.data_path,
        data_impl=args.data_impl,
        splits_string=args.split,
        train_valid_test_num_samples=train_val_test_num_samples,
        max_seq_length=args.seq_length,
        masked_lm_prob=args.mask_prob,
        short_seq_prob=args.short_seq_prob,
        seed=args.seed,
        skip_warmup=(not args.mmap_warmup),
        binary_head=args.bert_binary_head)
    print_rank_0("> finished creating BERT datasets ...")

    return train_ds, valid_ds, test_ds


if __name__ == "__main__":

    pretrain(train_valid_test_datasets_provider, model_provider,
             ModelType.encoder_or_decoder,
             forward_step, args_defaults={'tokenizer_type': 'BertWordPieceLowerCase'})