# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""BERT model."""

import torch

Mohammad's avatar
Mohammad committed
20
from megatron import get_args
21
22
23
24
25
26
27
from megatron.model.language_model import parallel_lm_logits
from megatron.model.language_model import get_language_model
from megatron.model.transformer import LayerNorm
from megatron.model.utils import openai_gelu
from megatron.model.utils import get_linear_layer
from megatron.model.utils import init_method_normal
from megatron.model.utils import scaled_init_method_normal
28
29
30
from megatron.model.utils import bert_attention_mask_func
from megatron.model.utils import bert_extended_attention_mask
from megatron.model.utils import bert_position_ids
31
32
33
34
35
36
37
38
39
40
41
from megatron.module import MegatronModule


class BertLMHead(MegatronModule):
    """Masked-LM output head for BERT.

    ``forward`` runs the hidden states through a dense layer, an
    activation and a layer norm, then hands the result to
    ``parallel_lm_logits`` together with the word-embedding weight
    supplied by the caller and a bias owned by this head.

    Arguments:
        mpu_vocab_size: model parallel size of vocabulary; used as the
            length of the bias vector created here.
        hidden_size: hidden size (input and output width of the dense
            layer, and the layer-norm size).
        init_method: init method for weight initialization of the dense
            layer.
        layernorm_epsilon: tolerance for layer norm divisions.
        parallel_output: whether output logits being distributed or not;
            passed through unchanged to ``parallel_lm_logits``.
    """

    def __init__(self, mpu_vocab_size, hidden_size, init_method,
                 layernorm_epsilon, parallel_output):
        super().__init__()

        args = get_args()

        # Zero-initialised logits bias. The extra attributes are plain
        # tags set on the parameter object itself.
        # NOTE(review): presumably read by the model-parallel utilities;
        # nothing in this file consumes them.
        vocab_bias = torch.nn.Parameter(torch.zeros(mpu_vocab_size))
        vocab_bias.model_parallel = True
        vocab_bias.partition_dim = 0
        vocab_bias.stride = 1
        self.bias = vocab_bias

        self.parallel_output = parallel_output

        # Registration order (bias, dense, layernorm) is kept as-is so the
        # parameter / state-dict ordering does not change.
        self.dense = get_linear_layer(hidden_size, hidden_size, init_method)
        self.layernorm = LayerNorm(hidden_size, eps=layernorm_epsilon)

        # Activation: torch's gelu unless the openai_gelu flag is set.
        if args.openai_gelu:
            self.gelu = openai_gelu
        else:
            self.gelu = torch.nn.functional.gelu

    def forward(self, hidden_states, word_embeddings_weight):
        """Return the LM logits for ``hidden_states``.

        Arguments:
            hidden_states: input to the dense layer.
            word_embeddings_weight: weight forwarded to
                ``parallel_lm_logits`` alongside this head's bias.
        """
        transformed = self.layernorm(self.gelu(self.dense(hidden_states)))
        return parallel_lm_logits(transformed,
                                  word_embeddings_weight,
                                  self.parallel_output,
                                  bias=self.bias)


class BertModel(MegatronModule):
    """Bert Language model.

    Wraps the language model returned by ``get_language_model`` with a
    ``BertLMHead`` and, when ``add_binary_head`` is true, a two-output
    linear head applied to the pooled output.

    Arguments:
        num_tokentypes: forwarded to ``get_language_model``.
        add_binary_head: when true, the language model is built with a
            pooler and a binary classification head is added.
        parallel_output: stored on the model and forwarded to the LM head.
    """

    def __init__(self, num_tokentypes=2, add_binary_head=True,
                 parallel_output=True):
        super().__init__()
        args = get_args()

        self.add_binary_head = add_binary_head
        self.parallel_output = parallel_output

        init_method = init_method_normal(args.init_method_std)
        scaled_init_method = scaled_init_method_normal(
            args.init_method_std, args.num_layers)

        # get_language_model returns the module together with the key
        # under which its weights are stored in checkpoints.
        language_model, language_model_key = get_language_model(
            attention_mask_func=bert_attention_mask_func,
            num_tokentypes=num_tokentypes,
            add_pooler=self.add_binary_head,
            init_method=init_method,
            scaled_init_method=scaled_init_method)
        self.language_model = language_model
        self._language_model_key = language_model_key

        # The LM head's bias length is the first dimension of the word
        # embedding weight.
        embedding_rows = \
            self.language_model.embedding.word_embeddings.weight.size(0)
        self.lm_head = BertLMHead(embedding_rows,
                                  args.hidden_size,
                                  init_method,
                                  args.layernorm_epsilon,
                                  parallel_output)
        self._lm_head_key = 'lm_head'

        if self.add_binary_head:
            self.binary_head = get_linear_layer(args.hidden_size, 2,
                                                init_method)
            self._binary_head_key = 'binary_head'

    def forward(self, input_ids, attention_mask, tokentype_ids=None):
        """Run the model and return ``(lm_logits, binary_logits)``.

        ``binary_logits`` is ``None`` when the model was built without
        the binary head.
        """
        # The extended mask is built with the dtype of the language
        # model's first parameter.
        model_dtype = next(self.language_model.parameters()).dtype
        extended_attention_mask = bert_extended_attention_mask(
            attention_mask, model_dtype)
        position_ids = bert_position_ids(input_ids)

        encoder_result = self.language_model(
            input_ids,
            position_ids,
            extended_attention_mask,
            tokentype_ids=tokentype_ids)
        # With the pooler enabled the language model yields a pair;
        # otherwise its result is used directly as the LM output.
        if self.add_binary_head:
            lm_output, pooled_output = encoder_result
        else:
            lm_output = encoder_result

        # Output.
        lm_logits = self.lm_head(
            lm_output, self.language_model.embedding.word_embeddings.weight)

        binary_logits = None
        if self.add_binary_head:
            binary_logits = self.binary_head(pooled_output)

        return lm_logits, binary_logits

    def state_dict_for_save_checkpoint(self, destination=None, prefix='',
                                       keep_vars=False):
        """Build the checkpoint dict with one entry per sub-module.

        Each sub-module's state is stored under its own key, which makes
        loading easy when the model is combined with other heads.
        """
        checkpoint = {
            self._language_model_key:
                self.language_model.state_dict_for_save_checkpoint(
                    destination, prefix, keep_vars),
            self._lm_head_key:
                self.lm_head.state_dict_for_save_checkpoint(
                    destination, prefix, keep_vars),
        }
        if self.add_binary_head:
            checkpoint[self._binary_head_key] = self.binary_head.state_dict(
                destination, prefix, keep_vars)
        return checkpoint

    def load_state_dict(self, state_dict, strict=True):
        """Customized load: restore each sub-module from its own key."""
        language_model_state = state_dict[self._language_model_key]
        self.language_model.load_state_dict(language_model_state,
                                            strict=strict)

        lm_head_state = state_dict[self._lm_head_key]
        self.lm_head.load_state_dict(lm_head_state, strict=strict)

        if not self.add_binary_head:
            return
        binary_head_state = state_dict[self._binary_head_key]
        self.binary_head.load_state_dict(binary_head_state, strict=strict)
161

162