biencoder_model.py 13.7 KB
Newer Older
xingjinliang's avatar
xingjinliang committed
1
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
Mostofa Patwary's avatar
Mostofa Patwary committed
2
3
4
5
import os
import torch
import sys

xingjinliang's avatar
xingjinliang committed
6
from megatron.training import get_args, print_rank_0, get_tokenizer
7
from megatron.core import mpu
xingjinliang's avatar
xingjinliang committed
8
9
10
11
12
13
14
15
16
from megatron.training.checkpointing import fix_query_key_value_ordering
from megatron.training.checkpointing import get_checkpoint_tracker_filename
from megatron.training.checkpointing import get_checkpoint_name
from megatron.legacy.model.bert_model import bert_position_ids
from megatron.legacy.model.enums import AttnMaskType
from megatron.legacy.model.language_model import get_language_model
from megatron.legacy.model.utils import get_linear_layer
from megatron.legacy.model.utils import init_method_normal
from megatron.legacy.model.utils import scaled_init_method_normal
Mostofa Patwary's avatar
Mostofa Patwary committed
17
from .module import MegatronModule
Mostofa Patwary's avatar
Mostofa Patwary committed
18

Mostofa Patwary's avatar
Mostofa Patwary committed
19
def get_model_provider(only_query_model=False, only_context_model=False,
                       biencoder_shared_query_context_model=False):
    """Return a ``model_provider`` callable bound to the given options.

    The returned closure accepts ``pre_process`` / ``post_process`` and
    forwards them, together with the three options captured here, to
    ``biencoder_model_provider``.
    """
    tower_options = dict(
        only_query_model=only_query_model,
        only_context_model=only_context_model,
        biencoder_shared_query_context_model=(
            biencoder_shared_query_context_model),
    )

    def model_provider(pre_process=True, post_process=True):
        """Build the model."""
        print_rank_0('building Bienoder model ...')
        return biencoder_model_provider(pre_process=pre_process,
                                        post_process=post_process,
                                        **tower_options)

    return model_provider


37
38
39
40
def biencoder_model_provider(only_query_model=False,
                             only_context_model=False,
                             biencoder_shared_query_context_model=False,
                             pre_process=True,
                             post_process=True):
    """Construct a ``BiEncoderModel`` after checking model parallelism is off.

    All arguments are passed straight through to ``BiEncoderModel``; the
    tensor- and pipeline-model-parallel world sizes must both equal 1.
    """
    assert (mpu.get_tensor_model_parallel_world_size() == 1
            and mpu.get_pipeline_model_parallel_world_size() == 1), \
        "Model parallel size > 1 not supported for ICT"

    print_rank_0('building BiEncoderModel...')

    # The LM used for initialization has 2 tokentypes, so the simplest
    # choice is to keep using 2 here as well.
    model_kwargs = {
        'num_tokentypes': 2,
        'parallel_output': False,
        'only_query_model': only_query_model,
        'only_context_model': only_context_model,
        'biencoder_shared_query_context_model':
            biencoder_shared_query_context_model,
        'pre_process': pre_process,
        'post_process': post_process,
    }
    return BiEncoderModel(**model_kwargs)


class BiEncoderModel(MegatronModule):
    """Bert-based module for Biencoder model.

    Wraps a query encoder and/or a context encoder, each an instance of
    ``PretrainedBertModel``. When ``biencoder_shared_query_context_model``
    is set, a single encoder is created and used for both roles.
    """

    def __init__(self,
                 num_tokentypes=1,
                 parallel_output=True,
                 only_query_model=False,
                 only_context_model=False,
                 biencoder_shared_query_context_model=False,
                 pre_process=True,
                 post_process=True):
        """Create the encoder(s).

        Args:
            num_tokentypes: forwarded to ``PretrainedBertModel``.
            parallel_output: forwarded to ``PretrainedBertModel``.
            only_query_model: build only the query encoder.
            only_context_model: build only the context encoder.
            biencoder_shared_query_context_model: use one encoder for
                both queries and contexts.
            pre_process: forwarded to ``PretrainedBertModel``.
            post_process: forwarded to ``PretrainedBertModel``.

        ``only_query_model`` and ``only_context_model`` are mutually
        exclusive.
        """
        super(BiEncoderModel, self).__init__()
        args = get_args()

        bert_kwargs = dict(
            num_tokentypes=num_tokentypes,
            parallel_output=parallel_output,
            pre_process=pre_process,
            post_process=post_process)

        self.biencoder_shared_query_context_model = \
            biencoder_shared_query_context_model
        assert not (only_context_model and only_query_model)
        self.use_context_model = not only_query_model
        self.use_query_model = not only_context_model
        self.biencoder_projection_dim = args.biencoder_projection_dim

        if self.biencoder_shared_query_context_model:
            self.model = PretrainedBertModel(**bert_kwargs)
            self._model_key = 'shared_model'
            self.query_model, self.context_model = self.model, self.model
        else:
            if self.use_query_model:
                # this model embeds (pseudo-)queries - Embed_input in the paper
                self.query_model = PretrainedBertModel(**bert_kwargs)
                self._query_key = 'query_model'

            if self.use_context_model:
                # this model embeds evidence blocks - Embed_doc in the paper
                self.context_model = PretrainedBertModel(**bert_kwargs)
                self._context_key = 'context_model'

    def set_input_tensor(self, input_tensor):
        """See megatron.legacy.model.transformer.set_input_tensor()"""
        # Intentionally a no-op placeholder: it will only be needed once
        # model parallelism is supported, at which point the input tensor
        # would be handed to the underlying language model.
        return

    def forward(self, query_tokens, query_attention_mask, query_types,
                context_tokens, context_attention_mask, context_types):
        """Run a forward pass for each of the models and
        return the respective embeddings.

        Returns:
            Tuple ``(query_logits, context_logits)``.

        Raises:
            ValueError: if either the query or the context encoder was
                not built.
        """
        # Validate both encoders up front so no embedding is computed
        # only to be thrown away by a later error.
        if not self.use_query_model:
            raise ValueError("Cannot embed query without the query model.")
        if not self.use_context_model:
            raise ValueError("Cannot embed block without the block model.")

        query_logits = self.embed_text(self.query_model,
                                       query_tokens,
                                       query_attention_mask,
                                       query_types)
        context_logits = self.embed_text(self.context_model,
                                         context_tokens,
                                         context_attention_mask,
                                         context_types)
        return query_logits, context_logits

    @staticmethod
    def embed_text(model, tokens, attention_mask, token_types):
        """Embed a batch of tokens using the model"""
        logits = model(tokens,
                       attention_mask,
                       token_types)
        return logits

    def state_dict_for_save_checkpoint(self, prefix='', keep_vars=False):
        """Save dict with state dicts of each of the models.

        The result is keyed by ``'shared_model'`` in shared mode, otherwise
        by ``'query_model'`` and/or ``'context_model'`` for the encoders
        that were built.
        """
        state_dict_ = {}
        if self.biencoder_shared_query_context_model:
            state_dict_[self._model_key] = \
                self.model.state_dict_for_save_checkpoint(
                    prefix=prefix, keep_vars=keep_vars)
        else:
            if self.use_query_model:
                state_dict_[self._query_key] = \
                    self.query_model.state_dict_for_save_checkpoint(
                        prefix=prefix, keep_vars=keep_vars)

            if self.use_context_model:
                state_dict_[self._context_key] = \
                    self.context_model.state_dict_for_save_checkpoint(
                        prefix=prefix, keep_vars=keep_vars)

        return state_dict_

    def load_state_dict(self, state_dict, strict=True):
        """Load the state dicts of each of the models"""
        if self.biencoder_shared_query_context_model:
            print_rank_0("Loading shared query-context model")
            self.model.load_state_dict(state_dict[self._model_key],
                                       strict=strict)
        else:
            if self.use_query_model:
                print_rank_0("Loading query model")
                self.query_model.load_state_dict(
                    state_dict[self._query_key], strict=strict)

            if self.use_context_model:
                print_rank_0("Loading context model")
                self.context_model.load_state_dict(
                    state_dict[self._context_key], strict=strict)

    @staticmethod
    def _load_bert_checkpoint(checkpoint_name):
        """Deserialize ``checkpoint_name`` onto the CPU and return it.

        If unpickling fails with ``ModuleNotFoundError``, the load is
        retried with the legacy ``fp16.loss_scaler`` module paths aliased
        to the deprecated loss-scaler module. Any other failure of the
        first load is logged and terminates the process.
        """
        # NOTE(review): torch.load unpickles the file; only load
        # checkpoints from a trusted location.
        try:
            return torch.load(checkpoint_name, map_location='cpu')
        except ModuleNotFoundError:
            from megatron.legacy.fp16_deprecated import loss_scaler
            # For backward compatibility.
            print_rank_0(' > deserializing using the old code structure ...')
            # Alias the imported module object directly rather than looking
            # it up in sys.modules under a hard-coded (and stale) name.
            legacy_aliases = ('fp16.loss_scaler', 'megatron.fp16.loss_scaler')
            for alias in legacy_aliases:
                sys.modules[alias] = loss_scaler
            try:
                return torch.load(checkpoint_name, map_location='cpu')
            finally:
                # Always drop the temporary aliases, even if the retry fails.
                for alias in legacy_aliases:
                    sys.modules.pop(alias, None)
        except Exception as err:
            print_rank_0(
                'could not load the BERT checkpoint: {}'.format(err))
            # Non-zero status so the failure is visible to the launcher.
            sys.exit(1)

    def init_state_dict_from_bert(self):
        """Initialize the state from a pretrained BERT model
        on iteration zero of ICT pretraining

        Does nothing (beyond a log line) when ``args.bert_load`` is None.

        Raises:
            FileNotFoundError: if the checkpoint tracker file is missing.
            SystemExit: if the checkpoint cannot be deserialized.
        """
        args = get_args()

        if args.bert_load is None:
            print_rank_0("bert-load argument is None")
            return

        tracker_filename = get_checkpoint_tracker_filename(args.bert_load)
        if not os.path.isfile(tracker_filename):
            raise FileNotFoundError("Could not find BERT checkpoint")
        with open(tracker_filename, 'r') as f:
            iteration = int(f.read().strip())
        assert iteration > 0

        checkpoint_name = get_checkpoint_name(args.bert_load, iteration, False)
        if mpu.get_data_parallel_rank() == 0:
            print('global rank {} is loading BERT checkpoint {}'.format(
                torch.distributed.get_rank(), checkpoint_name))

        # Load the checkpoint.
        state_dict = self._load_bert_checkpoint(checkpoint_name)
        checkpoint_version = state_dict.get('checkpoint_version', 0)

        # load the LM state dict into each model
        model_dict = state_dict['model']['language_model']

        if self.biencoder_shared_query_context_model:
            self.model.language_model.load_state_dict(model_dict)
            fix_query_key_value_ordering(self.model, checkpoint_version)
            return

        has_projection = self.biencoder_projection_dim > 0
        query_proj_state_dict = None

        if self.use_query_model:
            self.query_model.language_model.load_state_dict(model_dict)
            if has_projection:
                # Remember the query projection weights so the context
                # encoder can start from the same projection head.
                query_proj_state_dict = \
                    self.query_model.projection_enc.state_dict()
            fix_query_key_value_ordering(self.query_model, checkpoint_version)

        if self.use_context_model:
            self.context_model.language_model.load_state_dict(model_dict)
            # Only copy the projection head when a query encoder exists;
            # with only_context_model there is nothing to copy from.
            if query_proj_state_dict is not None:
                self.context_model.projection_enc.load_state_dict(
                    query_proj_state_dict)
            fix_query_key_value_ordering(self.context_model, checkpoint_version)
Mostofa Patwary's avatar
Mostofa Patwary committed
244
245
246


class PretrainedBertModel(MegatronModule):
    """BERT-based encoder for queries or contexts used for
    learned information retrieval.

    Consists of a language model built by ``get_language_model`` and, when
    ``args.biencoder_projection_dim > 0``, a linear projection head
    (``projection_enc``) applied to the pooled output.
    """

    def __init__(self, num_tokentypes=2,
                 parallel_output=True, pre_process=True, post_process=True):
        """Build the language model and the optional projection head.

        Args:
            num_tokentypes: forwarded to ``get_language_model``.
            parallel_output: stored on the instance.
            pre_process: stored and forwarded to ``get_language_model``.
            post_process: stored and forwarded to ``get_language_model``.
        """
        super(PretrainedBertModel, self).__init__()

        args = get_args()
        tokenizer = get_tokenizer()
        self.pad_id = tokenizer.pad
        self.biencoder_projection_dim = args.biencoder_projection_dim
        self.parallel_output = parallel_output
        self.pre_process = pre_process
        self.post_process = post_process
        init_method = init_method_normal(args.init_method_std)
        scaled_init_method = scaled_init_method_normal(
            args.init_method_std, args.num_layers)

        self.language_model, self._language_model_key = get_language_model(
            num_tokentypes=num_tokentypes,
            add_pooler=False,
            encoder_attn_mask_type=AttnMaskType.padding,
            init_method=init_method,
            scaled_init_method=scaled_init_method,
            pre_process=self.pre_process,
            post_process=self.post_process)

        if args.biencoder_projection_dim > 0:
            self.projection_enc = get_linear_layer(args.hidden_size,
                                                   args.biencoder_projection_dim,
                                                   init_method)
            self._projection_enc_key = 'projection_enc'

    def forward(self, input_ids, attention_mask, tokentype_ids=None):
        """Encode a batch and return its pooled (optionally projected) output.

        Args:
            input_ids: token ids passed to the language model.
            attention_mask: mask; a singleton axis is inserted at dim 1
                before it is passed to the language model.
            tokentype_ids: optional token-type ids for the language model.

        Returns:
            The slice of the language-model output at index 0 of its first
            axis, passed through ``projection_enc`` when the projection
            dimension is positive.
        """
        extended_attention_mask = attention_mask.unsqueeze(1)
        position_ids = bert_position_ids(input_ids)

        lm_output = self.language_model(input_ids,
                                        position_ids,
                                        extended_attention_mask,
                                        tokentype_ids=tokentype_ids)

        # Taking the representation of the [CLS] token of BERT
        # (assumes the first axis of lm_output is the sequence axis --
        # TODO confirm against the language model's output layout).
        pooled_output = lm_output[0, :, :]

        # Keep the pooled output in the language model's output dtype.
        pooled_output = pooled_output.to(lm_output.dtype)

        # Output. Use the same "> 0" test as __init__ and the checkpoint
        # methods so projection_enc is only touched when it was created.
        if self.biencoder_projection_dim > 0:
            pooled_output = self.projection_enc(pooled_output)

        return pooled_output

    def state_dict_for_save_checkpoint(self, prefix='', keep_vars=False):
        """For easy load when model is combined with other heads,
        add an extra key.

        Returns a dict holding the language model's state under
        ``self._language_model_key`` and, when the projection head exists,
        its state under ``'projection_enc'``.
        """
        state_dict_ = {}
        state_dict_[self._language_model_key] \
            = self.language_model.state_dict_for_save_checkpoint(
                prefix=prefix, keep_vars=keep_vars)

        if self.biencoder_projection_dim > 0:
            state_dict_[self._projection_enc_key] = \
                self.projection_enc.state_dict(prefix=prefix,
                                               keep_vars=keep_vars)

        return state_dict_

    def load_state_dict(self, state_dict, strict=True):
        """Customized load.

        Expects ``state_dict`` in the layout produced by
        ``state_dict_for_save_checkpoint``.
        """
        print_rank_0("loading pretrained weights")
        self.language_model.load_state_dict(
            state_dict[self._language_model_key], strict=strict)

        if self.biencoder_projection_dim > 0:
            print_rank_0("loading projection head weights")
            self.projection_enc.load_state_dict(
                state_dict[self._projection_enc_key], strict=strict)