# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""BERT model."""

import torch

from megatron import get_args
from megatron import mpu
from megatron.model.enums import AttnMaskType
from megatron.model.language_model import parallel_lm_logits
from megatron.model.language_model import get_language_model
from megatron.model import LayerNorm
from megatron.model.utils import openai_gelu, erf_gelu
from megatron.model.utils import get_linear_layer
from megatron.model.utils import init_method_normal
from megatron.model.utils import scaled_init_method_normal
from .module import MegatronModule


def bert_extended_attention_mask(attention_mask):
    """Build a boolean pairwise attention mask from a 2D padding mask.

    Arguments:
        attention_mask: tensor of shape [b, s]; presumably 1 for real tokens
            and 0 for padding (TODO confirm against the data pipeline).

    Returns:
        Boolean tensor of shape [b, 1, s, s]. Entry (i, j) is True when the
        product attention_mask[i] * attention_mask[j] is below 0.5, i.e. when
        either position is padding (presumably "mask this score out").
    """
    # We create a 3D attention mask from a 2D tensor mask.
    # [b, 1, s]
    attention_mask_b1s = attention_mask.unsqueeze(1)
    # [b, s, 1]
    attention_mask_bs1 = attention_mask.unsqueeze(2)
    # [b, s, s]: outer product, so a pair is kept only if both positions are.
    attention_mask_bss = attention_mask_b1s * attention_mask_bs1
    # [b, 1, s, s]: extra singleton dim, presumably broadcast over heads.
    extended_attention_mask = attention_mask_bss.unsqueeze(1)

    # Convert attention mask to binary (True where the product is < 0.5).
    extended_attention_mask = (extended_attention_mask < 0.5)

    return extended_attention_mask


def bert_position_ids(token_ids):
    """Return position ids 0..s-1 broadcast to the shape of ``token_ids``.

    Arguments:
        token_ids: tensor of shape [b, s].

    Returns:
        Long tensor of shape [b, s] on ``token_ids.device`` in which every
        row is [0, 1, ..., s-1]. It is produced by ``expand_as``, so it is an
        expanded view whose rows share the same storage.
    """
    # Create position ids
    seq_length = token_ids.size(1)
    position_ids = torch.arange(seq_length, dtype=torch.long,
                                device=token_ids.device)
    position_ids = position_ids.unsqueeze(0).expand_as(token_ids)

    return position_ids


class BertLMHead(MegatronModule):
    """Masked LM head for Bert

    Computes dense -> GELU -> LayerNorm on the hidden states, then projects
    onto the supplied word-embedding weight (plus a learned bias) to produce
    vocabulary logits.

    Arguments:
        mpu_vocab_size: number of vocabulary entries in the local
            word-embedding weight (BertModel passes
            ``word_embeddings_weight().size(0)``); sizes the output bias.
        hidden_size: hidden size
        init_method: init method for weight initialization
        layernorm_epsilon: tolerance for layer norm divisions
        parallel_output: whether output logits being distributed or not.
    """

    def __init__(self, mpu_vocab_size, hidden_size, init_method,
                 layernorm_epsilon, parallel_output):

        super(BertLMHead, self).__init__()

        args = get_args()

        # Learned output bias, one entry per local vocabulary item.
        self.bias = torch.nn.Parameter(torch.zeros(mpu_vocab_size))
        # Mark the bias as tensor-model-parallel; the positional arguments are
        # presumably (is_parallel=True, partition_dim=0, stride=1) -- see mpu.
        mpu.set_tensor_model_parallel_attributes(self.bias, True, 0, 1)
        self.parallel_output = parallel_output

        self.dense = get_linear_layer(hidden_size, hidden_size, init_method)
        # Tag the dense parameters with the sequence-parallel flag so code
        # elsewhere can identify them.
        setattr(self.dense.weight, 'sequence_parallel', args.sequence_parallel)
        setattr(self.dense.bias, 'sequence_parallel', args.sequence_parallel)

        self.layernorm = LayerNorm(hidden_size,
                                   eps=layernorm_epsilon,
                                   sequence_parallel=args.sequence_parallel)

        # Activation: torch's GELU by default; the openai_gelu / onnx_safe
        # flags select alternative implementations (openai_gelu wins).
        self.gelu = torch.nn.functional.gelu
        if args.openai_gelu:
            self.gelu = openai_gelu
        elif args.onnx_safe:
            self.gelu = erf_gelu

    def forward(self, hidden_states, word_embeddings_weight):
        """Map hidden states to vocabulary logits.

        Arguments:
            hidden_states: output of the language model.
            word_embeddings_weight: embedding matrix used as the output
                projection by ``parallel_lm_logits``.

        Returns:
            Logits from ``parallel_lm_logits`` (distributed across ranks or
            not, according to ``self.parallel_output``).
        """
        hidden_states = self.dense(hidden_states)
        hidden_states = self.gelu(hidden_states)
        hidden_states = self.layernorm(hidden_states)
        output = parallel_lm_logits(hidden_states,
                                    word_embeddings_weight,
                                    self.parallel_output,
                                    bias=self.bias)
        return output


def post_language_model_processing(lm_output, pooled_output,
                                   lm_head, binary_head,
                                   lm_labels,
                                   logit_weights,
                                   fp16_lm_cross_entropy):
    """Apply the LM head (and optional binary head) to language-model output.

    Arguments:
        lm_output: hidden states in sequence-first [s, b, h] layout.
        pooled_output: input for ``binary_head``, or None.
        lm_head: callable invoked as ``lm_head(lm_output, logit_weights)``;
            expected to return logits in [s, b, v] layout.
        binary_head: callable applied to ``pooled_output``, or None.
        lm_labels: [b, s] labels, or None to return logits instead of loss.
        logit_weights: weight passed through to ``lm_head``.
        fp16_lm_cross_entropy: if True, compute cross entropy directly on
            the (asserted half-precision) logits; otherwise upcast to float32.

    Returns:
        ``(lm_logits, binary_logits)`` with lm_logits as [b, s, v] when
        ``lm_labels`` is None, otherwise ``(lm_loss, binary_logits)`` with
        lm_loss as [b, s]. ``binary_logits`` is None unless both
        ``binary_head`` and ``pooled_output`` are provided.
    """
    # Output.
    lm_logits = lm_head(lm_output, logit_weights)

    binary_logits = None
    if binary_head is not None and pooled_output is not None:
        binary_logits = binary_head(pooled_output)

    if lm_labels is None:
        # [s b v] => [b s v]
        return lm_logits.transpose(0, 1).contiguous(), binary_logits

    # [b s] => [s b]
    lm_labels = lm_labels.transpose(0, 1).contiguous()
    # lm_logits: [s, b, v] (vocab possibly partitioned) and lm_labels: [s, b]
    if fp16_lm_cross_entropy:
        assert lm_logits.dtype == torch.half
        lm_loss = mpu.vocab_parallel_cross_entropy(lm_logits, lm_labels)
    else:
        lm_loss = mpu.vocab_parallel_cross_entropy(lm_logits.float(),
                                                   lm_labels)
    # [s, b] => [b s]
    lm_loss = lm_loss.transpose(0, 1).contiguous()
    return lm_loss, binary_logits


class BertModel(MegatronModule):
    """Bert Language model.

    Wraps the Megatron language model built by ``get_language_model`` with a
    masked-LM head (``BertLMHead``) and, when ``add_binary_head`` is set, a
    2-way linear head on the pooled output. ``post_process`` controls whether
    this instance builds and runs those heads; ``pre_process`` is forwarded to
    ``get_language_model`` (presumably marking the first pipeline stage).
    """

    def __init__(self,
                 num_tokentypes=2,
                 add_binary_head=True,
                 parallel_output=True,
                 pre_process=True,
                 post_process=True):
        super(BertModel, self).__init__()
        args = get_args()

        self.fp16_lm_cross_entropy = args.fp16_lm_cross_entropy
        self.add_binary_head = add_binary_head
        self.parallel_output = parallel_output
        self.pre_process = pre_process
        self.post_process = post_process

        init_method = init_method_normal(args.init_method_std)
        scaled_init_method = scaled_init_method_normal(args.init_method_std,
                                                       args.num_layers)

        # The pooler is only needed to feed the binary head.
        self.language_model, self._language_model_key = get_language_model(
            num_tokentypes=num_tokentypes,
            add_pooler=self.add_binary_head,
            encoder_attn_mask_type=AttnMaskType.padding,
            init_method=init_method,
            scaled_init_method=scaled_init_method,
            pre_process=self.pre_process,
            post_process=self.post_process)

        self.initialize_word_embeddings(init_method_normal)
        if self.post_process:
            # The LM head's bias is sized to the local word-embedding rows.
            self.lm_head = BertLMHead(
                self.word_embeddings_weight().size(0),
                args.hidden_size, init_method, args.layernorm_epsilon, parallel_output)
            self._lm_head_key = 'lm_head'
            self.binary_head = None
            if self.add_binary_head:
                self.binary_head = get_linear_layer(args.hidden_size, 2,
                                                    init_method)
                self._binary_head_key = 'binary_head'

    def set_input_tensor(self, input_tensor):
        """See megatron.model.transformer.set_input_tensor()"""
        self.language_model.set_input_tensor(input_tensor)

    def forward(self, bert_model_input, attention_mask,
                tokentype_ids=None, lm_labels=None):
        """Run the BERT model.

        Arguments:
            bert_model_input: token ids, shape [b, s].
            attention_mask: [b, s] padding mask, expanded with
                ``bert_extended_attention_mask``.
            tokentype_ids: optional token-type ids for the embedding.
            lm_labels: optional [b, s] labels; when given, the LM loss is
                returned instead of logits.

        Returns:
            When ``post_process`` is set, the tuple returned by
            ``post_language_model_processing``; otherwise the raw output of
            the language model.
        """
        extended_attention_mask = bert_extended_attention_mask(attention_mask)
        input_ids = bert_model_input
        position_ids = bert_position_ids(input_ids)

        lm_output = self.language_model(
            input_ids,
            position_ids,
            extended_attention_mask,
            tokentype_ids=tokentype_ids
        )

        # With a pooler, the language model returns (hidden, pooled).
        if self.post_process and self.add_binary_head:
            lm_output, pooled_output = lm_output
        else:
            pooled_output = None

        if self.post_process:
            return post_language_model_processing(lm_output, pooled_output,
                                                  self.lm_head, self.binary_head,
                                                  lm_labels,
                                                  self.word_embeddings_weight(),
                                                  self.fp16_lm_cross_entropy)
        else:
            return lm_output

    def state_dict_for_save_checkpoint(self, destination=None, prefix='',
                                       keep_vars=False):
        """For easy load when model is combined with other heads,
        add an extra key."""

        state_dict_ = {}
        state_dict_[self._language_model_key] \
            = self.language_model.state_dict_for_save_checkpoint(
            destination, prefix, keep_vars)
        if self.post_process:
            state_dict_[self._lm_head_key] \
                = self.lm_head.state_dict_for_save_checkpoint(
                destination, prefix, keep_vars)
        if self.post_process and self.add_binary_head:
            # Keyword arguments: positional args to Module.state_dict are
            # deprecated in recent PyTorch releases.
            state_dict_[self._binary_head_key] \
                = self.binary_head.state_dict(destination=destination,
                                              prefix=prefix,
                                              keep_vars=keep_vars)
        # Save word_embeddings.
        # A post-process-only stage keeps its own copy of the word embeddings
        # for the LM head; _word_embeddings_for_head_key comes from the base
        # class (not visible here).
        if self.post_process and not self.pre_process:
            state_dict_[self._word_embeddings_for_head_key] \
                = self.word_embeddings.state_dict(destination=destination,
                                                  prefix=prefix,
                                                  keep_vars=keep_vars)
        return state_dict_

    def load_state_dict(self, state_dict, strict=True):
        """Customized load.

        Mirrors ``state_dict_for_save_checkpoint``: each sub-module loads
        from its own key, and only the sub-modules this instance owns (per
        ``pre_process`` / ``post_process`` / ``add_binary_head``) are loaded.
        """

        self.language_model.load_state_dict(
            state_dict[self._language_model_key], strict=strict)
        if self.post_process:
            self.lm_head.load_state_dict(
                state_dict[self._lm_head_key], strict=strict)
        if self.post_process and self.add_binary_head:
            self.binary_head.load_state_dict(
                state_dict[self._binary_head_key], strict=strict)
        # Load word_embeddings.
        if self.post_process and not self.pre_process:
            self.word_embeddings.load_state_dict(
                state_dict[self._word_embeddings_for_head_key], strict=strict)