# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Pretrain GPT"""

import torch

from megatron import get_args
from megatron import print_rank_0
from megatron import get_timers
from megatron import get_tokenizer
from megatron import mpu
from megatron.data.gpt_dataset import build_train_valid_test_datasets
from megatron.model import (GPTModel,
                            GPTModelFirstStage,
                            GPTModelIntermediateStage,
                            GPTModelLastStage)
from megatron.training import pretrain
from megatron.utils import get_ltor_masks_and_position_ids
from megatron.utils import average_losses_across_data_parallel_group

def model_provider():
    """Build the GPT model (or model chunks) owned by this rank.

    Returns:
        * a ``GPTModel`` when the pipeline-model-parallel world size is
          not greater than 1;
        * otherwise, if ``args.virtual_pipeline_model_parallel_size`` is
          None, a single stage module chosen by this rank's pipeline
          position;
        * otherwise, a list holding one stage module per virtual pipeline
          rank ``0 .. virtual_pipeline_model_parallel_size - 1``.
    """
    print_rank_0('building GPT model ...')

    def build_stage_module():
        # Choose the stage-specific GPT class from this rank's position in
        # the pipeline: first stage, last stage, or anything in between.
        if mpu.is_pipeline_first_stage():
            return GPTModelFirstStage(num_tokentypes=0)
        if mpu.is_pipeline_last_stage():
            return GPTModelLastStage(num_tokentypes=0, parallel_output=True)
        return GPTModelIntermediateStage(num_tokentypes=0)

    args = get_args()

    # Without pipeline parallelism the full model is built in one piece.
    if mpu.get_pipeline_model_parallel_world_size() <= 1:
        return GPTModel(num_tokentypes=0, parallel_output=True)

    virtual_size = args.virtual_pipeline_model_parallel_size
    if virtual_size is None:
        return build_stage_module()

    # One module per virtual pipeline rank. The virtual rank is set before
    # each build (presumably so the stage checks in build_stage_module see
    # it — confirm in mpu). NOTE(review): the virtual rank is left at the
    # last value once this loop finishes.
    chunks = []
    for virtual_rank in range(virtual_size):
        mpu.set_virtual_pipeline_model_parallel_rank(virtual_rank)
        chunks.append(build_stage_module())
    return chunks


def get_batch(data_iterator):
    """Produce the input tensors for one training batch.

    The ``'text'`` field is read from ``data_iterator`` when it is not None.
    It is then passed through ``mpu.broadcast_data`` with dtype
    ``torch.int64`` and split into inputs and next-token targets.

    Returns:
        ``(tokens, labels, loss_mask, attention_mask, position_ids)``.
        ``tokens`` drops the last position along dim 1 and ``labels`` drops
        the first. The other three come from
        ``get_ltor_masks_and_position_ids``.
    """
    args = get_args()
    tokenizer = get_tokenizer()

    # Ranks without an iterator contribute None and obtain the batch from
    # the broadcast (presumably from the rank that read it — confirm).
    raw_sample = None if data_iterator is None else next(data_iterator)
    broadcast = mpu.broadcast_data(['text'], raw_sample, torch.int64)

    # Shift by one: tokens[:, i] is paired with labels[:, i], which is the
    # token that follows it in the original text.
    text = broadcast['text'].long()
    labels = text[:, 1:].contiguous()
    tokens = text[:, :-1].contiguous()

    # Masks and position ids, computed from the input tokens, the
    # tokenizer's EOD id and three flags read from args.
    attention_mask, loss_mask, position_ids = get_ltor_masks_and_position_ids(
        tokens,
        tokenizer.eod,
        args.reset_position_ids,
        args.reset_attention_mask,
        args.eod_mask_loss)

    return tokens, labels, loss_mask, attention_mask, position_ids


def forward_step(data_iterator, model, input_tensor):
    """Run one forward step for this rank's pipeline stage.

    Args:
        data_iterator: iterator handed to ``get_batch``; may be None on
            ranks that do not read data themselves.
        model: the module for this pipeline stage.
        input_tensor: activations received from the previous stage. Must be
            None on the first pipeline stage and not None on every other
            stage (checked with asserts).

    Returns:
        On the last pipeline stage, ``(loss, {'lm loss': averaged_loss})``.
        ``loss`` is the ``loss_mask``-weighted mean of the model output.
        ``averaged_loss`` is that loss averaged across the data-parallel
        group for logging. On any other stage, the model's output tensor.
    """
    timers = get_timers()

    # Get the batch. Every stage calls get_batch, even stages that ignore
    # tokens/position_ids, so the mpu.broadcast_data inside it runs on all
    # ranks (presumably required because the broadcast is collective —
    # confirm).
    timers('batch-generator').start()
    tokens, labels, loss_mask, attention_mask, position_ids = get_batch(
        data_iterator)
    timers('batch-generator').stop()

    # Forward pass. The first stage consumes raw tokens and later stages
    # consume the previous stage's activations. Only the last stage gets
    # labels.
    if mpu.is_pipeline_first_stage():
        assert input_tensor is None
        if mpu.is_pipeline_last_stage():
            # This rank is both the first and the last stage.
            output_tensor = model(tokens, position_ids, attention_mask,
                                  labels=labels)
        else:
            output_tensor = model(tokens, position_ids, attention_mask)
    elif mpu.is_pipeline_last_stage():
        assert input_tensor is not None
        output_tensor = model(input_tensor, attention_mask, labels=labels)
    else:
        assert input_tensor is not None
        output_tensor = model(input_tensor, attention_mask)

    if not mpu.is_pipeline_last_stage():
        # Intermediate stages just hand their output onward.
        return output_tensor

    # Masked mean over all positions of the last stage's output (presumably
    # per-token LM losses, since labels were supplied).
    # NOTE(review): if loss_mask is all zeros this is 0/0 and yields NaN.
    losses = output_tensor.float()
    loss_mask = loss_mask.view(-1).float()
    loss = torch.sum(losses.view(-1) * loss_mask) / loss_mask.sum()

    # Reduce loss for logging.
    averaged_loss = average_losses_across_data_parallel_group([loss])

    return loss, {'lm loss': averaged_loss[0]}


def train_valid_test_datasets_provider(train_val_test_num_samples):
    """Create the GPT train, validation and test datasets.

    Args:
        train_val_test_num_samples: forwarded unchanged to
            ``build_train_valid_test_datasets`` as its
            ``train_valid_test_num_samples`` argument.

    Returns:
        The ``(train_ds, valid_ds, test_ds)`` triple produced by
        ``build_train_valid_test_datasets``.
    """
    args = get_args()

    print_rank_0('> building train, validation, and test datasets '
                 'for GPT ...')

    # Every setting except the sample counts comes from get_args().
    # skip_warmup is simply the negation of args.mmap_warmup.
    dataset_kwargs = dict(
        data_prefix=args.data_path,
        data_impl=args.data_impl,
        splits_string=args.split,
        train_valid_test_num_samples=train_val_test_num_samples,
        seq_length=args.seq_length,
        seed=args.seed,
        skip_warmup=not args.mmap_warmup,
    )
    train_ds, valid_ds, test_ds = build_train_valid_test_datasets(
        **dataset_kwargs)

    print_rank_0("> finished creating GPT datasets ...")

    return train_ds, valid_ds, test_ds


if __name__ == "__main__":
    # Hand the three GPT-specific hooks from this file to the generic
    # pretraining driver. args_defaults supplies 'GPT2BPETokenizer' as the
    # default tokenizer_type (NOTE(review): presumably overridable by the
    # caller's arguments — confirm in pretrain).
    pretrain(
        train_valid_test_datasets_provider,
        model_provider,
        forward_step,
        args_defaults={'tokenizer_type': 'GPT2BPETokenizer'},
    )