# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.

"""T5 model."""

import torch

from megatron import get_args
from megatron.core import tensor_parallel
from megatron.model.enums import AttnMaskType
from megatron.model.language_model import parallel_lm_logits, get_language_model
from megatron.model import LayerNorm
from megatron.model.utils import (
    openai_gelu,
    get_linear_layer
)
from .module import MegatronModule


def t5_extended_attention_mask(attention_mask_list):
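    # Insert a singleton head dimension so each mask broadcasts over the
    # attention heads: [b, s, s] -> [b, 1, s, s].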

    def attn_mask_postprocess(attn_mask):
        # [b, 1, s, s]
        extended_attention_mask = attn_mask.unsqueeze(1)
        return extended_attention_mask

    return [attn_mask_postprocess(attn_mask) for attn_mask in attention_mask_list]


def t5_position_ids(token_ids):
    # Create position ids
    seq_length = token_ids.size(1)
    position_ids = torch.arange(seq_length, dtype=torch.long,
                                device=token_ids.device)
    position_ids = position_ids.unsqueeze(0).expand_as(token_ids)

    return position_ids


class T5LMHead(MegatronModule):
    """Masked LM head for T5

    Arguments:
        mpu_vocab_size: vocabulary size on this model-parallel partition.
        parallel_output: whether the output logits remain distributed across
            tensor-parallel ranks instead of being gathered.
    """

    def __init__(self, mpu_vocab_size, parallel_output):
        super(T5LMHead, self).__init__()

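        # The bias is added to the column-parallel logits, so it is itself
        # partitioned along the vocabulary dimension; these attributes mark it
        # as a tensor-parallel parameter for the optimizer and checkpointing.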
        self.bias = torch.nn.Parameter(torch.zeros(mpu_vocab_size))
        self.bias.model_parallel = True
        self.bias.partition_dim = 0
        self.bias.stride = 1
        self.parallel_output = parallel_output

    def forward(self, hidden_states, word_embeddings_weight):
        output = parallel_lm_logits(hidden_states,
                                    word_embeddings_weight,
                                    self.parallel_output,
                                    bias=self.bias)
        return output


class T5Model(MegatronModule):
    """T5 Language model."""

    def __init__(self,
                 config,
                 num_tokentypes=0,
                 parallel_output=True,
                 pre_process=True,
                 post_process=True,
                 add_encoder=True,
                 add_decoder=True):
        super().__init__(config=config)
        args = get_args()

        self.fp16_lm_cross_entropy = args.fp16_lm_cross_entropy
        self.parallel_output = parallel_output
        self.pre_process = pre_process
        self.post_process = post_process
        self.add_encoder = add_encoder
        self.add_decoder = add_decoder

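        # Build the underlying encoder-decoder transformer. Under pipeline
        # parallelism a given stage may hold only the encoder or only the
        # decoder (add_encoder / add_decoder), and only the first / last
        # stages run the embedding / output logic (pre_process / post_process).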
        self.language_model, self._language_model_key = get_language_model(
            config=config,
            num_tokentypes=num_tokentypes,
            add_pooler=False,
            add_encoder=add_encoder,
            add_decoder=add_decoder,
            encoder_attn_mask_type=AttnMaskType.padding,
            pre_process=self.pre_process,
            post_process=self.post_process)

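        # Set up the word embeddings that are shared between the input
        # embedding and the output projection (and synchronized across
        # pipeline stages when the first and last stages differ).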
        self.initialize_word_embeddings()

        if self.post_process and self.add_decoder:
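            # The LM head reuses the shared embedding weights as its output
            # projection; only the bias lives in T5LMHead itself.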
            self.lm_head = T5LMHead(
                self.shared_embedding_or_output_weight().size(0),
                parallel_output)
            self._lm_head_key = 'lm_head'

    def set_input_tensor(self, input_tensor):
        """See megatron.model.transformer.set_input_tensor()"""
        self.language_model.set_input_tensor(input_tensor)

    def forward(self, encoder_input_ids, decoder_input_ids, encoder_attn_mask,
                decoder_attn_mask, encoder_decoder_attn_mask,
                tokentype_ids=None, lm_labels=None, enc_hidden_states=None):

        # Convert the attention masks to the extended [b, 1, s, s] form
        # expected by the transformer layers.
        encoder_attn_mask, decoder_attn_mask, encoder_decoder_attn_mask = t5_extended_attention_mask(
            [encoder_attn_mask, decoder_attn_mask, encoder_decoder_attn_mask])

        encoder_position_ids = t5_position_ids(encoder_input_ids)
        decoder_position_ids = t5_position_ids(decoder_input_ids)

        lm_output = self.language_model(encoder_input_ids,
                                        encoder_position_ids,
                                        encoder_attn_mask,
                                        decoder_input_ids,
                                        decoder_position_ids,
                                        decoder_attn_mask,
                                        encoder_decoder_attn_mask,
                                        tokentype_ids=tokentype_ids,
                                        enc_hidden_states=enc_hidden_states)

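        # On the final decoder stage, project to vocabulary logits and
        # optionally compute the LM loss; other configurations just return
        # the raw decoder or encoder output.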
        if self.post_process and self.add_decoder:
            decoder_output, encoder_output = lm_output
            # Output. [s, b, h]
            lm_logits = self.lm_head(decoder_output,
                                     self.shared_embedding_or_output_weight())

            if lm_labels is None:
                # [s b h] => [b s h]
                return lm_logits.transpose(0,1).contiguous()
            else:
                # [b s] => [s b]
                lm_labels = lm_labels.transpose(0,1).contiguous()
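                # Cross entropy is computed over the vocab-parallel logits;
                # optionally keep the computation in fp16 to save memory.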
                if self.fp16_lm_cross_entropy:
                    assert lm_logits.dtype == torch.half
                    lm_loss = tensor_parallel.vocab_parallel_cross_entropy(lm_logits, lm_labels)
                else:
                    lm_loss = tensor_parallel.vocab_parallel_cross_entropy(lm_logits.float(),
                                                                           lm_labels)
                # [s b] => [b s]
                lm_loss = lm_loss.transpose(0,1).contiguous()
            return lm_loss
        elif self.add_decoder and not self.add_encoder:
            decoder_output, encoder_output = lm_output
            return decoder_output
        else:
            encoder_output = lm_output
            return encoder_output

    def state_dict_for_save_checkpoint(self, prefix='', keep_vars=False):
        """For easy load when model is combined with other heads,
        add an extra key."""

        state_dict_ = {}
        state_dict_[self._language_model_key] \
            = self.language_model.state_dict_for_save_checkpoint(prefix=prefix,
                                                                 keep_vars=keep_vars)
        if self.post_process and self.add_decoder:
            state_dict_[self._lm_head_key] \
                = self.lm_head.state_dict_for_save_checkpoint(prefix=prefix,
                                                              keep_vars=keep_vars)
        # Save word_embeddings.
        if self.post_process and not self.pre_process and self.add_decoder:
            state_dict_[self._word_embeddings_for_head_key] \
                = self.word_embeddings.state_dict(prefix=prefix,
                                                  keep_vars=keep_vars)
        return state_dict_

    def load_state_dict(self, state_dict, strict=True):
        """Customized load."""

        self.language_model.load_state_dict(
            state_dict[self._language_model_key], strict=strict)
        if self.post_process and self.add_decoder:
            self.lm_head.load_state_dict(state_dict[self._lm_head_key],
                                         strict=strict)
        # Load word embeddings.
        if self.post_process and not self.pre_process and self.add_decoder:
            self.word_embeddings.load_state_dict(
                state_dict[self._word_embeddings_for_head_key], strict=strict)