# pretrain_gpt2.py
# coding=utf-8
# Copyright (c) 2019, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Pretrain GPT2"""

import sys

import torch

from configure_data import configure_data
from gpt2_data_loader import make_gpt2_dataloaders
from megatron import mpu
from megatron import print_rank_0
from megatron.model import GPT2Model
from megatron.training import run
from megatron.utils import get_ltor_masks_and_position_ids
from megatron.utils import reduce_losses
from megatron.utils import vocab_size_with_padding


def model_provider(args):
    """Construct the GPT-2 network described by ``args``.

    Reads the layer/hidden/head sizes, dropout probabilities, maximum
    sequence length, activation-checkpointing settings, layernorm epsilon
    and attention-scaling flags off ``args`` and forwards them to
    ``GPT2Model``. The model is always built with ``parallel_output=True``.

    Returns:
        The constructed ``GPT2Model`` instance.
    """
    print_rank_0('building GPT2 model ...')

    # ``hidden_dropout`` is reused for both the embedding and the output
    # dropout probability.
    model_kwargs = dict(
        num_layers=args.num_layers,
        vocab_size=args.vocab_size,
        hidden_size=args.hidden_size,
        num_attention_heads=args.num_attention_heads,
        embedding_dropout_prob=args.hidden_dropout,
        attention_dropout_prob=args.attention_dropout,
        output_dropout_prob=args.hidden_dropout,
        max_sequence_length=args.max_position_embeddings,
        checkpoint_activations=args.checkpoint_activations,
        checkpoint_num_layers=args.checkpoint_num_layers,
        layernorm_epsilon=args.layernorm_epsilon,
        parallel_output=True,
        apply_query_key_layer_scaling=args.apply_query_key_layer_scaling,
        attention_softmax_in_fp32=args.attention_softmax_in_fp32,
    )
    return GPT2Model(**model_kwargs)


def get_batch(data_iterator, args, timers):
    """Fetch the next batch and build GPT-2 inputs, targets and masks.

    Only ranks that hold a ``data_iterator`` read from it. ``mpu.broadcast_data``
    is then called with key ``'text'`` and dtype ``torch.int64`` so that every
    rank ends up with the same text tensor.

    Args:
        data_iterator: Iterator yielding batches that contain a ``'text'``
            entry, or ``None`` on ranks that only receive the broadcast.
        args: Parsed arguments; reads ``eod_token``, ``reset_position_ids``,
            ``reset_attention_mask``, ``eod_mask_loss`` and ``fp16``.
        timers: Timer factory; the ``'data loader'`` timer wraps the fetch.

    Returns:
        Tuple ``(tokens, labels, loss_mask, attention_mask, position_ids)``.
    """
    timers('data loader').start()
    data = next(data_iterator) if data_iterator is not None else None
    timers('data loader').stop()
    broadcast = mpu.broadcast_data(['text'], data, torch.int64)

    # Shift by one along the second axis: the model reads text[:, :-1] and is
    # trained to predict text[:, 1:].
    text = broadcast['text'].long()
    tokens = text[:, :-1].contiguous()
    labels = text[:, 1:].contiguous()

    # Left-to-right attention mask, loss mask and position ids.
    attention_mask, loss_mask, position_ids = get_ltor_masks_and_position_ids(
        tokens,
        args.eod_token,
        args.reset_position_ids,
        args.reset_attention_mask,
        args.eod_mask_loss)

    # Under fp16 training the attention mask is converted to half precision.
    if args.fp16:
        attention_mask = attention_mask.half()

    return tokens, labels, loss_mask, attention_mask, position_ids


def forward_step(data_iterator, model, args, timers):
    """Run one forward pass and compute the masked language-model loss.

    Args:
        data_iterator: Passed through to ``get_batch``.
        model: Callable invoked as ``model(tokens, position_ids,
            attention_mask)``.
        args: Parsed arguments, passed through to ``get_batch``.
        timers: Timer factory; ``'batch generator'`` wraps batch creation.

    Returns:
        ``(loss, stats)``. ``loss`` is the local masked-mean loss and
        ``stats`` maps ``'lm loss'`` to the first value returned by
        ``reduce_losses([loss])`` for logging.
    """
    timers('batch generator').start()
    batch = get_batch(data_iterator, args, timers)
    timers('batch generator').stop()
    tokens, labels, loss_mask, attention_mask, position_ids = batch

    logits = model(tokens, position_ids, attention_mask)
    # Cross entropy is computed on the logits cast to float32.
    per_token_loss = mpu.vocab_parallel_cross_entropy(
        logits.contiguous().float(), labels)

    # Average only over the positions the loss mask keeps.
    # NOTE(review): an all-zero mask divides by zero here — confirm upstream
    # guarantees at least one unmasked token per batch.
    flat_mask = loss_mask.view(-1)
    loss = torch.sum(per_token_loss.view(-1) * flat_mask) / flat_mask.sum()

    averaged = reduce_losses([loss])
    return loss, {'lm loss': averaged[0]}


def get_train_val_test_data(args):
    """Load the data on rank zero and broadcast the number of tokens to all GPUs.

    Only rank 0 of each model-parallel group builds the train/validation/test
    data; the other ranks in the group return ``None`` for all three. From
    that rank, the padded vocabulary size, the end-of-document token id and
    the ``do_train``/``do_valid``/``do_test`` flags are broadcast over the
    model-parallel group. They are then written back onto ``args``:
    ``vocab_size``, ``eod_token``, ``do_train``, ``do_valid`` and ``do_test``.

    Args:
        args: Parsed arguments. ``args.data_loader`` selects the backend:
            ``'numpy'``, or ``'raw'``/``'lazy'``.

    Returns:
        Tuple ``(train_data, val_data, test_data)``; the entries are ``None``
        on ranks that do not build the data.

    Note:
        Terminates the process with exit status 1 if ``args.data_loader`` is
        not one of the supported values.
    """
    (train_data, val_data, test_data) = (None, None, None)

    # Data loader only on rank 0 of each model parallel group.
    if mpu.get_model_parallel_rank() == 0:
        if args.data_loader == 'numpy':
            # The numpy loader takes a single path per split, so unwrap the
            # one-element lists in place on ``args``.
            assert len(args.train_data) == 1, \
                'numpy data loader expects exactly one train_data entry'
            args.train_data = args.train_data[0]
            assert len(args.valid_data) == 1, \
                'numpy data loader expects exactly one valid_data entry'
            args.valid_data = args.valid_data[0]
            assert len(args.test_data) == 1, \
                'numpy data loader expects exactly one test_data entry'
            args.test_data = args.test_data[0]
            (train_data, val_data, test_data), num_tokens, \
                eod_token = make_gpt2_dataloaders(args)
        elif args.data_loader in ('raw', 'lazy'):
            data_config = configure_data()
            data_config.set_defaults(data_set_type='GPT2', transpose=False)
            (train_data, val_data, test_data), tokenizer = data_config.apply(
                args)
            num_tokens = tokenizer.num_tokens
            eod_token = tokenizer.get_command('eos').Id
            # The code requires the eos id to equal the pad id.
            assert eod_token == tokenizer.get_command('pad').Id
        else:
            print("Unsupported data loader for GPT2.")
            # Use sys.exit rather than the site-injected exit(), which is
            # meant for the interactive interpreter and is missing when
            # Python runs without the site module (``python -S``).
            sys.exit(1)

        # Pad the vocabulary size.
        num_tokens = vocab_size_with_padding(num_tokens, args)
        print_rank_0('> found end-of-document token: {}'.format(eod_token))
        token_counts = torch.cuda.LongTensor([num_tokens, eod_token,
                                              int(args.do_train),
                                              int(args.do_valid),
                                              int(args.do_test)])
    else:
        # Placeholder buffer that is overwritten by the broadcast below.
        token_counts = torch.cuda.LongTensor([0, 0, 0, 0, 0])

    # Broadcast num tokens.
    torch.distributed.broadcast(token_counts,
                                mpu.get_model_parallel_src_rank(),
                                group=mpu.get_model_parallel_group())
    num_tokens = token_counts[0].item()
    eod_token = token_counts[1].item()
    args.do_train = token_counts[2].item()
    args.do_valid = token_counts[3].item()
    args.do_test = token_counts[4].item()

    args.vocab_size = num_tokens
    args.eod_token = eod_token

    return train_data, val_data, test_data


if __name__ == "__main__":
    # Delegate to megatron.training.run with this script's data, model and
    # forward-step callbacks.
    run('Pretrain GPT-2 model',
        get_train_val_test_data,
        model_provider,
        forward_step)