# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.

"""Finetune utilities."""

from functools import partial
import sys
import torch

from megatron.training import get_args
from megatron.core.num_microbatches_calculator import get_num_microbatches
from megatron.training import print_rank_0
from megatron.training import get_timers
from megatron.core import mpu
from megatron.core.enums import ModelType
from megatron.training.checkpointing import load_checkpoint
from megatron.training.checkpointing import save_checkpoint
from megatron.training.training import evaluate_and_print_results
from megatron.training.training import setup_model_and_optimizer
from megatron.training.training import train_step
from megatron.training.training import training_log
from megatron.training.utils import average_losses_across_data_parallel_group
from megatron.training.utils import calc_params_l2_norm
from megatron.training.utils import check_adlr_autoresume_termination


def process_batch(batch):
    """Process batch and produce inputs for the model."""
    args = get_args()

    tokens = batch['text'].long().cuda().contiguous()
    types = batch['types'].long().cuda().contiguous()
    labels = batch['label'].long().cuda().contiguous()
    attention_mask = batch['padding_mask'].float().cuda().contiguous()
    if args.fp16:
        attention_mask = attention_mask.half()

    return tokens, types, labels, attention_mask
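
# Illustrative sketch (not from the original file) of the batch layout that
# process_batch() expects; the field names come from the code above, the
# shapes and values below are hypothetical:
#
#     batch = {
#         'text':         torch.randint(0, 30000, (4, 512)),     # token ids
#         'types':        torch.zeros(4, 512, dtype=torch.long),  # token type ids
#         'label':        torch.randint(0, 2, (4,)),              # class labels
#         'padding_mask': torch.ones(4, 512, dtype=torch.long),   # 1 = attend
#     }
#     tokens, types, labels, attention_mask = process_batch(batch)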


def cross_entropy_loss_func(labels, output_tensor):
    logits = output_tensor

    # Cross-entropy loss.
    loss_func = torch.nn.CrossEntropyLoss()
    loss = loss_func(logits.contiguous().float(), labels)

    # Reduce loss for logging.
    averaged_loss = average_losses_across_data_parallel_group([loss])

    return loss, {'lm loss': averaged_loss[0]}


def _cross_entropy_forward_step(batch, model):
    """Simple forward step with cross-entropy loss."""
    timers = get_timers()

    # Get the batch.
    timers('batch-generator', log_level=2).start()
    try:
        batch_ = next(batch)
    except Exception:
        batch_ = batch
    tokens, types, labels, attention_mask = process_batch(batch_)
    timers('batch-generator').stop()

    # Forward model.
    output_tensor = model(tokens, attention_mask, tokentype_ids=types)

    return output_tensor, partial(cross_entropy_loss_func, labels)
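
# Contract sketch (illustrative, not invoked anywhere in this file): the
# training/evaluation loops call the forward step and then apply the returned
# loss closure to the output tensor, roughly:
#
#     output_tensor, loss_func = _cross_entropy_forward_step(batch_iterator, model)
#     loss, logging_dict = loss_func(output_tensor)   # {'lm loss': averaged_loss}
#
# Here batch_iterator is a hypothetical name for an iterator over batches.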


def build_data_loader(dataset, micro_batch_size, num_workers, drop_last,
        task_collate_fn=None):
    """Data loader. Note that batch-size is the local (per GPU) batch-size."""

    # Sampler.
    world_size = mpu.get_data_parallel_world_size()
    rank = mpu.get_data_parallel_rank()
    sampler = torch.utils.data.distributed.DistributedSampler(
        dataset, num_replicas=world_size, rank=rank)

    # Data loader. Note that batch size is the per GPU batch size.
    data_loader = torch.utils.data.DataLoader(dataset,
                                              batch_size=micro_batch_size,
                                              sampler=sampler,
                                              shuffle=False,
                                              num_workers=num_workers,
                                              drop_last=drop_last,
                                              pin_memory=True,
                                              collate_fn=task_collate_fn)

    return data_loader
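
# Example call (argument values are hypothetical): the DistributedSampler above
# gives each data-parallel rank its own shard, so every GPU sees roughly
# len(dataset) / data_parallel_world_size samples per epoch.
#
#     loader = build_data_loader(train_dataset, micro_batch_size=8,
#                                num_workers=2, drop_last=True)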


def _build_infinite_size_dataloader(dataloader):
    """Build a looped dataloader with infinite size."""

    iterator = dataloader.__iter__()
    while True:
        try:
            yield iterator.__next__()
        except StopIteration:
            iterator = dataloader.__iter__()
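
# Usage note (illustrative, hypothetical names): because the wrapper re-creates
# the iterator on StopIteration, next() on it never exhausts, e.g.
#
#     valid_iter = _build_infinite_size_dataloader(valid_loader)
#     batch = next(valid_iter)   # always yields a batch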


def _build_train_valid_dataloaders(train_dataset, valid_dataset, 
    task_collate_fn=None):
    """Traing and validation dataloaders."""
    args = get_args()

    print_rank_0('building train and validation dataloaders ...')
    # Training dataset.
    train_dataloader = build_data_loader(train_dataset, args.micro_batch_size,
                                         args.num_workers, not args.keep_last,
                                         task_collate_fn)
    # Set the training iterations.
    args.train_iters_per_epoch = len(train_dataloader)
    args.train_iters = args.epochs * args.train_iters_per_epoch
    # Validation dataset. For this dataset, we do not need to set up
    # shuffling so we can just use a simple infinite loop.
    valid_dataloader_ = build_data_loader(valid_dataset, args.micro_batch_size,
                                          args.num_workers, not args.keep_last,
                                          task_collate_fn)
    valid_dataloader = _build_infinite_size_dataloader(valid_dataloader_)

    # Now that we've built the data loaders, set batch_size arguments
    # to the actual batch size the model will see for this dataset.
    # This is necessary so pipeline transfers know what size they are
    # and the LR schedule, which is based on samples seen, gets set
    # correctly.
    args.orig_micro_batch_size = args.micro_batch_size
    args.orig_global_batch_size = args.global_batch_size
    if hasattr(train_dataset, 'sample_multiplier'):
        # If our dataset has a sample_multiplier attribute, it means each
        # "sample" from the dataset actually contains multiple samples that
        # will collapse into the batch dimension (for example the RACE
        # dataset, which has several answer options per question), so we need
        # to account for that when setting the micro batch size.
        args.micro_batch_size *= train_dataset.sample_multiplier
        args.global_batch_size *= train_dataset.sample_multiplier
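        # For example (hypothetical numbers): with the RACE task each sample
        # carries 4 answer options, so sample_multiplier == 4 and a
        # --micro-batch-size of 8 becomes 32 rows per forward pass.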

    return train_dataloader, valid_dataloader


def _train(model, optimizer, opt_param_scheduler, forward_step,
           train_dataloader, valid_dataloader, end_of_epoch_callback):
    """Train the model."""
    args = get_args()
    timers = get_timers()

    assert get_num_microbatches() == 1, "finetuning with gradient accumulation doesn't currently work"

    # Turn on training mode which enables dropout.
    for m in model:
        m.train()

    # Tracking loss.
    losses_dict_sum = {}

    # Starting epoch and iteration
    start_epoch = args.iteration // args.train_iters_per_epoch
    start_iteration = args.iteration % args.train_iters_per_epoch
    iteration = args.iteration
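    # (Worked example, hypothetical numbers: with train_iters_per_epoch == 1000
    # and a checkpoint saved at iteration 2500, training resumes at
    # start_epoch == 2, start_iteration == 500.)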

    # Memory reporting flag.
    report_memory_flag = True

    # For each remaining epoch
    timers('interval-time', log_level=0).start(barrier=True)
    for epoch in range(start_epoch, args.epochs):
        print_rank_0('working on epoch {} ...'.format(epoch + 1))

        # Set the data loader epoch to shuffle the index iterator.
        train_dataloader.sampler.set_epoch(args.seed + epoch)

        # For all the batches in the dataset.
        for iteration_, batch in enumerate(train_dataloader):

            # Ignore the iterations before starting value
            if iteration_ < start_iteration:
                continue
            # Set to zero so the next epoch does not skip any batches.
            start_iteration = 0

            # Train for one step.
            out = train_step(forward_step, batch, model, optimizer, opt_param_scheduler)

            losses_dict, skipped_iter, grad_norm, num_zeros_in_grad = out
            iteration += 1

            # Logging.
            params_norm = None
            if args.log_params_norm:
                params_norm = calc_params_l2_norm(model)
            report_memory_flag = training_log(losses_dict, losses_dict_sum,
                                              optimizer.param_groups[0]['lr'],
                                              iteration,
                                              optimizer.get_loss_scale().item(),
                                              report_memory_flag, skipped_iter,
                                              grad_norm, params_norm, num_zeros_in_grad)

            # Autoresume
            if args.adlr_autoresume and \
               (iteration % args.adlr_autoresume_interval == 0):
                check_adlr_autoresume_termination(iteration, model,
                                                  optimizer, opt_param_scheduler)

            # Checkpointing
            saved_checkpoint = False
            if args.save and args.save_interval and \
               iteration % args.save_interval == 0:
                save_checkpoint(iteration, model, optimizer, opt_param_scheduler)
                saved_checkpoint = True

            # Evaluation
            if args.eval_interval and iteration % args.eval_interval == 0:
                prefix = 'iteration {}'.format(iteration)
                evaluate_and_print_results(prefix, forward_step,
                                           valid_dataloader, model,
                                           iteration, None, False)

            # Exiting based on iterations
            if args.exit_interval and iteration % args.exit_interval == 0:
                if not saved_checkpoint:
                    save_checkpoint(iteration, model, optimizer, opt_param_scheduler)
                torch.distributed.barrier()
                print_rank_0('exiting program at iteration {}'.format(iteration))
                sys.exit()

        # Checkpointing at the end of each epoch.
        if args.save:
            save_checkpoint(iteration, model, optimizer, opt_param_scheduler)

        # Callback at the end of each epoch.
        if end_of_epoch_callback is not None:
            end_of_epoch_callback(model, epoch)


def finetune(train_valid_datasets_provider, model_provider,
             model_type=ModelType.encoder_or_decoder,
             forward_step=_cross_entropy_forward_step,
             end_of_epoch_callback_provider=None,
             task_collate_fn=None):
    """Main finetune function used across all tasks."""
    args = get_args()
    timers = get_timers()

    assert args.rampup_batch_size is None, \
        'batch size scaling is not supported for finetuning'

    # Train and validation data loaders.
    timers('train/valid/test dataset/dataloder', log_level=0).start()
    if args.epochs > 0:
        train_dataset, valid_dataset = train_valid_datasets_provider()
        train_dataloader, valid_dataloader = _build_train_valid_dataloaders(
            train_dataset, valid_dataset, task_collate_fn)
    else:
        args.train_iters = 0
    timers('train/valid/test dataset/dataloder').stop()

    # Build callback function.
    timers('callback function', log_level=0).start()
    end_of_epoch_callback = None
    if end_of_epoch_callback_provider is not None:
        end_of_epoch_callback = end_of_epoch_callback_provider()
    timers('callback function').stop()

    # Build model, optimizer and learning rate scheduler.
    timers('model and optimizer', log_level=0).start()
    model, optimizer, opt_param_scheduler = setup_model_and_optimizer(model_provider, model_type)
    timers('model and optimizer').stop()

    # If pretrained checkpoint is provided and we have not trained for
    # any iteration (i.e., iteration is zero), then load the pretrained
    # checkpoint.
    timers('pretrained checkpoint', log_level=0).start(barrier=True)
    if args.iteration == 0 and args.pretrained_checkpoint is not None:
        original_load = args.load
        args.load = args.pretrained_checkpoint
        original_rng = args.no_load_rng
        args.no_load_rng = True
        _ = load_checkpoint(model, None, None)
        args.load = original_load
        args.no_load_rng = original_rng
        # This is critical when only the model is loaded: make sure the
        # optimizer's main parameters are also updated from the newly
        # loaded weights.
        optimizer.reload_model_params()
    timers('pretrained checkpoint').stop()

    # Print setup timing.
    print_rank_0('done with setups ...')
    timers.log(['train/valid/test dataset/dataloder', 'callback function',
                'model and optimizer', 'pretrained checkpoint'], barrier=True)
    print_rank_0('training ...')

    # Finetune the model.
    if args.epochs > 0:
        _train(model, optimizer, opt_param_scheduler, forward_step,
               train_dataloader, valid_dataloader, end_of_epoch_callback)
    # Or just evaluate.
    else:
        if end_of_epoch_callback is not None:
            print_rank_0('evaluation only mode, setting epoch to -1')
            end_of_epoch_callback(model, epoch=-1, output_predictions=True)
    print_rank_0('done :-)')
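

# Illustrative usage sketch (not part of the original file): a task script is
# expected to supply its own providers and then call finetune(); the provider
# bodies below are hypothetical placeholders.
#
#     def train_valid_datasets_provider():
#         ...  # return (train_dataset, valid_dataset) for the task
#
#     def model_provider(pre_process=True, post_process=True):
#         ...  # build and return the task model (e.g. with a classification head)
#
#     if __name__ == '__main__':
#         finetune(train_valid_datasets_provider, model_provider,
#                  forward_step=_cross_entropy_forward_step)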