pretrain_realm.py
# coding=utf-8
# Copyright (c) 2019, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Pretrain BERT for Inverse Cloze Task"""
import sys

import numpy as np
import torch
import torch.nn.functional as F

from indexer import load_ict_checkpoint, get_ict_dataset
from megatron.data.realm_index import BlockData, RandProjectionLSHIndex, FaissMIPSIndex
from megatron import get_args
from megatron import get_timers
from megatron import mpu
from megatron import print_rank_0
from megatron.data.dataset_utils import build_train_valid_test_datasets
from megatron.model import REALMBertModel, REALMRetriever
from megatron.training import pretrain
from megatron.utils import reduce_losses, report_memory
from indexer import initialize_and_run_async_megatron
from megatron.mpu.initialize import get_data_parallel_group

num_batches = 0


def model_provider():
    """Build the model."""
    args = get_args()
    print_rank_0('building REALM models ...')

    # Prefer loading retriever weights from a REALM checkpoint; fall back to an ICT checkpoint.
    try:
        ict_model = load_ict_checkpoint(from_realm_chkpt=True)
    except Exception:
        ict_model = load_ict_checkpoint(from_realm_chkpt=False)
    ict_dataset = get_ict_dataset(use_titles=False)
    all_block_data = BlockData.load_from_file(args.block_data_path)
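    # Build a FAISS inner-product (MIPS) index over the precomputed block embeddings so the
    # retriever can score every evidence block against a query embedding.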
    hashed_index = FaissMIPSIndex(index_type='flat_ip', embed_size=128, use_gpu=args.faiss_use_gpu)
    hashed_index.add_block_embed_data(all_block_data)

    # top_k + 1 because we may need to exclude trivial candidate
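    # (The "trivial" candidate is presumably the block the query itself was drawn from;
    # query_block_indices from each batch let the retriever filter it out.)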
    retriever = REALMRetriever(ict_model, ict_dataset, all_block_data, hashed_index, args.block_top_k)
    model = REALMBertModel(retriever)

    return model


def get_batch(data_iterator):
    # Items and their type.
    keys = ['tokens', 'labels', 'loss_mask', 'pad_mask', 'query_block_indices']
    datatype = torch.int64

    # Broadcast data.
    if data_iterator is None:
        data = None
    else:
        data = next(data_iterator)

    data_b = mpu.broadcast_data(keys, data, datatype)

    # Unpack.
    tokens = data_b['tokens'].long()
    labels = data_b['labels'].long()
    loss_mask = data_b['loss_mask'].long()
    pad_mask = data_b['pad_mask'].long()
    query_block_indices = data_b['query_block_indices'].long()

    return tokens, labels, loss_mask, pad_mask, query_block_indices


def get_qa_batch(data_iterator):
    question_tokens, question_attention_mask, answer_tokens, answer_token_lengths = next(data_iterator)
    return question_tokens, question_attention_mask, answer_tokens, answer_token_lengths


def forward_step(data_iterator, model):
    """Forward step."""
    timers = get_timers()

    # Get the batch.
    timers('batch generator').start()
    tokens, labels, loss_mask, pad_mask, query_block_indices = get_batch(data_iterator)
    timers('batch generator').stop()

    # Forward model.
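    # The model returns LM logits for each candidate block (the last slot acts as the
    # null / no-evidence block) and block_probs, the retriever's distribution P(z|x).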
    lm_logits, block_probs = model(tokens, pad_mask, query_block_indices)

    with torch.no_grad():
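        # Retrieval-utility diagnostics (how much each retrieved block lowers the reader's
        # loss relative to the null block); computed with activation checkpointing to save memory.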
        max_retrieval_utility, top_retrieval_utility, avg_retrieval_utility = mpu.checkpoint(
            get_retrieval_utility, lm_logits, block_probs, labels, loss_mask)

    # Average probability the retriever assigns to the last (null / no-evidence) block,
    # reported below as the 'null_prob' diagnostic.
    null_block_probs = torch.mean(block_probs[:, -1])

    # logits: [batch x top_k x 2 * seq_length x vocab_size]
    # labels: [batch x seq_length]
    relevant_logits = lm_logits[:, :, :labels.shape[1]].float()
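    # Broadcast P(z|x) over the sequence and vocabulary dimensions so each block's token
    # distribution can be weighted by its retrieval probability.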
    block_probs = block_probs.unsqueeze(2).unsqueeze(3).expand_as(relevant_logits)

    def get_log_probs(logits, b_probs):
        # Shift logits by their per-position max (softmax is invariant to this shift).
        max_logits = torch.max(logits, dim=-1, keepdim=True)[0].expand_as(logits)
        logits = logits - max_logits

        # P(y|x) = sum_z(P(y|z, x) * P(z|x)): per-block softmax over the vocabulary,
        # then a weighted sum over the retrieved blocks, then log for use with NLLLoss below.
        softmaxed_logits = F.softmax(logits, dim=-1)
        marginalized_probs = torch.sum(softmaxed_logits * b_probs, dim=1)
        l_probs = torch.log(marginalized_probs)
        return l_probs

    log_probs = mpu.checkpoint(get_log_probs, relevant_logits, block_probs)

    def get_loss(l_probs, labs):
        vocab_size = l_probs.shape[2]
        # ignore_index=-1 skips positions whose label is -1 (tokens that were not masked).
        loss = torch.nn.NLLLoss(ignore_index=-1)(l_probs.reshape(-1, vocab_size), labs.reshape(-1))
        return loss.float()

    lm_loss = mpu.checkpoint(get_loss, log_probs, labels)

    reduced_loss = reduce_losses([lm_loss, max_retrieval_utility, top_retrieval_utility, avg_retrieval_utility, null_block_probs])
    return lm_loss, {'lm_loss': reduced_loss[0],
                     'max_ru': reduced_loss[1],
                     'top_ru': reduced_loss[2],
                     'avg_ru': reduced_loss[3],
                     'null_prob': reduced_loss[4]}


def get_retrieval_utility(lm_logits_, block_probs, labels, loss_mask):
    """log P(y | z, x) - log P(y | null, x)"""
    # [batch x seq_len x vocab_size]
    lm_logits = lm_logits_[:, :, :labels.shape[1], :]
    null_block_lm_logits = lm_logits[:, -1, :, :]
    null_block_loss_ = mpu.vocab_parallel_cross_entropy(null_block_lm_logits.contiguous().float(),
                                                       labels.contiguous())
    null_block_loss = torch.sum(
        null_block_loss_.view(-1) * loss_mask.reshape(-1)) / loss_mask.sum()

    retrieved_block_losses = []
    for block_num in range(lm_logits.shape[1] - 1):
        retrieved_block_lm_logits = lm_logits[:, block_num, :, :]
        retrieved_block_loss_ = mpu.vocab_parallel_cross_entropy(retrieved_block_lm_logits.contiguous().float(),
                                                                 labels.contiguous())
        retrieved_block_loss = torch.sum(
            retrieved_block_loss_.view(-1) * loss_mask.reshape(-1)) / loss_mask.sum()
        retrieved_block_losses.append(retrieved_block_loss)
    avg_retrieved_block_loss = torch.sum(torch.cuda.FloatTensor(retrieved_block_losses)) / (lm_logits.shape[1] - 1)
    max_retrieval_utility = null_block_loss - min(retrieved_block_losses)
    top_retrieval_utility = null_block_loss - retrieved_block_losses[0]
    avg_retrieval_utility = null_block_loss - avg_retrieved_block_loss
    return max_retrieval_utility, top_retrieval_utility, avg_retrieval_utility


def qa_forward_step(data_iterator, model):
    timers = get_timers()

    # this dataset interface needs to be implemented
    timers('batch generator').start()
    question_tokens, question_attention_mask, answer_tokens, answer_token_lengths = get_qa_batch(data_iterator)
    timers('batch generator').stop()

    batch_span_logits, batch_loss_masks, block_probs = model(question_tokens, question_attention_mask,
                                                             answer_tokens, answer_token_lengths)
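    # batch_span_logits scores candidate answer spans for each retrieved block; block_probs is
    # the retriever distribution P(z|x), used below to marginalize span probabilities over blocks.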
    # [batch_size x k x num_spans]
    block_probs = block_probs.unsqueeze(2).expand_as(batch_span_logits)
    batch_span_probs = F.softmax(batch_span_logits, dim=2)
    reduced_block_span_probs = torch.sum(batch_span_probs * block_probs, dim=1)
    qa_span_loss_ = -torch.log(reduced_block_span_probs)
    qa_span_loss = torch.sum(
        qa_span_loss_.view(-1) * batch_loss_masks)

    # Assumed return signature, mirroring forward_step: (loss, dict of reduced metrics).
    reduced_loss = reduce_losses([qa_span_loss])
    return qa_span_loss, {'qa_span_loss': reduced_loss[0]}


def train_valid_test_datasets_provider(train_val_test_num_samples):
    """Build train, valid and test datasets."""
    args = get_args()
    print_rank_0('> building train, validation, and test datasets '
                 'for BERT ...')

    train_ds, valid_ds, test_ds = build_train_valid_test_datasets(
        data_prefix=args.data_path,
        data_impl=args.data_impl,
        splits_string=args.split,
        train_valid_test_num_samples=train_val_test_num_samples,
        max_seq_length=args.seq_length,
        masked_lm_prob=args.mask_prob,
        short_seq_prob=args.short_seq_prob,
        seed=args.seed,
        skip_warmup=(not args.mmap_warmup),
        dataset_type='realm')
    print_rank_0("> finished creating BERT ICT datasets ...")

    return train_ds, valid_ds, test_ds


if __name__ == "__main__":
    pretrain(train_valid_test_datasets_provider, model_provider, forward_step,
             args_defaults={'tokenizer_type': 'BertWordPieceLowerCase'},
             initializer_func=initialize_and_run_async_megatron)