pretrain_gpt2.py

# coding=utf-8
# Copyright (c) 2019, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Pretrain GPT2"""

import os

import torch

from megatron import get_args
from megatron import get_timers
from megatron import get_tokenizer
from megatron import mpu
from megatron import print_rank_0
from megatron.data.gpt2_dataset import GPT2Dataset
from megatron.model import GPT2Model
from megatron.training import pretrain
from megatron.utils import get_ltor_masks_and_position_ids
from megatron.utils import make_data_loader
from megatron.utils import reduce_losses


def model_provider():
    """Build the model."""
    args = get_args()

    print_rank_0('building GPT2 model ...')
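    # parallel_output=True leaves the logits split across the model-parallel
    # ranks; forward_step consumes them with vocab_parallel_cross_entropy.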
    model = GPT2Model(num_tokentypes=0, parallel_output=True)

    return model


def get_batch(data_iterator):
    """Generate a batch"""
    args = get_args()

    # Items and their type.
    keys = ['text']
    datatype = torch.int64

    # Broadcast the data from rank 0 of the model parallel group; the other
    # ranks receive a None iterator (see get_train_val_test_data).
    if data_iterator is not None:
        data = next(data_iterator)
    else:
        data = None
    data_b = mpu.broadcast_data(keys, data, datatype)

    # Unpack.
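    # Labels are the input tokens shifted left by one (next-token prediction).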
    tokens_ = data_b['text'].long()
    labels = tokens_[:, 1:].contiguous()
    tokens = tokens_[:, :-1].contiguous()

    # Get the masks and position ids.
    attention_mask, loss_mask, position_ids = get_ltor_masks_and_position_ids(
        tokens,
        args.eod_token,
        args.reset_position_ids,
        args.reset_attention_mask,
        args.eod_mask_loss)
    # Convert the attention mask to half precision when running in fp16.
    if args.fp16:
        attention_mask = attention_mask.half()

    return tokens, labels, loss_mask, attention_mask, position_ids


def forward_step(data_iterator, model):
    """Forward step."""
    timers = get_timers()

    # Get the batch.
    timers('batch generator').start()
    tokens, labels, loss_mask, attention_mask, position_ids = get_batch(
        data_iterator)
    timers('batch generator').stop()

    # Forward model.
    output = model(tokens, position_ids, attention_mask)
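    # Cross entropy is computed directly on the vocab-parallel logits (no
    # all-gather of the full vocabulary), then averaged over unmasked tokens.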
    losses = mpu.vocab_parallel_cross_entropy(output.contiguous().float(),
                                              labels)
    loss_mask = loss_mask.view(-1)
    loss = torch.sum(losses.view(-1) * loss_mask) / loss_mask.sum()

    # Reduce loss for logging.
    reduced_loss = reduce_losses([loss])

    return loss, {'lm loss': reduced_loss[0]}


def make_gpt2_dataloaders():
    """Build gpt2 dataloders."""
    args = get_args()

    # Input parameters.
    input_data_sizes_file = args.input_data_sizes_file
    seq_length = args.seq_length
    initial_seed = args.seed

    # Build the datasets.
    def build_dataset_(name):
        return GPT2Dataset(os.path.join(args.data_path, name),
                           input_data_sizes_file,
                           seq_length, initial_seed)
    train_ds = build_dataset_('train')
    valid_ds = build_dataset_('valid')
    test_ds = build_dataset_('test')

    # Dataloaders
    train = make_data_loader(train_ds)
    valid = make_data_loader(valid_ds)
    test = make_data_loader(test_ds)

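    # Record which splits actually produced a data loader; these flags are
    # broadcast to the other model-parallel ranks in get_train_val_test_data.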
    args.do_train = False
    args.do_valid = False
    args.do_test = False

    if train is not None:
        args.do_train = True
    if valid is not None:
        args.do_valid = True
    if test is not None:
        args.do_test = True

    return (train, valid, test)


def get_train_val_test_data():
    """Load the data on rank zero and boradcast number of tokens to all GPUS."""
    args = get_args()

    (train_data, val_data, test_data) = (None, None, None)

    # Data loader only on rank 0 of each model parallel group.
    if mpu.get_model_parallel_rank() == 0:
        (train_data, val_data, test_data) = make_gpt2_dataloaders()
        flags = torch.cuda.LongTensor([int(args.do_train),
                                       int(args.do_valid),
                                       int(args.do_test)])
    else:
        flags = torch.cuda.LongTensor([0, 0, 0])

    # Broadcast the flags from rank 0 to the rest of the model parallel group.
    torch.distributed.broadcast(flags,
                                mpu.get_model_parallel_src_rank(),
                                group=mpu.get_model_parallel_group())
    args.do_train = flags[0].item()
    args.do_valid = flags[1].item()
    args.do_test = flags[2].item()

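    # The end-of-document token id is needed by get_ltor_masks_and_position_ids
    # in get_batch.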
    tokenizer = get_tokenizer()
    args.eod_token = tokenizer.eod_id

    return train_data, val_data, test_data


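# A minimal launch sketch (assumption: the flag names below match this
# revision's megatron/arguments.py; GPT2BPETokenizer also expects
# --vocab-file and --merge-file):
#
#   python pretrain_gpt2.py \
#       --num-layers 24 --hidden-size 1024 --num-attention-heads 16 \
#       --seq-length 1024 --batch-size 8 --train-iters 500000 \
#       --data-path <path/to/data> --vocab-file gpt2-vocab.json \
#       --merge-file gpt2-merges.txt --fp16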
if __name__ == "__main__":

    pretrain(get_train_val_test_data, model_provider, forward_step,
             args_defaults={'tokenizer_type': 'GPT2BPETokenizer'})