pretrain_gpt2.py 6.65 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
# coding=utf-8
# Copyright (c) 2019, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Pretrain GPT2"""

Mohammad's avatar
Mohammad committed
18
19
import os

20
21
import torch

Mohammad's avatar
Mohammad committed
22
23
from megatron import get_args
from megatron import get_timers
Mohammad's avatar
Mohammad committed
24
from megatron import get_tokenizer
Mohammad Shoeybi's avatar
Mohammad Shoeybi committed
25
from megatron import mpu
Mohammad's avatar
Mohammad committed
26
from megatron import print_rank_0
Mohammad's avatar
Mohammad committed
27
28
from megatron.data.gpt2_dataset import GPT2Dataset
from megatron.data_utils.samplers import DistributedBatchSampler
29
from megatron.model import GPT2Model
Mohammad's avatar
Mohammad committed
30
from megatron.training import pretrain
Mohammad Shoeybi's avatar
Mohammad Shoeybi committed
31
from megatron.utils import get_ltor_masks_and_position_ids
Mohammad Shoeybi's avatar
Mohammad Shoeybi committed
32
from megatron.utils import reduce_losses
Mohammad's avatar
Mohammad committed
33

34

Mohammad's avatar
Mohammad committed
35
def model_provider():
    """Construct the GPT-2 model from the global arguments.

    Every hyper-parameter is taken from ``get_args()``. The model is built
    with ``parallel_output=True``; ``forward_step`` feeds the output directly
    into ``mpu.vocab_parallel_cross_entropy`` (presumably the output is left
    partitioned across model-parallel ranks -- confirm in GPT2Model).

    Returns:
        The newly constructed ``GPT2Model`` instance.
    """
    args = get_args()
    print_rank_0('building GPT2 model ...')

    model_kwargs = dict(
        num_layers=args.num_layers,
        vocab_size=args.padded_vocab_size,
        hidden_size=args.hidden_size,
        num_attention_heads=args.num_attention_heads,
        # args.hidden_dropout is reused for both embedding and output dropout.
        embedding_dropout_prob=args.hidden_dropout,
        attention_dropout_prob=args.attention_dropout,
        output_dropout_prob=args.hidden_dropout,
        max_sequence_length=args.max_position_embeddings,
        checkpoint_activations=args.checkpoint_activations,
        checkpoint_num_layers=args.checkpoint_num_layers,
        layernorm_epsilon=args.layernorm_epsilon,
        parallel_output=True,
        apply_query_key_layer_scaling=args.apply_query_key_layer_scaling,
        attention_softmax_in_fp32=args.attention_softmax_in_fp32,
    )
    return GPT2Model(**model_kwargs)


Mohammad's avatar
Mohammad committed
58
def get_batch(data_iterator):
    """Fetch one batch and derive the left-to-right LM inputs from it.

    Only a caller holding a non-None ``data_iterator`` pulls a raw batch; the
    others pass ``None`` to ``mpu.broadcast_data`` and receive the ``'text'``
    tensor from it (presumably from the model-parallel source rank).

    Returns:
        ``(tokens, labels, loss_mask, attention_mask, position_ids)`` where
        ``tokens`` drops the last position of ``'text'`` and ``labels`` drops
        the first, so each label is the token following its input position.
        ``attention_mask`` is cast to half precision when ``args.fp16`` is set.
    """
    args = get_args()

    # Ranks without an iterator contribute None and get the broadcast result.
    raw_batch = None if data_iterator is None else next(data_iterator)
    broadcast = mpu.broadcast_data(['text'], raw_batch, torch.int64)

    # Shift by one along dim 1: inputs lose the final token, targets the first.
    text = broadcast['text'].long()
    tokens = text[:, :-1].contiguous()
    labels = text[:, 1:].contiguous()

    # Masks and position ids depend on where EOD tokens occur in the inputs.
    attention_mask, loss_mask, position_ids = get_ltor_masks_and_position_ids(
        tokens, args.eod_token, args.reset_position_ids,
        args.reset_attention_mask, args.eod_mask_loss)

    # Cast the attention mask to half precision for fp16 runs.
    if args.fp16:
        attention_mask = attention_mask.half()

    return tokens, labels, loss_mask, attention_mask, position_ids


Mohammad's avatar
Mohammad committed
92
def forward_step(data_iterator, model):
    """Run a forward pass on one batch and compute the masked LM loss.

    Args:
        data_iterator: Passed through to ``get_batch`` (may be ``None``).
        model: Callable as ``model(tokens, position_ids, attention_mask)``.

    Returns:
        ``(loss, stats)``. ``loss`` is the per-token cross entropy weighted by
        ``loss_mask`` and divided by the mask's sum. ``stats`` maps
        ``'lm loss'`` to the value produced by ``reduce_losses`` for logging.
    """
    timers = get_timers()

    # Time batch acquisition (including any broadcast inside get_batch).
    timers('batch generator').start()
    tokens, labels, loss_mask, attention_mask, position_ids = get_batch(
        data_iterator)
    timers('batch generator').stop()

    logits = model(tokens, position_ids, attention_mask)
    # Cross entropy is computed in fp32 regardless of the model's precision.
    per_token_loss = mpu.vocab_parallel_cross_entropy(
        logits.contiguous().float(), labels)

    # Flatten and take the mask-weighted mean over all positions.
    flat_mask = loss_mask.view(-1)
    weighted_total = torch.sum(per_token_loss.view(-1) * flat_mask)
    loss = weighted_total / flat_mask.sum()

    logged = reduce_losses([loss])
    return loss, {'lm loss': logged[0]}
113
114


Mohammad's avatar
Mohammad committed
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
def make_gpt2_dataloaders():
    """Build the train, valid and test GPT-2 data loaders.

    Each split is loaded from ``os.path.join(args.data_path, <split>)`` with
    split names ``'train'``, ``'valid'`` and ``'test'``. Samples are visited in
    sequential order and grouped by ``DistributedBatchSampler``, which is given
    this rank's data-parallel rank and world size.

    Side effects:
        Sets ``args.do_train``, ``args.do_valid`` and ``args.do_test``.

    Returns:
        A ``(train, valid, test)`` tuple of ``torch.utils.data.DataLoader``.
    """
    args = get_args()

    sizes_file = args.input_data_sizes_file
    sequence_length = args.seq_length
    seed = args.seed

    dp_world_size = mpu.get_data_parallel_world_size()
    dp_rank = mpu.get_data_parallel_rank()
    # The sampler receives the per-rank batch size times the data-parallel size.
    global_batch_size = args.batch_size * dp_world_size
    worker_count = args.num_workers

    def build_loader(split_path):
        # One dataset per split, read sequentially, sharded by the batch sampler.
        dataset = GPT2Dataset(split_path, sizes_file, sequence_length, seed)
        batch_sampler = DistributedBatchSampler(
            sampler=torch.utils.data.SequentialSampler(dataset),
            batch_size=global_batch_size,
            drop_last=True,
            rank=dp_rank,
            world_size=dp_world_size)
        return torch.utils.data.DataLoader(dataset,
                                           batch_sampler=batch_sampler,
                                           num_workers=worker_count,
                                           pin_memory=True)

    train, valid, test = (build_loader(os.path.join(args.data_path, split))
                          for split in ('train', 'valid', 'test'))

    # NOTE(review): build_loader never returns None, so these flags are always
    # True here. get_train_val_test_data reads them to build its broadcast flags.
    args.do_train = train is not None
    args.do_valid = valid is not None
    args.do_test = test is not None

    return (train, valid, test)


Mohammad's avatar
Mohammad committed
165
def get_train_val_test_data():
    """Build data loaders on model-parallel rank 0 and broadcast split flags.

    Only the rank whose model-parallel rank is 0 calls
    ``make_gpt2_dataloaders``. Its ``do_train``/``do_valid``/``do_test``
    flags are packed into a CUDA LongTensor and broadcast from
    ``mpu.get_model_parallel_src_rank()`` over the model-parallel group. Every
    rank then stores the received values (as ints) back on ``args``. The
    tokenizer's EOD id is also recorded in ``args.eod_token``.

    Returns:
        ``(train_data, val_data, test_data)``. All three are ``None`` on ranks
        whose model-parallel rank is not 0.
    """
    args = get_args()

    train_data = val_data = test_data = None

    # Only model-parallel rank 0 builds loaders; others contribute zeros.
    if mpu.get_model_parallel_rank() == 0:
        train_data, val_data, test_data = make_gpt2_dataloaders()
        split_flags = [args.do_train, args.do_valid, args.do_test]
    else:
        split_flags = [0, 0, 0]
    flags = torch.cuda.LongTensor([int(flag) for flag in split_flags])

    # Broadcast which splits are available (not token counts).
    torch.distributed.broadcast(flags,
                                mpu.get_model_parallel_src_rank(),
                                group=mpu.get_model_parallel_group())
    args.do_train, args.do_valid, args.do_test = flags.tolist()

    args.eod_token = get_tokenizer().eod_id

    return train_data, val_data, test_data
193
194
195


if __name__ == "__main__":
    # Launch pretraining with this script's data, model and step callbacks;
    # 'GPT2BPETokenizer' is passed as the default for tokenizer_type.
    pretrain(get_train_val_test_data,
             model_provider,
             forward_step,
             args_defaults=dict(tokenizer_type='GPT2BPETokenizer'))