biencoder_model.py 13.5 KB
Newer Older
Mostofa Patwary's avatar
Mostofa Patwary committed
1
2
3
4
5
import os
import torch
import sys

from megatron import get_args, print_rank_0
Mostofa Patwary's avatar
Mostofa Patwary committed
6
7
8
from megatron.checkpointing import fix_query_key_value_ordering
from megatron.checkpointing import get_checkpoint_tracker_filename
from megatron.checkpointing import get_checkpoint_name
Mostofa Patwary's avatar
Mostofa Patwary committed
9
10
from megatron import mpu, get_tokenizer
from megatron.model.bert_model import bert_position_ids
Mostofa Patwary's avatar
Mostofa Patwary committed
11
from megatron.model.enums import AttnMaskType
Mostofa Patwary's avatar
Mostofa Patwary committed
12
13
14
15
from megatron.model.language_model import get_language_model
from megatron.model.utils import get_linear_layer
from megatron.model.utils import init_method_normal
from megatron.model.utils import scaled_init_method_normal
Mostofa Patwary's avatar
Mostofa Patwary committed
16
from .module import MegatronModule
Mostofa Patwary's avatar
Mostofa Patwary committed
17

Mostofa Patwary's avatar
Mostofa Patwary committed
18
def get_model_provider(only_query_model=False, only_context_model=False,
        biencoder_shared_query_context_model=False):
    """Return a ``model_provider`` callable that builds a biencoder model.

    The returned callable closes over the three flags below and forwards
    them, together with its own ``pre_process`` / ``post_process`` keyword
    arguments, to ``biencoder_model_provider``.

    Args:
        only_query_model: build only the query encoder.
        only_context_model: build only the context encoder.
        biencoder_shared_query_context_model: use one encoder for both
            queries and contexts.

    Returns:
        A function ``model_provider(pre_process=True, post_process=True)``
        that returns the constructed model.
    """

    def model_provider(pre_process=True, post_process=True):
        """Build the model."""

        print_rank_0('building Biencoder model ...')

        model = biencoder_model_provider(
            only_query_model=only_query_model,
            only_context_model=only_context_model,
            biencoder_shared_query_context_model=(
                biencoder_shared_query_context_model),
            pre_process=pre_process,
            post_process=post_process)

        return model

    return model_provider


def biencoder_model_provider(only_query_model=False,
                             only_context_model=False,
                             biencoder_shared_query_context_model=False,
                             pre_process=True,
                             post_process=True):
    """Build the model.

    Constructs a ``BiEncoderModel`` with two token types and non-parallel
    output. Model parallelism is not supported: both the tensor- and
    pipeline-parallel world sizes must be 1.

    Args:
        only_query_model: build only the query encoder.
        only_context_model: build only the context encoder.
        biencoder_shared_query_context_model: share one encoder between
            queries and contexts.
        pre_process, post_process: forwarded to ``BiEncoderModel``.

    Returns:
        The constructed ``BiEncoderModel``.
    """
    assert (mpu.get_tensor_model_parallel_world_size() == 1
            and mpu.get_pipeline_model_parallel_world_size() == 1), \
        "Model parallel size > 1 not supported for ICT"

    print_rank_0('building BiEncoderModel...')

    model_kwargs = {
        'only_query_model': only_query_model,
        'only_context_model': only_context_model,
        'biencoder_shared_query_context_model':
            biencoder_shared_query_context_model,
        'pre_process': pre_process,
        'post_process': post_process,
    }

    # Two token types are kept because the LM we initialize from
    # was trained with two token types.
    return BiEncoderModel(num_tokentypes=2, parallel_output=False,
                          **model_kwargs)


class BiEncoderModel(MegatronModule):
    """Bert-based module for Biencoder model.

    Holds a query encoder and a context encoder (both
    ``PretrainedBertModel``), or a single encoder shared by both roles when
    ``biencoder_shared_query_context_model`` is set.
    """

    def __init__(self,
                 num_tokentypes=1,
                 parallel_output=True,
                 only_query_model=False,
                 only_context_model=False,
                 biencoder_shared_query_context_model=False,
                 pre_process=True,
                 post_process=True):
        """Build the encoder(s).

        Args:
            num_tokentypes: number of token types for each encoder.
            parallel_output: forwarded to ``PretrainedBertModel``.
            only_query_model: build only the query encoder.
            only_context_model: build only the context encoder. Must not be
                combined with ``only_query_model``.
            biencoder_shared_query_context_model: build a single encoder
                stored as ``self.model`` and aliased as both
                ``query_model`` and ``context_model``.
            pre_process, post_process: forwarded to ``PretrainedBertModel``.
        """
        super(BiEncoderModel, self).__init__()
        args = get_args()

        bert_kwargs = dict(
            num_tokentypes=num_tokentypes,
            parallel_output=parallel_output,
            pre_process=pre_process,
            post_process=post_process)

        self.biencoder_shared_query_context_model = \
            biencoder_shared_query_context_model
        assert not (only_context_model and only_query_model)
        self.use_context_model = not only_query_model
        self.use_query_model = not only_context_model
        self.biencoder_projection_dim = args.biencoder_projection_dim

        if self.biencoder_shared_query_context_model:
            self.model = PretrainedBertModel(**bert_kwargs)
            self._model_key = 'shared_model'
            self.query_model, self.context_model = self.model, self.model
        else:
            if self.use_query_model:
                # this model embeds (pseudo-)queries - Embed_input in the paper
                self.query_model = PretrainedBertModel(**bert_kwargs)
                self._query_key = 'query_model'

            if self.use_context_model:
                # this model embeds evidence blocks - Embed_doc in the paper
                self.context_model = PretrainedBertModel(**bert_kwargs)
                self._context_key = 'context_model'

    def set_input_tensor(self, input_tensor):
        """See megatron.model.transformer.set_input_tensor()"""
        # This is just a placeholder; it will be needed once model
        # parallelism is supported, e.g.:
        # self.language_model.set_input_tensor(input_tensor)
        return

    def forward(self, query_tokens, query_attention_mask, query_types,
                context_tokens, context_attention_mask, context_types):
        """Run a forward pass for each of the models and
        return the respective embeddings.

        Returns:
            ``(query_logits, context_logits)``.

        Raises:
            ValueError: if the query or context encoder was not built.
        """
        if self.use_query_model:
            query_logits = self.embed_text(self.query_model,
                                           query_tokens,
                                           query_attention_mask,
                                           query_types)
        else:
            raise ValueError("Cannot embed query without the query model.")
        if self.use_context_model:
            context_logits = self.embed_text(self.context_model,
                                             context_tokens,
                                             context_attention_mask,
                                             context_types)
        else:
            raise ValueError("Cannot embed block without the block model.")
        return query_logits, context_logits

    @staticmethod
    def embed_text(model, tokens, attention_mask, token_types):
        """Embed a batch of tokens using the model"""
        logits = model(tokens,
                       attention_mask,
                       token_types)
        return logits

    def state_dict_for_save_checkpoint(self, prefix='', keep_vars=False):
        """Save dict with state dicts of each of the models.

        Keys are ``'shared_model'`` in shared mode, otherwise
        ``'query_model'`` and/or ``'context_model'`` for whichever
        encoders were built.
        """
        state_dict_ = {}
        if self.biencoder_shared_query_context_model:
            state_dict_[self._model_key] = \
                self.model.state_dict_for_save_checkpoint(
                    prefix=prefix, keep_vars=keep_vars)
        else:
            if self.use_query_model:
                state_dict_[self._query_key] = \
                    self.query_model.state_dict_for_save_checkpoint(
                        prefix=prefix, keep_vars=keep_vars)

            if self.use_context_model:
                state_dict_[self._context_key] = \
                    self.context_model.state_dict_for_save_checkpoint(
                        prefix=prefix, keep_vars=keep_vars)

        return state_dict_

    def load_state_dict(self, state_dict, strict=True):
        """Load the state dicts of each of the models.

        Expects the key layout produced by
        ``state_dict_for_save_checkpoint``.
        """
        if self.biencoder_shared_query_context_model:
            print_rank_0("Loading shared query-context model")
            self.model.load_state_dict(state_dict[self._model_key],
                                       strict=strict)
        else:
            if self.use_query_model:
                print_rank_0("Loading query model")
                self.query_model.load_state_dict(
                    state_dict[self._query_key], strict=strict)

            if self.use_context_model:
                print_rank_0("Loading context model")
                self.context_model.load_state_dict(
                    state_dict[self._context_key], strict=strict)

    def init_state_dict_from_bert(self):
        """Initialize the state from a pretrained BERT model
        on iteration zero of ICT pretraining.

        Reads the checkpoint referenced by ``args.bert_load``. It loads its
        ``['model']['language_model']`` weights into every encoder and then
        applies ``fix_query_key_value_ordering``. In the non-shared case, the
        query encoder's projection head is also copied into the context
        encoder, so both start from the same head.

        Does nothing if ``args.bert_load`` is None.

        Raises:
            FileNotFoundError: if the checkpoint tracker file is missing.
        """
        args = get_args()

        if args.bert_load is None:
            print_rank_0("bert-load argument is None")
            return

        tracker_filename = get_checkpoint_tracker_filename(args.bert_load)
        if not os.path.isfile(tracker_filename):
            raise FileNotFoundError("Could not find BERT checkpoint")
        with open(tracker_filename, 'r') as f:
            metastring = f.read().strip()
        # The tracker holds either an iteration number or the literal
        # 'release'; the latter used to crash in int().
        release = metastring == 'release'
        if release:
            iteration = 0
        else:
            iteration = int(metastring)
            assert iteration > 0

        checkpoint_name = get_checkpoint_name(args.bert_load, iteration,
                                              release)
        if mpu.get_data_parallel_rank() == 0:
            print('global rank {} is loading BERT checkpoint {}'.format(
                torch.distributed.get_rank(), checkpoint_name))

        # Load the checkpoint.
        try:
            state_dict = torch.load(checkpoint_name, map_location='cpu')
        except ModuleNotFoundError:
            # For backward compatibility: old checkpoints pickled references
            # to fp16.loss_scaler. Importing the deprecated module registers
            # it in sys.modules so it can be aliased under the old names.
            from megatron.fp16_deprecated import loss_scaler  # noqa: F401
            print_rank_0(' > deserializing using the old code structure ...')
            sys.modules['fp16.loss_scaler'] = sys.modules[
                'megatron.fp16_deprecated.loss_scaler']
            sys.modules['megatron.fp16.loss_scaler'] = sys.modules[
                'megatron.fp16_deprecated.loss_scaler']
            try:
                state_dict = torch.load(checkpoint_name, map_location='cpu')
            finally:
                # Always drop the temporary aliases, even if loading fails.
                sys.modules.pop('fp16.loss_scaler', None)
                sys.modules.pop('megatron.fp16.loss_scaler', None)
        except Exception:
            # Exit with a failure status (not 0), and let KeyboardInterrupt
            # / SystemExit propagate instead of being swallowed.
            print_rank_0('could not load the BERT checkpoint')
            sys.exit(1)

        checkpoint_version = state_dict.get('checkpoint_version', 0)

        # load the LM state dict into each model
        model_dict = state_dict['model']['language_model']

        if self.biencoder_shared_query_context_model:
            self.model.language_model.load_state_dict(model_dict)
            fix_query_key_value_ordering(self.model, checkpoint_version)
        else:
            # Stays None unless the query encoder exists and has a projection
            # head; guards the context branch when only_context_model=True.
            query_proj_state_dict = None

            if self.use_query_model:
                self.query_model.language_model.load_state_dict(model_dict)
                # give each model the same ict_head to begin with as well
                if self.biencoder_projection_dim > 0:
                    query_proj_state_dict = \
                        self.query_model.projection_enc.state_dict()
                fix_query_key_value_ordering(self.query_model,
                                             checkpoint_version)

            if self.use_context_model:
                self.context_model.language_model.load_state_dict(model_dict)
                if query_proj_state_dict is not None:
                    self.context_model.projection_enc.load_state_dict(
                        query_proj_state_dict)
                fix_query_key_value_ordering(self.context_model,
                                             checkpoint_version)


class PretrainedBertModel(MegatronModule):
    """BERT-based encoder for queries or contexts used for
    learned information retrieval.

    Produces the [CLS]-position output of the language model, optionally
    passed through a linear projection of size
    ``args.biencoder_projection_dim`` when that value is > 0.
    """

    def __init__(self, num_tokentypes=2,
                 parallel_output=True, pre_process=True, post_process=True):
        super(PretrainedBertModel, self).__init__()

        args = get_args()
        tokenizer = get_tokenizer()
        self.pad_id = tokenizer.pad
        self.biencoder_projection_dim = args.biencoder_projection_dim
        self.parallel_output = parallel_output
        self.pre_process = pre_process
        self.post_process = post_process
        init_method = init_method_normal(args.init_method_std)
        scaled_init_method = scaled_init_method_normal(
            args.init_method_std, args.num_layers)

        self.language_model, self._language_model_key = get_language_model(
            num_tokentypes=num_tokentypes,
            add_pooler=False,
            encoder_attn_mask_type=AttnMaskType.padding,
            init_method=init_method,
            scaled_init_method=scaled_init_method,
            pre_process=self.pre_process,
            post_process=self.post_process)

        if args.biencoder_projection_dim > 0:
            self.projection_enc = get_linear_layer(args.hidden_size,
                                                   args.biencoder_projection_dim,
                                                   init_method)
            self._projection_enc_key = 'projection_enc'

    def forward(self, input_ids, attention_mask, tokentype_ids=None):
        """Encode ``input_ids`` and return the [CLS]-position embedding.

        Args:
            input_ids: token ids.
            attention_mask: mask passed to the language model after a
                singleton dim is inserted at index 1 (presumably the head
                dim; TODO confirm the expected mask shape).
            tokentype_ids: optional token-type ids.

        Returns:
            The first-position output, projected by ``projection_enc`` when
            ``biencoder_projection_dim > 0``.
        """
        extended_attention_mask = attention_mask.unsqueeze(1)
        position_ids = bert_position_ids(input_ids)

        lm_output = self.language_model(input_ids,
                                        position_ids,
                                        extended_attention_mask,
                                        tokentype_ids=tokentype_ids)

        # Taking the representation of the [CLS] token of BERT.
        # Indexing the first axis assumes lm_output is sequence-first
        # (seq, batch, hidden); TODO confirm against the language model.
        pooled_output = lm_output[0, :, :]

        # Output. Use the same "> 0" test as __init__, which is the only
        # place projection_enc is created.
        if self.biencoder_projection_dim > 0:
            pooled_output = self.projection_enc(pooled_output)

        return pooled_output

    def state_dict_for_save_checkpoint(self, prefix='', keep_vars=False):
        """For easy load when model is combined with other heads,
        add an extra key."""

        state_dict_ = {}
        state_dict_[self._language_model_key] \
            = self.language_model.state_dict_for_save_checkpoint(
                prefix=prefix, keep_vars=keep_vars)

        if self.biencoder_projection_dim > 0:
            state_dict_[self._projection_enc_key] = \
                self.projection_enc.state_dict(prefix=prefix,
                                               keep_vars=keep_vars)

        return state_dict_

    def load_state_dict(self, state_dict, strict=True):
        """Customized load.

        Expects the key layout produced by
        ``state_dict_for_save_checkpoint``.
        """
        print_rank_0("loading pretrained weights")
        self.language_model.load_state_dict(
            state_dict[self._language_model_key], strict=strict)

        if self.biencoder_projection_dim > 0:
            print_rank_0("loading projection head weights")
            self.projection_enc.load_state_dict(
                state_dict[self._projection_enc_key], strict=strict)