"symphony/utils/requirements.txt" did not exist on "d8d5365056e6e43ee8ca9884eb38144bddc60ec5"
t5_model.py 7.76 KB
Newer Older
Jared Casper's avatar
Jared Casper committed
1
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.

"""T5 model."""

import torch

from megatron import get_args
from megatron.core import tensor_parallel
from megatron.model.enums import AttnMaskType
from megatron.model.language_model import parallel_lm_logits, get_language_model
from megatron.model import LayerNorm
from megatron.model.utils import (
    openai_gelu,
    get_linear_layer,
    init_method_normal,
    scaled_init_method_normal
)
from .module import MegatronModule


def t5_extended_attention_mask(attention_mask_list):

    def attn_mask_postprocess(attn_mask):
        # [b, 1, s, s]
        extended_attention_mask = attn_mask.unsqueeze(1)
        return extended_attention_mask

    return [attn_mask_postprocess(attn_mask) for attn_mask in attention_mask_list]
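
# Shape sketch (the raw mask shapes are an assumption about the callers, not
# stated in this file): each mask arrives as [b, s_q, s_k], e.g. encoder
# [b, s_enc, s_enc], decoder [b, s_dec, s_dec], encoder-decoder
# [b, s_dec, s_enc], and leaves with a broadcastable head dimension as
# [b, 1, s_q, s_k].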


def t5_position_ids(token_ids):
    # Create position ids
    seq_length = token_ids.size(1)
    position_ids = torch.arange(seq_length, dtype=torch.long,
                                device=token_ids.device)
    position_ids = position_ids.unsqueeze(0).expand_as(token_ids)

    return position_ids
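
# Illustration (hypothetical input): for token_ids of shape [2, 4],
#   t5_position_ids(token_ids) -> tensor([[0, 1, 2, 3],
#                                         [0, 1, 2, 3]])
# on the same device as token_ids.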


class T5LMHead(MegatronModule):
    """Masked LM head for T5

    Arguments:
        mpu_vocab_size: model parallel size of vocabulary.
        parallel_output: whether the output logits are kept tensor parallel
            (not gathered over the vocabulary dimension).
    """

    def __init__(self, mpu_vocab_size, parallel_output):
        super(T5LMHead, self).__init__()

        args = get_args()

        self.bias = torch.nn.Parameter(torch.zeros(mpu_vocab_size))
        self.bias.model_parallel = True
        self.bias.partition_dim = 0
        self.bias.stride = 1
        self.parallel_output = parallel_output

    def forward(self, hidden_states, word_embeddings_weight):
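        # hidden_states is the decoder output in [s, b, h]; parallel_lm_logits
        # projects it onto the tied word embedding matrix. With
        # parallel_output=True the logits stay split over the vocabulary
        # dimension across tensor-parallel ranks; otherwise they are gathered.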
        output = parallel_lm_logits(hidden_states,
                                    word_embeddings_weight,
                                    self.parallel_output,
                                    bias=self.bias)
        return output


class T5Model(MegatronModule):
    """T5 Language model."""

    def __init__(self,
                 num_tokentypes=0,
                 parallel_output=True,
                 pre_process=True,
                 post_process=True,
                 add_encoder=True,
                 add_decoder=True):
        super(T5Model, self).__init__()
        args = get_args()

        self.fp16_lm_cross_entropy = args.fp16_lm_cross_entropy
        self.parallel_output = parallel_output
        init_method = init_method_normal(args.init_method_std)
        scaled_init_method = scaled_init_method_normal(args.init_method_std,
                                                       args.num_layers)
        self.pre_process = pre_process
        self.post_process = post_process
        self.add_encoder = add_encoder
        self.add_decoder = add_decoder

        self.language_model, self._language_model_key = get_language_model(
            num_tokentypes=num_tokentypes,
            add_pooler=False,
            add_encoder=add_encoder,
            add_decoder=add_decoder,
            encoder_attn_mask_type=AttnMaskType.padding,
            init_method=init_method,
            scaled_init_method=scaled_init_method,
            pre_process=self.pre_process,
            post_process=self.post_process)

        self.initialize_word_embeddings(init_method_normal)

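        # The output projection reuses the word embedding matrix (weight tying),
        # so T5LMHead below only adds a trainable bias of size mpu_vocab_size.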
        if self.post_process and self.add_decoder:
            self.lm_head = T5LMHead(
                self.word_embeddings_weight().size(0),
                parallel_output)
            self._lm_head_key = 'lm_head'

    def set_input_tensor(self, input_tensor):
        """See megatron.model.transformer.set_input_tensor()"""
        self.language_model.set_input_tensor(input_tensor)

    def forward(self, encoder_input_ids, decoder_input_ids, encoder_attn_mask,
                decoder_attn_mask, encoder_decoder_attn_mask,
                tokentype_ids=None, lm_labels=None, enc_hidden_states=None):
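        # Expected shapes (an assumption inferred from the mask handling below,
        # not spelled out in this file): encoder_input_ids [b, s_enc],
        # decoder_input_ids [b, s_dec], encoder_attn_mask [b, s_enc, s_enc],
        # decoder_attn_mask [b, s_dec, s_dec], encoder_decoder_attn_mask
        # [b, s_dec, s_enc]. On the final decoder stage the call returns logits
        # [b, s_dec, vocab] (vocab possibly split across tensor-parallel ranks)
        # when lm_labels is None, otherwise a per-token loss of shape [b, s_dec].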

        # Converting the attention masks to proper parameter settings
        encoder_attn_mask, decoder_attn_mask, encoder_decoder_attn_mask = t5_extended_attention_mask(
            [encoder_attn_mask, decoder_attn_mask, encoder_decoder_attn_mask])

        encoder_position_ids = t5_position_ids(encoder_input_ids)
        decoder_position_ids = t5_position_ids(decoder_input_ids)

        lm_output = self.language_model(encoder_input_ids,
                                        encoder_position_ids,
                                        encoder_attn_mask,
                                        decoder_input_ids,
                                        decoder_position_ids,
                                        decoder_attn_mask,
                                        encoder_decoder_attn_mask,
                                        tokentype_ids=tokentype_ids,
                                        enc_hidden_states=enc_hidden_states)

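        # Depending on add_encoder/add_decoder and the pipeline stage, lm_output
        # is either a (decoder_output, encoder_output) tuple or the encoder
        # output alone; the branches below unpack it accordingly.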
        if self.post_process and self.add_decoder:
            decoder_output, encoder_output = lm_output
            # Output. [s, b, h]
            lm_logits = self.lm_head(decoder_output,
                                     self.word_embeddings_weight())

            if lm_labels is None:
                # [s b h] => [b s h]
                return lm_logits.transpose(0,1).contiguous()
            else:
                # [b s] => [s b]
                lm_labels = lm_labels.transpose(0,1).contiguous()
                if self.fp16_lm_cross_entropy:
                    assert lm_logits.dtype == torch.half
                    lm_loss = tensor_parallel.vocab_parallel_cross_entropy(lm_logits, lm_labels)
                else:
                    lm_loss = tensor_parallel.vocab_parallel_cross_entropy(lm_logits.float(),
                                                                           lm_labels)
                # [s b] => [b s]
                lm_loss = lm_loss.transpose(0,1).contiguous()
            return lm_loss
        elif self.add_decoder and not self.add_encoder:
            decoder_output, encoder_output = lm_output
            return decoder_output
        else:
            encoder_output = lm_output
            return encoder_output

    def state_dict_for_save_checkpoint(self, prefix='', keep_vars=False):
        """For easy load when model is combined with other heads,
        add an extra key."""

        state_dict_ = {}
        state_dict_[self._language_model_key] \
            = self.language_model.state_dict_for_save_checkpoint(prefix=prefix,
                                                                 keep_vars=keep_vars)
        if self.post_process and self.add_decoder:
            state_dict_[self._lm_head_key] \
                = self.lm_head.state_dict_for_save_checkpoint(prefix=prefix,
                                                              keep_vars=keep_vars)
        # Save word_embeddings.
        if self.post_process and not self.pre_process and self.add_decoder:
            state_dict_[self._word_embeddings_for_head_key] \
                = self.word_embeddings.state_dict(prefix=prefix,
                                                  keep_vars=keep_vars)
        return state_dict_

    def load_state_dict(self, state_dict, strict=True):
        """Customized load."""

        self.language_model.load_state_dict(
            state_dict[self._language_model_key], strict=strict)
        if self.post_process and self.add_decoder:
            self.lm_head.load_state_dict(state_dict[self._lm_head_key],
                                         strict=strict)
        # Load word embeddings.
        if self.post_process and not self.pre_process and self.add_decoder:
            self.word_embeddings.load_state_dict(
                state_dict[self._word_embeddings_for_head_key], strict=strict)