realm_model.py 5.56 KB
Newer Older
1
import os
2
3
4
5
import numpy as np
import torch
import torch.nn.functional as F

6
from megatron import get_args
7
from megatron.checkpointing import load_checkpoint, get_checkpoint_tracker_filename, get_checkpoint_name
8
from megatron.data.realm_index import detach, BlockData, FaissMIPSIndex
9
from megatron.model import BertModel
10
from megatron.model.utils import get_linear_layer, init_method_normal
11
from megatron.module import MegatronModule
12
13
from megatron.utils import report_memory
from megatron import mpu
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54


class ICTBertModel(MegatronModule):
    """Bert-based module for Inverse Cloze task."""
    def __init__(self,
                 ict_head_size,
                 num_tokentypes=1,
                 parallel_output=True,
                 only_query_model=False,
                 only_block_model=False):
        """Build the query encoder and/or the block encoder used for ICT.

        Args:
            ict_head_size: size of the ICT head, forwarded to each BertModel.
            num_tokentypes: number of token types, forwarded to each BertModel.
            parallel_output: forwarded to each BertModel.
            only_query_model: if True, skip building the block encoder.
            only_block_model: if True, skip building the query encoder.

        The two ``only_*`` flags are mutually exclusive (checked with an assert).
        """
        super().__init__()
        shared_bert_kwargs = {
            'num_tokentypes': num_tokentypes,
            'add_binary_head': False,
            'ict_head_size': ict_head_size,
            'parallel_output': parallel_output,
        }
        assert not (only_block_model and only_query_model)
        self.use_block_model = not only_query_model
        self.use_query_model = not only_block_model

        # Query encoder first, block encoder second: this order determines
        # submodule registration (and random init) order, so keep it.
        if self.use_query_model:
            # embeds (pseudo-)queries - Embed_input in the paper
            self.query_model = BertModel(**shared_bert_kwargs)
            self._query_key = 'question_model'

        if self.use_block_model:
            # embeds evidence blocks - Embed_doc in the paper
            self.block_model = BertModel(**shared_bert_kwargs)
            self._block_key = 'context_model'

    def forward(self, query_tokens, query_attention_mask, block_tokens, block_attention_mask, only_query=False, only_block=False):
        """Embed queries and/or evidence blocks.

        If ``only_query`` is set, return just ``embed_query(...)``. This flag is
        checked first, so it wins when both flags are set. If ``only_block`` is
        set, return just ``embed_block(...)``. Otherwise return the tuple
        ``(query_logits, block_logits)``. Inputs for a side that is skipped
        are never touched.
        """
        if only_query:
            return self.embed_query(query_tokens, query_attention_mask)
        if only_block:
            return self.embed_block(block_tokens, block_attention_mask)
        # Tuple elements evaluate left to right: the query is embedded before the block.
        return (self.embed_query(query_tokens, query_attention_mask),
                self.embed_block(block_tokens, block_attention_mask))

    def embed_query(self, query_tokens, query_attention_mask):
        """Embed a batch of tokens using the query model.

        Args:
            query_tokens: token tensor for the query BertModel. An all-zero int64
                token-type tensor of the same shape is created for it on the
                current CUDA device.
            query_attention_mask: attention mask, passed through unchanged.

        Returns:
            The first output of the query model (its ICT logits).

        Raises:
            ValueError: if the model was built with ``only_block_model=True``.
        """
        if not self.use_query_model:
            raise ValueError("Cannot embed query without query model.")
        # Allocate the token types directly as int64 on the GPU. The old code
        # built a float32 CPU tensor, cast it, and copied it on every call.
        query_types = torch.zeros(query_tokens.shape, dtype=torch.int64, device='cuda')
        # Call the module itself (not .forward) so any registered hooks run.
        query_ict_logits, _ = self.query_model(query_tokens, query_attention_mask, query_types)
        return query_ict_logits

    def embed_block(self, block_tokens, block_attention_mask):
        """Embed a batch of tokens using the block model.

        Args:
            block_tokens: token tensor for the block BertModel. An all-zero int64
                token-type tensor of the same shape is created for it on the
                current CUDA device.
            block_attention_mask: attention mask, passed through unchanged.

        Returns:
            The first output of the block model (its ICT logits).

        Raises:
            ValueError: if the model was built with ``only_query_model=True``.
        """
        if not self.use_block_model:
            raise ValueError("Cannot embed block without block model.")
        # Allocate the token types directly as int64 on the GPU. The old code
        # built a float32 CPU tensor, cast it, and copied it on every call.
        block_types = torch.zeros(block_tokens.shape, dtype=torch.int64, device='cuda')
        # Call the module itself (not .forward) so any registered hooks run.
        block_ict_logits, _ = self.block_model(block_tokens, block_attention_mask, block_types)
        return block_ict_logits

    def state_dict_for_save_checkpoint(self, destination=None, prefix='', keep_vars=False):
        """Gather the checkpoint state of every sub-model that was built.

        Returns a dict with an entry under ``self._query_key`` when the query
        model exists and under ``self._block_key`` when the block model exists.
        Each value is that sub-model's own
        ``state_dict_for_save_checkpoint(destination, prefix, keep_vars)``.
        """
        saved = {}
        if self.use_query_model:
            query_state = self.query_model.state_dict_for_save_checkpoint(destination, prefix, keep_vars)
            saved[self._query_key] = query_state
        if self.use_block_model:
            block_state = self.block_model.state_dict_for_save_checkpoint(destination, prefix, keep_vars)
            saved[self._block_key] = block_state
        return saved

    def load_state_dict(self, state_dict, strict=True):
        """Restore every sub-model that was built from its entry in ``state_dict``.

        The entries are looked up under ``self._query_key`` and
        ``self._block_key``, matching the keys produced by
        ``state_dict_for_save_checkpoint``. ``strict`` is forwarded to each
        sub-model. A progress line is printed before each load. Returns None.
        """
        if self.use_query_model:
            print("Loading ICT query model", flush=True)
            query_state = state_dict[self._query_key]
            self.query_model.load_state_dict(query_state, strict=strict)

        if self.use_block_model:
            print("Loading ICT block model", flush=True)
            block_state = state_dict[self._block_key]
            self.block_model.load_state_dict(block_state, strict=strict)

    def init_state_dict_from_bert(self):
        """Initialize the state from a pretrained BERT model on iteration zero of ICT pretraining.

        Steps:
            1. Read the iteration number from the checkpoint tracker file under
               ``args.bert_load``.
            2. Load that checkpoint on the CPU.
            3. Copy its ``['model']['language_model']`` weights into each
               sub-model this instance owns.
            4. If both the query and block models exist, overwrite the block
               model's ICT head with the query model's, so both start identical.

        Raises:
            FileNotFoundError: if the tracker file does not exist.
            ValueError: if the tracker contents are not an integer, or if the
                checkpoint file cannot be loaded (the original error is chained).
            AssertionError: if the tracked iteration is not positive.
        """
        args = get_args()
        tracker_filename = get_checkpoint_tracker_filename(args.bert_load)
        if not os.path.isfile(tracker_filename):
            raise FileNotFoundError("Could not find BERT load for ICT")
        with open(tracker_filename, 'r') as f:
            iteration = int(f.read().strip())
        assert iteration > 0

        checkpoint_name = get_checkpoint_name(args.bert_load, iteration, False)
        if mpu.get_data_parallel_rank() == 0:
            print('global rank {} is loading checkpoint {}'.format(
                torch.distributed.get_rank(), checkpoint_name))

        try:
            state_dict = torch.load(checkpoint_name, map_location='cpu')
        except Exception as err:
            # Catch Exception, not BaseException, so Ctrl-C and SystemExit still
            # propagate. Chain the cause so the real load failure stays visible.
            raise ValueError("Could not load checkpoint") from err

        # Load the LM state dict into each model that was actually built.
        # Without these guards, an AttributeError was raised when the instance
        # was built with only_query_model or only_block_model.
        model_dict = state_dict['model']['language_model']
        if self.use_query_model:
            self.query_model.language_model.load_state_dict(model_dict)
        if self.use_block_model:
            self.block_model.language_model.load_state_dict(model_dict)

        # Give both models the same ict_head to begin with as well. Read the
        # head's state straight from the query model instead of building the
        # full checkpoint dict of both towers.
        if self.use_query_model and self.use_block_model:
            query_ict_head_state_dict = self.query_model.ict_head.state_dict()
            self.block_model.ict_head.load_state_dict(query_ict_head_state_dict)