pretrain_bert_ict.py 4.19 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
# coding=utf-8
# Copyright (c) 2019, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Pretrain BERT for Inverse Cloze Task"""

import torch
import torch.nn.functional as F

21
22
from megatron import get_args
from megatron import get_timers
23
from megatron import mpu
24
from megatron import print_rank_0
25
from megatron.data.dataset_utils import build_train_valid_test_datasets
26
from megatron.model import ICTBertModel
27
from megatron.training import pretrain
28
29
from megatron.utils import reduce_losses

Neel Kant's avatar
Neel Kant committed
# Module-level batch counter. Nothing in the code of this file reads or
# updates it. NOTE(review): possibly referenced by other modules; confirm
# before removing.
num_batches: int = 0

Neel Kant's avatar
Neel Kant committed
32

Neel Kant's avatar
Neel Kant committed
33
def model_provider(only_query_model=False, only_block_model=False):
    """Construct the ICT BERT model.

    Args:
        only_query_model: forwarded unchanged to ``ICTBertModel``.
        only_block_model: forwarded unchanged to ``ICTBertModel``.

    Returns:
        The ``ICTBertModel`` instance. It is built with a 128-wide ICT head,
        2 token types and ``parallel_output=True``.
    """
    # The return value is unused. The call is kept exactly as in the
    # original implementation.
    get_args()
    print_rank_0('building BERT models ...')

    ict_model_kwargs = dict(
        ict_head_size=128,
        num_tokentypes=2,
        parallel_output=True,
        only_query_model=only_query_model,
        only_block_model=only_block_model,
    )
    return ICTBertModel(**ict_model_kwargs)


def get_batch(data_iterator):
    """Fetch the next ICT sample and broadcast it with ``mpu.broadcast_data``.

    Args:
        data_iterator: iterator yielding dicts keyed by the names below, or
            None. When it is None, None is passed to ``broadcast_data``
            (presumably on ranks that receive the data from another rank).

    Returns:
        Tuple ``(query_tokens, query_pad_mask, block_tokens, block_pad_mask,
        block_indices)``. Each element is ``.long()`` of the broadcast entry.
        ``block_indices`` comes from the ``'block_data'`` key.
    """
    # Every entry is broadcast with a single dtype, int64.
    field_names = ['query_tokens', 'query_pad_mask',
                   'block_tokens', 'block_pad_mask', 'block_data']
    field_dtype = torch.int64

    sample = next(data_iterator) if data_iterator is not None else None
    broadcast = mpu.broadcast_data(field_names, sample, field_dtype)

    # Unpack in the same order as ``field_names``.
    return tuple(broadcast[name].long() for name in field_names)


def forward_step(data_iterator, model):
    """Run one ICT forward pass and compute the in-batch retrieval loss.

    Each query in the batch is scored against the blocks in the batch.
    Column ``i`` is treated as the correct block for query ``i`` (the target
    is ``torch.arange(batch_size)``). All other columns act as negatives.

    Args:
        data_iterator: iterator handed to ``get_batch`` (may be None).
        model: callable invoked as
            ``model(query_tokens, query_pad_mask, block_tokens, block_pad_mask)``.
            It returns a 2-D score tensor, presumably shaped
            (num_queries, num_blocks). TODO confirm against ICTBertModel.

    Returns:
        ``(retrieval_loss, stats_dict)``.
        ``retrieval_loss`` is the scalar cross-entropy loss to
        back-propagate.
        ``stats_dict`` maps ``'retrieval loss'``, ``'top1_acc'`` and
        ``'top5_acc'`` to the values returned by ``reduce_losses``.
    """
    timers = get_timers()

    # Get the batch. block_indices is not used by this step.
    timers('batch generator').start()
    query_tokens, query_pad_mask, \
    block_tokens, block_pad_mask, block_indices = get_batch(data_iterator)
    timers('batch generator').stop()

    # Forward model.
    retrieval_scores = model(query_tokens, query_pad_mask,
                             block_tokens, block_pad_mask).float()
    batch_size = retrieval_scores.shape[0]
    # Query i's positive block is block i. The targets are built on the
    # scores' device instead of hard-coding `.cuda()`.
    targets = torch.arange(batch_size, device=retrieval_scores.device)

    # Rank blocks by raw score. Softmax is monotonic, so the ranking matches
    # ranking by probabilities, without the ties that float rounding inside
    # softmax can introduce. k is clamped so torch.topk does not raise when
    # there are fewer than 5 candidate blocks.
    k = min(5, retrieval_scores.shape[1])
    _, topk_indices = torch.topk(retrieval_scores, k=k, dim=1, sorted=True)

    # hits[i, j] is True when the j-th ranked block for query i is its
    # positive block. This is vectorized; the previous per-element Python
    # loop forced a device sync for every element.
    hits = topk_indices == targets.unsqueeze(1)
    # keepdim=True keeps the previous 1-element float32 tensor shape.
    top1_acc = hits[:, 0].float().mean(dim=0, keepdim=True)
    top5_acc = hits.any(dim=1).float().mean(dim=0, keepdim=True)

    # F.cross_entropy applies log_softmax internally, so it must get the raw
    # scores (logits). Passing softmax probabilities applies softmax twice,
    # which distorts the loss and flattens its gradient.
    retrieval_loss = F.cross_entropy(retrieval_scores, targets)
    reduced_losses = reduce_losses([retrieval_loss, top1_acc, top5_acc])
    stats_dict = {
        'retrieval loss': reduced_losses[0],
        'top1_acc': reduced_losses[1],
        'top5_acc': reduced_losses[2]
    }

    return retrieval_loss, stats_dict


def train_valid_test_datasets_provider(train_val_test_num_samples):
    """Build the ICT train, validation and test datasets.

    Args:
        train_val_test_num_samples: per-split sample counts, passed through
            as ``train_valid_test_num_samples``.

    Returns:
        ``(train_ds, valid_ds, test_ds)`` as unpacked from
        ``build_train_valid_test_datasets(..., dataset_type='ict')``. The
        remaining arguments are read from ``get_args()``.
    """
    cfg = get_args()
    print_rank_0('> building train, validation, and test datasets '
                 'for BERT ICT...')

    train_ds, valid_ds, test_ds = build_train_valid_test_datasets(
        data_prefix=cfg.data_path,
        data_impl=cfg.data_impl,
        splits_string=cfg.split,
        train_valid_test_num_samples=train_val_test_num_samples,
        max_seq_length=cfg.seq_length,
        masked_lm_prob=cfg.mask_prob,
        short_seq_prob=cfg.short_seq_prob,
        seed=cfg.seed,
        # Warm-up is skipped unless mmap_warmup is set.
        skip_warmup=not cfg.mmap_warmup,
        dataset_type='ict',
    )
    print_rank_0("> finished creating BERT ICT datasets ...")

    return train_ds, valid_ds, test_ds


if __name__ == "__main__":
    # Script entry point. Hand the ICT dataset provider, model provider and
    # forward step to Megatron's `pretrain`. 'BertWordPieceLowerCase' is
    # supplied as the tokenizer_type entry of args_defaults.
    _ict_args_defaults = {'tokenizer_type': 'BertWordPieceLowerCase'}
    pretrain(train_valid_test_datasets_provider,
             model_provider,
             forward_step,
             args_defaults=_ict_args_defaults)