# coding=utf-8
# Copyright (c) 2019, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""BERT model."""

import pickle

import numpy as np
import torch
import torch.nn.functional as F

from megatron import get_args
from megatron.module import MegatronModule

from .language_model import parallel_lm_logits
from .language_model import get_language_model
from .transformer import LayerNorm
from .utils import gelu
from .utils import get_linear_layer
from .utils import init_method_normal
from .utils import scaled_init_method_normal


def bert_attention_mask_func(attention_scores, attention_mask):
    """Apply an additive attention mask to raw attention scores.

    The mask is added (with broadcasting) rather than multiplied. It is
    presumably the additive bias built by ``bert_extended_attention_mask``
    (0 where attention is allowed, a large negative value where it is not).
    The input tensors are not modified; a new tensor is returned.
    """
    masked_scores = attention_scores + attention_mask
    return masked_scores


def bert_extended_attention_mask(attention_mask, dtype):
    """Turn a [b, s] padding mask into an additive [b, 1, s, s] attention bias.

    ``attention_mask`` holds 1 for tokens that may be attended and 0 for
    masked (padding) tokens. The result is 0.0 where both the query and the
    key token are real and -10000.0 elsewhere. Adding it to the raw attention
    scores before the softmax effectively removes the masked positions.

    Args:
        attention_mask: [b, s] tensor of 1s and 0s.
        dtype: dtype of the returned bias (cast happens before the
            arithmetic, for fp16 compatibility).
    """
    # [b, 1, s] * [b, s, 1] -> [b, s, s]: position (i, j) survives only if
    # both token i and token j are unmasked.
    pairwise = attention_mask[:, None] * attention_mask[:, :, None]
    # Insert a singleton head axis -> [b, 1, s, s], then cast to the model dtype.
    pairwise = pairwise[:, None].to(dtype=dtype)
    return (1.0 - pairwise) * -10000.0


def bert_position_ids(token_ids):
    """Return position indices ``0 .. s-1`` shaped like ``token_ids``.

    ``s`` is ``token_ids.size(1)``. The indices are a long tensor created on
    ``token_ids``' device. They are broadcast with ``expand_as``, which
    returns an expanded view rather than a copy.
    """
    seq_positions = torch.arange(token_ids.size(1), dtype=torch.long,
                                 device=token_ids.device)
    return seq_positions[None].expand_as(token_ids)



class BertLMHead(MegatronModule):
    """Masked LM head for Bert.

    Applies dense -> gelu -> layernorm to the transformer output, then
    projects onto the word-embedding matrix passed to ``forward`` (via
    ``parallel_lm_logits``), adding a learned per-vocabulary-entry bias.

    Arguments:
        mpu_vocab_size: model parallel size of vocabulary.
        hidden_size: hidden size
        init_method: init method for weight initialization
        layernorm_epsilon: tolerance for layer norm divisions
        parallel_output: whether output logits being distributed or not.
    """

    def __init__(self, mpu_vocab_size, hidden_size, init_method,
                 layernorm_epsilon, parallel_output):

        super(BertLMHead, self).__init__()

        self.bias = torch.nn.Parameter(torch.zeros(mpu_vocab_size))
        # Mark the bias as split along dim 0 across model-parallel ranks;
        # presumably consumed by Megatron's model-parallel utilities
        # (e.g. grad clipping / checkpointing) — confirm against mpu.
        self.bias.model_parallel = True
        self.bias.partition_dim = 0
        self.bias.stride = 1
        self.parallel_output = parallel_output

        self.dense = get_linear_layer(hidden_size, hidden_size, init_method)
        self.layernorm = LayerNorm(hidden_size, eps=layernorm_epsilon)

    def forward(self, hidden_states, word_embeddings_weight):
        """Return LM logits for ``hidden_states``.

        Args:
            hidden_states: transformer output to be scored.
            word_embeddings_weight: embedding matrix used as the output
                projection. BertModel passes its own input word embeddings
                here, so the projection is tied to the input embedding.
        """
        hidden_states = self.dense(hidden_states)
        hidden_states = gelu(hidden_states)
        hidden_states = self.layernorm(hidden_states)
        output = parallel_lm_logits(hidden_states,
                                    word_embeddings_weight,
                                    self.parallel_output,
                                    bias=self.bias)
        return output



class BertModel(MegatronModule):
    """Bert Language model.

    Wraps the Megatron language model and attaches heads on top of it:

    * a masked-LM head, unless the ICT head is requested;
    * a 2-way binary head on the pooled output when ``add_binary_head``;
    * a linear ICT head of width ``ict_head_size`` on the pooled output when
      ``ict_head_size`` is not None.

    The binary and ICT heads are mutually exclusive.

    Arguments:
        num_tokentypes: number of token-type (segment) embeddings.
        add_binary_head: attach the binary classifier on the pooled output.
        ict_head_size: if not None, attach the ICT head (and skip the LM head).
        parallel_output: passed to the LM head; whether its logits stay
            distributed across model-parallel ranks.
    """

    def __init__(self, num_tokentypes=2, add_binary_head=True,
                 ict_head_size=None, parallel_output=True):
        super(BertModel, self).__init__()
        args = get_args()

        self.add_binary_head = add_binary_head
        self.ict_head_size = ict_head_size
        self.add_ict_head = ict_head_size is not None
        assert not (self.add_binary_head and self.add_ict_head)

        self.parallel_output = parallel_output
        init_method = init_method_normal(args.init_method_std)
        # Both the binary and the ICT head consume the pooled output.
        add_pooler = self.add_binary_head or self.add_ict_head
        scaled_init_method = scaled_init_method_normal(args.init_method_std,
                                                       args.num_layers)

        # Without a pooled head (as in REALMBertModel, which feeds a query and
        # an evidence block concatenated), allow position embeddings for twice
        # the configured sequence length.
        max_pos_embeds = None
        if not add_pooler:
            max_pos_embeds = 2 * args.seq_length

        self.language_model, self._language_model_key = get_language_model(
            attention_mask_func=bert_attention_mask_func,
            num_tokentypes=num_tokentypes,
            add_pooler=add_pooler,
            init_method=init_method,
            scaled_init_method=scaled_init_method,
            max_pos_embeds=max_pos_embeds)

        if not self.add_ict_head:
            self.lm_head = BertLMHead(
                self.language_model.embedding.word_embeddings.weight.size(0),
                args.hidden_size, init_method, args.layernorm_epsilon,
                parallel_output)
            self._lm_head_key = 'lm_head'
        if self.add_binary_head:
            self.binary_head = get_linear_layer(args.hidden_size, 2,
                                                init_method)
            self._binary_head_key = 'binary_head'
        elif self.add_ict_head:
            self.ict_head = get_linear_layer(args.hidden_size, ict_head_size,
                                             init_method)
            self._ict_head_key = 'ict_head'

    def forward(self, input_ids, attention_mask, tokentype_ids=None):
        """Run the encoder and the configured heads.

        Args:
            input_ids: token ids; position ids are derived from dim 1.
            attention_mask: 1/0 padding mask, expanded by
                ``bert_extended_attention_mask``.
            tokentype_ids: optional token-type ids.

        Returns:
            ``(ict_logits, None)`` when the ICT head is used; otherwise
            ``(lm_logits, binary_logits)`` where ``binary_logits`` is None
            unless the binary head is used.
        """
        extended_attention_mask = bert_extended_attention_mask(
            attention_mask, next(self.language_model.parameters()).dtype)
        position_ids = bert_position_ids(input_ids)

        # The language model also returns a pooled output only when it was
        # built with a pooler.
        if self.add_binary_head or self.add_ict_head:
            lm_output, pooled_output = self.language_model(
                input_ids,
                position_ids,
                extended_attention_mask,
                tokentype_ids=tokentype_ids)
        else:
            lm_output = self.language_model(
                input_ids,
                position_ids,
                extended_attention_mask,
                tokentype_ids=tokentype_ids)

        # Output.
        if self.add_ict_head:
            ict_logits = self.ict_head(pooled_output)
            return ict_logits, None

        lm_logits = self.lm_head(
            lm_output, self.language_model.embedding.word_embeddings.weight)
        if self.add_binary_head:
            binary_logits = self.binary_head(pooled_output)
            return lm_logits, binary_logits

        return lm_logits, None

    def state_dict_for_save_checkpoint(self, destination=None, prefix='',
                                       keep_vars=False):
        """For easy load when model is combined with other heads,
        add an extra key."""

        state_dict_ = {}
        state_dict_[self._language_model_key] \
            = self.language_model.state_dict_for_save_checkpoint(
                destination, prefix, keep_vars)
        if not self.add_ict_head:
            state_dict_[self._lm_head_key] \
                = self.lm_head.state_dict_for_save_checkpoint(
                    destination, prefix, keep_vars)
        if self.add_binary_head:
            state_dict_[self._binary_head_key] \
                = self.binary_head.state_dict(destination, prefix, keep_vars)
        elif self.add_ict_head:
            state_dict_[self._ict_head_key] \
                = self.ict_head.state_dict(destination, prefix, keep_vars)
        return state_dict_

    def load_state_dict(self, state_dict, strict=True):
        """Customized load: each sub-module reads its own key written by
        ``state_dict_for_save_checkpoint``."""

        self.language_model.load_state_dict(
            state_dict[self._language_model_key], strict=strict)
        if not self.add_ict_head:
            self.lm_head.load_state_dict(
                state_dict[self._lm_head_key], strict=strict)
        if self.add_binary_head:
            self.binary_head.load_state_dict(
                state_dict[self._binary_head_key], strict=strict)
        elif self.add_ict_head:
            self.ict_head.load_state_dict(
                state_dict[self._ict_head_key], strict=strict)


class REALMBertModel(MegatronModule):
    """BERT LM run over (query, retrieved evidence block) pairs.

    It also returns retrieval probabilities over the retrieved blocks.
    """

    # TODO: load BertModel checkpoint
    def __init__(self, retriever):
        super(REALMBertModel, self).__init__()
        bert_args = dict(
            num_tokentypes=1,
            add_binary_head=False,
            parallel_output=True
        )
        self.lm_model = BertModel(**bert_args)
        self._lm_key = 'realm_lm'

        self.retriever = retriever
        self._retriever_key = 'retriever'

    def forward(self, tokens, attention_mask):
        """Score retrieved blocks for each query and run the LM on each pair.

        Args:
            tokens: query token ids, [batch_size x seq_length]. They are
                assumed to be as long as the retrieved blocks.
            attention_mask: query padding mask, same shape as ``tokens``.

        Returns:
            lm_logits: [batch_size x num_blocks x total_length x vocab],
                where total_length is query length + block length.
            block_probs: [batch_size x num_blocks], softmax over the
                query/block inner-product scores.
        """
        # numpy arrays, [batch_size x num_blocks x seq_length]
        top_block_tokens, top_block_attention_mask = \
            self.retriever.retrieve_evidence_blocks(tokens, attention_mask)

        batch_size = tokens.shape[0]
        num_blocks = top_block_tokens.shape[1]
        seq_length = top_block_tokens.shape[2]

        # [batch_size * num_blocks x seq_length]
        top_block_tokens = torch.cuda.LongTensor(top_block_tokens).reshape(-1, seq_length)
        top_block_attention_mask = torch.cuda.LongTensor(top_block_attention_mask).reshape(-1, seq_length)

        # Re-embed the blocks with the (trainable) block model.
        # [batch_size x num_blocks x embed_size]
        fresh_block_logits = self.retriever.ict_model(
            None, None, top_block_tokens, top_block_attention_mask,
            only_block=True).reshape(batch_size, num_blocks, -1)

        # [batch_size x embed_size x 1]
        query_logits = self.retriever.ict_model(
            tokens, attention_mask, None, None, only_query=True).unsqueeze(2)

        # [batch_size x num_blocks]. Squeeze only the trailing singleton so a
        # batch of one keeps its batch dimension for softmax(dim=1).
        fresh_block_scores = torch.matmul(fresh_block_logits, query_logits).squeeze(2)
        block_probs = F.softmax(fresh_block_scores, dim=1)

        # Repeat each query once per retrieved block, keeping the block order.
        # [batch_size * num_blocks x seq_length]
        tokens = tokens.repeat_interleave(num_blocks, dim=0)
        attention_mask = attention_mask.repeat_interleave(num_blocks, dim=0)

        # [batch_size * num_blocks x total_length]
        all_tokens = torch.cat((tokens, top_block_tokens), dim=1)
        all_attention_mask = torch.cat((attention_mask, top_block_attention_mask), dim=1)
        all_token_types = torch.zeros_like(all_tokens, dtype=torch.long)

        # [batch_size x num_blocks x total_length x vocab]
        lm_logits, _ = self.lm_model(all_tokens, all_attention_mask, all_token_types)
        lm_logits = lm_logits.reshape(batch_size, num_blocks, all_tokens.shape[1], -1)

        return lm_logits, block_probs

    def state_dict_for_save_checkpoint(self, destination=None, prefix='',
                                       keep_vars=False):
        """For easy load when model is combined with other heads,
        add an extra key. Only the LM is saved; the retriever is not."""

        state_dict_ = {}
        state_dict_[self._lm_key] = self.lm_model.state_dict_for_save_checkpoint(
            destination, prefix, keep_vars)
        return state_dict_


class REALMRetriever(MegatronModule):
    """Retriever which uses a pretrained ICTBertModel and a HashedIndex.

    Returns ``top_k`` evidence blocks per query.
    """

    def __init__(self, ict_model, ict_dataset, hashed_index, top_k=5):
        super(REALMRetriever, self).__init__()
        self.ict_model = ict_model
        self.ict_dataset = ict_dataset
        self.hashed_index = hashed_index
        self.top_k = top_k

    def retrieve_evidence_blocks_text(self, query_text):
        """Get the top k evidence blocks for query_text in text form.

        The query and each retrieved block are printed to stdout.
        """
        print("-" * 100)
        print("Query: ", query_text)
        # Leave two positions free; presumably for the special tokens that
        # concat_and_pad_tokens adds — confirm against the dataset.
        padless_max_len = self.ict_dataset.max_seq_length - 2
        query_tokens = self.ict_dataset.encode_text(query_text)[:padless_max_len]

        query_tokens, query_pad_mask = self.ict_dataset.concat_and_pad_tokens(query_tokens)
        query_tokens = torch.cuda.LongTensor(np.array(query_tokens).reshape(1, -1))
        query_pad_mask = torch.cuda.LongTensor(np.array(query_pad_mask).reshape(1, -1))

        top_block_tokens, _ = self.retrieve_evidence_blocks(query_tokens, query_pad_mask)
        for i, block in enumerate(top_block_tokens[0]):
            block_text = self.ict_dataset.decode_tokens(block)
            print('\n    > Block {}: {}'.format(i, block_text))

    def retrieve_evidence_blocks(self, query_tokens, query_pad_mask):
        """Retrieve the top ``self.top_k`` evidence blocks for each query.

        Each query is embedded with the ICT query model and hashed into a
        bucket by ``hashed_index``. The blocks in that bucket are ranked by
        inner product with the query embedding.

        Returns:
            (tokens, pad_masks): numpy arrays of shape
            [batch_size x top_k x seq_length].

        Raises:
            RuntimeError: if a query's bucket holds fewer than ``top_k``
                blocks and no bucket in the batch is large enough to borrow.
        """
        query_embeds = self.ict_model(query_tokens, query_pad_mask, None, None, only_query=True)
        query_hashes = self.hashed_index.hash_embeds(query_embeds)
        block_buckets = [self.hashed_index.get_block_bucket(query_hash)
                         for query_hash in query_hashes]

        # A bucket too small for top-k retrieval borrows a copy of the first
        # bucket in the batch that is large enough (a workaround, not a true
        # search over the index).
        for j, bucket in enumerate(block_buckets):
            if len(bucket) >= self.top_k:
                continue
            donor = next((b for b in block_buckets if len(b) >= self.top_k), None)
            if donor is None:
                raise RuntimeError(
                    "No hash bucket in the batch holds at least top_k={} blocks".format(self.top_k))
            block_buckets[j] = donor.copy()

        # [batch_size x bucket_size x embed_size]; arr[3] is the index passed
        # to get_block_embed.
        block_embeds = [torch.cuda.FloatTensor(np.array([self.hashed_index.get_block_embed(arr[3])
                                                         for arr in bucket]))
                        for bucket in block_buckets]

        all_top_tokens, all_top_pad_masks = [], []
        for query_embed, embed_tensor, bucket in zip(query_embeds, block_embeds, block_buckets):
            # [bucket_size]
            retrieval_scores = query_embed.matmul(
                torch.transpose(embed_tensor.reshape(-1, query_embed.size()[0]), 0, 1))
            _, top_indices = torch.topk(retrieval_scores, k=self.top_k, sorted=True)

            # First three fields of each bucket entry are the get_block
            # arguments (presumably start, end, doc — confirm in HashedIndex).
            top_start_end_doc = [bucket[idx][:3] for idx in top_indices.tolist()]
            # top_k tuples of (block_tokens, block_pad_mask)
            top_block_data = [self.ict_dataset.get_block(*indices) for indices in top_start_end_doc]
            top_tokens, top_pad_masks = zip(*top_block_data)

            all_top_tokens.append(np.array(top_tokens))
            all_top_pad_masks.append(np.array(top_pad_masks))

        # [batch_size x top_k x seq_length]
        return np.array(all_top_tokens), np.array(all_top_pad_masks)


class ICTBertModel(MegatronModule):
    """Bert-based module for Inverse Cloze task.

    Holds up to two BertModels with ICT heads: one embeds queries and one
    embeds evidence blocks. Either can be omitted via ``only_query_model``
    or ``only_block_model`` (not both).
    """

    def __init__(self,
                 ict_head_size,
                 num_tokentypes=1,
                 parallel_output=True,
                 only_query_model=False,
                 only_block_model=False):
        super(ICTBertModel, self).__init__()
        bert_args = dict(
            num_tokentypes=num_tokentypes,
            add_binary_head=False,
            ict_head_size=ict_head_size,
            parallel_output=parallel_output
        )
        assert not (only_block_model and only_query_model)
        self.use_block_model = not only_query_model
        self.use_query_model = not only_block_model

        if self.use_query_model:
            # this model embeds (pseudo-)queries - Embed_input in the paper
            self.query_model = BertModel(**bert_args)
            self._query_key = 'question_model'

        if self.use_block_model:
            # this model embeds evidence blocks - Embed_doc in the paper
            self.block_model = BertModel(**bert_args)
            self._block_key = 'context_model'

    def forward(self, query_tokens, query_attention_mask, block_tokens, block_attention_mask, only_query=False, only_block=False):
        """Run a forward pass for each of the models and compute the similarity scores.

        With ``only_query`` (or ``only_block``) only that embedding is
        returned. Otherwise the result is the [query_batch x block_batch]
        matrix of inner products between query and block embeddings.
        """
        if only_query:
            return self.embed_query(query_tokens, query_attention_mask)

        if only_block:
            return self.embed_block(block_tokens, block_attention_mask)

        query_logits = self.embed_query(query_tokens, query_attention_mask)
        block_logits = self.embed_block(block_tokens, block_attention_mask)

        # [batch x embed] * [embed x batch]
        retrieval_scores = query_logits.matmul(torch.transpose(block_logits, 0, 1))
        return retrieval_scores

    def embed_query(self, query_tokens, query_attention_mask):
        """Embed a batch of tokens using the query model.

        Raises:
            ValueError: if this instance was built without a query model.
        """
        if not self.use_query_model:
            raise ValueError("Cannot embed query without query model.")
        # Single token type: all-zero long ids on the same device as the tokens.
        query_types = torch.zeros_like(query_tokens, dtype=torch.long)
        # Call the module (not .forward) so registered hooks run.
        query_ict_logits, _ = self.query_model(query_tokens, query_attention_mask, query_types)
        return query_ict_logits

    def embed_block(self, block_tokens, block_attention_mask):
        """Embed a batch of tokens using the block model.

        Raises:
            ValueError: if this instance was built without a block model.
        """
        if not self.use_block_model:
            raise ValueError("Cannot embed block without block model.")
        # Single token type: all-zero long ids on the same device as the tokens.
        block_types = torch.zeros_like(block_tokens, dtype=torch.long)
        # Call the module (not .forward) so registered hooks run.
        block_ict_logits, _ = self.block_model(block_tokens, block_attention_mask, block_types)
        return block_ict_logits

    def state_dict_for_save_checkpoint(self, destination=None, prefix='', keep_vars=False):
        """Save dict with state dicts of each of the models."""
        state_dict_ = {}
        if self.use_query_model:
            state_dict_[self._query_key] \
                = self.query_model.state_dict_for_save_checkpoint(
                    destination, prefix, keep_vars)

        if self.use_block_model:
            state_dict_[self._block_key] \
                = self.block_model.state_dict_for_save_checkpoint(
                    destination, prefix, keep_vars)

        return state_dict_

    def load_state_dict(self, state_dict, strict=True):
        """Load the state dicts of each of the models"""
        if self.use_query_model:
            print("Loading ICT query model", flush=True)
            self.query_model.load_state_dict(
                state_dict[self._query_key], strict=strict)

        if self.use_block_model:
            print("Loading ICT block model", flush=True)
            self.block_model.load_state_dict(
                state_dict[self._block_key], strict=strict)