# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""GPT-2 model."""

import torch

from megatron import get_args
from megatron import mpu
from .module import MegatronModule
from .enums import AttnMaskType
from .language_model import parallel_lm_logits
from .language_model import get_language_model
from .utils import init_method_normal
from .utils import scaled_init_method_normal


def post_language_model_processing(lm_output, labels, logit_weights,
                                   parallel_output,
                                   fp16_lm_cross_entropy):
    """Project language-model output to vocabulary logits; optionally compute loss.

    Args:
        lm_output: output of the language model. Per the layout comments
            below it is sequence-first ([s b h]) — TODO confirm against the
            language model.
        labels: target token ids in [b s] (batch-first) layout, or ``None``
            to request logits only.
        logit_weights: weight tensor passed to ``parallel_lm_logits``
            (``GPTModel.forward`` passes the word-embedding weights).
        parallel_output: forwarded unchanged to ``parallel_lm_logits``.
        fp16_lm_cross_entropy: if true, the cross entropy is computed on the
            logits as-is (which must be ``torch.half``); otherwise the logits
            are upcast with ``.float()`` first.

    Returns:
        If ``labels`` is None: the logits transposed to batch-first
        ([b s h]). Otherwise: the result of
        ``mpu.vocab_parallel_cross_entropy``, transposed back to batch-first
        ([b s]) so it lines up with the layout ``labels`` was given in.
        (Assumes that function returns an unreduced [s b] loss — confirm.)
    """
    # Output. Format [s b h]
    output = parallel_lm_logits(
        lm_output,
        logit_weights,
        parallel_output)

    if labels is None:
        # [s b h] => [b s h]
        return output.transpose(0, 1).contiguous()

    # [b s] => [s b], to match the sequence-first logits.
    labels = labels.transpose(0, 1).contiguous()
    if fp16_lm_cross_entropy:
        assert output.dtype == torch.half
        loss = mpu.vocab_parallel_cross_entropy(output, labels)
    else:
        loss = mpu.vocab_parallel_cross_entropy(output.float(), labels)

    # [s b] => [b s]: hand the loss back batch-first, consistent with the
    # logits branch above and with the layout of the caller's labels
    # (a [b s] loss mask would otherwise be misaligned).
    return loss.transpose(0, 1).contiguous()


class GPTModel(MegatronModule):
    """GPT-2 Language model.

    Wraps the language model built by ``get_language_model`` (no pooler,
    causal encoder attention mask). When ``post_process`` is set, ``forward``
    also projects hidden states onto the vocabulary with the word-embedding
    weights and, if labels are supplied, computes the loss.
    """

    def __init__(self,
                 num_tokentypes=0,
                 parallel_output=True,
                 pre_process=True,
                 post_process=True):
        """Build the GPT model from the global Megatron arguments.

        Args:
            num_tokentypes: forwarded to ``get_language_model``.
            parallel_output: forwarded to ``parallel_lm_logits`` when the
                logits are produced.
            pre_process: forwarded to ``get_language_model``; also selects
                which word-embedding state is saved/loaded by this instance.
            post_process: forwarded to ``get_language_model``; when true,
                ``forward`` returns logits or loss instead of the raw
                language-model output.
        """
        super(GPTModel, self).__init__()
        args = get_args()

        self.parallel_output = parallel_output
        self.pre_process = pre_process
        self.post_process = post_process
        self.fp16_lm_cross_entropy = args.fp16_lm_cross_entropy

        self.language_model, self._language_model_key = get_language_model(
            num_tokentypes=num_tokentypes,
            add_pooler=False,
            encoder_attn_mask_type=AttnMaskType.causal,
            init_method=init_method_normal(args.init_method_std),
            scaled_init_method=scaled_init_method_normal(args.init_method_std,
                                                         args.num_layers),
            pre_process=self.pre_process,
            post_process=self.post_process)

        # Helper provided by MegatronModule (defined elsewhere). Note that it
        # is handed the init-method factory itself, not a built init method.
        self.initialize_word_embeddings(init_method_normal)

    def set_input_tensor(self, input_tensor):
        """See megatron.model.transformer.set_input_tensor()"""
        self.language_model.set_input_tensor(input_tensor)

    def forward(self, input_ids, position_ids, attention_mask, labels=None,
                tokentype_ids=None, inference_params=None):
        """Run the language model and, when ``post_process`` is set, the head.

        Returns:
            With ``post_process``: the value of
            ``post_language_model_processing`` — the loss when ``labels`` is
            given, the logits otherwise. Without ``post_process``: the raw
            language-model output, unchanged.
        """
        # NOTE(review): ``tokentype_ids`` is accepted but never forwarded to
        # the language model, although ``num_tokentypes`` is passed to
        # ``get_language_model`` in ``__init__``. Left as-is to preserve
        # behavior; confirm whether token types should be supported here.
        lm_output = self.language_model(
            input_ids,
            position_ids,
            attention_mask,
            inference_params=inference_params)

        if not self.post_process:
            return lm_output

        return post_language_model_processing(
            lm_output, labels,
            self.word_embeddings_weight(),
            self.parallel_output,
            self.fp16_lm_cross_entropy)

    def state_dict_for_save_checkpoint(self, destination=None, prefix='',
                                       keep_vars=False):
        """Collect the state to write into a checkpoint.

        The language model's checkpoint state is stored under
        ``self._language_model_key``. An instance with ``post_process`` but
        not ``pre_process`` additionally stores its ``word_embeddings`` under
        ``self._word_embeddings_for_head_key``.
        """
        state_dict_ = {}
        state_dict_[self._language_model_key] \
            = self.language_model.state_dict_for_save_checkpoint(
                destination, prefix, keep_vars)
        # Save word_embeddings.
        if self.post_process and not self.pre_process:
            # Keyword arguments: positional use of nn.Module.state_dict is
            # deprecated in recent PyTorch releases; semantics are identical.
            state_dict_[self._word_embeddings_for_head_key] \
                = self.word_embeddings.state_dict(destination=destination,
                                                  prefix=prefix,
                                                  keep_vars=keep_vars)
        return state_dict_

    def load_state_dict(self, state_dict, strict=True):
        """Customized load.

        Restores ``word_embeddings`` from
        ``state_dict[self._word_embeddings_for_head_key]`` on instances with
        ``post_process`` but not ``pre_process``, then loads the language
        model. If ``state_dict`` contains ``self._language_model_key`` only
        that sub-dictionary is given to the language model; otherwise the
        whole ``state_dict`` is.
        """
        # Load word_embeddings.
        if self.post_process and not self.pre_process:
            self.word_embeddings.load_state_dict(
                state_dict[self._word_embeddings_for_head_key], strict=strict)
        if self._language_model_key in state_dict:
            state_dict = state_dict[self._language_model_key]
        self.language_model.load_state_dict(state_dict, strict=strict)