# coding=utf-8
# Copyright (c) 2019, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""BERT model."""

import torch

from megatron import get_args
from megatron.module import MegatronModule

from .language_model import parallel_lm_logits
from .language_model import get_language_model
from .transformer import LayerNorm
from .utils import gelu
from .utils import get_linear_layer
from .utils import init_method_normal
from .utils import scaled_init_method_normal


def bert_attention_mask_func(attention_scores, attention_mask):
    """Apply an additive mask to raw attention scores.

    ``BertModel.forward`` builds the mask with ``bert_extended_attention_mask``,
    which holds 0.0 for kept positions and -10000.0 for masked ones, so
    masking reduces to an element-wise (broadcasting) sum.
    """
    masked_scores = attention_scores + attention_mask
    return masked_scores


def bert_extended_attention_mask(attention_mask, dtype):
    """Turn a 2D padding mask into an additive 4D attention mask.

    Args:
        attention_mask: [b, s] tensor holding 1 for positions that may be
            attended to and 0 for masked (padding) positions.
        dtype: dtype of the returned mask; the caller passes the model's
            parameter dtype so that fp16 models receive an fp16 mask.

    Returns:
        [b, 1, s, s] tensor that is 0.0 where both the query and the key
        position are valid and -10000.0 everywhere else. Added to the raw
        scores before the softmax, this effectively removes masked positions.
    """
    # Pairwise product: entry (i, j) is nonzero only if positions i and j
    # are both valid.  [b, 1, s] x [b, s, 1] -> [b, s, s]
    pairwise = attention_mask.unsqueeze(1) * attention_mask.unsqueeze(2)
    # Add a broadcastable head axis ([b, 1, s, s]) and cast for fp16 support.
    pairwise = pairwise.unsqueeze(1).to(dtype=dtype)
    # 1 -> 0.0 (keep), 0 -> -10000.0 (mask out).
    return (1.0 - pairwise) * -10000.0


def bert_position_ids(token_ids):
    """Return position indices shaped like ``token_ids``.

    Args:
        token_ids: [b, s] tensor of token ids.

    Returns:
        int64 tensor on ``token_ids``'s device with the same shape as
        ``token_ids``; every row is ``0, 1, ..., s - 1``. It is an expanded
        view of a single row (no copy is made).
    """
    positions = torch.arange(token_ids.size(1), dtype=torch.long,
                             device=token_ids.device)
    # Prepend a batch axis and broadcast it across the batch dimension.
    return positions[None, :].expand_as(token_ids)



class BertLMHead(MegatronModule):
    """Masked LM head for Bert.

    Transforms hidden states with dense -> gelu -> layernorm, then computes
    vocabulary logits with ``parallel_lm_logits`` against the embedding
    matrix passed to ``forward``, using a learned per-vocabulary bias.

    Arguments:
        mpu_vocab_size: size of the vocabulary partition held by this
            model-parallel rank (BertModel passes the row count of its
            word-embedding matrix).
        hidden_size: hidden size.
        init_method: init method for the dense layer's weights.
        layernorm_epsilon: epsilon added for numerical stability in the
            layer norm.
        parallel_output: forwarded to ``parallel_lm_logits``; presumably
            selects whether logits stay distributed across model-parallel
            ranks or are gathered.
    """

    def __init__(self, mpu_vocab_size, hidden_size, init_method,
                 layernorm_epsilon, parallel_output):

        super(BertLMHead, self).__init__()

        # Per-vocabulary output bias, initialised to zero.
        self.bias = torch.nn.Parameter(torch.zeros(mpu_vocab_size))
        # Partitioning metadata: the bias is marked model-parallel and split
        # along dim 0 with stride 1. NOTE(review): these attribute names are
        # read outside this file (presumably by the mpu utilities), so they
        # must not be renamed.
        self.bias.model_parallel = True
        self.bias.partition_dim = 0
        self.bias.stride = 1
        self.parallel_output = parallel_output

        self.dense = get_linear_layer(hidden_size, hidden_size, init_method)
        self.layernorm = LayerNorm(hidden_size, eps=layernorm_epsilon)

    def forward(self, hidden_states, word_embeddings_weight):
        """Compute vocabulary logits from language-model hidden states.

        Args:
            hidden_states: language-model output whose last dimension is
                ``hidden_size``.
            word_embeddings_weight: embedding matrix used as the output
                projection (BertModel passes its input word embeddings).

        Returns:
            The logits produced by ``parallel_lm_logits`` with ``self.bias``
            supplied as its bias.
        """
        hidden_states = self.dense(hidden_states)
        hidden_states = gelu(hidden_states)
        hidden_states = self.layernorm(hidden_states)
        output = parallel_lm_logits(hidden_states,
                                    word_embeddings_weight,
                                    self.parallel_output,
                                    bias=self.bias)
        return output



class BertModel(MegatronModule):
    """Bert Language model.

    Wraps the language model built by ``get_language_model`` and attaches
    task heads on top of it. The head configuration determines what
    ``forward`` returns:

    * default (``add_binary_head=False``, no ``ict_head_size``): masked-LM
      head only -> ``(lm_logits, None)``.
    * ``add_binary_head=True``: masked-LM head plus a 2-way linear head on
      the pooled output -> ``(lm_logits, binary_logits)``.
    * ``ict_head_size`` given: no LM head; a linear projection of the pooled
      output to ``ict_head_size`` -> ``(ict_logits, None)``.

    ``add_binary_head`` and ``ict_head_size`` are mutually exclusive.
    Hyper-parameters (``init_method_std``, ``num_layers``, ``hidden_size``,
    ``layernorm_epsilon``) are read from the global ``get_args()``.
    """

    def __init__(self, num_tokentypes=2, add_binary_head=True,
                 ict_head_size=None, parallel_output=True):
        super(BertModel, self).__init__()
        args = get_args()

        self.add_binary_head = add_binary_head
        self.ict_head_size = ict_head_size
        self.add_ict_head = ict_head_size is not None
        assert not (self.add_binary_head and self.add_ict_head), \
            'add_binary_head and ict_head_size are mutually exclusive'

        self.parallel_output = parallel_output
        init_method = init_method_normal(args.init_method_std)
        # Both the binary head and the ICT head consume the pooled output.
        add_pooler = self.add_binary_head or self.add_ict_head
        # Second init method, parameterised by the number of layers (see
        # .utils.scaled_init_method_normal).
        scaled_init_method = scaled_init_method_normal(args.init_method_std,
                                                       args.num_layers)

        self.language_model, self._language_model_key = get_language_model(
            attention_mask_func=bert_attention_mask_func,
            num_tokentypes=num_tokentypes,
            add_pooler=add_pooler,
            init_method=init_method,
            scaled_init_method=scaled_init_method)

        if not self.add_ict_head:
            # The LM head's vocabulary size comes from the word-embedding
            # matrix, which forward() also reuses as the output projection.
            self.lm_head = BertLMHead(
                self.language_model.embedding.word_embeddings.weight.size(0),
                args.hidden_size, init_method, args.layernorm_epsilon,
                parallel_output)
            self._lm_head_key = 'lm_head'

        if self.add_binary_head:
            self.binary_head = get_linear_layer(args.hidden_size, 2,
                                                init_method)
            self._binary_head_key = 'binary_head'
        elif self.add_ict_head:
            self.ict_head = get_linear_layer(args.hidden_size, ict_head_size,
                                             init_method)
            self._ict_head_key = 'ict_head'

    def forward(self, input_ids, attention_mask, tokentype_ids=None):
        """Run the language model and the configured head(s).

        Args:
            input_ids: [b, s] token ids.
            attention_mask: [b, s] mask with 1 for tokens to attend to and 0
                for masked positions.
            tokentype_ids: optional token-type ids forwarded to the language
                model.

        Returns:
            A 2-tuple: ``(ict_logits, None)`` with an ICT head,
            ``(lm_logits, binary_logits)`` with a binary head, otherwise
            ``(lm_logits, None)``.
        """
        # Cast the mask to the model's parameter dtype (fp16 compatibility).
        extended_attention_mask = bert_extended_attention_mask(
            attention_mask, next(self.language_model.parameters()).dtype)
        position_ids = bert_position_ids(input_ids)

        lm_outputs = self.language_model(
            input_ids,
            position_ids,
            extended_attention_mask,
            tokentype_ids=tokentype_ids)
        # With a pooler, the language model also returns the pooled output.
        if self.add_binary_head or self.add_ict_head:
            lm_output, pooled_output = lm_outputs
        else:
            lm_output = lm_outputs

        # Output.
        if self.add_ict_head:
            ict_logits = self.ict_head(pooled_output)
            return ict_logits, None

        lm_logits = self.lm_head(
            lm_output, self.language_model.embedding.word_embeddings.weight)
        if self.add_binary_head:
            binary_logits = self.binary_head(pooled_output)
            return lm_logits, binary_logits

        return lm_logits, None

    def state_dict_for_save_checkpoint(self, destination=None, prefix='',
                                       keep_vars=False):
        """For easy load when model is combined with other heads,
        add an extra key.

        Returns a dict mapping each sub-module's key (language model, plus
        whichever heads exist) to that sub-module's state dict.
        """

        state_dict_ = {}
        state_dict_[self._language_model_key] \
            = self.language_model.state_dict_for_save_checkpoint(
                destination, prefix, keep_vars)
        if not self.add_ict_head:
            state_dict_[self._lm_head_key] \
                = self.lm_head.state_dict_for_save_checkpoint(
                    destination, prefix, keep_vars)
        if self.add_binary_head:
            state_dict_[self._binary_head_key] \
                = self.binary_head.state_dict(destination, prefix, keep_vars)
        elif self.add_ict_head:
            state_dict_[self._ict_head_key] \
                = self.ict_head.state_dict(destination, prefix, keep_vars)
        return state_dict_

    def load_state_dict(self, state_dict, strict=True):
        """Customized load.

        Expects the nested layout produced by
        ``state_dict_for_save_checkpoint`` and loads each existing sub-module
        from its own key.
        """

        self.language_model.load_state_dict(
            state_dict[self._language_model_key], strict=strict)
        if not self.add_ict_head:
            self.lm_head.load_state_dict(
                state_dict[self._lm_head_key], strict=strict)
        if self.add_binary_head:
            self.binary_head.load_state_dict(
                state_dict[self._binary_head_key], strict=strict)
        elif self.add_ict_head:
            self.ict_head.load_state_dict(
                state_dict[self._ict_head_key], strict=strict)

class ICTBertModel(MegatronModule):
    """Two-tower BERT scoring queries against contexts (ICT).

    Holds two separately parameterised ``BertModel`` instances: one encodes
    questions and one encodes contexts. Each is built with an ICT projection
    head of size ``ict_head_size`` and no binary head. ``forward`` scores
    every question in the batch against every context in the batch.
    """

    def __init__(self,
                 ict_head_size,
                 num_tokentypes=0,
                 parallel_output=True):
        super(ICTBertModel, self).__init__()
        bert_args = dict(
            num_tokentypes=num_tokentypes,
            add_binary_head=False,
            ict_head_size=ict_head_size,
            parallel_output=parallel_output
        )

        self.question_model = BertModel(**bert_args)
        self._question_key = 'question_model'
        self.context_model = BertModel(**bert_args)
        self._context_key = 'context_model'

    def forward(self, input_tokens, input_attention_mask, input_types,
                context_tokens, context_attention_mask, context_types):
        """Return the question-by-context retrieval score matrix.

        Returns:
            Tensor whose (i, j) entry is the dot product of the i-th question
            embedding and the j-th context embedding.
        """
        # BertModel expects 1 for tokens to attend to, so the incoming masks
        # are inverted here. NOTE(review): this assumes callers pass masks
        # with 1 at padded positions -- confirm against the ICT data loader.
        # The sub-models are invoked via __call__ (not .forward) so that any
        # registered nn.Module hooks run.
        question_ict_logits, _ = self.question_model(
            input_tokens, 1 - input_attention_mask, input_types)
        context_ict_logits, _ = self.context_model(
            context_tokens, 1 - context_attention_mask, context_types)

        # [batch x h] * [h x batch] -> [batch x batch]
        retrieval_scores = question_ict_logits.matmul(
            torch.transpose(context_ict_logits, 0, 1))

        return retrieval_scores

    def state_dict_for_save_checkpoint(self, destination=None, prefix='',
                                       keep_vars=False):
        """Return a dict with the question and context models' checkpoint
        state dicts under their respective keys."""
        state_dict_ = {}
        state_dict_[self._question_key] \
            = self.question_model.state_dict_for_save_checkpoint(
                destination, prefix, keep_vars)
        state_dict_[self._context_key] \
            = self.context_model.state_dict_for_save_checkpoint(
                destination, prefix, keep_vars)
        return state_dict_

    def load_state_dict(self, state_dict, strict=True):
        """Customized load: restore each tower from its own key."""

        self.question_model.load_state_dict(
            state_dict[self._question_key], strict=strict)
        self.context_model.load_state_dict(
            state_dict[self._context_key], strict=strict)