# Copyright (c) 2022, NVIDIA CORPORATION.  All rights reserved.

"""Pretrain BERT"""

from functools import partial

import torch
import torch.nn.functional as F

from megatron import get_args
from megatron import print_rank_0
from megatron import get_timers
from megatron.core import tensor_parallel
from megatron.core.enums import ModelType
from megatron.data.dataset_utils import build_train_valid_test_datasets
from megatron.model import BertModel
from megatron.training import pretrain
from megatron.utils import average_losses_across_data_parallel_group
from megatron.arguments import core_transformer_config_from_args


def model_provider(pre_process=True, post_process=True):
    """Build the model."""

    print_rank_0('building BERT model ...')

    args = get_args()
    config = core_transformer_config_from_args(args)
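    # Token-type (segment A/B) embeddings are only needed when the binary
    # sentence-order head is trained.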
    num_tokentypes = 2 if args.bert_binary_head else 0
    model = BertModel(
        config=config,
        num_tokentypes=num_tokentypes,
        add_binary_head=args.bert_binary_head,
        parallel_output=True,
        pre_process=pre_process,
        post_process=post_process)

    return model


def get_batch(data_iterator):
    """Build the batch."""

    # Items and their type.
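    # 'text': token ids, 'types': segment ids, 'labels': masked-LM targets,
    # 'is_random': the sentence-order label, plus the loss and padding masks.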
    keys = ['text', 'types', 'labels', 'is_random', 'loss_mask', 'padding_mask']
    datatype = torch.int64

    # Broadcast data.
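    # Only one rank per tensor-parallel group reads the iterator; the named
    # tensors are then broadcast to the remaining ranks in the group.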
    if data_iterator is not None:
        data = next(data_iterator)
    else:
        data = None
    data_b = tensor_parallel.broadcast_data(keys, data, datatype)

    # Unpack.
    tokens = data_b['text'].long()
    types = data_b['types'].long()
    sentence_order = data_b['is_random'].long()
    loss_mask = data_b['loss_mask'].float()
    lm_labels = data_b['labels'].long()
    padding_mask = data_b['padding_mask'].long()

    return tokens, types, sentence_order, loss_mask, lm_labels, padding_mask


def loss_func(loss_mask, sentence_order, output_tensor):
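    """Masked-LM loss averaged over the masked tokens, plus the
    sentence-order-prediction (SOP) loss when the binary head is enabled.
    Returns the total loss and data-parallel-averaged losses for logging."""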
    lm_loss_, sop_logits = output_tensor

    lm_loss_ = lm_loss_.float()
    loss_mask = loss_mask.float()
    lm_loss = torch.sum(
        lm_loss_.view(-1) * loss_mask.reshape(-1)) / loss_mask.sum()

    if sop_logits is not None:
        sop_loss = F.cross_entropy(sop_logits.view(-1, 2).float(),
                                   sentence_order.view(-1),
                                   ignore_index=-1)
        sop_loss = sop_loss.float()
        loss = lm_loss + sop_loss
        averaged_losses = average_losses_across_data_parallel_group(
            [lm_loss, sop_loss])
        return loss, {'lm loss': averaged_losses[0],
                      'sop loss': averaged_losses[1]}

    else:
        loss = lm_loss
        averaged_losses = average_losses_across_data_parallel_group(
            [lm_loss])
        return loss, {'lm loss': averaged_losses[0]}


def forward_step(data_iterator, model):
    """Forward step."""
    args = get_args()
    timers = get_timers()

    # Get the batch.
    timers('batch-generator', log_level=2).start()
    tokens, types, sentence_order, loss_mask, lm_labels, padding_mask = get_batch(
        data_iterator)
    timers('batch-generator').stop()

    # Without the binary head the model was built with num_tokentypes=0,
    # so it takes no token-type ids.
    if not args.bert_binary_head:
        types = None

    # Forward pass through the model.
    output_tensor = model(tokens, padding_mask, tokentype_ids=types,
                          lm_labels=lm_labels)

    # Return the loss function with loss_mask and sentence_order bound, so
    # the training loop can apply it to the model's output tensor.
    return output_tensor, partial(loss_func, loss_mask, sentence_order)


def train_valid_test_datasets_provider(train_val_test_num_samples):
    """Build train, valid, and test datasets."""
    args = get_args()

    print_rank_0('> building train, validation, and test datasets '
                 'for BERT ...')
    train_ds, valid_ds, test_ds = build_train_valid_test_datasets(
        data_prefix=args.data_path,
        data_impl=args.data_impl,
        splits_string=args.split,
        train_valid_test_num_samples=train_val_test_num_samples,
        max_seq_length=args.seq_length,
        seed=args.seed,
        skip_warmup=(not args.mmap_warmup),
        binary_head=args.bert_binary_head)
    print_rank_0("> finished creating BERT datasets ...")

    return train_ds, valid_ds, test_ds


if __name__ == "__main__":

    pretrain(train_valid_test_datasets_provider, model_provider,
             ModelType.encoder_or_decoder,
             forward_step, args_defaults={'tokenizer_type': 'BertWordPieceLowerCase'})
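
# pretrain() parses the command line, builds the datasets and model through
# the providers above, and runs the train/valid/test loop.
#
# One possible launch, loosely following the scripts under examples/ (a
# sketch; exact flags vary across Megatron-LM versions):
#
#   torchrun --nproc_per_node=8 pretrain_bert.py \
#       --num-layers 24 --hidden-size 1024 --num-attention-heads 16 \
#       --seq-length 512 --max-position-embeddings 512 \
#       --micro-batch-size 4 --global-batch-size 32 \
#       --lr 0.0001 --train-iters 1000000 --split 949,50,1 \
#       --data-path <data-prefix> --vocab-file <vocab.txt>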