gpt_model.py 4.1 KB
Newer Older
Jared Casper's avatar
Jared Casper committed
1
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
2
3
4
5
6

"""GPT-2 model."""

import torch

Mohammad's avatar
Mohammad committed
7
from megatron import get_args
8
from megatron import mpu
9
from megatron import core
10
from .module import MegatronModule
11

12
from .enums import AttnMaskType
13
14
15
16
17
18
from .language_model import parallel_lm_logits
from .language_model import get_language_model
from .utils import init_method_normal
from .utils import scaled_init_method_normal


19
def post_language_model_processing(lm_output, labels, logit_weights,
                                   parallel_output,
                                   fp16_lm_cross_entropy):
    """Turn language-model output into logits, or into a loss if labels exist.

    Args:
        lm_output: language-model output, laid out [s b h] per the layout
            comments below.
        labels: target ids laid out [b s], or None to skip the loss.
        logit_weights: weight tensor forwarded to ``parallel_lm_logits``.
        parallel_output: flag forwarded unchanged to ``parallel_lm_logits``.
        fp16_lm_cross_entropy: when true, the logits go to the cross-entropy
            as-is and must already be ``torch.half`` (asserted); otherwise
            they are upcast with ``.float()`` first.

    Returns:
        With ``labels`` None, the logits transposed from [s b h] to [b s h].
        Otherwise, the cross-entropy output transposed from [s b] to [b s].
    """
    # Layout: [s b h]
    logits = parallel_lm_logits(lm_output, logit_weights, parallel_output)

    if labels is None:
        # [s b h] => [b s h]
        return logits.transpose(0, 1).contiguous()

    # [b s] => [s b], matching the sequence-first logits.
    targets = labels.transpose(0, 1).contiguous()

    if fp16_lm_cross_entropy:
        assert logits.dtype == torch.half
        loss_input = logits
    else:
        loss_input = logits.float()
    token_loss = core.tensor_parallel.vocab_parallel_cross_entropy(
        loss_input, targets)

    # [s b] => [b s]
    return token_loss.transpose(0, 1).contiguous()


46
class GPTModel(MegatronModule):
    """GPT-2 language model.

    Wraps the module returned by ``get_language_model``, which is built with
    ``add_pooler=False`` and a causal encoder attention mask. When
    ``post_process`` is set, ``forward`` passes the language-model output
    through ``post_language_model_processing``.
    """

    def __init__(self,
                 num_tokentypes=0,
                 parallel_output=True,
                 pre_process=True,
                 post_process=True):
        """Build the wrapped language model and the word embeddings.

        Args:
            num_tokentypes: forwarded to ``get_language_model``.
            parallel_output: stored and later passed to
                ``post_language_model_processing``.
            pre_process: forwarded to ``get_language_model``. It is also
                consulted when saving and loading checkpoints.
            post_process: forwarded to ``get_language_model``. When false,
                ``forward`` returns the raw language-model output.
        """
        super().__init__()
        args = get_args()

        self.parallel_output = parallel_output
        self.pre_process = pre_process
        self.post_process = post_process
        self.fp16_lm_cross_entropy = args.fp16_lm_cross_entropy

        base_init = init_method_normal(args.init_method_std)
        scaled_init = scaled_init_method_normal(args.init_method_std,
                                                args.num_layers)
        self.language_model, self._language_model_key = get_language_model(
            num_tokentypes=num_tokentypes,
            add_pooler=False,
            encoder_attn_mask_type=AttnMaskType.causal,
            init_method=base_init,
            scaled_init_method=scaled_init,
            pre_process=self.pre_process,
            post_process=self.post_process)

        # The factory itself (not an initializer built from it) is handed
        # to the base-class helper.
        self.initialize_word_embeddings(init_method_normal)

    def set_input_tensor(self, input_tensor):
        """Delegate to ``self.language_model.set_input_tensor``.

        See megatron.model.transformer.set_input_tensor().
        """
        self.language_model.set_input_tensor(input_tensor)

    def forward(self, input_ids, position_ids, attention_mask, labels=None,
                tokentype_ids=None, inference_params=None):
        """Run the language model, then post-process on the output stage.

        Returns the raw language-model output when ``post_process`` is
        false. Otherwise it returns ``post_language_model_processing``
        applied with the word-embedding weight: logits when ``labels`` is
        None, the loss otherwise.

        NOTE(review): ``tokentype_ids`` is accepted but never passed to the
        language model. Confirm this is intended when ``num_tokentypes > 0``.
        """
        hidden = self.language_model(
            input_ids,
            position_ids,
            attention_mask,
            inference_params=inference_params)

        if not self.post_process:
            return hidden

        return post_language_model_processing(
            hidden, labels,
            self.word_embeddings_weight(),
            self.parallel_output,
            self.fp16_lm_cross_entropy)

    def state_dict_for_save_checkpoint(self, prefix='', keep_vars=False):
        """Collect checkpoint state as a dict.

        The language model's checkpoint state is stored under
        ``self._language_model_key``. The ``word_embeddings`` state is added
        under ``self._word_embeddings_for_head_key``, but only when
        ``post_process`` is set and ``pre_process`` is not.
        """
        checkpoint_state = {
            self._language_model_key:
                self.language_model.state_dict_for_save_checkpoint(
                    prefix=prefix, keep_vars=keep_vars),
        }

        if self.post_process and not self.pre_process:
            checkpoint_state[self._word_embeddings_for_head_key] = \
                self.word_embeddings.state_dict(prefix=prefix,
                                                keep_vars=keep_vars)

        return checkpoint_state

    def load_state_dict(self, state_dict, strict=True):
        """Customized load, the counterpart of the save method above.

        When ``post_process`` is set and ``pre_process`` is not,
        ``word_embeddings`` is restored first from
        ``state_dict[self._word_embeddings_for_head_key]``. The language
        model is then loaded from ``state_dict[self._language_model_key]``
        if that key exists, otherwise from ``state_dict`` itself.
        """
        has_head_embeddings = self.post_process and not self.pre_process
        if has_head_embeddings:
            self.word_embeddings.load_state_dict(
                state_dict[self._word_embeddings_for_head_key], strict=strict)

        lm_key = self._language_model_key
        lm_state = state_dict[lm_key] if lm_key in state_dict else state_dict
        self.language_model.load_state_dict(lm_state, strict=strict)