# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""BERT model."""

import torch

Mohammad's avatar
Mohammad committed
20
from megatron import get_args
mohammad's avatar
mohammad committed
21
from megatron import mpu
22
23
24
25
26
27
28
from megatron.model.language_model import parallel_lm_logits
from megatron.model.language_model import get_language_model
from megatron.model.transformer import LayerNorm
from megatron.model.utils import openai_gelu
from megatron.model.utils import get_linear_layer
from megatron.model.utils import init_method_normal
from megatron.model.utils import scaled_init_method_normal
29
30
31
from megatron.module import MegatronModule


def bert_attention_mask_func(attention_scores, attention_mask):
    """Apply an additive attention mask to raw attention scores.

    The mask is simply added to the scores, so it must already be an
    additive bias rather than a 0/1 mask. Presumably this is the output of
    ``bert_extended_attention_mask``: 0.0 where attention is allowed and
    -10000.0 where it is not. Standard tensor broadcasting applies to the
    addition.

    Arguments:
        attention_scores: raw (pre-softmax) attention scores.
        attention_mask: additive bias, broadcastable to ``attention_scores``.

    Returns:
        The masked scores.
    """
    attention_scores = attention_scores + attention_mask
    return attention_scores


def bert_extended_attention_mask(attention_mask, dtype):
    """Build an additive [b, 1, s, s] attention bias from a [b, s] mask.

    ``attention_mask`` holds 1 for positions that may be attended to and 0
    for masked positions. A (row, column) pair survives only when both of
    its positions survive, so the pairwise mask is the product of the mask
    with itself along the sequence dimension.

    Surviving pairs become 0.0 and masked pairs become -10000.0. Adding this
    bias to the raw scores before the softmax effectively removes the masked
    positions.

    Arguments:
        attention_mask: [b, s] tensor of 1s (attend) and 0s (mask out).
        dtype: dtype of the returned bias. The cast happens before the
            arithmetic, which keeps fp16 models in fp16.

    Returns:
        Tensor of shape [b, 1, s, s] and dtype ``dtype``.
    """
    cols = attention_mask.unsqueeze(1)   # [b, 1, s]
    rows = attention_mask.unsqueeze(2)   # [b, s, 1]
    # Broadcasted product -> [b, s, s], then add a heads axis -> [b, 1, s, s].
    pairwise = (cols * rows).unsqueeze(1)
    pairwise = pairwise.to(dtype=dtype)
    return (1.0 - pairwise) * -10000.0


def bert_position_ids(token_ids):
    """Return position indices 0..s-1 for every row of ``token_ids``.

    The sequence length ``s`` is read from dim 1 of ``token_ids``. The
    indices are int64 on the same device as ``token_ids``. They are
    broadcast to the shape of ``token_ids`` with ``expand_as``, so the result
    is an expanded view rather than a materialised copy.
    """
    positions = torch.arange(token_ids.size(1), dtype=torch.long,
                             device=token_ids.device)
    # [s] -> [1, s] -> same shape as token_ids.
    return positions[None, :].expand_as(token_ids)


class BertLMHead(MegatronModule):
    """Masked LM head for Bert.

    The head applies dense -> gelu -> layernorm, then projects to the
    vocabulary with ``parallel_lm_logits``. The projection uses the weight
    matrix passed to ``forward`` plus this head's own bias.

    Arguments:
        mpu_vocab_size: model parallel size of vocabulary (the length of the
            output bias).
        hidden_size: hidden size
        init_method: init method for weight initialization
        layernorm_epsilon: tolerance for layer norm divisions
        parallel_output: whether output logits being distributed or not.
    """

    def __init__(self, mpu_vocab_size, hidden_size, init_method,
                 layernorm_epsilon, parallel_output):

        super(BertLMHead, self).__init__()

        args = get_args()

        # Output bias over the (per-rank) vocabulary. The attributes below tag
        # it as a model-parallel parameter partitioned along dim 0 with
        # stride 1. NOTE(review): whatever consumes these tags lives outside
        # this file.
        self.bias = torch.nn.Parameter(torch.zeros(mpu_vocab_size))
        self.bias.model_parallel = True
        self.bias.partition_dim = 0
        self.bias.stride = 1
        self.parallel_output = parallel_output

        self.dense = get_linear_layer(hidden_size, hidden_size, init_method)
        self.layernorm = LayerNorm(hidden_size, eps=layernorm_epsilon)
        # torch's gelu by default; the OpenAI approximation when requested.
        self.gelu = torch.nn.functional.gelu
        if args.openai_gelu:
            self.gelu = openai_gelu

    def forward(self, hidden_states, word_embeddings_weight):
        """Compute vocabulary logits from transformer hidden states.

        Arguments:
            hidden_states: output of the language model.
            word_embeddings_weight: projection weight for the logits. BertModel
                passes the word-embedding matrix here, tying input and output
                embeddings.

        Returns:
            Output of ``parallel_lm_logits``. It is kept distributed or
            gathered according to ``self.parallel_output``.
        """
        hidden_states = self.dense(hidden_states)
        hidden_states = self.gelu(hidden_states)
        hidden_states = self.layernorm(hidden_states)
        output = parallel_lm_logits(hidden_states,
                                    word_embeddings_weight,
                                    self.parallel_output,
                                    bias=self.bias)
        return output


class BertModel(MegatronModule):
    """Bert Language model.

    The model wraps the language model returned by ``get_language_model``.
    On top of it sits a masked-LM head whose output projection reuses the
    word-embedding weight. Optionally, a 2-way linear "binary" head is added
    on the pooled output.

    Arguments:
        num_tokentypes: number of token-type embeddings passed to the
            language model.
        add_binary_head: if True, build the pooler and the 2-way binary head.
        parallel_output: forwarded to the LM head; whether the output logits
            stay distributed.
    """

    def __init__(self, num_tokentypes=2, add_binary_head=True,
                 parallel_output=True):
        super(BertModel, self).__init__()
        args = get_args()

        self.fp16_lm_cross_entropy = args.fp16_lm_cross_entropy
        self.add_binary_head = add_binary_head
        self.parallel_output = parallel_output

        init_method = init_method_normal(args.init_method_std)
        scaled_init_method = scaled_init_method_normal(args.init_method_std,
                                                       args.num_layers)

        self.language_model, self._language_model_key = get_language_model(
            attention_mask_func=bert_attention_mask_func,
            num_tokentypes=num_tokentypes,
            add_pooler=self.add_binary_head,
            init_method=init_method,
            scaled_init_method=scaled_init_method)

        # The LM head's bias length matches the word-embedding table's first
        # dimension, because the head projects with that same weight.
        self.lm_head = BertLMHead(
            self.language_model.embedding.word_embeddings.weight.size(0),
            args.hidden_size, init_method, args.layernorm_epsilon,
            parallel_output)
        self._lm_head_key = 'lm_head'

        if self.add_binary_head:
            self.binary_head = get_linear_layer(args.hidden_size, 2,
                                                init_method)
            self._binary_head_key = 'binary_head'

    def forward(self, input_ids, attention_mask,
                tokentype_ids=None, lm_labels=None):
        """Run BERT and return LM logits (or LM loss) plus binary logits.

        Arguments:
            input_ids: token ids. The sequence length is read from dim 1.
            attention_mask: 1 for positions to attend to, 0 for masked
                positions. It is expanded by ``bert_extended_attention_mask``.
            tokentype_ids: optional token-type ids forwarded to the language
                model.
            lm_labels: optional LM targets. When given, the vocab-parallel
                cross-entropy loss is returned in place of the LM logits.

        Returns:
            ``(lm_logits, binary_logits)`` when ``lm_labels`` is None,
            otherwise ``(lm_loss, binary_logits)``. ``binary_logits`` is None
            when the model was built without the binary head.
        """
        # Build the additive mask in the parameter dtype, so an fp16 model
        # gets an fp16 mask.
        extended_attention_mask = bert_extended_attention_mask(
            attention_mask, next(self.language_model.parameters()).dtype)
        position_ids = bert_position_ids(input_ids)

        lm_output = self.language_model(
            input_ids,
            position_ids,
            extended_attention_mask,
            tokentype_ids=tokentype_ids)
        pooled_output = None
        if self.add_binary_head:
            # With the pooler enabled, the language model returns
            # (sequence_output, pooled_output).
            lm_output, pooled_output = lm_output

        # Output.
        lm_logits = self.lm_head(
            lm_output, self.language_model.embedding.word_embeddings.weight)

        binary_logits = None
        if self.add_binary_head:
            binary_logits = self.binary_head(pooled_output)

        if lm_labels is None:
            return lm_logits, binary_logits

        if self.fp16_lm_cross_entropy:
            # Compute the loss directly on half-precision logits.
            assert lm_logits.dtype == torch.half
            lm_loss = mpu.vocab_parallel_cross_entropy(lm_logits, lm_labels)
        else:
            # Upcast to fp32 before computing the loss.
            lm_loss = mpu.vocab_parallel_cross_entropy(lm_logits.float(),
                                                       lm_labels)
        return lm_loss, binary_logits

    def state_dict_for_save_checkpoint(self, destination=None, prefix='',
                                       keep_vars=False):
        """For easy load when model is combined with other heads,
        add an extra key.

        Each sub-module's state is stored under its own key:
        ``language_model``'s key, ``'lm_head'``, and (when present)
        ``'binary_head'``.
        """

        state_dict_ = {}
        state_dict_[self._language_model_key] \
            = self.language_model.state_dict_for_save_checkpoint(
            destination, prefix, keep_vars)
        state_dict_[self._lm_head_key] \
            = self.lm_head.state_dict_for_save_checkpoint(
            destination, prefix, keep_vars)
        if self.add_binary_head:
            # torch.nn.Module.state_dict takes these as keyword arguments;
            # positional use is deprecated in recent PyTorch releases.
            state_dict_[self._binary_head_key] \
                = self.binary_head.state_dict(destination=destination,
                                              prefix=prefix,
                                              keep_vars=keep_vars)
        return state_dict_

    def load_state_dict(self, state_dict, strict=True):
        """Customized load.

        This is the inverse of ``state_dict_for_save_checkpoint``. Each
        sub-module is loaded from its own key in ``state_dict``.
        """

        self.language_model.load_state_dict(
            state_dict[self._language_model_key], strict=strict)
        self.lm_head.load_state_dict(
            state_dict[self._lm_head_key], strict=strict)
        if self.add_binary_head:
            self.binary_head.load_state_dict(
                state_dict[self._binary_head_key], strict=strict)