# Copyright (c) 2022, NVIDIA CORPORATION.  All rights reserved.

"""Pretrain T5"""

from copy import deepcopy
from functools import partial
from typing import Union

import torch

from megatron.training import (
    get_args,
    get_timers,
    get_tokenizer,
    print_rank_0
)
from megatron.core import mpu, tensor_parallel
from megatron.core.datasets.blended_megatron_dataset_builder import BlendedMegatronDatasetBuilder
from megatron.core.datasets.t5_dataset import (
    T5MaskedWordPieceDataset,
    T5MaskedWordPieceDatasetConfig,
)
from megatron.core.enums import ModelType
from megatron.core.models.T5 import T5Model
from megatron.training import pretrain
from megatron.training.arguments import core_transformer_config_from_args
from megatron.core.datasets.utils import get_blend_from_list
from megatron.core.models.T5.t5_spec import (
    get_t5_encoder_with_transformer_engine_block_spec,
    get_t5_decoder_with_transformer_engine_block_spec,
    get_t5_encoder_with_local_block_spec,
    get_t5_decoder_with_local_block_spec,
)
from megatron.legacy.model import T5Model as LegacyT5Model
from pretrain_gpt import loss_func

"""
Pipeline parallelism for T5

T5 is a model architecture with both encoder and decoder blocks.
Consequently, pipeline parallelism is implemented slightly differently
compared to architectures like GPT and BERT.

In particular, when pipeline_model_parallel_world_size > 1, each stage
either executes an encoder block or a decoder block. The
--pipeline-model-parallel-split-rank argument controls the rank at which
the split happens: all ranks lower than this argument execute the
encoder block, and all ranks equal to or higher than this argument value
execute the decoder block.

In the encoder section of the model, only one tensor is sent downstream:
the intermediate encoder_hidden_state. In the decoder section of the
model, two tensors are sent downstream in the forward pass: the fully
computed encoder_hidden_state, and the intermediate decoder_hidden_state.

Concretely, the tensors sent between different workers are:
    If rank is in decoder section:
        intermediate decoder_hidden_state (pre-transpose),
        complete encoder_hidden_state (post-transpose).
    If rank is at boundary between encoder and decoder sections:
        complete encoder_hidden_state (post-transpose).
    If rank is in encoder section:
        intermediate encoder_hidden_state (pre-transpose).

Additionally, we have code in the backward_step function in schedules.py
to accumulate the encoder_hidden_state gradient across skip connections
(encoder_hidden_state fed in as input to each layer in the decoder).
"""


def model_provider(
    pre_process=True, post_process=True, add_encoder=True, add_decoder=True
) -> Union[LegacyT5Model, T5Model]:
    """Builds the model.

    Args:
        pre_process (bool, optional): Set to true if you need to compute embeddings. Defaults to True.
        post_process (bool, optional): Set to true if you want to compute output logits/loss. Defaults to True.
        add_encoder (bool, optional): Set to true to include the encoder block. Defaults to True.
        add_decoder (bool, optional): Set to true to include the decoder block. Defaults to True.
    Returns:
        T5Model: The returned T5 model
    """

    args = get_args()

    assert (
        args.encoder_tensor_model_parallel_size == 0 or
        args.encoder_tensor_model_parallel_size == args.tensor_model_parallel_size
    ), f"Because word embeddings are shared between the encoder & decoder, these have to have the same tensor parallel size."

    config = core_transformer_config_from_args(args)
    if args.use_legacy_models:
        model = LegacyT5Model(
            config=config,
            num_tokentypes=0,
            parallel_output=True,
            pre_process=pre_process,
            post_process=post_process,
            add_encoder=add_encoder,
            add_decoder=add_decoder,
        )
    else:
        encoder_config = deepcopy(config)
        encoder_config.num_layers = args.encoder_num_layers

        if args.pipeline_model_parallel_size > 1:
            assert args.encoder_pipeline_model_parallel_size > 0, "Need to know how to shard the encoder & decoder."

        if args.encoder_pipeline_model_parallel_size > 0:
            encoder_config.pipeline_model_parallel_size = args.encoder_pipeline_model_parallel_size

        encoder_layers_per_pipeline = encoder_config.num_layers // encoder_config.pipeline_model_parallel_size
        decoder_layers_per_pipeline = config.num_layers // config.pipeline_model_parallel_size

        if args.transformer_impl == "local":
            en_block_spec = get_t5_encoder_with_local_block_spec(encoder_layers_per_pipeline)
            de_block_spec = get_t5_decoder_with_local_block_spec(decoder_layers_per_pipeline)
        elif args.transformer_impl == "transformer_engine":
            en_block_spec = get_t5_encoder_with_transformer_engine_block_spec(
                encoder_layers_per_pipeline
            )
            de_block_spec = get_t5_decoder_with_transformer_engine_block_spec(
                decoder_layers_per_pipeline
            )
        else:
            raise ValueError(f"Unknown transformer_impl: {args.transformer_impl!r}")

        print_rank_0('building T5 model ...')
        model = T5Model(
            config=config,
            encoder_config=encoder_config,
            transformer_encoder_layer_spec=en_block_spec,
            transformer_decoder_layer_spec=de_block_spec,
            vocab_size=args.padded_vocab_size,
            max_sequence_length=args.max_position_embeddings,
            pre_process=pre_process,
            post_process=post_process,
            fp16_lm_cross_entropy=args.fp16_lm_cross_entropy,
            parallel_output=True,
            share_embeddings_and_output_weights=not args.untie_embeddings_and_output_weights,
            position_embedding_type=args.position_embedding_type,
            rotary_percent=args.rotary_percent,
            add_encoder=add_encoder,
            add_decoder=add_decoder
        )

    return model


def get_batch(data_iterator):
    """Build the batch."""

    keys = ['text_enc', 'text_dec', 'labels', 'loss_mask', 'enc_mask', 'dec_mask', 'enc_dec_mask']
    datatype = torch.int64

    # Broadcast data.
    if data_iterator is not None:
        data = next(data_iterator)
    else:
        data = None
    data_b = tensor_parallel.broadcast_data(keys, data, datatype)

    # Unpack.
    tokens_enc = data_b['text_enc'].long()
    tokens_dec = data_b['text_dec'].long()
    labels = data_b['labels'].long()
    loss_mask = data_b['loss_mask'].float()

    # Convert the int64 masks to boolean attention masks: True marks positions
    # that should be masked out (ignored by attention).
    enc_mask = data_b['enc_mask'] < 0.5
    dec_mask = data_b['dec_mask'] < 0.5
    enc_dec_mask = data_b['enc_dec_mask'] < 0.5

    return tokens_enc, tokens_dec, loss_mask, labels, enc_mask, dec_mask, enc_dec_mask


def forward_step(data_iterator, model: T5Model):
    """Forward training step.

    Args:
        data_iterator : Input data iterator
        model (T5Model): The T5 Model
    """

    timers = get_timers()

    # Get the batch.
    timers('batch generator', log_level=2).start()
    tokens_enc, tokens_dec, loss_mask, lm_labels, enc_mask, dec_mask, enc_dec_mask = get_batch(
        data_iterator
    )
    timers('batch generator').stop()

    # Forward pass through the model; with lm_labels provided, the model
    # returns the per-token loss rather than raw logits.
    output_tensor = model(
        tokens_enc, tokens_dec, enc_mask, dec_mask, enc_dec_mask, lm_labels=lm_labels
    )

    # Bind the loss mask now; the pipeline schedule later calls the returned
    # callable with just the output tensor (loss_func is shared with GPT pretraining).
    return output_tensor, partial(loss_func, loss_mask)


def train_valid_test_datasets_provider(train_val_test_num_samples: list):
    """Build the train, validation, and test datasets.

    Args:
        train_val_test_num_samples : A list containing the number of samples for each of the train, validation, and test datasets.
    """
    args = get_args()

    tokenizer = get_tokenizer()

    # The span-corruption settings below (n-grams up to length 10 sampled from a
    # geometric distribution, whole-word masking, no permutation) are hardcoded
    # rather than exposed as command-line arguments.
    config = T5MaskedWordPieceDatasetConfig(
        random_seed=args.seed,
        sequence_length=args.encoder_seq_length,
        sequence_length_decoder=args.decoder_seq_length,
        blend=get_blend_from_list(args.data_path),
        blend_per_split=[
            get_blend_from_list(args.train_data_path),
            get_blend_from_list(args.valid_data_path),
            get_blend_from_list(args.test_data_path)
        ],
        renormalize_blend_weights=args.renormalize_blend_weights,
        split=args.split,
        path_to_cache=args.data_cache_path,
        tokenizer=tokenizer,
        masking_probability=args.mask_prob,
        short_sequence_probability=args.short_seq_prob,
        masking_max_ngram=10,
        masking_do_full_word=True,
        masking_do_permutation=False,
        masking_use_longer_ngrams=False,
        masking_use_geometric_distribution=True,
    )

    print_rank_0('> building train, validation, and test datasets for T5 ...')

    train_ds, valid_ds, test_ds = BlendedMegatronDatasetBuilder(
        T5MaskedWordPieceDataset,
        train_val_test_num_samples,
        lambda: mpu.get_tensor_model_parallel_rank() == 0,
        config,
    ).build()

    print_rank_0("> finished creating T5 datasets ...")

    return train_ds, valid_ds, test_ds


def t5_embedding_ranks(pp_ranks):
    """T5's embedding ranks consist of the encoder's first rank, and the decoder's first & last ranks.
    Args:
        pp_ranks: A list of global ranks that constitute a pipeline group.
    """
    args = get_args()

    first_rank = pp_ranks[0]
    last_rank = pp_ranks[-1]

    # encoder size is also the index to the first rank of the decoder.
    epp = args.encoder_pipeline_model_parallel_size

    if len(pp_ranks) == 1:
        return [first_rank]
    elif pp_ranks[epp] not in (first_rank, last_rank):
        return [first_rank, pp_ranks[epp], last_rank]
    else:
        return [first_rank, last_rank]
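
# Example (hypothetical ranks, assuming encoder_pipeline_model_parallel_size == 1):
#     t5_embedding_ranks([0, 1, 2, 3]) -> [0, 1, 3]
# i.e. the encoder's first rank plus the decoder's first and last ranks.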


def t5_position_embedding_ranks(pp_ranks):
    """T5's positional embeddings are the encoder & decoder first rank stages
    Args:
        pp_ranks: A list of global ranks that constitute a pipeline group.
    """
    args = get_args()

    # encoder size is also the index to the first rank of the decoder.
    epp = args.encoder_pipeline_model_parallel_size

    if len(pp_ranks) == 1 or pp_ranks[0] == pp_ranks[epp]:
        return [pp_ranks[0]]
    else:
        return [pp_ranks[0], pp_ranks[epp]]
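
# Example (same hypothetical ranks as above):
#     t5_position_embedding_ranks([0, 1, 2, 3]) -> [0, 1]
# i.e. the first rank of the encoder and the first rank of the decoder.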


if __name__ == "__main__":

    # Temporary for transition to core datasets
    train_valid_test_datasets_provider.is_distributed = True

    pretrain(
        train_valid_test_datasets_provider,
        model_provider,
        ModelType.encoder_and_decoder,
        forward_step,
        args_defaults={'tokenizer_type': 'BertWordPieceLowerCase'},
        get_embedding_ranks=t5_embedding_ranks,
        get_position_embedding_ranks=t5_position_embedding_ranks,
    )