pretrain_bert_ict.py 5.64 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
# coding=utf-8
# Copyright (c) 2019, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Pretrain BERT for Inverse Cloze Task"""

import torch
Neel Kant's avatar
Neel Kant committed
19
import torch.distributed as dist
20
21
import torch.nn.functional as F

22
23
from megatron import get_args
from megatron import get_timers
24
from megatron import mpu
25
from megatron import print_rank_0
26
from megatron.data.dataset_utils import build_train_valid_test_datasets
27
from megatron.model import ICTBertModel
28
from megatron.training import pretrain
29
30
from megatron.utils import reduce_losses

Neel Kant's avatar
Neel Kant committed
31
# NOTE(review): module-level counter that is neither read nor written by any
# function visible in this file; possibly a leftover -- confirm no external
# module imports it before removing.
num_batches: int = 0
32

Neel Kant's avatar
Neel Kant committed
33

34
def general_model_provider(only_query_model=False, only_block_model=False):
    """Build the model.

    Constructs an ``ICTBertModel`` whose head size is taken from the
    ``--ict-head-size`` argument.

    Args:
        only_query_model: forwarded unchanged to ``ICTBertModel``.
        only_block_model: forwarded unchanged to ``ICTBertModel``.

    Returns:
        The constructed ``ICTBertModel``.

    Raises:
        ValueError: if ``--ict-head-size`` was not provided.
    """
    head_size = get_args().ict_head_size
    if head_size is None:
        raise ValueError("Need to specify --ict-head-size to provide an ICTBertModel")

    print_rank_0('building ICTBertModel...')

    # Two token types are kept because the language model used for
    # initialization was built with two token types.
    return ICTBertModel(
        ict_head_size=head_size,
        num_tokentypes=2,
        parallel_output=True,
        only_query_model=only_query_model,
        only_block_model=only_block_model,
    )


53
54
55
56
def model_provider():
    """Build the ICT model with neither the query-only nor block-only restriction."""
    return general_model_provider(only_query_model=False, only_block_model=False)


57
def get_batch(data_iterator):
    """Fetch one ICT batch and broadcast it with ``mpu.broadcast_data``.

    Ranks that have no iterator (``data_iterator is None``) pass ``None`` and
    receive their tensors from the broadcast.

    Returns:
        tuple: ``(query_tokens, query_pad_mask, block_tokens, block_pad_mask,
        block_indices)``, each converted with ``.long()``; ``block_indices``
        comes from the ``'block_data'`` field.
    """
    # Field names, in the order the tensors are returned.
    keys = ['query_tokens', 'query_pad_mask',
            'block_tokens', 'block_pad_mask', 'block_data']

    raw_batch = None if data_iterator is None else next(data_iterator)
    broadcast = mpu.broadcast_data(keys, raw_batch, torch.int64)

    return tuple(broadcast[name].long() for name in keys)
79
80


81
def forward_step(data_iterator, model):
    """Forward step for the Inverse Cloze Task.

    Encodes the local query and block batches, gathers the embeddings from
    every data-parallel rank, and scores each query against every block in
    the resulting global batch. Query ``i`` is labelled with block ``i``
    (labels are ``arange(global_batch_size)``), so all other blocks in the
    global batch act as in-batch negatives.

    Args:
        data_iterator: iterator handed to ``get_batch`` (may be None on ranks
            that only receive broadcast data).
        model: module called as ``model(query_tokens, query_pad_mask,
            block_tokens, block_pad_mask)`` returning
            ``(query_logits, block_logits)``.

    Returns:
        tuple: ``(retrieval_loss, stats_dict)`` where ``stats_dict`` holds the
        loss and top-1/8/20/100 retrieval accuracies after ``reduce_losses``.
    """
    timers = get_timers()

    # Get the batch.
    timers('batch generator').start()
    query_tokens, query_pad_mask, \
        block_tokens, block_pad_mask, block_indices = get_batch(data_iterator)
    timers('batch generator').stop()

    # Forward model.
    query_logits, block_logits = model(query_tokens, query_pad_mask,
                                       block_tokens, block_pad_mask)

    # Size the global batch and place this rank's slice using the
    # data-parallel group rather than the global world/rank, so the buffer
    # size is an integer and the offset stays in bounds when model parallel
    # size > 1.
    data_parallel_size = mpu.get_data_parallel_world_size()
    data_parallel_rank = mpu.get_data_parallel_rank()
    batch_size = query_logits.shape[0]
    embed_dim = query_logits.shape[1]
    global_batch_size = batch_size * data_parallel_size

    # Allocate directly on the logits' device instead of CPU-then-.cuda().
    all_query_logits = torch.zeros((global_batch_size, embed_dim),
                                   dtype=query_logits.dtype,
                                   device=query_logits.device)
    all_block_logits = torch.zeros_like(all_query_logits)

    # Record this process's embeddings in its own slice; summing the
    # zero-padded buffers across ranks merges all slices.
    start = data_parallel_rank * batch_size
    end = start + batch_size
    all_query_logits[start:end] = query_logits
    all_block_logits[start:end] = block_logits

    # NOTE(review): dist.all_reduce is not tracked by autograd, so presumably
    # gradients reach only this rank's slice; this also assumes model-parallel
    # ranks hold identical logits -- confirm.
    data_parallel_group = mpu.get_data_parallel_group()
    dist.all_reduce(all_query_logits, group=data_parallel_group)
    dist.all_reduce(all_block_logits, group=data_parallel_group)

    # Scores are inner products between query and block embeddings.
    retrieval_scores = all_query_logits.float().matmul(
        torch.transpose(all_block_logits, 0, 1).float())
    softmaxed = F.softmax(retrieval_scores, dim=1)
    _, sorted_indices = torch.topk(softmaxed, k=softmaxed.shape[1], sorted=True)

    # The positive block for query i is block i.
    labels = torch.arange(global_batch_size, device=retrieval_scores.device)
    # hits[i, r] is True when the block ranked r-th for query i is its positive.
    hits = sorted_indices == labels.unsqueeze(1)

    def topk_acc(k):
        # Fraction of queries whose positive is within the top k, computed on
        # device (no per-row host sync). k beyond the batch covers all blocks.
        return hits[:, :k].any(dim=1).float().mean().view(1)

    top_accs = [topk_acc(k) for k in [1, 8, 20, 100]]

    retrieval_loss = F.cross_entropy(retrieval_scores, labels)
    reduced_losses = reduce_losses([retrieval_loss, *top_accs])

    stats_dict = {
        'retrieval loss': reduced_losses[0],
        'top1_acc': reduced_losses[1],
        'top8_acc': reduced_losses[2],
        'top20_acc': reduced_losses[3],
        'top100_acc': reduced_losses[4],
    }

    return retrieval_loss, stats_dict
132
133


Neel Kant's avatar
Neel Kant committed
134
135
def train_valid_test_datasets_provider(train_val_test_num_samples):
    """Build train, valid and test datasets.

    Args:
        train_val_test_num_samples: forwarded as
            ``train_valid_test_num_samples`` to
            ``build_train_valid_test_datasets``.

    Returns:
        tuple: ``(train_dataset, valid_dataset, test_dataset)`` built with
        ``dataset_type='ict'``.
    """
    args = get_args()
    print_rank_0('> building train, validation, and test datasets '
                 'for BERT ICT...')

    # mmap warmup is skipped unless --mmap-warmup was requested.
    skip_warmup = not args.mmap_warmup
    train_dataset, valid_dataset, test_dataset = build_train_valid_test_datasets(
        data_prefix=args.data_path,
        data_impl=args.data_impl,
        splits_string=args.split,
        train_valid_test_num_samples=train_val_test_num_samples,
        max_seq_length=args.seq_length,
        masked_lm_prob=args.mask_prob,
        short_seq_prob=args.short_seq_prob,
        seed=args.seed,
        skip_warmup=skip_warmup,
        dataset_type='ict',
    )
    print_rank_0("> finished creating BERT ICT datasets ...")

    return train_dataset, valid_dataset, test_dataset
154
155
156
157


if __name__ == "__main__":
    # Default tokenizer handed to pretrain() through args_defaults.
    default_args = {'tokenizer_type': 'BertWordPieceLowerCase'}
    pretrain(train_valid_test_datasets_provider,
             model_provider,
             forward_step,
             args_defaults=default_args)