# coding=utf-8
# Copyright (c) 2019, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Pretrain GPT2"""

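# This script plugs three callbacks into megatron.training.pretrain (see the
# __main__ block at the bottom of the file): model_provider builds the GPT2
# model, get_train_val_test_data builds the data loaders, and forward_step
# computes the language-modeling loss for a single batch.
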
import os

import torch

from megatron import get_args
from megatron import get_timers
from megatron import get_tokenizer
from megatron import mpu
from megatron import print_rank_0
from megatron.data.gpt2_dataset import build_train_valid_test_datasets
from megatron.model import GPT2Model
from megatron.training import pretrain
from megatron.utils import get_ltor_masks_and_position_ids
from megatron.utils import make_data_loader
from megatron.utils import reduce_losses


def model_provider():
    """Build the model."""
    args = get_args()

    print_rank_0('building GPT2 model ...')
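    # num_tokentypes=0 disables token-type embeddings; parallel_output=True
    # leaves the logits split across model-parallel ranks so forward_step can
    # feed them to mpu.vocab_parallel_cross_entropy without an all-gather.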
    model = GPT2Model(num_tokentypes=0, parallel_output=True)

    return model


def get_batch(data_iterator):
    """Generate a batch"""
    args = get_args()
    tokenizer = get_tokenizer()

    # Items and their type.
    keys = ['text']
    datatype = torch.int64

    # Broadcast data.
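    # Only rank 0 of each model-parallel group has a data iterator (see
    # get_train_val_test_data below); mpu.broadcast_data replicates the batch
    # to the remaining ranks of the group so they all see the same tokens.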
    if data_iterator is not None:
        data = next(data_iterator)
    else:
        data = None
    data_b = mpu.broadcast_data(keys, data, datatype)

    # Unpack.
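    # Next-token prediction: labels are the inputs shifted left by one token,
    # so the model at position i is trained to predict token i + 1.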
    tokens_ = data_b['text'].long()
    labels = tokens_[:, 1:].contiguous()
    tokens = tokens_[:, :-1].contiguous()

    # Get the masks and position ids.
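    # reset_position_ids / reset_attention_mask control how packed documents
    # separated by the EOD token are treated, and eod_mask_loss zeroes the
    # loss on EOD positions.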
    attention_mask, loss_mask, position_ids = get_ltor_masks_and_position_ids(
        tokens,
        tokenizer.eod,
        args.reset_position_ids,
        args.reset_attention_mask,
        args.eod_mask_loss,
        args.fp16)

    return tokens, labels, loss_mask, attention_mask, position_ids


def forward_step(data_iterator, model):
    """Forward step."""
    timers = get_timers()

    # Get the batch.
    timers('batch generator').start()
    tokens, labels, loss_mask, attention_mask, position_ids = get_batch(
        data_iterator)
    timers('batch generator').stop()

    # Forward model.
    output = model(tokens, position_ids, attention_mask)
    losses = mpu.vocab_parallel_cross_entropy(output.contiguous().float(),
                                              labels)
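    # Masked mean: average the per-token loss only over positions where
    # loss_mask is 1.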
    loss_mask = loss_mask.view(-1)
    loss = torch.sum(losses.view(-1) * loss_mask) / loss_mask.sum()

    # Reduce loss for logging.
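    # The reduced value is only used for reporting; `loss` above is what the
    # trainer backpropagates.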
    reduced_loss = reduce_losses([loss])

    return loss, {'lm loss': reduced_loss[0]}


def get_train_val_test_data():
    """Load the data on rank zero and boradcast number of tokens to all GPUS."""
    args = get_args()

    (train_data, valid_data, test_data) = (None, None, None)

    # Data loader only on rank 0 of each model parallel group.
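    # The other ranks in each group keep (None, None, None) and receive every
    # batch through mpu.broadcast_data in get_batch above.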
    if mpu.get_model_parallel_rank() == 0:
        print_rank_0('> building train, validation, and test datasets '
                     'for GPT2 ...')

        data_parallel_size = mpu.get_data_parallel_world_size()
        data_parallel_rank = mpu.get_data_parallel_rank()
        global_batch_size = args.batch_size * data_parallel_size

        # Number of train/valid/test samples.
        train_iters = args.train_iters
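        # Evaluation runs every eval_interval iterations, plus once at the
        # end of training, for eval_iters iterations each time.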
        eval_iters = (train_iters // args.eval_interval + 1) * args.eval_iters
        test_iters = args.eval_iters
        train_val_test_num_samples = [train_iters * global_batch_size,
                                      eval_iters * global_batch_size,
                                      test_iters * global_batch_size]
        print_rank_0(' > datasets target sizes (minimum size):')
        print_rank_0('    train:      {}'.format(train_val_test_num_samples[0]))
        print_rank_0('    validation: {}'.format(train_val_test_num_samples[1]))
        print_rank_0('    test:       {}'.format(train_val_test_num_samples[2]))

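        # A single preprocessed corpus at args.data_path is split into
        # train/validation/test pieces according to the proportions in
        # args.split (a comma-separated string such as "949,50,1").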
        train_ds, valid_ds, test_ds = build_train_valid_test_datasets(
            data_prefix=args.data_path,
            data_impl=args.data_impl,
            splits_string=args.split,
            train_valid_test_num_samples=train_val_test_num_samples,
            seq_length=args.seq_length,
            seed=args.seed,
            skip_warmup=(not args.mmap_warmup))
        print_rank_0("> finished creating GPT2 datasets ...")

        train_data = make_data_loader(train_ds)
        valid_data = make_data_loader(valid_ds)
        test_data = make_data_loader(test_ds)

        do_train = train_data is not None and args.train_iters > 0
        do_valid = valid_data is not None and args.eval_iters > 0
        do_test = test_data is not None and args.eval_iters > 0
        # Flags telling every rank whether to run train/valid/test.
        flags = torch.cuda.LongTensor(
            [int(do_train), int(do_valid), int(do_test)])
    else:
        flags = torch.cuda.LongTensor([0, 0, 0])

    # Broadcast the flags from rank 0 of each model-parallel group.
    torch.distributed.broadcast(flags,
                                mpu.get_model_parallel_src_rank(),
                                group=mpu.get_model_parallel_group())
    args.do_train = flags[0].item()
    args.do_valid = flags[1].item()
    args.do_test = flags[2].item()

    return train_data, valid_data, test_data


if __name__ == "__main__":

    pretrain(get_train_val_test_data, model_provider, forward_step,
             args_defaults={'tokenizer_type': 'GPT2BPETokenizer'})
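

# A sketch of a single-GPU launch, for orientation only: the values below are
# placeholders and the authoritative flag list lives in the repository's
# example scripts. The data/tokenizer/training flags correspond to arguments
# referenced in this file; the model-size and optimizer flags are assumed.
#
#   python pretrain_gpt2.py \
#       --num-layers 24 --hidden-size 1024 --num-attention-heads 16 \
#       --seq-length 1024 --max-position-embeddings 1024 \
#       --batch-size 8 --train-iters 500000 --lr 0.00015 \
#       --data-path <prefix-of-preprocessed-dataset> \
#       --vocab-file <gpt2-vocab.json> --merge-file <gpt2-merges.txt> \
#       --split 949,50,1 --fp16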