# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.

"""T5 model."""

import torch

from megatron.training import get_args
from megatron.core import tensor_parallel
from megatron.legacy.model.enums import AttnMaskType
from megatron.legacy.model.language_model import parallel_lm_logits, get_language_model
from megatron.legacy.model import LayerNorm
from megatron.legacy.model.utils import (
    openai_gelu,
    get_linear_layer
)
from .module import MegatronModule


def t5_extended_attention_mask(attention_mask_list):
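    """Expand each [b, s, s] attention mask in the list to [b, 1, s, s].

    Illustrative sketch (shapes assumed for the example, not taken from this file):

        enc_mask = torch.ones(2, 16, 16, dtype=torch.bool)         # [b, s, s]
        (ext_enc_mask,) = t5_extended_attention_mask([enc_mask])
        ext_enc_mask.shape                                          # torch.Size([2, 1, 16, 16])
    """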

    def attn_mask_postprocess(attn_mask):
        # [b, 1, s, s]
        extended_attention_mask = attn_mask.unsqueeze(1)
        return extended_attention_mask

    return [attn_mask_postprocess(attn_mask) for attn_mask in attention_mask_list]


def t5_position_ids(token_ids):
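    """Build [b, s] position ids (0 .. s-1 for every sequence) matching token_ids.

    Illustrative sketch (values assumed for the example):

        tokens = torch.zeros(3, 5, dtype=torch.long)
        t5_position_ids(tokens)[0]   # tensor([0, 1, 2, 3, 4])
    """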
    # Create position ids
    seq_length = token_ids.size(1)
    position_ids = torch.arange(seq_length, dtype=torch.long,
                                device=token_ids.device)
    position_ids = position_ids.unsqueeze(0).expand_as(token_ids)

    return position_ids


class T5LMHead(MegatronModule):
    """Masked LM head for T5

    Args:
        mpu_vocab_size: model-parallel (per-partition) size of the vocabulary.
        parallel_output: whether the output logits are kept distributed or gathered.
    """

    def __init__(self, mpu_vocab_size, parallel_output):
        super(T5LMHead, self).__init__()

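        # Bias over the vocab-parallel logits; the attributes below mark it as
        # a tensor-model-parallel parameter partitioned along dim 0.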
        self.bias = torch.nn.Parameter(torch.zeros(mpu_vocab_size))
        self.bias.model_parallel = True
        self.bias.partition_dim = 0
        self.bias.stride = 1
        self.parallel_output = parallel_output

    def forward(self, hidden_states, word_embeddings_weight):
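        # Project the [s, b, h] hidden states onto the shared word-embedding
        # weight; parallel_output controls whether the resulting logits stay
        # split across tensor-parallel ranks or are gathered.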
        output = parallel_lm_logits(hidden_states,
                                    word_embeddings_weight,
                                    self.parallel_output,
                                    bias=self.bias)
        return output


class T5Model(MegatronModule):
    """T5 Language model."""

    def __init__(self,
                 config,
                 num_tokentypes=0,
                 parallel_output=True,
                 pre_process=True,
                 post_process=True,
                 add_encoder=True,
                 add_decoder=True):
        super().__init__(config=config)
        args = get_args()

        self.fp16_lm_cross_entropy = args.fp16_lm_cross_entropy
        self.parallel_output = parallel_output
        self.pre_process = pre_process
        self.post_process = post_process
        self.add_encoder = add_encoder
        self.add_decoder = add_decoder

        self.language_model, self._language_model_key = get_language_model(
            config=config,
            num_tokentypes=num_tokentypes,
            add_pooler=False,
            add_encoder=add_encoder,
            add_decoder=add_decoder,
            encoder_attn_mask_type=AttnMaskType.padding,
            pre_process=self.pre_process,
            post_process=self.post_process)

        self.initialize_word_embeddings()

        if self.pre_process:
            self.position_embeddings = self.language_model.embedding.position_embeddings
        else:
            self.position_embeddings = None

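        # Only the final decoder pipeline stage builds the LM head; its output
        # projection reuses the shared embedding / output weight.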
        if self.post_process and self.add_decoder:
            self.lm_head = T5LMHead(
                self.shared_embedding_or_output_weight().size(0),
                parallel_output)
            self._lm_head_key = 'lm_head'

        # Tells schedules.py that this model has a skip connection between the encoder's output and the decoder
        # (and hence both the encoder and decoder's tensors are required for correct backprop).
        self.xattn_needed = True

    def set_input_tensor(self, input_tensor):
        """See megatron.legacy.model.transformer.set_input_tensor()"""
        self.language_model.set_input_tensor(input_tensor)

    def forward(self, encoder_input_ids, decoder_input_ids, encoder_attn_mask,
                decoder_attn_mask, encoder_decoder_attn_mask,
                tokentype_ids=None, lm_labels=None, enc_hidden_states=None):
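        """Run the encoder-decoder stack.

        On the final decoder stage this returns either the LM logits transposed
        to [b, s, ...] (when lm_labels is None) or the vocab-parallel
        cross-entropy loss transposed to [b, s]; encoder-only configurations
        return the encoder output.
        """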

        # Expand the attention masks to the extended form (add a broadcast head dim).
        encoder_attn_mask, decoder_attn_mask, encoder_decoder_attn_mask = t5_extended_attention_mask(
            [encoder_attn_mask, decoder_attn_mask, encoder_decoder_attn_mask])

        encoder_position_ids = t5_position_ids(encoder_input_ids)
        decoder_position_ids = t5_position_ids(decoder_input_ids)

        lm_output = self.language_model(encoder_input_ids,
                                        encoder_position_ids,
                                        encoder_attn_mask,
                                        decoder_input_ids,
                                        decoder_position_ids,
                                        decoder_attn_mask,
                                        encoder_decoder_attn_mask,
                                        tokentype_ids=tokentype_ids,
                                        enc_hidden_states=enc_hidden_states)

        if self.post_process and self.add_decoder:
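            # With a decoder present, lm_output is the (decoder_output,
            # encoder_output) pair produced by the encoder-decoder stack.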
            decoder_output, encoder_output = lm_output
            # Output. [s, b, h]
            lm_logits = self.lm_head(decoder_output,
                                     self.shared_embedding_or_output_weight())

            if lm_labels is None:
                # [s b h] => [b s h]
                return lm_logits.transpose(0,1).contiguous()
            else:
                # [b s] => [s b]
                lm_labels = lm_labels.transpose(0,1).contiguous()
                if self.fp16_lm_cross_entropy:
                    assert lm_logits.dtype == torch.half
                    lm_loss = tensor_parallel.vocab_parallel_cross_entropy(lm_logits, lm_labels)
                else:
                    lm_loss = tensor_parallel.vocab_parallel_cross_entropy(lm_logits.float(),
                                                                           lm_labels)
                # [s b] => [b s]
                lm_loss = lm_loss.transpose(0,1).contiguous()
            return lm_loss
        elif self.add_decoder and not self.add_encoder:
            decoder_output, encoder_output = lm_output
            return decoder_output
        else:
            encoder_output = lm_output
            return encoder_output

    def state_dict_for_save_checkpoint(self, prefix='', keep_vars=False):
        """For easy load when model is combined with other heads,
        add an extra key."""

        state_dict_ = {}
        state_dict_[self._language_model_key] \
            = self.language_model.state_dict_for_save_checkpoint(prefix=prefix,
                                                                 keep_vars=keep_vars)
        if self.post_process and self.add_decoder:
            state_dict_[self._lm_head_key] \
                = self.lm_head.state_dict_for_save_checkpoint(prefix=prefix,
                                                              keep_vars=keep_vars)
        # Save word_embeddings.
        if self.post_process and not self.pre_process and self.add_decoder:
            state_dict_[self._word_embeddings_for_head_key] \
                = self.word_embeddings.state_dict(prefix=prefix,
                                                  keep_vars=keep_vars)
        return state_dict_

    def load_state_dict(self, state_dict, strict=True):
        """Customized load."""

        self.language_model.load_state_dict(
            state_dict[self._language_model_key], strict=strict)
        if self.post_process and self.add_decoder:
            self.lm_head.load_state_dict(state_dict[self._lm_head_key],
                                         strict=strict)
        # Load word embeddings.
        if self.post_process and not self.pre_process and self.add_decoder:
            self.word_embeddings.load_state_dict(
                state_dict[self._word_embeddings_for_head_key], strict=strict)
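

# A minimal usage sketch (illustrative, not part of the original file). It
# assumes Megatron's global args and tensor/pipeline-parallel state have
# already been initialized and that `config` is the transformer config built
# from those args; shapes in the comments are examples only, so the snippet is
# kept as commented pseudocode rather than runnable code:
#
#   model = T5Model(config, num_tokentypes=0, parallel_output=True)
#   loss = model(enc_ids, dec_ids,          # [b, s_enc], [b, s_dec] token ids
#                enc_mask, dec_mask,        # [b, s, s] self-attention masks
#                enc_dec_mask,              # [b, s_dec, s_enc] cross-attention mask
#                lm_labels=labels)          # [b, s_dec] targets -> per-token loss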