# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""GPT-2 model."""

import torch

from megatron import get_args
from megatron import mpu
from .module import MegatronModule
from .enums import AttnMaskType
from .language_model import parallel_lm_logits
from .language_model import get_language_model
from .utils import init_method_normal
from .utils import scaled_init_method_normal


def post_language_model_processing(lm_output, labels, logit_weights,
                                   get_key_value, parallel_output,
                                   forward_method_parallel_output,
                                   fp16_lm_cross_entropy):
    """Turn language-model output into logits, or into a loss given labels.

    Args:
        lm_output: output of the language model. When ``get_key_value`` is
            True it is unpacked as a ``(hidden_states, presents)`` pair;
            otherwise it is used directly as the hidden states.
        labels: target ids for the cross entropy, or None to return the
            logits instead of a loss.
        logit_weights: weight passed to ``parallel_lm_logits`` for the
            output projection.
        get_key_value: whether ``lm_output`` carries ``presents`` that must
            be returned alongside the result.
        parallel_output: default ``parallel_output`` flag forwarded to
            ``parallel_lm_logits``.
        forward_method_parallel_output: per-call override of
            ``parallel_output``; ignored when None.
        fp16_lm_cross_entropy: if True, the cross entropy is computed on the
            logits as-is (they must be ``torch.half``); otherwise the logits
            are upcast with ``.float()`` first.

    Returns:
        The logits when ``labels`` is None, otherwise the result of
        ``mpu.vocab_parallel_cross_entropy``. When ``get_key_value`` is True
        the result is returned as ``[result, presents]``.
    """
    presents = None
    if get_key_value:
        lm_output, presents = lm_output

    # A per-call override, when given, wins over the model-level default.
    if forward_method_parallel_output is not None:
        parallel_output = forward_method_parallel_output
    logits = parallel_lm_logits(lm_output, logit_weights, parallel_output)

    # Compute the loss from the bare logits tensor. Previously the logits were
    # wrapped as [logits, presents] before this point, so combining labels
    # with get_key_value crashed on `list.float()` / `list.dtype`.
    if labels is None:
        result = logits
    elif fp16_lm_cross_entropy:
        assert logits.dtype == torch.half, \
            'fp16_lm_cross_entropy requires half-precision logits'
        result = mpu.vocab_parallel_cross_entropy(logits, labels)
    else:
        # Upcast so the cross entropy is evaluated in float32.
        result = mpu.vocab_parallel_cross_entropy(logits.float(), labels)

    if get_key_value:
        # Hand the key/value presents back to the caller with the result.
        return [result, presents]
    return result


class GPTModel(MegatronModule):
    """GPT-2 Language model."""

    def __init__(self,
                 num_tokentypes=0,
                 parallel_output=True,
                 pre_process=True,
                 post_process=True):
        """Build the GPT language model.

        Args:
            num_tokentypes: forwarded to ``get_language_model``.
            parallel_output: default ``parallel_output`` flag used by
                ``forward`` when computing logits.
            pre_process: forwarded to ``get_language_model``; together with
                ``post_process`` it also decides whether this instance saves
                and loads its own ``word_embeddings`` in checkpoints.
            post_process: when True, ``forward`` converts the language-model
                output into logits or a loss; when False, ``forward`` returns
                the language-model output unchanged.
        """
        super(GPTModel, self).__init__()
        args = get_args()

        self.parallel_output = parallel_output
        self.pre_process = pre_process
        self.post_process = post_process
        self.fp16_lm_cross_entropy = args.fp16_lm_cross_entropy

        # Causal attention mask and no pooler: a decoder-only language model.
        self.language_model, self._language_model_key = get_language_model(
            num_tokentypes=num_tokentypes,
            add_pooler=False,
            encoder_attn_mask_type=AttnMaskType.causal,
            init_method=init_method_normal(args.init_method_std),
            scaled_init_method=scaled_init_method_normal(args.init_method_std,
                                                         args.num_layers),
            pre_process=self.pre_process,
            post_process=self.post_process)

        # NOTE(review): initialize_word_embeddings is inherited from
        # MegatronModule (not visible here) and is handed the init-method
        # factory itself rather than an already-configured init method.
        self.initialize_word_embeddings(init_method_normal)

    def set_input_tensor(self, input_tensor):
        """See megatron.model.transformer.set_input_tensor()"""
        self.language_model.set_input_tensor(input_tensor)

    def forward(self, input_ids, position_ids, attention_mask, labels=None,
                tokentype_ids=None, layer_past=None, get_key_value=False,
                forward_method_parallel_output=None):
        """Run the language model and, if ``post_process``, the output head.

        Returns the language-model output unchanged when ``post_process`` is
        False; otherwise the result of ``post_language_model_processing``
        (logits, or a loss when ``labels`` is given).
        """
        # NOTE(review): tokentype_ids is accepted but never forwarded to the
        # language model, so it currently has no effect -- confirm intended.
        lm_output = self.language_model(
            input_ids,
            position_ids,
            attention_mask,
            layer_past=layer_past,
            get_key_value=get_key_value)

        if not self.post_process:
            return lm_output

        return post_language_model_processing(
            lm_output, labels,
            self.word_embeddings_weight(),
            get_key_value,
            self.parallel_output,
            forward_method_parallel_output,
            self.fp16_lm_cross_entropy)

    def state_dict_for_save_checkpoint(self, destination=None, prefix='',
                                       keep_vars=False):
        """Collect the state to write into a checkpoint.

        The language model's checkpoint state is stored under
        ``self._language_model_key``. An instance with ``post_process`` set
        but ``pre_process`` unset additionally stores ``word_embeddings``
        under ``self._word_embeddings_for_head_key``.
        """
        state_dict_ = {}
        state_dict_[self._language_model_key] \
            = self.language_model.state_dict_for_save_checkpoint(
                destination, prefix, keep_vars)
        # Save word_embeddings.
        # NOTE(review): presumably the head's own copy of the embeddings on a
        # stage that lacks the input embedding -- confirm against
        # MegatronModule.initialize_word_embeddings.
        if self.post_process and not self.pre_process:
            state_dict_[self._word_embeddings_for_head_key] \
                = self.word_embeddings.state_dict(destination, prefix, keep_vars)
        return state_dict_

    def load_state_dict(self, state_dict, strict=True):
        """Customized load.

        Mirrors ``state_dict_for_save_checkpoint``: restores
        ``word_embeddings`` when ``post_process`` is set and ``pre_process``
        is not. It then loads the language model, from the nested entry under
        ``self._language_model_key`` when present, otherwise from
        ``state_dict`` itself.
        """
        # Load word_embeddings.
        if self.post_process and not self.pre_process:
            self.word_embeddings.load_state_dict(
                state_dict[self._word_embeddings_for_head_key], strict=strict)
        if self._language_model_key in state_dict:
            state_dict = state_dict[self._language_model_key]
        self.language_model.load_state_dict(state_dict, strict=strict)