finetune_utils.py 11.6 KB
Newer Older
1
# coding=utf-8
Mohammad's avatar
Mohammad committed
2
# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Finetune utilities."""

Jared Casper's avatar
Jared Casper committed
18
from functools import partial
Mostofa Patwary's avatar
Mostofa Patwary committed
19
import sys
Jared Casper's avatar
Jared Casper committed
20

21
22
import torch

Neel Kant's avatar
Neel Kant committed
23
24
from megatron import get_args
from megatron import print_rank_0
Mohammad's avatar
Mohammad committed
25
from megatron import get_timers
26
from megatron import mpu
Neel Kant's avatar
Neel Kant committed
27
from megatron.checkpointing import load_checkpoint
Mohammad's avatar
Mohammad committed
28
from megatron.checkpointing import save_checkpoint
29
30
31
32
from megatron.training import evaluate_and_print_results
from megatron.training import setup_model_and_optimizer
from megatron.training import train_step
from megatron.training import training_log
33
from megatron.utils import average_losses_across_data_parallel_group
mohammad's avatar
mohammad committed
34
35
from megatron.utils import calc_params_l2_norm
from megatron.utils import check_adlr_autoresume_termination
36
37


Mohammad's avatar
Mohammad committed
38
def process_batch(batch):
    """Unpack a raw batch into CUDA tensors ready for the model.

    Reads the 'text', 'types', 'label' and 'padding_mask' entries of
    `batch`. The first three are cast to long, the padding mask to float;
    all four are moved to the GPU and made contiguous. When `args.fp16`
    is set the padding mask is additionally converted to half precision.

    Returns:
        Tuple `(tokens, types, labels, attention_mask)`.
    """
    args = get_args()

    def _as_cuda_long(key):
        # Shared cast/move chain for the integer-valued entries.
        return batch[key].long().cuda().contiguous()

    tokens = _as_cuda_long('text')
    types = _as_cuda_long('types')
    labels = _as_cuda_long('label')

    # The mask is kept floating point; halve it only for fp16 runs.
    attention_mask = batch['padding_mask'].float().cuda().contiguous()
    if args.fp16:
        attention_mask = attention_mask.half()

    return tokens, types, labels, attention_mask


Jared Casper's avatar
Jared Casper committed
52
53
54
55
56
57
58
59
60
61
62
63
64
65
def cross_entropy_loss_func(labels, output_tensor):
    """Compute the cross-entropy loss of the model output against `labels`.

    Args:
        labels: target tensor handed to `torch.nn.CrossEntropyLoss`.
        output_tensor: model output, used directly as the logits.

    Returns:
        Tuple `(loss, {'lm loss': reduced_loss})` where `reduced_loss` is
        the first entry returned by
        `average_losses_across_data_parallel_group` and is meant for
        logging only.
    """
    # Evaluate the loss on contiguous fp32 logits.
    criterion = torch.nn.CrossEntropyLoss()
    fp32_logits = output_tensor.contiguous().float()
    loss = criterion(fp32_logits, labels)

    # Reduced copy of the loss, used for reporting.
    reduced = average_losses_across_data_parallel_group([loss])
    logged = {'lm loss': reduced[0]}

    return loss, logged


def _cross_entropy_forward_step(batch, model):
    """Simple forward step with cross-entropy loss.

    Args:
        batch: either an iterator that yields batches (the next one is
            pulled with `next`) or an already materialized batch that
            `process_batch` can consume directly.
        model: callable invoked as
            `model(tokens, attention_mask, tokentype_ids=types)`.

    Returns:
        Tuple `(output_tensor, loss_func)` where `loss_func` is
        `cross_entropy_loss_func` with the batch labels already bound, so
        it only needs the output tensor.

    Raises:
        Whatever `next(batch)` raises when `batch` is an iterator
        (including `StopIteration` on exhaustion). These are no longer
        swallowed, so the real cause of a data-loading failure is reported
        instead of a secondary error from `process_batch`.
    """
    timers = get_timers()

    # Get the batch.
    timers('batch-generator').start()
    # `next()` resolves `__next__` on the type, so test the type as well.
    # Only a non-iterator is treated as an already materialized batch;
    # errors raised while iterating (and KeyboardInterrupt) must propagate.
    if hasattr(type(batch), '__next__'):
        batch_ = next(batch)
    else:
        batch_ = batch
    tokens, types, labels, attention_mask = process_batch(batch_)
    timers('batch-generator').stop()

    # Forward model.
    output_tensor = model(tokens, attention_mask, tokentype_ids=types)

    return output_tensor, partial(cross_entropy_loss_func, labels)
82
83


Mostofa Patwary's avatar
Mostofa Patwary committed
84
85
def build_data_loader(dataset, micro_batch_size, num_workers, drop_last,
        task_collate_fn=None):
    """Build a data loader over `dataset` using a distributed sampler.

    Note that `micro_batch_size` is the local (per GPU) batch size.

    Args:
        dataset: dataset to draw samples from.
        micro_batch_size: batch size passed to the `DataLoader`.
        num_workers: number of loader worker processes.
        drop_last: forwarded to the `DataLoader` as `drop_last`.
        task_collate_fn: optional collate function; `None` leaves the
            `DataLoader` default in place.

    Returns:
        The constructed `torch.utils.data.DataLoader`.
    """
    # The sampler is sized and ranked by the data-parallel group.
    sampler = torch.utils.data.distributed.DistributedSampler(
        dataset,
        num_replicas=mpu.get_data_parallel_world_size(),
        rank=mpu.get_data_parallel_rank())

    # `shuffle` stays False on the loader because a sampler is supplied.
    loader_options = dict(
        batch_size=micro_batch_size,
        sampler=sampler,
        shuffle=False,
        num_workers=num_workers,
        drop_last=drop_last,
        pin_memory=True,
        collate_fn=task_collate_fn,
    )
    return torch.utils.data.DataLoader(dataset, **loader_options)


def _build_infinite_size_dataloader(dataloader):
    """Build a looped dataloader with infinite size."""

    iterator = dataloader.__iter__()
    while True:
        try:
            yield iterator.__next__()
        except StopIteration:
            iterator = dataloader.__iter__()


Mostofa Patwary's avatar
Mostofa Patwary committed
118
def _build_train_valid_dataloaders(train_dataset, valid_dataset, task_collate_fn=None):
    """Training and validation dataloaders.

    Besides building the loaders, this updates the global args:
    `train_iters_per_epoch` and `train_iters` are set from the length of
    the training loader, the original batch sizes are saved in
    `orig_micro_batch_size` / `orig_global_batch_size`, and the batch
    sizes are scaled when the training dataset has a `sample_multiplier`.

    Returns:
        Tuple `(train_dataloader, valid_dataloader)`; the validation
        loader is wrapped so that it loops without end.
    """
    args = get_args()

    print_rank_0('building train and validation dataloaders ...')

    # Training loader; its length defines the iteration counts.
    train_dataloader = build_data_loader(
        train_dataset, args.micro_batch_size, args.num_workers,
        not args.keep_last, task_collate_fn)
    iters_per_epoch = len(train_dataloader)
    args.train_iters_per_epoch = iters_per_epoch
    args.train_iters = args.epochs * iters_per_epoch

    # Validation loader. No shuffling has to be set up for this dataset,
    # so a simple infinite loop over it is enough.
    valid_dataloader = _build_infinite_size_dataloader(
        build_data_loader(
            valid_dataset, args.micro_batch_size, args.num_workers,
            not args.keep_last, task_collate_fn))

    # With the loaders built, switch the batch-size arguments to the
    # batch size the model will actually see for this dataset. Pipeline
    # transfers need to know that size, and the LR schedule, which is
    # based on samples seen, depends on it too.
    args.orig_micro_batch_size = args.micro_batch_size
    args.orig_global_batch_size = args.global_batch_size
    if hasattr(train_dataset, 'sample_multiplier'):
        # A dataset with a sample_multiplier attribute packs several
        # samples into each "sample" (for example RACE, which has several
        # options); they collapse into the batch dimension, so the batch
        # sizes have to be scaled accordingly.
        multiplier = train_dataset.sample_multiplier
        args.micro_batch_size *= multiplier
        args.global_batch_size *= multiplier

    return train_dataloader, valid_dataloader


def _train(model, optimizer, lr_scheduler, forward_step,
           train_dataloader, valid_dataloader, end_of_epoch_callback):
    """Run the finetuning loop over all remaining epochs.

    Resumes from `args.iteration`: epochs that are already complete are
    not revisited and the already-consumed batches of the first epoch are
    skipped. Every step is logged; autoresume checks, periodic
    checkpoints and periodic evaluation run at their configured
    intervals. After each epoch a checkpoint is written (when `args.save`
    is set) and `end_of_epoch_callback(model, epoch)` is invoked (when it
    is not None).
    """
    args = get_args()
    timers = get_timers()

    # Switch every model chunk to training mode, which enables dropout.
    for module in model:
        module.train()

    # Accumulator for the losses, shared with training_log.
    losses_dict_sum = {}

    # Where to resume: finished epochs and position within the epoch.
    start_epoch, start_iteration = divmod(args.iteration,
                                          args.train_iters_per_epoch)
    iteration = args.iteration

    # Memory reporting flag, updated by training_log.
    report_memory_flag = True

    timers('interval-time').start()
    for epoch in range(start_epoch, args.epochs):
        print_rank_0('working on epoch {} ...'.format(epoch + 1))

        # Re-seed the sampler so the index order is shuffled per epoch.
        train_dataloader.sampler.set_epoch(args.seed + epoch)

        for batch_index, batch in enumerate(train_dataloader):

            # Skip the batches that were consumed before the restart.
            if batch_index < start_iteration:
                continue
            # Later epochs must not skip anything.
            start_iteration = 0

            # One optimization step.
            (losses_dict, skipped_iter,
             grad_norm, num_zeros_in_grad) = train_step(
                 forward_step, batch, model, optimizer, lr_scheduler)
            iteration += 1

            # Logging.
            if args.log_params_norm:
                params_norm = calc_params_l2_norm(model)
            else:
                params_norm = None
            learning_rate = optimizer.param_groups[0]['lr']
            loss_scale = optimizer.get_loss_scale().item()
            report_memory_flag = training_log(
                losses_dict, losses_dict_sum, learning_rate, iteration,
                loss_scale, report_memory_flag, skipped_iter,
                grad_norm, params_norm, num_zeros_in_grad)

            # Autoresume
            if args.adlr_autoresume and \
               (iteration % args.adlr_autoresume_interval == 0):
                check_adlr_autoresume_termination(
                    iteration, model, optimizer, lr_scheduler)

            # Periodic checkpoint.
            if args.save and args.save_interval and \
               iteration % args.save_interval == 0:
                save_checkpoint(iteration, model, optimizer, lr_scheduler)

            # Periodic evaluation.
            if args.eval_interval and iteration % args.eval_interval == 0:
                evaluate_and_print_results(
                    'iteration {}'.format(iteration), forward_step,
                    valid_dataloader, model, iteration, False)

        # Checkpoint at the end of each epoch.
        if args.save:
            save_checkpoint(iteration, model, optimizer, lr_scheduler)

        # End-of-epoch hook.
        if end_of_epoch_callback is not None:
            end_of_epoch_callback(model, epoch)
239
240


Mohammad's avatar
Mohammad committed
241
def finetune(train_valid_datasets_provider, model_provider,
             forward_step=_cross_entropy_forward_step,
             end_of_epoch_callback_provider=None,
             task_collate_fn=None):
    """Main finetune function used across all tasks.

    Args:
        train_valid_datasets_provider: zero-argument callable returning
            `(train_dataset, valid_dataset)`; only called when
            `args.epochs > 0`.
        model_provider: passed on to `setup_model_and_optimizer`.
        forward_step: forward-step function used for training and
            evaluation.
        end_of_epoch_callback_provider: optional zero-argument callable
            that returns the end-of-epoch callback.
        task_collate_fn: optional collate function for the data loaders.

    With `args.epochs > 0` the model is trained; otherwise only the
    end-of-epoch callback (if any) is run once in evaluation-only mode.
    """
    args = get_args()
    timers = get_timers()

    assert args.rampup_batch_size is None, \
        'batch size scaling is not supported for finetuning'

    # Timer names, reused below for the setup-timing report.
    data_timer = 'train/valid/test dataset/dataloder'
    callback_timer = 'callback function'
    model_timer = 'model and optimizer'
    checkpoint_timer = 'pretrained checkpoint'

    # Train and validation data loaders.
    timers(data_timer).start()
    if args.epochs > 0:
        train_dataset, valid_dataset = train_valid_datasets_provider()
        train_dataloader, valid_dataloader = _build_train_valid_dataloaders(
            train_dataset, valid_dataset, task_collate_fn)
    else:
        args.train_iters = 0
    timers(data_timer).stop()

    # Build the callback function.
    timers(callback_timer).start()
    if end_of_epoch_callback_provider is not None:
        end_of_epoch_callback = end_of_epoch_callback_provider()
    else:
        end_of_epoch_callback = None
    timers(callback_timer).stop()

    # Build model, optimizer and learning rate scheduler.
    timers(model_timer).start()
    model, optimizer, lr_scheduler = setup_model_and_optimizer(model_provider)
    timers(model_timer).stop()

    # Load the pretrained checkpoint only if one is provided and nothing
    # has been trained yet (i.e., iteration is zero).
    timers(checkpoint_timer).start()
    if args.iteration == 0 and args.pretrained_checkpoint is not None:
        # Temporarily point args.load at the pretrained checkpoint.
        saved_load_path = args.load
        args.load = args.pretrained_checkpoint
        load_checkpoint(model, None, None)
        args.load = saved_load_path
        # This is critical when only the model is loaded: the main
        # parameters have to be updated as well.
        optimizer.reload_model_params()
    timers(checkpoint_timer).stop()

    # Print setup timing.
    print_rank_0('done with setups ...')
    timers.log([data_timer, callback_timer, model_timer, checkpoint_timer])
    print_rank_0('training ...')

    if args.epochs > 0:
        # Finetune the model.
        _train(model, optimizer, lr_scheduler, forward_step,
               train_dataloader, valid_dataloader, end_of_epoch_callback)
    elif end_of_epoch_callback is not None:
        # Or just evaluate.
        print_rank_0('evaluation only mode, setting epoch to -1')
        end_of_epoch_callback(model, epoch=-1, output_predictions=True)
    print_rank_0('done :-)')