import os
import torch
import sys

from megatron import get_args, print_rank_0
from megatron.checkpointing import fix_query_key_value_ordering
from megatron.checkpointing import get_checkpoint_tracker_filename
from megatron.checkpointing import get_checkpoint_name
from megatron import mpu, get_tokenizer
from megatron.model.bert_model import bert_position_ids
from megatron.model.enums import AttnMaskType
from megatron.model.language_model import get_language_model
from megatron.model.utils import get_linear_layer
from megatron.model.utils import init_method_normal
from megatron.model.utils import scaled_init_method_normal
from .module import MegatronModule

def biencoder_model_provider(only_query_model=False,
                             only_context_model=False,
                             biencoder_shared_query_context_model=False):
    """Build the model."""
    args = get_args()

    assert mpu.get_tensor_model_parallel_world_size() == 1 and \
        mpu.get_pipeline_model_parallel_world_size() == 1, \
        "Model parallel size > 1 not supported for ICT"

    print_rank_0('building BiEncoderModel...')

    # simpler to just keep using 2 tokentypes since
    # the LM we initialize with has 2 tokentypes
    model = BiEncoderModel(
        num_tokentypes=2,
        parallel_output=False,
        only_query_model=only_query_model,
        only_context_model=only_context_model,
        biencoder_shared_query_context_model=\
            biencoder_shared_query_context_model)

    return model

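# Illustrative usage sketch (assumes Megatron has already been initialized so
# that get_args(), mpu state and the retrieval data pipeline are available;
# the query_*/context_* tensors below are placeholders for prepared batches):
#
#     model = biencoder_model_provider(
#         biencoder_shared_query_context_model=args.biencoder_shared_query_context_model)
#     query_logits, context_logits = model(query_tokens, query_attention_mask,
#                                          query_types, context_tokens,
#                                          context_attention_mask, context_types)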

class BiEncoderModel(MegatronModule):
    """Bert-based module for Biencoder model."""

    def __init__(self,
                 num_tokentypes=1,
                 parallel_output=True,
                 only_query_model=False,
                 only_context_model=False,
                 biencoder_shared_query_context_model=False):
        super(BiEncoderModel, self).__init__()
        args = get_args()

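        # Keyword arguments shared by the query and context BERT encoders.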
        bert_kwargs = dict(
            num_tokentypes=num_tokentypes,
            parallel_output=parallel_output)

        self.biencoder_shared_query_context_model = \
            biencoder_shared_query_context_model
        assert not (only_context_model and only_query_model)
        self.use_context_model = not only_query_model
        self.use_query_model = not only_context_model
        self.biencoder_projection_dim = args.biencoder_projection_dim

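        # With a shared encoder, one BERT model serves both queries and
        # contexts; otherwise separate query/context encoders are built below.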
        if self.biencoder_shared_query_context_model:
            self.model = PretrainedBertModel(**bert_kwargs)
            self._model_key = 'shared_model'
            self.query_model, self.context_model = self.model, self.model
        else:
            if self.use_query_model:
                # this model embeds (pseudo-)queries - Embed_input in the paper
                self.query_model = PretrainedBertModel(**bert_kwargs)
                self._query_key = 'query_model'

            if self.use_context_model:
                # this model embeds evidence blocks - Embed_doc in the paper
                self.context_model = PretrainedBertModel(**bert_kwargs)
                self._context_key = 'context_model'

    def forward(self, query_tokens, query_attention_mask, query_types,
                context_tokens, context_attention_mask, context_types):
        """Run a forward pass for each of the models and
        return the respective embeddings."""

        if self.use_query_model:
            query_logits = self.embed_text(self.query_model,
                                           query_tokens,
                                           query_attention_mask,
                                           query_types)
        else:
            raise ValueError("Cannot embed query without the query model.")
        if self.use_context_model:
            context_logits = self.embed_text(self.context_model,
                                             context_tokens,
                                             context_attention_mask,
                                             context_types)
        else:
            raise ValueError("Cannot embed block without the block model.")
        return query_logits, context_logits

    @staticmethod
    def embed_text(model, tokens, attention_mask, token_types):
        """Embed a batch of tokens using the model"""
        logits = model(tokens,
                       attention_mask,
                       token_types)
        return logits

    def state_dict_for_save_checkpoint(self, destination=None,
                                       prefix='', keep_vars=False):
        """Save dict with state dicts of each of the models."""
        state_dict_ = {}
        if self.biencoder_shared_query_context_model:
            state_dict_[self._model_key] = \
                self.model.state_dict_for_save_checkpoint(destination,
                                                          prefix,
                                                          keep_vars)
        else:
            if self.use_query_model:
                state_dict_[self._query_key] = \
                    self.query_model.state_dict_for_save_checkpoint(
                        destination, prefix, keep_vars)

            if self.use_context_model:
                state_dict_[self._context_key] = \
                    self.context_model.state_dict_for_save_checkpoint(
                        destination, prefix, keep_vars)

        return state_dict_

    def load_state_dict(self, state_dict, strict=True):
        """Load the state dicts of each of the models"""
        if self.biencoder_shared_query_context_model:
            print_rank_0("Loading shared query-context model")
            self.model.load_state_dict(state_dict[self._model_key], \
                strict=strict)
        else:
            if self.use_query_model:
                print_rank_0("Loading query model")
                self.query_model.load_state_dict( \
                    state_dict[self._query_key], strict=strict)

            if self.use_context_model:
                print_rank_0("Loading context model")
                self.context_model.load_state_dict( \
                    state_dict[self._context_key], strict=strict)

    def init_state_dict_from_bert(self):
        """Initialize the state from a pretrained BERT model
        on iteration zero of ICT pretraining"""
        args = get_args()

        if args.bert_load is None:
            print_rank_0("bert-load argument is None")
            return

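        # The tracker file under args.bert_load records the iteration of the
        # most recent BERT checkpoint.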
        tracker_filename = get_checkpoint_tracker_filename(args.bert_load)
        if not os.path.isfile(tracker_filename):
            raise FileNotFoundError("Could not find BERT checkpoint")
        with open(tracker_filename, 'r') as f:
            iteration = int(f.read().strip())
            assert iteration > 0

        checkpoint_name = get_checkpoint_name(args.bert_load, iteration, False)
        if mpu.get_data_parallel_rank() == 0:
            print('global rank {} is loading BERT checkpoint {}'.format(
                torch.distributed.get_rank(), checkpoint_name))

        # Load the checkpoint.
        try:
            state_dict = torch.load(checkpoint_name, map_location='cpu')
        except ModuleNotFoundError:
            from megatron.fp16_deprecated import loss_scaler
            # For backward compatibility.
            print_rank_0(' > deserializing using the old code structure ...')
            sys.modules['fp16.loss_scaler'] = sys.modules[
                'megatron.fp16_deprecated.loss_scaler']
            sys.modules['megatron.fp16.loss_scaler'] = sys.modules[
                'megatron.fp16_deprecated.loss_scaler']
            state_dict = torch.load(checkpoint_name, map_location='cpu')
            sys.modules.pop('fp16.loss_scaler', None)
            sys.modules.pop('megatron.fp16.loss_scaler', None)
        except BaseException:
            print_rank_0('could not load the BERT checkpoint')
            sys.exit()

        checkpoint_version = state_dict.get('checkpoint_version', 0)

        # load the LM state dict into each model
        model_dict = state_dict['model']['language_model']

        if self.biencoder_shared_query_context_model:
            self.model.language_model.load_state_dict(model_dict)
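            # Older checkpoint versions store the fused query/key/value
            # weights in a different order; rewrite them in place to match
            # this model's layout.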
            fix_query_key_value_ordering(self.model, checkpoint_version)
        else:
            if self.use_query_model:
                self.query_model.language_model.load_state_dict(model_dict)
                # give each model the same projection head to begin with as well
                if self.biencoder_projection_dim > 0:
                    query_proj_state_dict = \
                        self.state_dict_for_save_checkpoint()\
                        [self._query_key]['projection_enc']
                fix_query_key_value_ordering(self.query_model, checkpoint_version)

            if self.use_context_model:
                self.context_model.language_model.load_state_dict(model_dict)
                # copy the query projection weights only when a query model
                # with a projection head was actually built above
                if self.use_query_model and \
                    self.biencoder_projection_dim > 0:
                    self.context_model.projection_enc.load_state_dict(
                        query_proj_state_dict)
                fix_query_key_value_ordering(self.context_model, checkpoint_version)


class PretrainedBertModel(MegatronModule):
    """BERT-based encoder for queries or contexts used for
    learned information retrieval."""

    def __init__(self, num_tokentypes=2,
                 parallel_output=True):
        super(PretrainedBertModel, self).__init__()

        args = get_args()
        tokenizer = get_tokenizer()
        self.pad_id = tokenizer.pad
        self.biencoder_projection_dim = args.biencoder_projection_dim
        self.parallel_output = parallel_output
        init_method = init_method_normal(args.init_method_std)
        scaled_init_method = scaled_init_method_normal(
            args.init_method_std, args.num_layers)

        self.language_model, self._language_model_key = get_language_model(
            num_tokentypes=num_tokentypes,
            add_pooler=False,
            encoder_attn_mask_type=AttnMaskType.padding,
            init_method=init_method,
            scaled_init_method=scaled_init_method)

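        # Optional projection head that maps the encoder output down to
        # biencoder_projection_dim for the retrieval embedding.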
        if args.biencoder_projection_dim > 0:
            self.projection_enc = get_linear_layer(args.hidden_size,
                                                   args.biencoder_projection_dim,
                                                   init_method)
            self._projection_enc_key = 'projection_enc'

    def forward(self, input_ids, attention_mask, tokentype_ids=None):
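        # Add a broadcast dimension for the attention heads to the padding mask.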
        extended_attention_mask = attention_mask.unsqueeze(1)
        #extended_attention_mask = bert_extended_attention_mask(attention_mask)
        position_ids = bert_position_ids(input_ids)
        lm_output = self.language_model(input_ids,
                                        position_ids,
                                        extended_attention_mask,
                                        tokentype_ids=tokentype_ids)
        # This mask will be used in average-pooling and max-pooling
        pool_mask = (input_ids == self.pad_id).unsqueeze(2)

        # Taking the representation of the [CLS] token of BERT
        pooled_output = lm_output[:, 0, :]

        # Keep the pooled output in the language model's dtype (e.g. float16).
        pooled_output = pooled_output.to(lm_output.dtype)

        # Output.
        if self.biencoder_projection_dim:
            pooled_output = self.projection_enc(pooled_output)

        return pooled_output

    def state_dict_for_save_checkpoint(self, destination=None, prefix='',
                                       keep_vars=False):
        """For easy load when model is combined with other heads,
        add an extra key."""

        state_dict_ = {}
        state_dict_[self._language_model_key] = \
            self.language_model.state_dict_for_save_checkpoint(
                destination, prefix, keep_vars)

        if self.biencoder_projection_dim > 0:
            state_dict_[self._projection_enc_key] = \
                self.projection_enc.state_dict(destination, prefix, keep_vars)

        return state_dict_

    def load_state_dict(self, state_dict, strict=True):
        """Customized load."""
        print_rank_0("loading BERT weights")
        self.language_model.load_state_dict(
            state_dict[self._language_model_key], strict=strict)

        if self.biencoder_projection_dim > 0:
            print_rank_0("loading projection head weights")
            self.projection_enc.load_state_dict(
                state_dict[self._projection_enc_key], strict=strict)