# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""BERT model."""

import torch

Mohammad's avatar
Mohammad committed
20
from megatron import get_args
mohammad's avatar
mohammad committed
21
from megatron import mpu
22
from megatron.model.enums import AttnMaskType
23
24
from megatron.model.language_model import parallel_lm_logits
from megatron.model.language_model import get_language_model
25
from megatron.model import import_layernorm
26
from megatron.model.utils import openai_gelu, erf_gelu
27
28
29
from megatron.model.utils import get_linear_layer
from megatron.model.utils import init_method_normal
from megatron.model.utils import scaled_init_method_normal
30
from .module import MegatronModule
31

32
def bert_extended_attention_mask(attention_mask):
    """Expand a 2D padding mask into a boolean 4D attention mask.

    Args:
        attention_mask: tensor of shape [b, s]; entries >= 0.5 mark positions
            that may be attended to, entries < 0.5 mark padding.

    Returns:
        Bool tensor of shape [b, 1, s, s] where ``True`` means "masked out".
        Element [i, 0, q, k] is ``False`` only when both query position q and
        key position k are non-padding in sample i.
    """
    # We create a 3D attention mask from a 2D tensor mask.
    # [b, 1, s]
    attention_mask_b1s = attention_mask.unsqueeze(1)
    # [b, s, 1]
    attention_mask_bs1 = attention_mask.unsqueeze(2)
    # [b, s, s]: outer product, non-zero only where both positions are kept.
    attention_mask_bss = attention_mask_b1s * attention_mask_bs1
    # [b, 1, s, s]: singleton dim broadcasts over attention heads.
    extended_attention_mask = attention_mask_bss.unsqueeze(1)

    # Convert attention mask to binary; True marks positions to mask out.
    extended_attention_mask = (extended_attention_mask < 0.5)

    return extended_attention_mask
Neel Kant's avatar
Neel Kant committed
47
48
49
50
51
52
53
54
55
56
57

def bert_position_ids(token_ids):
    """Build position ids ``0 .. s-1`` broadcast to the shape of ``token_ids``.

    Args:
        token_ids: tensor whose dimension 1 is the sequence length.

    Returns:
        A ``torch.long`` tensor on ``token_ids.device``, expanded (not copied)
        to ``token_ids``'s shape, holding each position's index along dim 1.
    """
    num_positions = token_ids.size(1)
    positions = torch.arange(num_positions,
                             dtype=torch.long,
                             device=token_ids.device)
    # Add a leading batch dim, then broadcast across the batch.
    return positions.unsqueeze(0).expand_as(token_ids)


58
59
60
61
62
63
64
65
class BertLMHead(MegatronModule):
    """Masked LM head for Bert.

    The forward pass applies dense -> gelu -> layernorm to the hidden states.
    It then computes logits with ``parallel_lm_logits`` against the supplied
    word-embedding weight, adding this head's own bias.

    Arguments:
        mpu_vocab_size: model parallel size of vocabulary.
        hidden_size: hidden size
        init_method: init method for weight initialization
        layernorm_epsilon: tolerance for layer norm divisions
        parallel_output: whether output logits being distributed or not.
    """

    def __init__(self, mpu_vocab_size, hidden_size, init_method,
                 layernorm_epsilon, parallel_output):

        super(BertLMHead, self).__init__()

        args = get_args()

        # Output bias, one entry per vocabulary row of size mpu_vocab_size.
        self.bias = torch.nn.Parameter(torch.zeros(mpu_vocab_size))
        # NOTE(review): tags the bias as tensor-model-parallel; the (True, 0, 1)
        # arguments presumably mean (is_parallel, partition dim, stride) --
        # confirm against mpu.set_tensor_model_parallel_attributes.
        mpu.set_tensor_model_parallel_attributes(self.bias, True, 0, 1)
        self.parallel_output = parallel_output

        self.dense = get_linear_layer(hidden_size, hidden_size, init_method)
        LayerNorm = import_layernorm(args.fp32_residual_connection)
        self.layernorm = LayerNorm(hidden_size, eps=layernorm_epsilon)

        # Activation: torch's gelu by default; args.openai_gelu takes
        # precedence over args.onnx_safe when both are set.
        self.gelu = torch.nn.functional.gelu
        if args.openai_gelu:
            self.gelu = openai_gelu
        elif args.onnx_safe:
            self.gelu = erf_gelu

    def forward(self, hidden_states, word_embeddings_weight):
        """Return LM logits for ``hidden_states``.

        ``word_embeddings_weight`` is passed in rather than owned here. The
        caller supplies the word-embedding matrix, so the output projection
        reuses it.
        """
        hidden_states = self.dense(hidden_states)
        hidden_states = self.gelu(hidden_states)
        hidden_states = self.layernorm(hidden_states)
        output = parallel_lm_logits(hidden_states,
                                    word_embeddings_weight,
                                    self.parallel_output,
                                    bias=self.bias)
        return output


100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
def post_language_model_processing(lm_output, pooled_output,
                                   lm_head, binary_head,
                                   lm_labels,
                                   logit_weights,
                                   fp16_lm_cross_entropy):
    """Run the BERT output heads and, if labels are given, the LM loss.

    Args:
        lm_output: language-model output fed to ``lm_head``.
        pooled_output: pooled output fed to ``binary_head`` (may be None when
            there is no binary head).
        lm_head: callable ``(lm_output, logit_weights) -> lm_logits``.
        binary_head: callable on ``pooled_output``, or None.
        lm_labels: labels for the vocab-parallel cross entropy, or None.
        logit_weights: weight matrix handed to ``lm_head``.
        fp16_lm_cross_entropy: if True, compute the loss on the logits as-is
            (which must be ``torch.half``); otherwise upcast them to float32.

    Returns:
        ``(lm_logits, binary_logits)`` when ``lm_labels`` is None, otherwise
        ``(lm_loss, binary_logits)``. ``binary_logits`` is None without a
        binary head.
    """
    lm_logits = lm_head(lm_output, logit_weights)
    binary_logits = (binary_head(pooled_output)
                     if binary_head is not None else None)

    if lm_labels is None:
        return lm_logits, binary_logits

    if fp16_lm_cross_entropy:
        assert lm_logits.dtype == torch.half
        loss_input = lm_logits
    else:
        loss_input = lm_logits.float()
    lm_loss = mpu.vocab_parallel_cross_entropy(loss_input, lm_labels)
    return lm_loss, binary_logits


125
class BertModelBase(MegatronModule):
    """Bert Language model.

    This is the shared implementation behind the full model and the
    per-pipeline-stage wrappers. The language model is always built. The LM
    head, and the binary head when ``add_binary_head`` is set, are only built
    on the last pipeline stage (``mpu.is_pipeline_last_stage()``).
    """

    def __init__(self, num_tokentypes=2, add_binary_head=True,
                 parallel_output=True):
        super(BertModelBase, self).__init__()
        args = get_args()

        self.fp16_lm_cross_entropy = args.fp16_lm_cross_entropy
        self.add_binary_head = add_binary_head
        self.parallel_output = parallel_output

        init_method = init_method_normal(args.init_method_std)
        scaled_init_method = scaled_init_method_normal(args.init_method_std,
                                                       args.num_layers)

        # A pooler is only requested when a binary head will consume it.
        self.language_model, self._language_model_key = get_language_model(
            num_tokentypes=num_tokentypes,
            add_pooler=self.add_binary_head,
            encoder_attn_mask_type=AttnMaskType.padding,
            init_method=init_method,
            scaled_init_method=scaled_init_method)

        # Provided by MegatronModule; receives the init-method factory itself.
        self.initialize_word_embeddings(init_method_normal)
        if mpu.is_pipeline_last_stage():
            # Vocab size for the LM head comes from the word-embedding matrix.
            self.lm_head = BertLMHead(
                self.word_embeddings_weight().size(0),
                args.hidden_size, init_method, args.layernorm_epsilon,
                parallel_output)
            self._lm_head_key = 'lm_head'
            self.binary_head = None
            if self.add_binary_head:
                self.binary_head = get_linear_layer(args.hidden_size, 2,
                                                    init_method)
                self._binary_head_key = 'binary_head'

    def forward(self, bert_model_input, attention_mask,
                tokentype_ids=None, lm_labels=None):
        """Run this rank's pipeline stage of BERT.

        Args:
            bert_model_input: token ids on the first pipeline stage; on later
                stages, whatever the language model accepts in place of token
                ids (presumably the previous stage's hidden states).
            attention_mask: 2D padding mask; values >= 0.5 are attended to.
            tokentype_ids: token type ids, only forwarded on the first stage.
            lm_labels: optional labels; on the last stage, if given, the LM
                loss is returned instead of the LM logits.

        Returns:
            On the last stage, ``(lm_logits, binary_logits)`` when
            ``lm_labels`` is None, else ``(lm_loss, binary_logits)``, where
            ``binary_logits`` is None without a binary head. On other stages,
            the raw language-model output.
        """
        extended_attention_mask = bert_extended_attention_mask(attention_mask)

        kwargs = {}
        if mpu.is_pipeline_first_stage():
            input_ids = bert_model_input
            position_ids = bert_position_ids(input_ids)
            lm_args = [input_ids, position_ids, extended_attention_mask]
            kwargs['tokentype_ids'] = tokentype_ids
        else:
            lm_args = [bert_model_input, extended_attention_mask]
        lm_output = self.language_model(*lm_args, **kwargs)

        # With a pooler (binary head), the last stage returns a pair.
        if mpu.is_pipeline_last_stage() and self.add_binary_head:
            lm_output, pooled_output = lm_output
        else:
            pooled_output = None

        if mpu.is_pipeline_last_stage():
            return post_language_model_processing(lm_output, pooled_output,
                                                  self.lm_head, self.binary_head,
                                                  lm_labels,
                                                  self.word_embeddings_weight(),
                                                  self.fp16_lm_cross_entropy)
        else:
            return lm_output

    def state_dict_for_save_checkpoint(self, destination=None, prefix='',
                                       keep_vars=False):
        """For easy load when model is combined with other heads,
        add an extra key."""

        state_dict_ = {}
        state_dict_[self._language_model_key] \
            = self.language_model.state_dict_for_save_checkpoint(
            destination, prefix, keep_vars)
        if mpu.is_pipeline_last_stage():
            state_dict_[self._lm_head_key] \
                = self.lm_head.state_dict_for_save_checkpoint(
                destination, prefix, keep_vars)
        if mpu.is_pipeline_last_stage() and self.add_binary_head:
            state_dict_[self._binary_head_key] \
                = self.binary_head.state_dict(destination, prefix, keep_vars)
        # Save word_embeddings. A last stage that is not also the first keeps
        # its own word_embeddings module, which is saved under a separate key.
        if mpu.is_pipeline_last_stage() and not mpu.is_pipeline_first_stage():
            state_dict_[self._word_embeddings_for_head_key] \
                = self.word_embeddings.state_dict(destination, prefix, keep_vars)
        return state_dict_

    def load_state_dict(self, state_dict, strict=True):
        """Customized load.

        Mirrors ``state_dict_for_save_checkpoint``: each sub-module is loaded
        from its own key, using the same pipeline-stage conditions.
        """

        self.language_model.load_state_dict(
            state_dict[self._language_model_key], strict=strict)
        if mpu.is_pipeline_last_stage():
            self.lm_head.load_state_dict(
                state_dict[self._lm_head_key], strict=strict)
        if mpu.is_pipeline_last_stage() and self.add_binary_head:
            self.binary_head.load_state_dict(
                state_dict[self._binary_head_key], strict=strict)
        # Load word_embeddings.
        if mpu.is_pipeline_last_stage() and not mpu.is_pipeline_first_stage():
            self.word_embeddings.load_state_dict(
                state_dict[self._word_embeddings_for_head_key], strict=strict)


class BertModel(BertModelBase):
    """BERT wrapper exposing every ``BertModelBase`` option.

    Takes token ids as input and accepts labels. Presumably used when the
    whole model lives on one pipeline stage.
    """

    def __init__(self, num_tokentypes=2, add_binary_head=True,
                 parallel_output=True):
        super().__init__(num_tokentypes=num_tokentypes,
                         add_binary_head=add_binary_head,
                         parallel_output=parallel_output)

    def forward(self, input_ids, attention_mask,
                tokentype_ids=None, lm_labels=None):
        """Delegate to ``BertModelBase.forward`` with token ids as input."""
        return super().forward(input_ids, attention_mask,
                               tokentype_ids=tokentype_ids,
                               lm_labels=lm_labels)


class BertModelFirstStage(BertModelBase):
    """First pipeline stage: consumes token ids and token type ids.

    Uses the base-class defaults for ``add_binary_head`` and
    ``parallel_output``.
    """

    def __init__(self, num_tokentypes=2):
        super().__init__(num_tokentypes=num_tokentypes)

    def forward(self, input_ids, attention_mask,
                tokentype_ids=None):
        """Delegate to ``BertModelBase.forward``; no labels on this stage."""
        return super().forward(input_ids, attention_mask,
                               tokentype_ids=tokentype_ids)


class BertModelIntermediateStage(BertModelBase):
    """Middle pipeline stage: takes the previous stage's output and a mask."""

    def __init__(self, num_tokentypes=2):
        super().__init__(num_tokentypes=num_tokentypes)

    def forward(self, hidden_state, attention_mask):
        """Delegate to ``BertModelBase.forward`` with no token types or labels."""
        return super().forward(hidden_state, attention_mask)


class BertModelLastStage(BertModelBase):
    """Final pipeline stage: owns the output heads and accepts LM labels."""

    def __init__(self, num_tokentypes=2, add_binary_head=True,
                 parallel_output=True):
        super().__init__(num_tokentypes=num_tokentypes,
                         add_binary_head=add_binary_head,
                         parallel_output=parallel_output)

    def forward(self, hidden_state, attention_mask,
                lm_labels=None):
        """Delegate to ``BertModelBase.forward``, forwarding ``lm_labels``."""
        return super().forward(hidden_state, attention_mask,
                               lm_labels=lm_labels)