# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.

"""GPT-2 model."""

import torch

from megatron import get_args
from megatron.core import tensor_parallel
from .module import MegatronModule

from .enums import AttnMaskType
from .language_model import parallel_lm_logits
from .language_model import get_language_model
from .utils import init_method_normal
from .utils import scaled_init_method_normal


def post_language_model_processing(lm_output, labels, logit_weights,
                                   parallel_output,
                                   fp16_lm_cross_entropy):
    """Turn language-model output into logits, or into a per-token loss.

    Args:
        lm_output: hidden states from the language model. The shape comments
            in this module indicate a sequence-first ``[s b h]`` layout
            (TODO confirm against the language model).
        labels: target ids, transposed from ``[b s]`` to ``[s b]`` before the
            loss. Pass ``None`` to get logits back instead of a loss.
        logit_weights: weight used by ``parallel_lm_logits`` for the output
            projection.
        parallel_output: forwarded unchanged to ``parallel_lm_logits``.
        fp16_lm_cross_entropy: when true, the cross entropy runs directly on
            the logits, which must then be ``torch.half``. Otherwise the logits
            are upcast with ``.float()`` first.

    Returns:
        If ``labels`` is ``None``, the logits with dims 0 and 1 swapped.
        Otherwise the result of
        ``tensor_parallel.vocab_parallel_cross_entropy`` with dims 0 and 1
        swapped. In both cases the returned tensor is made contiguous.
    """
    logits = parallel_lm_logits(lm_output, logit_weights, parallel_output)

    # No labels: hand back the logits with the first two dims swapped
    # ([s b ...] => [b s ...]).
    if labels is None:
        return logits.transpose(0, 1).contiguous()

    # Labels are batch-first; bring them to sequence-first ([b s] => [s b])
    # so they line up with the logits.
    seq_first_labels = labels.transpose(0, 1).contiguous()

    if fp16_lm_cross_entropy:
        # Half-precision loss was requested, so the logits must already be fp16.
        assert logits.dtype == torch.half
        ce_input = logits
    else:
        ce_input = logits.float()
    per_token_loss = tensor_parallel.vocab_parallel_cross_entropy(
        ce_input, seq_first_labels)

    # Return the loss batch-first again ([s b] => [b s]).
    return per_token_loss.transpose(0, 1).contiguous()


45
class GPTModel(MegatronModule):
    """GPT-2 Language model.

    Wraps the module returned by ``get_language_model``, which is built with
    a causal encoder attention mask and no pooler. When ``post_process`` is
    set, ``forward`` also applies ``post_language_model_processing``.
    """

    def __init__(self,
                 num_tokentypes=0,
                 parallel_output=True,
                 pre_process=True,
                 post_process=True):
        """Build the language model and set up the word embeddings.

        Args:
            num_tokentypes: number of token types, forwarded to
                ``get_language_model``. When non-zero, pass ``tokentype_ids``
                to ``forward``.
            parallel_output: forwarded to ``post_language_model_processing``
                (and from there to ``parallel_lm_logits``).
            pre_process: forwarded to ``get_language_model``. Together with
                ``post_process``, it also decides whether a separate
                ``word_embeddings`` copy is saved and loaded.
            post_process: forwarded to ``get_language_model``. When true,
                ``forward`` returns logits/loss instead of the raw
                language-model output.
        """
        super(GPTModel, self).__init__()
        args = get_args()

        self.parallel_output = parallel_output
        self.pre_process = pre_process
        self.post_process = post_process
        self.fp16_lm_cross_entropy = args.fp16_lm_cross_entropy
        self.untie_embeddings_and_output_weights = \
            args.untie_embeddings_and_output_weights

        self.language_model, self._language_model_key = get_language_model(
            num_tokentypes=num_tokentypes,
            add_pooler=False,
            encoder_attn_mask_type=AttnMaskType.causal,
            init_method=init_method_normal(args.init_method_std),
            scaled_init_method=scaled_init_method_normal(args.init_method_std,
                                                         args.num_layers),
            pre_process=self.pre_process,
            post_process=self.post_process)

        # Base-class helper. It presumably creates self.word_embeddings and
        # self._word_embeddings_for_head_key, which the checkpoint methods
        # below use. TODO confirm in MegatronModule.
        self.initialize_word_embeddings(init_method_normal)

    def set_input_tensor(self, input_tensor):
        """See megatron.model.transformer.set_input_tensor()"""
        self.language_model.set_input_tensor(input_tensor)

    def forward(self, input_ids, position_ids, attention_mask,
                ret_input_ids=None, ret_position_ids=None, ret_attn_mask=None,
                labels=None, tokentype_ids=None, inference_params=None):
        """Run the language model and, if ``post_process`` is set, the head.

        Args:
            input_ids, position_ids, attention_mask: passed positionally to
                the language model.
            ret_input_ids, ret_position_ids, ret_attn_mask: passed to the
                language model by keyword. The ``ret_`` prefix suggests
                retrieval inputs (TODO confirm).
            labels: when given and ``post_process`` is set, a loss is returned
                instead of logits. See ``post_language_model_processing``.
            tokentype_ids: passed to the language model, which was built with
                ``num_tokentypes`` token types.
            inference_params: passed to the language model by keyword.

        Returns:
            The raw language-model output when ``post_process`` is false.
            Otherwise the result of ``post_language_model_processing``.
        """
        lm_output = self.language_model(
            input_ids,
            position_ids,
            attention_mask,
            ret_input_ids=ret_input_ids,
            ret_position_ids=ret_position_ids,
            ret_attn_mask=ret_attn_mask,
            # Without this, a model built with num_tokentypes > 0 could never
            # receive its token-type ids.
            tokentype_ids=tokentype_ids,
            inference_params=inference_params)

        if not self.post_process:
            return lm_output

        # Untied embeddings use the language model's own output layer.
        # Otherwise the logits reuse the (shared) word embedding weight.
        if self.untie_embeddings_and_output_weights:
            logit_weights = self.language_model.output_layer.weight
        else:
            logit_weights = self.word_embeddings_weight()

        return post_language_model_processing(
            lm_output, labels,
            logit_weights,
            self.parallel_output,
            self.fp16_lm_cross_entropy)

    def state_dict_for_save_checkpoint(self, prefix='', keep_vars=False):
        """Return the state to checkpoint for this model.

        The language model's checkpoint state is stored under
        ``self._language_model_key``. On an instance with ``post_process`` set
        and ``pre_process`` unset, the separate ``word_embeddings`` module's
        state is also stored, under ``self._word_embeddings_for_head_key``.
        """
        state_dict_ = {}
        state_dict_[self._language_model_key] \
            = self.language_model.state_dict_for_save_checkpoint(
                prefix=prefix, keep_vars=keep_vars)
        # Save word_embeddings.
        # NOTE(review): forward() only calls word_embeddings_weight() when
        # embeddings are tied. Confirm whether this copy still needs to be
        # checkpointed when untie_embeddings_and_output_weights is set.
        if self.post_process and not self.pre_process:
            state_dict_[self._word_embeddings_for_head_key] \
                = self.word_embeddings.state_dict(prefix=prefix,
                                                  keep_vars=keep_vars)
        return state_dict_

    def load_state_dict(self, state_dict, strict=True):
        """Customized load.

        This is the counterpart of ``state_dict_for_save_checkpoint``.

        - On an instance with ``post_process`` set and ``pre_process`` unset,
          ``word_embeddings`` is loaded from
          ``state_dict[self._word_embeddings_for_head_key]``.
        - If ``state_dict`` contains ``self._language_model_key``, only that
          nested dict goes to the language model. Otherwise the whole
          ``state_dict`` is passed through.
        """
        # Load word_embeddings.
        if self.post_process and not self.pre_process:
            self.word_embeddings.load_state_dict(
                state_dict[self._word_embeddings_for_head_key], strict=strict)
        if self._language_model_key in state_dict:
            state_dict = state_dict[self._language_model_key]
        self.language_model.load_state_dict(state_dict, strict=strict)