pretrain_gpt.py 3.37 KB
Newer Older
Jared Casper's avatar
Jared Casper committed
1
# Copyright (c) 2022, NVIDIA CORPORATION.  All rights reserved.
2

3
"""Pretrain GPT"""
4
5

import torch
6
from functools import partial
Neel Kant's avatar
Neel Kant committed
7
8
from megatron import get_args
from megatron import print_rank_0
Mohammad's avatar
Mohammad committed
9
from megatron import get_timers
Mohammad's avatar
Mohammad committed
10
from megatron import get_tokenizer
Mohammad Shoeybi's avatar
Mohammad Shoeybi committed
11
from megatron import mpu
12
from megatron import core
13
from megatron.data.gpt_dataset import build_train_valid_test_datasets
14
from megatron.model import GPTModel, ModelType
Mohammad's avatar
Mohammad committed
15
from megatron.training import pretrain
Mohammad Shoeybi's avatar
Mohammad Shoeybi committed
16
from megatron.utils import get_ltor_masks_and_position_ids
17
from megatron.utils import average_losses_across_data_parallel_group
Mohammad's avatar
Mohammad committed
18

19
def model_provider(pre_process=True, post_process=True):
    """Construct the GPT model used for pretraining.

    Args:
        pre_process: forwarded unchanged to ``GPTModel(pre_process=...)``.
        post_process: forwarded unchanged to ``GPTModel(post_process=...)``.
            (Presumably these toggle the input/output stages under pipeline
            parallelism — TODO confirm against ``GPTModel``.)

    Returns:
        A new ``GPTModel`` with no token types and parallel output enabled.
    """
    print_rank_0('building GPT model ...')
    return GPTModel(num_tokentypes=0,
                    parallel_output=True,
                    pre_process=pre_process,
                    post_process=post_process)


Mohammad's avatar
Mohammad committed
32
def get_batch(data_iterator):
    """Fetch one batch and build the inputs for a left-to-right LM step.

    A sample is drawn from ``data_iterator`` when it is not None. The
    ``'text'`` field is then distributed with
    ``core.tensor_parallel.broadcast_data`` as int64.

    Args:
        data_iterator: iterator yielding dicts with a ``'text'`` entry, or
            None (presumably on ranks that only receive the broadcast —
            TODO confirm).

    Returns:
        ``(tokens, labels, loss_mask, attention_mask, position_ids)`` where
        ``tokens`` drops the last column of the text and ``labels`` drops the
        first, i.e. labels are the inputs shifted left by one position.
    """
    args = get_args()
    tokenizer = get_tokenizer()

    # Only ranks holding an iterator pull a sample; everyone joins the broadcast.
    data = next(data_iterator) if data_iterator is not None else None
    data_b = core.tensor_parallel.broadcast_data(['text'], data, torch.int64)

    # Split the text into next-token-prediction inputs and targets.
    text = data_b['text'].long()
    tokens = text[:, :-1].contiguous()
    labels = text[:, 1:].contiguous()

    # Attention mask, loss mask and position ids, derived from the EOD token
    # and the reset/mask flags on args.
    attention_mask, loss_mask, position_ids = get_ltor_masks_and_position_ids(
        tokens,
        tokenizer.eod,
        args.reset_position_ids,
        args.reset_attention_mask,
        args.eod_mask_loss)

    return tokens, labels, loss_mask, attention_mask, position_ids

63
64
65
66
def loss_func(loss_mask, output_tensor):
    """Reduce per-token losses to a masked mean.

    Args:
        loss_mask: per-token weights; flattened with ``view(-1)`` and cast to
            float.
        output_tensor: per-token losses produced by the model; cast to float
            and flattened.

    Returns:
        ``(loss, {'lm loss': averaged})`` where ``loss`` is
        ``sum(losses * mask) / sum(mask)``. ``averaged`` is the same loss
        averaged over the data-parallel group and is used for logging only.
    """
    flat_mask = loss_mask.view(-1).float()
    flat_losses = output_tensor.float().view(-1)
    # NOTE(review): an all-zero mask makes the denominator 0 and the loss NaN.
    loss = (flat_losses * flat_mask).sum() / flat_mask.sum()

    # Logging-only copy, averaged across data-parallel ranks.
    reported = average_losses_across_data_parallel_group([loss])

    return loss, {'lm loss': reported[0]}


def forward_step(data_iterator, model):
    """Run the model forward on one batch.

    Args:
        data_iterator: iterator passed straight to ``get_batch``. It may be
            None, because ``get_batch`` checks for that case.
        model: the GPT model, called with tokens, position ids and the
            attention mask, plus ``labels`` as a keyword argument.

    Returns:
        ``(output_tensor, loss_fn)``. ``loss_fn`` is ``loss_func`` with this
        batch's ``loss_mask`` already bound, so the caller only has to pass
        the model output, which ``loss_func`` treats as per-token losses.
    """
    timers = get_timers()

    # Time batch construction separately (recorded at log level 2).
    timers('batch-generator', log_level=2).start()
    tokens, labels, loss_mask, attention_mask, position_ids = get_batch(
        data_iterator)
    timers('batch-generator').stop()

    output_tensor = model(tokens, position_ids, attention_mask,
                          labels=labels)

    # Bind the mask now; the training loop supplies output_tensor later.
    return output_tensor, partial(loss_func, loss_mask)
89
90


91
92
def train_valid_test_datasets_provider(train_val_test_num_samples):
    """Build the GPT train, validation and test datasets.

    Args:
        train_val_test_num_samples: forwarded unchanged as
            ``train_valid_test_num_samples`` to
            ``build_train_valid_test_datasets``.

    Returns:
        ``(train_ds, valid_ds, test_ds)``, unpacked from the builder's result.
    """
    args = get_args()

    print_rank_0('> building train, validation, and test datasets '
                 'for GPT ...')

    # All dataset settings come from the global args; warmup is skipped
    # unless mmap_warmup is set.
    train_split, valid_split, test_split = build_train_valid_test_datasets(
        data_prefix=args.data_path,
        data_impl=args.data_impl,
        splits_string=args.split,
        train_valid_test_num_samples=train_val_test_num_samples,
        seq_length=args.seq_length,
        seed=args.seed,
        skip_warmup=not args.mmap_warmup,
    )

    print_rank_0("> finished creating GPT datasets ...")
    return train_split, valid_split, test_split
108
109
110


if __name__ == "__main__":
    # Hand the GPT-specific providers to the generic Megatron pretrain loop.
    # The tokenizer type is supplied through args_defaults.
    pretrain(
        train_valid_test_datasets_provider,
        model_provider,
        ModelType.encoder_or_decoder,
        forward_step,
        args_defaults={'tokenizer_type': 'GPT2BPETokenizer'},
    )