# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""BERT model."""

import torch

from megatron import get_args
from megatron.model.language_model import parallel_lm_logits
from megatron.model.language_model import get_language_model
from megatron.model.transformer import LayerNorm
from megatron.model.utils import openai_gelu
from megatron.model.utils import get_linear_layer
from megatron.model.utils import init_method_normal
from megatron.model.utils import scaled_init_method_normal
from megatron.module import MegatronModule


def bert_attention_mask_func(attention_scores, attention_mask):
    """Apply an additive attention mask to raw attention scores.

    The mask is simply added to the scores.  In this file the mask handed to
    the language model comes from ``bert_extended_attention_mask``, which is
    0.0 at attendable positions and a large negative value elsewhere, so the
    addition suppresses masked positions once a softmax is applied.

    Returns a new tensor; neither input is modified in place.
    """
    return attention_scores + attention_mask


def bert_extended_attention_mask(attention_mask, dtype):
    """Turn a 2D padding mask into an additive 4D attention mask.

    Args:
        attention_mask: tensor of shape [b, s] holding 1 for positions that
            may be attended to and 0 for masked positions.
        dtype: dtype of the returned mask (BertModel passes the dtype of its
            parameters, which keeps the mask fp16-compatible).

    Returns:
        Tensor of shape [b, 1, s, s] that is 0.0 where both the query and the
        key position are unmasked and -10000.0 everywhere else.  Adding it to
        the raw scores before the softmax effectively removes masked entries.
    """
    # Pairwise product of the mask with itself: [b, 1, s] * [b, s, 1] -> [b, s, s].
    pairwise = attention_mask.unsqueeze(1) * attention_mask.unsqueeze(2)
    # Insert a singleton dim -> [b, 1, s, s] (presumably broadcast over heads).
    pairwise = pairwise.unsqueeze(1)
    # Cast first so the arithmetic below happens in the requested dtype.
    pairwise = pairwise.to(dtype=dtype)
    return (1.0 - pairwise) * -10000.0


def bert_position_ids(token_ids):
    """Build position indices 0..s-1 for every row of a token batch.

    Args:
        token_ids: tensor assumed to be of shape [b, s].

    Returns:
        torch.long tensor shaped like ``token_ids`` whose every row is
        ``[0, 1, ..., s-1]``, created on ``token_ids``' device.  It is an
        ``expand_as`` view, so the rows share storage rather than being copied.
    """
    num_positions = token_ids.size(1)
    positions = torch.arange(num_positions, dtype=torch.long,
                             device=token_ids.device)
    return positions.unsqueeze(0).expand_as(token_ids)


class BertLMHead(MegatronModule):
    """Masked LM head for Bert.

    Applies dense -> GELU -> LayerNorm to the transformer output, then
    projects onto the vocabulary with the (externally owned) word-embedding
    matrix plus a bias owned by this module.

    Arguments:
        mpu_vocab_size: model parallel size of vocabulary (length of the
            output bias held by this module).
        hidden_size: hidden size
        init_method: init method for weight initialization
        layernorm_epsilon: epsilon passed to LayerNorm for numerical stability
        parallel_output: whether output logits are kept distributed across
            model-parallel ranks or not.
    """

    def __init__(self, mpu_vocab_size, hidden_size, init_method,
                 layernorm_epsilon, parallel_output):

        super(BertLMHead, self).__init__()

        args = get_args()

        # Output bias over the vocabulary; the projection weight itself is not
        # owned here but passed into forward().
        self.bias = torch.nn.Parameter(torch.zeros(mpu_vocab_size))
        # Sharding metadata attached to the bias.  NOTE(review): presumably
        # read by the model-parallel utilities (partitioned along dim 0 with
        # stride 1) -- confirm against megatron.mpu.
        self.bias.model_parallel = True
        self.bias.partition_dim = 0
        self.bias.stride = 1
        self.parallel_output = parallel_output

        self.dense = get_linear_layer(hidden_size, hidden_size, init_method)
        self.layernorm = LayerNorm(hidden_size, eps=layernorm_epsilon)
        # torch's GELU by default; megatron's `openai_gelu` variant when the
        # `openai_gelu` argument is set.
        self.gelu = torch.nn.functional.gelu
        if args.openai_gelu:
            self.gelu = openai_gelu

    def forward(self, hidden_states, word_embeddings_weight):
        """Compute masked-LM logits.

        Args:
            hidden_states: transformer output to transform.
            word_embeddings_weight: word-embedding matrix reused as the
                output projection (BertModel passes its own embedding weight).

        Returns:
            Logits produced by ``parallel_lm_logits`` using ``self.bias``;
            ``self.parallel_output`` controls whether they stay distributed.
        """
        hidden_states = self.dense(hidden_states)
        hidden_states = self.gelu(hidden_states)
        hidden_states = self.layernorm(hidden_states)
        output = parallel_lm_logits(hidden_states,
                                    word_embeddings_weight,
                                    self.parallel_output,
                                    bias=self.bias)
        return output


class BertModel(MegatronModule):
    """Bert Language model.

    Wraps the shared language model and attaches task heads on top of it:

    * a masked-LM head (``BertLMHead``) unless an ICT head is requested,
    * an optional 2-way linear ``binary_head`` on the pooled output, or
    * alternatively, an ``ict_head`` projecting the pooled output to
      ``ict_head_size`` features.

    The binary head and the ICT head are mutually exclusive.
    """

    def __init__(self, num_tokentypes=2, add_binary_head=True,
                 ict_head_size=None, parallel_output=True):
        """
        Args:
            num_tokentypes: number of token types, forwarded to
                ``get_language_model``.
            add_binary_head: attach a 2-output linear head on the pooled output.
            ict_head_size: when not None, attach an ICT head with this many
                outputs and skip the masked-LM head.  Requires
                ``add_binary_head=False``.
            parallel_output: forwarded to ``BertLMHead``.
        """
        super(BertModel, self).__init__()
        args = get_args()

        self.add_binary_head = add_binary_head
        self.ict_head_size = ict_head_size
        self.add_ict_head = ict_head_size is not None
        assert not (self.add_binary_head and self.add_ict_head), \
            'add_binary_head and ict_head_size are mutually exclusive'

        self.parallel_output = parallel_output
        init_method = init_method_normal(args.init_method_std)
        # Both the binary head and the ICT head consume the pooled output.
        add_pooler = self.add_binary_head or self.add_ict_head
        scaled_init_method = scaled_init_method_normal(args.init_method_std,
                                                       args.num_layers)

        self.language_model, self._language_model_key = get_language_model(
            attention_mask_func=bert_attention_mask_func,
            num_tokentypes=num_tokentypes,
            add_pooler=add_pooler,
            init_method=init_method,
            scaled_init_method=scaled_init_method)

        if not self.add_ict_head:
            # The LM head's bias is sized to match the rows of the word
            # embedding matrix that it is tied to in forward().
            self.lm_head = BertLMHead(
                self.language_model.embedding.word_embeddings.weight.size(0),
                args.hidden_size, init_method, args.layernorm_epsilon,
                parallel_output)
            self._lm_head_key = 'lm_head'
        if self.add_binary_head:
            self.binary_head = get_linear_layer(args.hidden_size, 2,
                                                init_method)
            self._binary_head_key = 'binary_head'
        elif self.add_ict_head:
            self.ict_head = get_linear_layer(args.hidden_size, ict_head_size,
                                             init_method)
            self._ict_head_key = 'ict_head'

    def forward(self, input_ids, attention_mask, tokentype_ids=None):
        """Run the language model and the configured heads.

        Args:
            input_ids: token ids, assumed [b, s].
            attention_mask: 1/0 mask of the same shape as ``input_ids``.
            tokentype_ids: optional token-type ids for the language model.

        Returns:
            ``(ict_logits, None)`` when the ICT head is configured, otherwise
            ``(lm_logits, binary_logits)`` where ``binary_logits`` is None if
            no binary head was added.
        """
        # Build the additive mask in the model's parameter dtype (fp16-safe).
        extended_attention_mask = bert_extended_attention_mask(
            attention_mask, next(self.language_model.parameters()).dtype)
        position_ids = bert_position_ids(input_ids)

        if self.add_binary_head or self.add_ict_head:
            # The language model was built with a pooler, so it also
            # returns the pooled output.
            lm_output, pooled_output = self.language_model(
                input_ids,
                position_ids,
                extended_attention_mask,
                tokentype_ids=tokentype_ids)
        else:
            lm_output = self.language_model(
                input_ids,
                position_ids,
                extended_attention_mask,
                tokentype_ids=tokentype_ids)

        # Output.
        if self.add_ict_head:
            ict_logits = self.ict_head(pooled_output)
            return ict_logits, None

        lm_logits = self.lm_head(
            lm_output, self.language_model.embedding.word_embeddings.weight)
        if self.add_binary_head:
            binary_logits = self.binary_head(pooled_output)
            return lm_logits, binary_logits

        return lm_logits, None

    def state_dict_for_save_checkpoint(self, destination=None, prefix='',
                                       keep_vars=False):
        """For easy load when model is combined with other heads,
        add an extra key.

        Returns a dict keyed by the per-component keys set in ``__init__``
        (language model, and whichever of the LM / binary / ICT heads exist).
        """

        state_dict_ = {}
        state_dict_[self._language_model_key] \
            = self.language_model.state_dict_for_save_checkpoint(
                destination, prefix, keep_vars)
        if not self.add_ict_head:
            state_dict_[self._lm_head_key] \
                = self.lm_head.state_dict_for_save_checkpoint(
                    destination, prefix, keep_vars)
        if self.add_binary_head:
            state_dict_[self._binary_head_key] \
                = self.binary_head.state_dict(destination, prefix, keep_vars)
        elif self.add_ict_head:
            state_dict_[self._ict_head_key] \
                = self.ict_head.state_dict(destination, prefix, keep_vars)
        return state_dict_

    def load_state_dict(self, state_dict, strict=True):
        """Customized load.

        Expects the nested layout produced by
        ``state_dict_for_save_checkpoint`` and loads each component that this
        instance was built with; a missing component key raises KeyError.
        """

        self.language_model.load_state_dict(
            state_dict[self._language_model_key], strict=strict)
        if not self.add_ict_head:
            self.lm_head.load_state_dict(
                state_dict[self._lm_head_key], strict=strict)
        if self.add_binary_head:
            self.binary_head.load_state_dict(
                state_dict[self._binary_head_key], strict=strict)
        elif self.add_ict_head:
            self.ict_head.load_state_dict(
                state_dict[self._ict_head_key], strict=strict)