pretrain_bert.py
# Copyright (c) 2022, NVIDIA CORPORATION.  All rights reserved.

"""Pretrain BERT"""

from functools import partial

import torch
import torch.nn.functional as F

from megatron.training import get_args
from megatron.training import get_tokenizer
from megatron.training import print_rank_0
from megatron.training import get_timers
from megatron.core import tensor_parallel
from megatron.core.enums import ModelType
import megatron.legacy.model
from megatron.core.models.bert.bert_model import BertModel
from megatron.training import pretrain
from megatron.training.utils import average_losses_across_data_parallel_group
from megatron.training.arguments import core_transformer_config_from_args
from megatron.core.transformer.spec_utils import import_module
from megatron.core.models.bert.bert_layer_specs import bert_layer_with_transformer_engine_spec, bert_layer_local_spec
from megatron.core.datasets.blended_megatron_dataset_builder import BlendedMegatronDatasetBuilder
from megatron.core.datasets.bert_dataset import BERTMaskedWordPieceDataset, BERTMaskedWordPieceDatasetConfig
from megatron.core.datasets.utils import get_blend_from_list
from megatron.core import mpu, tensor_parallel


def model_provider(pre_process=True, post_process=True):
    """Build the model."""

    print_rank_0('building BERT model ...')

    args = get_args()
    config = core_transformer_config_from_args(args)
    num_tokentypes = 2 if args.bert_binary_head else 0

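    # Two construction paths are kept side by side: the legacy
    # megatron.legacy.model.BertModel, and the Megatron Core BertModel, which is
    # parameterized by a transformer layer spec (Transformer Engine by default,
    # the local spec, or a module imported from --spec).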
    if args.use_legacy_models:
        model = megatron.legacy.model.BertModel(
            config=config,
            num_tokentypes=num_tokentypes,
            add_binary_head=args.bert_binary_head,
            parallel_output=True,
            pre_process=pre_process,
            post_process=post_process)
    else:
        if args.spec is None:
            transformer_layer_spec = bert_layer_with_transformer_engine_spec  # default spec
        elif args.spec[0] == 'local':
            print_rank_0('Using Local spec for transformer layers')
            transformer_layer_spec = bert_layer_local_spec
        else:
            transformer_layer_spec = import_module(args.spec)

        model = BertModel(
            config=config,
            transformer_layer_spec=transformer_layer_spec,
            vocab_size=args.padded_vocab_size,
            max_sequence_length=args.max_position_embeddings,
            num_tokentypes=num_tokentypes,
            add_binary_head=args.bert_binary_head,
            share_embeddings_and_output_weights=not args.untie_embeddings_and_output_weights,
            parallel_output=True,
            pre_process=pre_process,
            post_process=post_process)

    return model


def get_batch(data_iterator):
    """Build the batch."""

    # Items and their type.
    keys = ['text', 'types', 'labels',
            'is_random', 'loss_mask', 'padding_mask']
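    # 'text' are the input token ids, 'types' the segment (token-type) ids,
    # 'labels' the masked-LM targets, 'is_random' the sentence-order target,
    # 'loss_mask' marks positions that contribute to the LM loss, and
    # 'padding_mask' marks real (non-padded) tokens.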
    datatype = torch.int64

    # Broadcast data.
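    # Only tensor-parallel rank 0 holds a real iterator (see the dataset builder
    # in train_valid_test_datasets_provider); broadcast_data sends the named
    # tensors to the rest of the tensor-parallel group.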
    if data_iterator is not None:
        data = next(data_iterator)
    else:
        data = None
    data_b = tensor_parallel.broadcast_data(keys, data, datatype)

    # Unpack.
    tokens = data_b['text'].long()
    types = data_b['types'].long()
    sentence_order = data_b['is_random'].long()
    loss_mask = data_b['loss_mask'].float()
    lm_labels = data_b['labels'].long()
    padding_mask = data_b['padding_mask'].long()

    return tokens, types, sentence_order, loss_mask, lm_labels, padding_mask


def loss_func(loss_mask, sentence_order, output_tensor):
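    # output_tensor packs the per-token masked-LM loss computed inside the model
    # and, when the binary head is enabled, the sentence-order-prediction logits.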
    lm_loss_, sop_logits = output_tensor

    lm_loss_ = lm_loss_.float()
    loss_mask = loss_mask.float()
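    # Average the LM loss over the masked positions only.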
    lm_loss = torch.sum(
        lm_loss_.view(-1) * loss_mask.reshape(-1)) / loss_mask.sum()

    if sop_logits is not None:
        sop_loss = F.cross_entropy(sop_logits.view(-1, 2).float(),
                                   sentence_order.view(-1),
                                   ignore_index=-1)
        sop_loss = sop_loss.float()
        loss = lm_loss + sop_loss
        averaged_losses = average_losses_across_data_parallel_group(
            [lm_loss, sop_loss])
        return loss, {'lm loss': averaged_losses[0],
                      'sop loss': averaged_losses[1]}
    else:
        loss = lm_loss
        averaged_losses = average_losses_across_data_parallel_group(
            [lm_loss])
        return loss, {'lm loss': averaged_losses[0]}


def forward_step(data_iterator, model):
    """Forward step."""
    args = get_args()
    timers = get_timers()

    # Get the batch.
    timers('batch-generator', log_level=2).start()
    tokens, types, sentence_order, loss_mask, lm_labels, padding_mask = get_batch(
        data_iterator)
    timers('batch-generator').stop()

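    # Token-type (segment) embeddings exist only when the binary sentence-order
    # head is enabled (num_tokentypes == 2 in model_provider), so drop them
    # otherwise.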
    if not args.bert_binary_head:
        types = None

    # Forward pass through the model.
    output_tensor = model(tokens, padding_mask,
                          tokentype_ids=types, lm_labels=lm_labels)

    return output_tensor, partial(loss_func, loss_mask, sentence_order)


def train_valid_test_datasets_provider(train_val_test_num_samples):
    """Build train, valid, and test datasets."""
    args = get_args()

    tokenizer = get_tokenizer()

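    # Blend/split settings come from the command line; the BERT masking scheme
    # (up to 3-gram, whole-word, no permutation) is fixed here rather than
    # exposed as arguments.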
    config = BERTMaskedWordPieceDatasetConfig(
        random_seed=args.seed,
        sequence_length=args.seq_length,
        blend=get_blend_from_list(args.data_path),
        blend_per_split=[
            get_blend_from_list(args.train_data_path),
            get_blend_from_list(args.valid_data_path),
            get_blend_from_list(args.test_data_path)
        ],
        renormalize_blend_weights=args.renormalize_blend_weights,
        split=args.split,
        path_to_cache=args.data_cache_path,
        tokenizer=tokenizer,
        masking_probability=args.mask_prob,
        short_sequence_probability=args.short_seq_prob,
        masking_max_ngram=3,
        masking_do_full_word=True,
        masking_do_permutation=False,
        masking_use_longer_ngrams=False,
        masking_use_geometric_distribution=False,
        classification_head=args.bert_binary_head,
    )

    print_rank_0('> building train, validation, and test datasets '
                 'for BERT ...')

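    # Only tensor-parallel rank 0 (the lambda below) builds real dataset objects;
    # the other ranks receive their batches via the broadcast in get_batch.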
    train_ds, valid_ds, test_ds = BlendedMegatronDatasetBuilder(
        BERTMaskedWordPieceDataset,
        train_val_test_num_samples,
        lambda: mpu.get_tensor_model_parallel_rank() == 0,
        config,
    ).build()

    print_rank_0("> finished creating BERT datasets ...")

    return train_ds, valid_ds, test_ds


if __name__ == "__main__":

    # Temporary for transition to core datasets
    train_valid_test_datasets_provider.is_distributed = True

    pretrain(train_valid_test_datasets_provider, model_provider,
             ModelType.encoder_or_decoder,
             forward_step, args_defaults={'tokenizer_type': 'BertWordPieceLowerCase'})
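
# Example launch (illustrative sketch, not part of this file): Megatron-LM's BERT
# example scripts typically invoke this module via torchrun with flags along
# these lines; exact names and values should be checked against your checkout.
#
#   torchrun --nproc_per_node=8 pretrain_bert.py \
#       --num-layers 24 --hidden-size 1024 --num-attention-heads 16 \
#       --seq-length 512 --max-position-embeddings 512 \
#       --micro-batch-size 4 --global-batch-size 32 \
#       --lr 0.0001 --lr-decay-style linear --train-iters 1000000 \
#       --data-path <data prefix> --vocab-file <vocab.txt> --split 949,50,1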