# coding=utf-8
# Copyright (c) 2019, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Pretrain BERT for Inverse Cloze Task"""
import torch
import torch.nn.functional as F

from indexer import load_ict_checkpoint, get_ict_dataset, initialize_and_run_async_megatron
from megatron.data.realm_index import BlockData, RandProjectionLSHIndex, FaissMIPSIndex
from megatron import get_args
from megatron import get_timers
from megatron import mpu
from megatron import print_rank_0
from megatron.data.dataset_utils import build_train_valid_test_datasets
from megatron.model import REALMBertModel, REALMRetriever
from megatron.training import pretrain
from megatron.utils import reduce_losses

num_batches = 0


def model_provider():
    """Build the model."""
    args = get_args()
    print_rank_0('building REALM models ...')

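    # Load the ICT-trained retriever pieces: the query/block encoders, the
    # precomputed block embeddings, and a MIPS index built over them.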
    ict_model = load_ict_checkpoint()
    ict_dataset = get_ict_dataset(use_titles=False)
    all_block_data = BlockData.load_from_file(args.block_data_path)
    # hashed_index = RandProjectionLSHIndex.load_from_file(args.block_index_path)
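    # Use an exact inner-product FAISS index over the 128-d block embeddings
    # instead of the hashed LSH index above.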
    hashed_index = FaissMIPSIndex(index_type='flat_ip', embed_size=128)
    hashed_index.add_block_embed_data(all_block_data)

    # top_k + 1 because we may need to exclude the trivial candidate (the block the query itself came from)
    retriever = REALMRetriever(ict_model, ict_dataset, all_block_data, hashed_index, args.block_top_k + 1)
    model = REALMBertModel(retriever)

    return model


def get_batch(data_iterator):
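    """Build the batch from the data iterator and broadcast it to the model-parallel group."""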
    # Items and their type.
    keys = ['tokens', 'labels', 'loss_mask', 'pad_mask', 'query_block_indices']
    datatype = torch.int64

    # Broadcast data.
    if data_iterator is None:
        data = None
    else:
        data = next(data_iterator)
    data_b = mpu.broadcast_data(keys, data, datatype)

    # Unpack.
    tokens = data_b['tokens'].long()
    labels = data_b['labels'].long()
    loss_mask = data_b['loss_mask'].long()
    pad_mask = data_b['pad_mask'].long()
    query_block_indices = data_b['query_block_indices'].long()

    return tokens, labels, loss_mask, pad_mask, query_block_indices


def get_qa_batch(data_iterator):
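    """Unpack a question-answering batch from the data iterator."""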
    question_tokens, question_attention_mask, answer_tokens, answer_token_lengths = next(data_iterator)
    return question_tokens, question_attention_mask, answer_tokens, answer_token_lengths


def forward_step(data_iterator, model):
    """Forward step."""
    timers = get_timers()

    # Get the batch.
    timers('batch generator').start()
    tokens, labels, loss_mask, pad_mask, query_block_indices = get_batch(data_iterator)
    timers('batch generator').stop()

    # Forward model.
    lm_logits, block_probs = model(tokens, pad_mask, query_block_indices)
    with torch.no_grad():
        retrieval_utility = get_retrieval_utility(lm_logits, block_probs, labels, loss_mask)

    # P(y|x) = sum_z(P(y|z, x) * P(z|x))
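    # block_probs: [batch x k] -> broadcast over the seq_len and vocab dims so the
    # per-block LM logits can be marginalized over the k retrieved blocks.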
    block_probs = block_probs.unsqueeze(2).unsqueeze(3).expand_as(lm_logits)
    lm_logits = torch.sum(lm_logits * block_probs, dim=1)[:, :labels.shape[1]]

    lm_loss_ = mpu.vocab_parallel_cross_entropy(lm_logits.contiguous().float(),
                                                labels.contiguous())
    lm_loss = torch.sum(
        lm_loss_.view(-1) * loss_mask.reshape(-1)) / loss_mask.sum()

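    # Average the losses across all workers for logging.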
    reduced_loss = reduce_losses([lm_loss, retrieval_utility])
    # torch.cuda.synchronize()
    return lm_loss, {'lm_loss': reduced_loss[0], 'retrieval_utility': reduced_loss[1]}


def get_retrieval_utility(lm_logits, block_probs, labels, loss_mask):
    """log P(y | z, x) - log P(y | null, x)"""
    # [batch x seq_len x vocab_size]
    lm_logits = lm_logits[:, :, :labels.shape[1], :]
    # non_null_block_probs = block_probs[:, :-1]
    # non_null_block_probs /= torch.sum(non_null_block_probs, axis=1, keepdim=True)
    # non_null_block_probs = non_null_block_probs.expand_as(lm_logits[:, :-1, :, :])
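
    # The last block along dim 1 is the null (no-retrieval) block.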
    null_block_lm_logits = lm_logits[:, -1, :, :]
    null_block_loss_ = mpu.vocab_parallel_cross_entropy(null_block_lm_logits.contiguous().float(),
                                                        labels.contiguous())
    null_block_loss = torch.sum(
        null_block_loss_.view(-1) * loss_mask.reshape(-1)) / loss_mask.sum()

    retrieved_block_losses = []
    for block_num in range(lm_logits.shape[1] - 1):
        retrieved_block_lm_logits = lm_logits[:, block_num, :, :]
        retrieved_block_loss_ = mpu.vocab_parallel_cross_entropy(retrieved_block_lm_logits.contiguous().float(),
                                                                 labels.contiguous())
        # retrieved_block_loss_ *= non_null_block_probs[:, block_num].reshape(-1, 1)
        retrieved_block_loss = torch.sum(
            retrieved_block_loss_.view(-1) * loss_mask.reshape(-1)) / loss_mask.sum()
        retrieved_block_losses.append(retrieved_block_loss)
    avg_retrieved_block_loss = torch.stack(retrieved_block_losses).sum() / (lm_logits.shape[1] - 1)

    retrieval_utility = null_block_loss - avg_retrieved_block_loss
    return retrieval_utility


def qa_forward_step(data_iterator, model):
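    """Forward step for QA fine-tuning: score answer spans and marginalize over retrieved blocks."""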
    timers = get_timers()

    # this dataset interface needs to be implemented
    timers('batch generator').start()
    question_tokens, question_attention_mask, answer_tokens, answer_token_lengths = get_qa_batch(data_iterator)
    timers('batch generator').stop()

    batch_span_logits, batch_loss_masks, block_probs = model(question_tokens, question_attention_mask,
                                                             answer_tokens, answer_token_lengths)
    # [batch_size x k x num_spans]
    block_probs = block_probs.unsqueeze(2).expand_as(batch_span_logits)
    batch_span_probs = F.softmax(batch_span_logits, dim=2)
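    # As with the LM logits above: P(span|x) = sum_z(P(span|z, x) * P(z|x)).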
    reduced_block_span_probs = torch.sum(batch_span_probs * block_probs, dim=1)
    qa_span_loss_ = -torch.log(reduced_block_span_probs)
    # Normalize by the number of unmasked span targets, as in forward_step above.
    qa_span_loss = torch.sum(
        qa_span_loss_.view(-1) * batch_loss_masks.view(-1)) / batch_loss_masks.sum()

    # Return (loss, stats) to match the interface forward_step exposes to pretrain().
    return qa_span_loss, {'qa_span_loss': reduce_losses([qa_span_loss])[0]}


def train_valid_test_datasets_provider(train_val_test_num_samples):
    """Build train, valid and test datasets."""
    args = get_args()
    print_rank_0('> building train, validation, and test datasets '
                 'for BERT ...')

    train_ds, valid_ds, test_ds = build_train_valid_test_datasets(
        data_prefix=args.data_path,
        data_impl=args.data_impl,
        splits_string=args.split,
        train_valid_test_num_samples=train_val_test_num_samples,
        max_seq_length=args.seq_length,
        masked_lm_prob=args.mask_prob,
        short_seq_prob=args.short_seq_prob,
        seed=args.seed,
        skip_warmup=(not args.mmap_warmup),
        dataset_type='realm')
    print_rank_0("> finished creating BERT ICT datasets ...")

    return train_ds, valid_ds, test_ds


if __name__ == "__main__":
    pretrain(train_valid_test_datasets_provider, model_provider, forward_step,
             args_defaults={'tokenizer_type': 'BertWordPieceLowerCase'},
             initializer_func=initialize_and_run_async_megatron)