realm_model.py 7.75 KB
Newer Older
1
import os
2
3
import torch

4
from megatron import get_args
Neel Kant's avatar
Neel Kant committed
5
from megatron.checkpointing import get_checkpoint_tracker_filename, get_checkpoint_name
6
7
from megatron.model import BertModel
from megatron.module import MegatronModule
8
from megatron import mpu
9
10
11
12
from megatron.model.utils import get_linear_layer
from megatron.model.utils import init_method_normal
from megatron.model.language_model import get_language_model
from megatron.model.utils import scaled_init_method_normal
Neel Kant's avatar
Neel Kant committed
13
from megatron.model.bert_model import bert_attention_mask_func, bert_extended_attention_mask, bert_position_ids
14
15
16
17
18
19
20
21
22
23
24


class ICTBertModel(MegatronModule):
    """Bert-based module for the Inverse Cloze Task (ICT).

    Holds up to two ``IREncoderBertModel`` towers built with identical
    arguments: a query tower (Embed_input in the paper) and a block tower
    (Embed_doc in the paper). Either tower can be left out with
    ``only_query_model`` / ``only_block_model``, but not both.
    """

    def __init__(self,
                 ict_head_size,
                 num_tokentypes=1,
                 parallel_output=True,
                 only_query_model=False,
                 only_block_model=False):
        super(ICTBertModel, self).__init__()
        bert_kwargs = dict(
            ict_head_size=ict_head_size,
            num_tokentypes=num_tokentypes,
            parallel_output=parallel_output
        )
        # Explicit check instead of `assert`, which is stripped under `python -O`
        # and would then silently produce a model with no towers at all.
        if only_block_model and only_query_model:
            raise ValueError(
                "only_query_model and only_block_model are mutually exclusive.")
        self.use_block_model = not only_query_model
        self.use_query_model = not only_block_model

        # Checkpoint keys for each tower's sub-state-dict.
        self._query_key = 'question_model'
        self._block_key = 'context_model'

        if self.use_query_model:
            # this model embeds (pseudo-)queries - Embed_input in the paper
            self.query_model = IREncoderBertModel(**bert_kwargs)

        if self.use_block_model:
            # this model embeds evidence blocks - Embed_doc in the paper
            self.block_model = IREncoderBertModel(**bert_kwargs)

    def forward(self, query_tokens, query_attention_mask, block_tokens, block_attention_mask):
        """Embed queries and blocks with their respective towers.

        Returns:
            A ``(query_logits, block_logits)`` tuple.

        Raises:
            ValueError: if either tower was not built.
        """
        query_logits = self.embed_query(query_tokens, query_attention_mask)
        block_logits = self.embed_block(block_tokens, block_attention_mask)
        return query_logits, block_logits

    @staticmethod
    def _embed(encoder, tokens, attention_mask):
        """Run ``encoder`` on ``tokens`` with all-zero token types; return its ICT logits."""
        # Every token gets token type 0. The tensor is created on the same device
        # as `tokens` rather than whatever CUDA device happens to be current.
        token_types = torch.zeros_like(tokens, dtype=torch.long)
        # Call the module (not `.forward`) so any registered hooks still run.
        ict_logits, _ = encoder(tokens, attention_mask, token_types)
        return ict_logits

    def embed_query(self, query_tokens, query_attention_mask):
        """Embed a batch of tokens using the query model.

        Raises:
            ValueError: if this instance was built with ``only_block_model``.
        """
        if not self.use_query_model:
            raise ValueError("Cannot embed query without query model.")
        return self._embed(self.query_model, query_tokens, query_attention_mask)

    def embed_block(self, block_tokens, block_attention_mask):
        """Embed a batch of tokens using the block model.

        Raises:
            ValueError: if this instance was built with ``only_query_model``.
        """
        if not self.use_block_model:
            raise ValueError("Cannot embed block without block model.")
        return self._embed(self.block_model, block_tokens, block_attention_mask)

    def state_dict_for_save_checkpoint(self, destination=None, prefix='', keep_vars=False):
        """Save dict with state dicts of each of the models.

        Only towers that were built are included, each under its own key.
        """
        state_dict_ = {}
        if self.use_query_model:
            state_dict_[self._query_key] \
                = self.query_model.state_dict_for_save_checkpoint(
                destination, prefix, keep_vars)

        if self.use_block_model:
            state_dict_[self._block_key] \
                = self.block_model.state_dict_for_save_checkpoint(
                destination, prefix, keep_vars)

        return state_dict_

    def load_state_dict(self, state_dict, strict=True):
        """Load the state dicts of each of the models.

        Expects the layout produced by ``state_dict_for_save_checkpoint``.
        Only towers that were built are loaded.
        """
        if self.use_query_model:
            print("Loading ICT query model", flush=True)
            self.query_model.load_state_dict(
                state_dict[self._query_key], strict=strict)

        if self.use_block_model:
            print("Loading ICT block model", flush=True)
            self.block_model.load_state_dict(
                state_dict[self._block_key], strict=strict)

    def init_state_dict_from_bert(self):
        """Initialize the state from a pretrained BERT model on iteration zero of ICT pretraining.

        Reads the checkpoint referenced by ``args.bert_load``. Its language-model
        weights are loaded into every tower this instance owns. When both towers
        exist, the query tower's ICT head is also copied into the block tower so
        that both start from the same head.

        Raises:
            FileNotFoundError: if the checkpoint tracker file does not exist.
            ValueError: if the tracked iteration is not a positive integer, or the
                checkpoint file cannot be loaded.
        """
        args = get_args()
        tracker_filename = get_checkpoint_tracker_filename(args.bert_load)
        if not os.path.isfile(tracker_filename):
            raise FileNotFoundError("Could not find BERT load for ICT")
        with open(tracker_filename, 'r') as f:
            iteration = int(f.read().strip())
        if iteration <= 0:
            raise ValueError(
                "Expected a positive BERT checkpoint iteration, got {}".format(iteration))

        checkpoint_name = get_checkpoint_name(args.bert_load, iteration, False)
        if mpu.get_data_parallel_rank() == 0:
            print('global rank {} is loading checkpoint {}'.format(
                torch.distributed.get_rank(), checkpoint_name))

        # NOTE: torch.load unpickles the file; only point bert_load at trusted checkpoints.
        try:
            state_dict = torch.load(checkpoint_name, map_location='cpu')
        except Exception as err:
            # Narrower than BaseException so Ctrl-C / SystemExit still propagate;
            # chain the cause so the real load failure stays visible.
            raise ValueError("Could not load checkpoint") from err

        # load the LM state dict into each model that exists
        model_dict = state_dict['model']['language_model']
        if self.use_query_model:
            self.query_model.language_model.load_state_dict(model_dict)
        if self.use_block_model:
            self.block_model.language_model.load_state_dict(model_dict)

        # give each model the same ict_head to begin with as well
        if self.use_query_model and self.use_block_model:
            self.block_model.ict_head.load_state_dict(
                self.query_model.ict_head.state_dict())


class IREncoderBertModel(MegatronModule):
    """BERT-based encoder for queries or blocks used for learned information retrieval.

    A BERT language model built with a pooler, followed by a linear "ICT head"
    that maps the pooled output from ``args.hidden_size`` to ``ict_head_size``.
    """

    def __init__(self, ict_head_size, num_tokentypes=2, parallel_output=True):
        super(IREncoderBertModel, self).__init__()
        args = get_args()

        self.ict_head_size = ict_head_size
        # Stored for the caller's benefit; not read anywhere in this class.
        self.parallel_output = parallel_output
        init_method = init_method_normal(args.init_method_std)
        scaled_init_method = scaled_init_method_normal(args.init_method_std,
                                                       args.num_layers)

        self.language_model, self._language_model_key = get_language_model(
            attention_mask_func=bert_attention_mask_func,
            num_tokentypes=num_tokentypes,
            add_pooler=True,
            init_method=init_method,
            scaled_init_method=scaled_init_method)

        self.ict_head = get_linear_layer(args.hidden_size, ict_head_size, init_method)
        self._ict_head_key = 'ict_head'

    def forward(self, input_ids, attention_mask, tokentype_ids=None):
        """Encode ``input_ids`` and project the pooled output through the ICT head.

        Returns:
            A ``(ict_logits, None)`` tuple. The second slot is always ``None``.
        """
        # The mask is converted to the dtype of the LM parameters
        # (e.g. fp16 when the model is in half precision).
        extended_attention_mask = bert_extended_attention_mask(
            attention_mask, next(self.language_model.parameters()).dtype)
        position_ids = bert_position_ids(input_ids)

        _, pooled_output = self.language_model(
            input_ids,
            position_ids,
            extended_attention_mask,
            tokentype_ids=tokentype_ids)

        # Output: only the pooled representation feeds the ICT head.
        ict_logits = self.ict_head(pooled_output)
        return ict_logits, None

    def state_dict_for_save_checkpoint(self, destination=None, prefix='',
                                       keep_vars=False):
        """For easy load when model is combined with other heads,
        add an extra key."""

        state_dict_ = {}
        state_dict_[self._language_model_key] \
            = self.language_model.state_dict_for_save_checkpoint(
            destination, prefix, keep_vars)
        # Keyword arguments: positional arguments to nn.Module.state_dict are
        # deprecated in PyTorch and trigger a warning.
        state_dict_[self._ict_head_key] \
            = self.ict_head.state_dict(destination=destination, prefix=prefix,
                                       keep_vars=keep_vars)
        return state_dict_

    def load_state_dict(self, state_dict, strict=True):
        """Customized load.

        Expects the layout produced by ``state_dict_for_save_checkpoint``.
        """
        self.language_model.load_state_dict(
            state_dict[self._language_model_key], strict=strict)
        self.ict_head.load_state_dict(
            state_dict[self._ict_head_key], strict=strict)