# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.

"""BERT model."""

import torch

from megatron import get_args
from megatron.core import tensor_parallel
from megatron.model.enums import AttnMaskType
from megatron.model.language_model import parallel_lm_logits
from megatron.model.language_model import get_language_model
from megatron.model import LayerNorm
from megatron.model.utils import openai_gelu, erf_gelu
from megatron.model.utils import get_linear_layer
from megatron.model.utils import init_method_normal
from megatron.model.utils import scaled_init_method_normal
from .module import MegatronModule


def bert_extended_attention_mask(attention_mask):
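    """Convert a 2D padding mask [b, s] into the boolean mask [b, 1, s, s]
    used by the attention layers (True marks positions to be masked out)."""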
    # We create a 3D attention mask from a 2D tensor mask.
    # [b, 1, s]
    attention_mask_b1s = attention_mask.unsqueeze(1)
    # [b, s, 1]
    attention_mask_bs1 = attention_mask.unsqueeze(2)
    # [b, s, s]
    attention_mask_bss = attention_mask_b1s * attention_mask_bs1
    # [b, 1, s, s]
    extended_attention_mask = attention_mask_bss.unsqueeze(1)

    # Invert to a boolean mask: True marks positions that will be masked out.
    extended_attention_mask = (extended_attention_mask < 0.5)

    return extended_attention_mask

def bert_position_ids(token_ids):
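    """Return position ids of shape [b, s] (values 0 .. s-1) on token_ids' device."""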
    # Create position ids
    seq_length = token_ids.size(1)
    position_ids = torch.arange(seq_length, dtype=torch.long,
                                device=token_ids.device)
    position_ids = position_ids.unsqueeze(0).expand_as(token_ids)

    return position_ids


class BertLMHead(MegatronModule):
    """Masked LM head for Bert

    Arguments:
        mpu_vocab_size: vocabulary size for this tensor-model-parallel partition.
        hidden_size: transformer hidden size.
        config: TransformerConfig object.
        parallel_output: whether the output logits stay split across tensor-parallel ranks.
    """

    def __init__(self, mpu_vocab_size, hidden_size, config, parallel_output):
        super().__init__(config=config)

        args = get_args()
        self.bias = torch.nn.Parameter(torch.zeros(mpu_vocab_size))
        tensor_parallel.set_tensor_model_parallel_attributes(self.bias, True, 0, 1)
        self.parallel_output = parallel_output

        self.dense = get_linear_layer(hidden_size, hidden_size, config.init_method)
        setattr(self.dense.weight, 'sequence_parallel', config.sequence_parallel)
        setattr(self.dense.bias, 'sequence_parallel', config.sequence_parallel)

        self.layernorm = LayerNorm(hidden_size,
                                   eps=config.layernorm_epsilon,
                                   sequence_parallel=config.sequence_parallel)
        self.gelu = torch.nn.functional.gelu
        if args.openai_gelu:
            self.gelu = openai_gelu
        elif args.onnx_safe:
            self.gelu = erf_gelu

    def forward(self, hidden_states, word_embeddings_weight):
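        # dense -> GeLU -> LayerNorm, then project to vocab logits using the
        # shared word embedding weights as the output projection.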
        hidden_states = self.dense(hidden_states)
        hidden_states = self.gelu(hidden_states)
        hidden_states = self.layernorm(hidden_states)
        output = parallel_lm_logits(hidden_states,
                                    word_embeddings_weight,
                                    self.parallel_output,
                                    bias=self.bias)
        return output


def post_language_model_processing(lm_output, pooled_output,
                                   lm_head, binary_head,
                                   lm_labels,
                                   logit_weights,
                                   fp16_lm_cross_entropy):
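    """Apply the LM head (and binary head, if given) to the encoder output.
    Returns (lm_logits, binary_logits) when lm_labels is None, otherwise
    (lm_loss, binary_logits)."""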
    # Output.
    lm_logits = lm_head(
        lm_output, logit_weights)

    binary_logits = None
    if binary_head is not None:
        binary_logits = binary_head(pooled_output)

    if lm_labels is None:
        # [s b h] => [b s h]
        return lm_logits.transpose(0,1).contiguous(), binary_logits
    else:
        # [b s] => [s b]
        lm_labels = lm_labels.transpose(0,1).contiguous()
        # lm_logits : [s, b, h] and lm_labels: [s, b]
        if fp16_lm_cross_entropy:
            assert lm_logits.dtype == torch.half
            lm_loss = tensor_parallel.vocab_parallel_cross_entropy(lm_logits, lm_labels)
        else:
            lm_loss = tensor_parallel.vocab_parallel_cross_entropy(lm_logits.float(),
                                                        lm_labels)
        # [s, b] => [b, s]
        lm_loss = lm_loss.transpose(0,1).contiguous()
        return lm_loss, binary_logits


class BertModel(MegatronModule):
    """Bert Language model."""

    def __init__(self,
                 config,
                 num_tokentypes=2,
                 add_binary_head=True,
                 parallel_output=True,
                 pre_process=True,
                 post_process=True):
        super().__init__(config=config)
        args = get_args()

        # TODO this option is not yet implemented in BERT
        assert args.untie_embeddings_and_output_weights is False

        self.fp16_lm_cross_entropy = args.fp16_lm_cross_entropy
        self.add_binary_head = add_binary_head
        self.parallel_output = parallel_output
        self.pre_process = pre_process
        self.post_process = post_process

        self.return_embeddings = args.output_bert_embeddings
        if self.return_embeddings:
            assert self.post_process and self.add_binary_head

        self.language_model, self._language_model_key = get_language_model(
            config=config,
            num_tokentypes=num_tokentypes,
            add_pooler=self.add_binary_head,
            encoder_attn_mask_type=AttnMaskType.padding,
            pre_process=self.pre_process,
            post_process=self.post_process)

        self.initialize_word_embeddings()
        if self.post_process:
            self.lm_head = BertLMHead(self.shared_embedding_or_output_weight().size(0), config.hidden_size,
                                      config, parallel_output)
            self._lm_head_key = 'lm_head'
            self.binary_head = None
            if self.add_binary_head:
                self.binary_head = get_linear_layer(config.hidden_size, 2,
                                                    config.init_method)
                self._binary_head_key = 'binary_head'

    def set_input_tensor(self, input_tensor):
        """See megatron.model.transformer.set_input_tensor()"""
        self.language_model.set_input_tensor(input_tensor)

    def forward(self, bert_model_input, attention_mask,
                tokentype_ids=None, lm_labels=None):
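        """Run the BERT encoder; on the last pipeline stage apply the LM (and
        optional binary) head, or return per-sequence mean embeddings when
        args.output_bert_embeddings is set."""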

        extended_attention_mask = bert_extended_attention_mask(attention_mask)
        input_ids = bert_model_input
        position_ids = bert_position_ids(input_ids)

        lm_output = self.language_model(
            input_ids,
            position_ids,
            extended_attention_mask,
            tokentype_ids=tokentype_ids
        )

        if self.post_process and self.add_binary_head:
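            # With the pooler enabled, the language model returns a tuple of
            # (sequence output [s, b, h], pooled [CLS] output [b, h]).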
            lm_output, pooled_output = lm_output

            # Return pooled output (e.g., when computing Bert embeddings).
            if self.return_embeddings:

                # [s, b, h] -> [b, s, h]; count the real (unpadded) tokens per sequence.
                embeddings = torch.transpose(lm_output, 0, 1)
                masks = torch.sum(attention_mask, dim=1)

                # Collect masked embeddings.
                output = torch.zeros(
                    size=(embeddings.shape[0], embeddings.shape[2]),
                    dtype=torch.float32,
                    device=torch.cuda.current_device())
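                # Average each sequence's token embeddings, skipping [CLS] and the
                # trailing [SEP] (mask counts the unpadded tokens).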
                for i, (embedding, mask) in enumerate(zip(embeddings, masks)):
                    output[i, :] = torch.mean(embedding[1: mask - 1], dim=0)

                return output

        else:
            pooled_output = None

        if self.post_process:
            return post_language_model_processing(lm_output, pooled_output,
                                                  self.lm_head, self.binary_head,
                                                  lm_labels,
                                                  self.shared_embedding_or_output_weight(),
                                                  self.fp16_lm_cross_entropy)
        else:
            return lm_output


    def state_dict_for_save_checkpoint(self, prefix='', keep_vars=False):
        """For easy load when model is combined with other heads,
        add an extra key."""

        state_dict_ = {}
        state_dict_[self._language_model_key] \
            = self.language_model.state_dict_for_save_checkpoint(prefix=prefix,
                                                                 keep_vars=keep_vars)
        if self.post_process:
            state_dict_[self._lm_head_key] \
                = self.lm_head.state_dict_for_save_checkpoint(prefix=prefix,
                                                              keep_vars=keep_vars)
        if self.post_process and self.add_binary_head:
            state_dict_[self._binary_head_key] \
                = self.binary_head.state_dict(prefix=prefix, keep_vars=keep_vars)
        # Save word_embeddings.
        if self.post_process and not self.pre_process:
            state_dict_[self._word_embeddings_for_head_key] \
                = self.word_embeddings.state_dict(prefix=prefix, keep_vars=keep_vars)
        return state_dict_

    def load_state_dict(self, state_dict, strict=True):
        """Customized load."""

        self.language_model.load_state_dict(
            state_dict[self._language_model_key], strict=strict)
        if self.post_process:
            self.lm_head.load_state_dict(
                state_dict[self._lm_head_key], strict=strict)
        if self.post_process and self.add_binary_head:
            self.binary_head.load_state_dict(
                state_dict[self._binary_head_key], strict=strict)
        # Load word_embeddings.
        if self.post_process and not self.pre_process:
            self.word_embeddings.load_state_dict(
                state_dict[self._word_embeddings_for_head_key], strict=strict)