# coding=utf-8
# Copyright (c) 2019, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Pretrain BERT for Inverse Cloze Task"""
import torch
import torch.nn.functional as F

from indexer import load_ict_checkpoint, get_ict_dataset
from megatron.data.realm_index import BlockData, RandProjectionLSHIndex, FaissMIPSIndex
from megatron import get_args
from megatron import get_timers
from megatron import mpu
from megatron import print_rank_0
from megatron.data.dataset_utils import build_train_valid_test_datasets
from megatron.model import REALMBertModel, REALMRetriever
from megatron.training import pretrain
from megatron.utils import reduce_losses

num_batches = 0


def model_provider():
    """Build the model."""
    args = get_args()
    print_rank_0('building REALM models ...')

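    # Retrieval setup: reuse the ICT-pretrained encoder, load the precomputed
    # block embeddings, and index them for maximum inner product search (MIPS).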
    ict_model = load_ict_checkpoint()
    ict_dataset = get_ict_dataset(use_titles=False)
    all_block_data = BlockData.load_from_file(args.block_data_path)
    # hashed_index = RandProjectionLSHIndex.load_from_file(args.block_index_path)
    hashed_index = FaissMIPSIndex(index_type='flat_ip', embed_size=128)
    hashed_index.add_block_embed_data(all_block_data)

    # top_k + 1 because we may need to exclude trivial candidate
    retriever = REALMRetriever(ict_model, ict_dataset, all_block_data, hashed_index, args.block_top_k + 1)
    model = REALMBertModel(retriever)

    return model


def get_batch(data_iterator):
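    """Build the batch from the data iterator and broadcast it within the
    model-parallel group."""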
    # Items and their type.
    keys = ['tokens', 'labels', 'loss_mask', 'pad_mask', 'query_block_indices']
    datatype = torch.int64

    # Broadcast data.
    if data_iterator is None:
        data = None
    else:
        data = next(data_iterator)
    data_b = mpu.broadcast_data(keys, data, datatype)

    # Unpack.
    tokens = data_b['tokens'].long()
    labels = data_b['labels'].long()
    loss_mask = data_b['loss_mask'].long()
    pad_mask = data_b['pad_mask'].long()
    query_block_indices = data_b['query_block_indices'].long()

    return tokens, labels, loss_mask, pad_mask, query_block_indices


def get_qa_batch(data_iterator):
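    """Fetch a question-answering batch straight from the iterator (no broadcast)."""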
    question_tokens, question_attention_mask, answer_tokens, answer_token_lengths = next(data_iterator)
    return question_tokens, question_attention_mask, answer_tokens, answer_token_lengths


def forward_step(data_iterator, model):
    """Forward step."""
    timers = get_timers()

    # Get the batch.
    timers('batch generator').start()
    tokens, labels, loss_mask, pad_mask, query_block_indices = get_batch(data_iterator)
    timers('batch generator').stop()

    # Forward model.
    lm_logits, block_probs = model(tokens, pad_mask, query_block_indices)
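    # Shapes (inferred from the marginalization below):
    #   lm_logits:   [batch x num_blocks x padded_seq_len x vocab_size]
    #   block_probs: [batch x num_blocks], the retrieval distribution P(z|x)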
    with torch.no_grad():
        retrieval_utility = get_retrieval_utility(lm_logits, block_probs, labels, loss_mask)

    # P(y|x) = sum_z(P(y|z, x) * P(z|x))
    block_probs = block_probs.unsqueeze(2).unsqueeze(3).expand_as(lm_logits)
    lm_logits = torch.sum(lm_logits * block_probs, dim=1)[:, :labels.shape[1]]

    lm_loss_ = mpu.vocab_parallel_cross_entropy(lm_logits.contiguous().float(),
                                                labels.contiguous())
    lm_loss = torch.sum(
        lm_loss_.view(-1) * loss_mask.reshape(-1)) / loss_mask.sum()

    reduced_loss = reduce_losses([lm_loss, retrieval_utility])
    torch.cuda.synchronize()
    return lm_loss, {'lm_loss': reduced_loss[0], 'retrieval_utility': reduced_loss[1]}


def get_retrieval_utility(lm_logits, block_probs, labels, loss_mask):
    """log P(y | z, x) - log P(y | null, x)"""
    # lm_logits: [batch x num_blocks x seq_len x vocab_size]; the seq dim is
    # trimmed to match labels below, and the last block along dim 1 is the null block.
    lm_logits = lm_logits[:, :, :labels.shape[1], :]
    #non_null_block_probs = block_probs[:, :-1]
    #non_null_block_probs /= torch.sum(non_null_block_probs, axis=1, keepdim=True)
    # non_null_block_probs = non_null_block_probs.expand_as(lm_logits[:, :-1, :, :])
    null_block_lm_logits = lm_logits[:, -1, :, :]
    null_block_loss_ = mpu.vocab_parallel_cross_entropy(null_block_lm_logits.contiguous().float(),
                                                       labels.contiguous())
    null_block_loss = torch.sum(
        null_block_loss_.view(-1) * loss_mask.reshape(-1)) / loss_mask.sum()

    retrieved_block_losses = []
    for block_num in range(lm_logits.shape[1] - 1):
        retrieved_block_lm_logits = lm_logits[:, block_num, :, :]
        retrieved_block_loss_ = mpu.vocab_parallel_cross_entropy(retrieved_block_lm_logits.contiguous().float(),
                                                                 labels.contiguous())
        #retrieved_block_loss_ *= non_null_block_probs[:, block_num].reshape(-1, 1)
        retrieved_block_loss = torch.sum(
            retrieved_block_loss_.view(-1) * loss_mask.reshape(-1)) / loss_mask.sum()
        retrieved_block_losses.append(retrieved_block_loss)
    avg_retrieved_block_loss = torch.sum(torch.cuda.FloatTensor(retrieved_block_losses)) / (lm_logits.shape[1] - 1)

    retrieval_utility = null_block_loss - avg_retrieved_block_loss
    return retrieval_utility


def qa_forward_step(data_iterator, model):
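    """Forward step for open-domain QA: marginalize answer span scores over the
    retrieved blocks."""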
    timers = get_timers()

    # this dataset interface needs to be implemented
    timers('batch generator').start()
    question_tokens, question_attention_mask, answer_tokens, answer_token_lengths = get_qa_batch(data_iterator)
    timers('batch generator').stop()

    batch_span_logits, batch_loss_masks, block_probs = model(question_tokens, question_attention_mask,
                                                             answer_tokens, answer_token_lengths)
    # [batch_size x k x num_spans]
    block_probs = block_probs.unsqueeze(2).expand_as(batch_span_logits)
    batch_span_probs = F.softmax(batch_span_logits, dim=2)
    reduced_block_span_probs = torch.sum(batch_span_probs * block_probs, dim=1)
    qa_span_loss_ = -torch.log(reduced_block_span_probs)
    qa_span_loss = torch.sum(
        qa_span_loss_.view(-1) * batch_loss_masks
    )
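
    # NOTE: the snippet above computes qa_span_loss but the original file does not
    # return it here; the return below mirrors forward_step and is an assumption.
    reduced_qa_loss = reduce_losses([qa_span_loss])
    return qa_span_loss, {'qa_span_loss': reduced_qa_loss[0]}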


def train_valid_test_datasets_provider(train_val_test_num_samples):
    """Build train, valid and test datasets."""
    args = get_args()
    print_rank_0('> building train, validation, and test datasets '
                 'for REALM ...')

    train_ds, valid_ds, test_ds = build_train_valid_test_datasets(
        data_prefix=args.data_path,
        data_impl=args.data_impl,
        splits_string=args.split,
        train_valid_test_num_samples=train_val_test_num_samples,
        max_seq_length=args.seq_length,
        masked_lm_prob=args.mask_prob,
        short_seq_prob=args.short_seq_prob,
        seed=args.seed,
        skip_warmup=(not args.mmap_warmup),
        dataset_type='realm')
    print_rank_0("> finished creating BERT ICT datasets ...")

    return train_ds, valid_ds, test_ds


if __name__ == "__main__":
    pretrain(train_valid_test_datasets_provider, model_provider, forward_step,
             args_defaults={'tokenizer_type': 'BertWordPieceLowerCase'})