# coding=utf-8
# Copyright (c) 2019, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Pretrain REALM (BERT augmented with a learned retriever over evidence blocks)."""
import torch
import torch.nn.functional as F

from indexer import load_ict_checkpoint, get_ict_dataset, initialize_and_run_async_megatron

from megatron import get_args
from megatron import get_timers
from megatron import mpu
from megatron import print_rank_0
from megatron.data.dataset_utils import build_train_valid_test_datasets
from megatron.data.realm_index import BlockData, RandProjectionLSHIndex, FaissMIPSIndex
from megatron.model import REALMBertModel, REALMRetriever
from megatron.training import pretrain
from megatron.utils import reduce_losses, report_memory

num_batches = 0


def model_provider():
    """Build the model."""
    args = get_args()
    print_rank_0('building REALM models ...')
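
    # Assemble REALM: the ICT model serves as the retriever over a MIPS index of
    # evidence blocks, and REALMBertModel wraps that retriever with the BERT reader.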
    # Prefer resuming the retriever from a REALM checkpoint; fall back to the
    # plain ICT checkpoint if one is not available.
    try:
        ict_model = load_ict_checkpoint(from_realm_chkpt=True)
    except Exception:
        ict_model = load_ict_checkpoint(from_realm_chkpt=False)
    ict_dataset = get_ict_dataset(use_titles=False)
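    # Load the precomputed block embeddings and index them for maximum inner
    # product search (MIPS) so the retriever can score evidence blocks quickly.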
    all_block_data = BlockData.load_from_file(args.block_data_path)
    # hashed_index = RandProjectionLSHIndex.load_from_file(args.block_index_path)
    hashed_index = FaissMIPSIndex(index_type='flat_ip', embed_size=128, use_gpu=args.faiss_use_gpu)
    hashed_index.add_block_embed_data(all_block_data)

    # top_k + 1 because we may need to exclude the trivial candidate
    retriever = REALMRetriever(ict_model, ict_dataset, all_block_data, hashed_index, args.block_top_k)
    model = REALMBertModel(retriever)

    return model


def get_batch(data_iterator):
    # Items and their type.
    keys = ['tokens', 'labels', 'loss_mask', 'pad_mask', 'query_block_indices']
    datatype = torch.int64

    # Broadcast data.
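    # Only ranks with a data iterator read the batch; the others receive it via
    # mpu.broadcast_data below.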
    if data_iterator is None:
        data = None
    else:
        data = next(data_iterator)
    data_b = mpu.broadcast_data(keys, data, datatype)

    # Unpack.
    tokens = data_b['tokens'].long()
    labels = data_b['labels'].long()
    loss_mask = data_b['loss_mask'].long()
    pad_mask = data_b['pad_mask'].long()
    query_block_indices = data_b['query_block_indices'].long()

    return tokens, labels, loss_mask, pad_mask, query_block_indices


def get_qa_batch(data_iterator):
    """Pull a QA batch: question tokens/attention mask and answer tokens/lengths."""
    question_tokens, question_attention_mask, answer_tokens, answer_token_lengths = next(data_iterator)
    return question_tokens, question_attention_mask, answer_tokens, answer_token_lengths


def forward_step(data_iterator, model):
    """Forward step."""
    timers = get_timers()

    # Get the batch.
    timers('batch generator').start()
    tokens, labels, loss_mask, pad_mask, query_block_indices = get_batch(data_iterator)
    timers('batch generator').stop()

    # Forward model.
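    # The model returns per-block masked-LM logits (including a null, no-retrieval
    # block in the last slot) and the retrieval probabilities P(z|x).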
    lm_logits, block_probs = model(tokens, pad_mask, query_block_indices)
    with torch.no_grad():
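        # Diagnostics only: the retrieval-utility metrics feed the logged stats dict
        # and do not contribute to the training loss.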
        max_retrieval_utility, top_retrieval_utility, avg_retrieval_utility = mpu.checkpoint(
            get_retrieval_utility, lm_logits, block_probs, labels, loss_mask)

    # P(y|x) = sum_z(P(y|z, x) * P(z|x))
    # The null (no-retrieval) block occupies the last index along the block dimension.
    null_block_probs = torch.mean(block_probs[:, -1])
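    # Marginalize over blocks: broadcast P(z|x) across the seq/vocab dimensions of
    # the per-block logits and sum along the block dimension (dim=1).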
    block_probs = block_probs.unsqueeze(2).unsqueeze(3).expand_as(lm_logits)
    lm_logits = torch.sum(lm_logits * block_probs, dim=1)[:, :labels.shape[1]]

    lm_loss_ = mpu.vocab_parallel_cross_entropy(lm_logits.contiguous().float(),
                                                labels.contiguous())
    lm_loss = torch.sum(
        lm_loss_.view(-1) * loss_mask.reshape(-1)) / loss_mask.sum()

    reduced_loss = reduce_losses([lm_loss, max_retrieval_utility, top_retrieval_utility,
                                  avg_retrieval_utility, null_block_probs])
    # torch.cuda.synchronize()
    return lm_loss, {'lm_loss': reduced_loss[0],
                     'max_ru': reduced_loss[1],
                     'top_ru': reduced_loss[2],
                     'avg_ru': reduced_loss[3],
                     'null_prob': reduced_loss[4]}


def get_retrieval_utility(lm_logits, block_probs, labels, loss_mask):
    """log P(y | z, x) - log P(y | null, x)"""
    # lm_logits: [batch x num_blocks x seq_len x vocab_size]; the last block is the null block.
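    # Utility is positive when conditioning on a retrieved block gives a lower
    # masked-LM loss than conditioning on the null (no-retrieval) block.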
    lm_logits = lm_logits[:, :, :labels.shape[1], :]
    # non_null_block_probs = block_probs[:, :-1]
    # non_null_block_probs /= torch.sum(non_null_block_probs, axis=1, keepdim=True)
    # non_null_block_probs = non_null_block_probs.expand_as(lm_logits[:, :-1, :, :])
    null_block_lm_logits = lm_logits[:, -1, :, :]
    null_block_loss_ = mpu.vocab_parallel_cross_entropy(null_block_lm_logits.contiguous().float(),
                                                       labels.contiguous())
    null_block_loss = torch.sum(
        null_block_loss_.view(-1) * loss_mask.reshape(-1)) / loss_mask.sum()

    retrieved_block_losses = []
    for block_num in range(lm_logits.shape[1] - 1):
        retrieved_block_lm_logits = lm_logits[:, block_num, :, :]
        retrieved_block_loss_ = mpu.vocab_parallel_cross_entropy(retrieved_block_lm_logits.contiguous().float(),
                                                                 labels.contiguous())
        #retrieved_block_loss_ *= non_null_block_probs[:, block_num].reshape(-1, 1)
        retrieved_block_loss = torch.sum(
            retrieved_block_loss_.view(-1) * loss_mask.reshape(-1)) / loss_mask.sum()
        retrieved_block_losses.append(retrieved_block_loss)
    avg_retrieved_block_loss = torch.stack(retrieved_block_losses).mean()
    max_retrieval_utility = null_block_loss - min(retrieved_block_losses)
    top_retrieval_utility = null_block_loss - retrieved_block_losses[0]
    avg_retrieval_utility = null_block_loss - avg_retrieved_block_loss
    return max_retrieval_utility, top_retrieval_utility, avg_retrieval_utility


def qa_forward_step(data_iterator, model):
    """Forward step for the open-domain QA objective."""
    timers = get_timers()

    # this dataset interface needs to be implemented
    timers('batch generator').start()
    question_tokens, question_attention_mask, answer_tokens, answer_token_lengths = get_qa_batch(data_iterator)
    timers('batch generator').stop()

    batch_span_logits, batch_loss_masks, block_probs = model(question_tokens, question_attention_mask,
                                                             answer_tokens, answer_token_lengths)
    # [batch_size x k x num_spans]
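    # Marginalize over blocks: P(span | x) = sum_z P(span | z, x) * P(z | x).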
    block_probs = block_probs.unsqueeze(2).expand_as(batch_span_logits)
    batch_span_probs = F.softmax(batch_span_logits, dim=2)
    reduced_block_span_probs = torch.sum(batch_span_probs * block_probs, dim=1)
    qa_span_loss_ = -torch.log(reduced_block_span_probs)
    qa_span_loss = torch.sum(
        qa_span_loss_.view(-1) * batch_loss_masks
    )

    # Assumed completion: return (loss, stats dict) to match the contract that
    # forward_step above satisfies for megatron.training.pretrain.
    reduced_loss = reduce_losses([qa_span_loss])
    return qa_span_loss, {'qa_span_loss': reduced_loss[0]}


def train_valid_test_datasets_provider(train_val_test_num_samples):
    """Build train, valid and test datasets."""
    args = get_args()
    print_rank_0('> building train, validation, and test datasets '
                 'for BERT ...')

    train_ds, valid_ds, test_ds = build_train_valid_test_datasets(
        data_prefix=args.data_path,
        data_impl=args.data_impl,
        splits_string=args.split,
        train_valid_test_num_samples=train_val_test_num_samples,
        max_seq_length=args.seq_length,
        masked_lm_prob=args.mask_prob,
        short_seq_prob=args.short_seq_prob,
        seed=args.seed,
        skip_warmup=(not args.mmap_warmup),
        dataset_type='realm')
    print_rank_0("> finished creating REALM datasets ...")

    return train_ds, valid_ds, test_ds


if __name__ == "__main__":
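    # Note: initialization goes through indexer.initialize_and_run_async_megatron,
    # which (as its name suggests) also starts the asynchronous index-building job.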
    pretrain(train_valid_test_datasets_provider, model_provider, forward_step,
             args_defaults={'tokenizer_type': 'BertWordPieceLowerCase'},
             initializer_func=initialize_and_run_async_megatron)