finetune_utils.py 12 KB
Newer Older
Jared Casper's avatar
Jared Casper committed
1
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
2
3
4

"""Finetune utilities."""

Jared Casper's avatar
Jared Casper committed
5
from functools import partial
6
import sys
7
8
import torch

9
from megatron import get_args, get_num_microbatches
Neel Kant's avatar
Neel Kant committed
10
from megatron import print_rank_0
Mohammad's avatar
Mohammad committed
11
from megatron import get_timers
12
from megatron import mpu
Neel Kant's avatar
Neel Kant committed
13
from megatron.checkpointing import load_checkpoint
Mohammad's avatar
Mohammad committed
14
from megatron.checkpointing import save_checkpoint
15
from megatron.model import ModelType
16
17
18
19
from megatron.training import evaluate_and_print_results
from megatron.training import setup_model_and_optimizer
from megatron.training import train_step
from megatron.training import training_log
20
from megatron.utils import average_losses_across_data_parallel_group
mohammad's avatar
mohammad committed
21
22
from megatron.utils import calc_params_l2_norm
from megatron.utils import check_adlr_autoresume_termination
23
24


Mohammad's avatar
Mohammad committed
25
def process_batch(batch):
    """Move one finetuning batch to the GPU and unpack its fields.

    Reads the ``'text'``, ``'types'``, ``'label'`` and ``'padding_mask'``
    entries of ``batch``. The first three are converted to contiguous int64
    CUDA tensors. The padding mask becomes a contiguous float32 CUDA tensor,
    downcast to half precision when ``args.fp16`` is set.

    Returns:
        ``(tokens, types, labels, attention_mask)``
    """
    args = get_args()

    def _as_cuda_long(key):
        # Integer inputs (token ids, token-type ids, labels) share one recipe.
        return batch[key].long().cuda().contiguous()

    tokens = _as_cuda_long('text')
    types = _as_cuda_long('types')
    labels = _as_cuda_long('label')

    mask = batch['padding_mask'].float().cuda().contiguous()
    # Keep the mask dtype in line with fp16 activations.
    attention_mask = mask.half() if args.fp16 else mask

    return tokens, types, labels, attention_mask


Jared Casper's avatar
Jared Casper committed
39
40
41
42
43
44
45
46
47
48
49
50
51
52
def cross_entropy_loss_func(labels, output_tensor):
    """Cross-entropy loss between model logits and integer class labels.

    ``output_tensor`` is treated as the logits. It is made contiguous and cast
    to float32 before ``torch.nn.CrossEntropyLoss`` (default reduction) is
    applied against ``labels``. Presumably the logits are
    (batch, num_classes); TODO confirm against the task heads.

    Returns:
        ``(loss, {'lm loss': averaged_loss})``. ``loss`` is the local value
        used for backprop. The dict entry is the same loss averaged across
        the data-parallel group, for logging.
    """
    criterion = torch.nn.CrossEntropyLoss()
    loss = criterion(output_tensor.contiguous().float(), labels)

    # Data-parallel average, used only for reporting.
    reduced_losses = average_losses_across_data_parallel_group([loss])
    return loss, {'lm loss': reduced_losses[0]}


def _cross_entropy_forward_step(batch, model):
    """Simple forward step with cross-entropy loss.

    Args:
        batch: either an iterator of batches, which is advanced once with
            ``next`` (as for the looped validation loader), or a single batch
            that has already been fetched (such as one obtained by iterating
            the training dataloader in ``_train``).
        model: callable invoked as
            ``model(tokens, attention_mask, tokentype_ids=types)``.

    Returns:
        ``(output_tensor, loss_func)``. ``loss_func`` is
        ``cross_entropy_loss_func`` with this batch's labels already bound,
        so ``loss_func(output_tensor)`` produces the loss.
    """
    timers = get_timers()

    # Get the batch.
    timers('batch-generator', log_level=2).start()
    # Only advance real iterators; anything else is an already-fetched batch.
    # Checking for ``__next__`` instead of catching every exception from
    # next() keeps KeyboardInterrupt/SystemExit and genuine data-loading
    # errors visible. Otherwise they would be replaced by an unrelated
    # failure when the iterator object itself reached process_batch.
    if hasattr(batch, '__next__'):
        batch_ = next(batch)
    else:
        batch_ = batch
    tokens, types, labels, attention_mask = process_batch(batch_)
    timers('batch-generator').stop()

    # Forward model.
    output_tensor = model(tokens, attention_mask, tokentype_ids=types)

    return output_tensor, partial(cross_entropy_loss_func, labels)
69
70


Mostofa Patwary's avatar
Mostofa Patwary committed
71
def build_data_loader(dataset, micro_batch_size, num_workers, drop_last,
                      task_collate_fn=None):
    """Build a data-parallel-aware DataLoader over ``dataset``.

    ``micro_batch_size`` is the local (per GPU) batch size. A
    ``DistributedSampler`` partitions the dataset across the data-parallel
    group, so the loader itself does not shuffle. Memory is pinned, and
    ``task_collate_fn`` (if given) is used as the ``collate_fn``.
    """
    # The sampler decides which indices this data-parallel rank sees.
    sampler = torch.utils.data.distributed.DistributedSampler(
        dataset,
        num_replicas=mpu.get_data_parallel_world_size(),
        rank=mpu.get_data_parallel_rank(),
    )

    # shuffle must stay False because a sampler is supplied.
    loader_options = dict(
        batch_size=micro_batch_size,
        sampler=sampler,
        shuffle=False,
        num_workers=num_workers,
        drop_last=drop_last,
        pin_memory=True,
        collate_fn=task_collate_fn,
    )
    return torch.utils.data.DataLoader(dataset, **loader_options)


def _build_infinite_size_dataloader(dataloader):
    """Build a looped dataloader with infinite size."""

    iterator = dataloader.__iter__()
    while True:
        try:
            yield iterator.__next__()
        except StopIteration:
            iterator = dataloader.__iter__()


Mostofa Patwary's avatar
Mostofa Patwary committed
105
106
def _build_train_valid_dataloaders(train_dataset, valid_dataset,
                                   task_collate_fn=None):
    """Training and validation dataloaders.

    Builds the per-rank training dataloader and a looped validation iterator.
    It also records derived sizes on ``args``:

    * ``train_iters_per_epoch`` and ``train_iters``,
    * ``orig_micro_batch_size`` and ``orig_global_batch_size``,
    * ``micro_batch_size`` and ``global_batch_size``, scaled by
      ``train_dataset.sample_multiplier`` when that attribute exists.

    Args:
        train_dataset: dataset used for training.
        valid_dataset: dataset used for validation.
        task_collate_fn: optional ``collate_fn`` forwarded to both loaders.

    Returns:
        ``(train_dataloader, valid_dataloader)``. The validation loader loops
        over the validation data indefinitely.

    Raises:
        ValueError: if the training dataloader yields no batches. Otherwise
            ``train_iters_per_epoch`` would be 0 and ``_train`` would fail
            with a ``ZeroDivisionError`` when computing the starting epoch.
    """
    args = get_args()

    print_rank_0('building train and validation dataloaders ...')
    # Training dataset.
    train_dataloader = build_data_loader(train_dataset, args.micro_batch_size,
                                         args.num_workers, not args.keep_last,
                                         task_collate_fn)
    # Set the training iterations.
    args.train_iters_per_epoch = len(train_dataloader)
    if args.train_iters_per_epoch == 0:
        raise ValueError(
            'training dataloader has no batches (dataset size {}, micro batch '
            'size {}, keep_last={}); use a larger dataset, a smaller micro '
            'batch size, or keep the last partial batch'.format(
                len(train_dataset), args.micro_batch_size, args.keep_last))
    args.train_iters = args.epochs * args.train_iters_per_epoch
    # Validation dataset. For this dataset, we do not need to set up
    # shuffling so we can just use a simple infinite loop.
    valid_dataloader_ = build_data_loader(valid_dataset, args.micro_batch_size,
                                          args.num_workers, not args.keep_last,
                                          task_collate_fn)
    valid_dataloader = _build_infinite_size_dataloader(valid_dataloader_)

    # Now that we've built the data loaders, set batch_size arguments
    # to the actual batch size the model will see for this dataset.
    # This is necessary so pipeline transfers know what size they are
    # and the LR schedule, which is based on samples seen, gets set
    # correctly.
    args.orig_micro_batch_size = args.micro_batch_size
    args.orig_global_batch_size = args.global_batch_size
    if hasattr(train_dataset, 'sample_multiplier'):
        # If our dataset has a sample_multiplier attribute, each "sample"
        # from the dataset actually contains multiple samples that collapse
        # into the batch dimension (for example, the RACE dataset has
        # several options per question). Account for that when setting the
        # micro batch size.
        args.micro_batch_size *= train_dataset.sample_multiplier
        args.global_batch_size *= train_dataset.sample_multiplier

    return train_dataloader, valid_dataloader


144
def _train(model, optimizer, opt_param_scheduler, forward_step,
           train_dataloader, valid_dataloader, end_of_epoch_callback):
    """Train the model.

    Resumes from ``args.iteration``, which is split into a starting epoch and
    an in-epoch offset via ``args.train_iters_per_epoch``, and runs until
    ``args.epochs`` epochs are complete. Each iteration runs ``train_step``
    and logs. Depending on the corresponding ``*_interval`` arguments, it
    also checks ADLR autoresume, saves a checkpoint, evaluates on
    ``valid_dataloader``, or exits the program. At the end of every epoch a
    checkpoint is saved (if ``args.save`` is set) and
    ``end_of_epoch_callback(model, epoch)`` is called (if provided).

    Args:
        model: iterable of model modules (each is put in train mode).
        optimizer: optimizer passed to ``train_step`` and checkpointing.
        opt_param_scheduler: LR/param scheduler passed to ``train_step`` and
            checkpointing.
        forward_step: forward function used for training and evaluation.
        train_dataloader: training loader; its ``sampler`` must support
            ``set_epoch``.
        valid_dataloader: iterator consumed by periodic evaluation.
        end_of_epoch_callback: optional callable ``(model, epoch)``.
    """
    args = get_args()
    timers = get_timers()

    assert get_num_microbatches() == 1, "finetuning with gradient accumulation doesn't currently work"

    # Turn on training mode which enables dropout.
    for m in model:
        m.train()

    # Tracking loss.
    losses_dict_sum = {}

    # Starting epoch and iteration. args.iteration counts iterations already
    # completed (presumably restored during model/checkpoint setup).
    start_epoch = args.iteration // args.train_iters_per_epoch
    start_iteration = args.iteration % args.train_iters_per_epoch
    iteration = args.iteration

    # Memory reporting flag.
    report_memory_flag = True

    # For each remaining epoch
    timers('interval-time', log_level=0).start(barrier=True)
    for epoch in range(start_epoch, args.epochs):
        print_rank_0('working on epoch {} ...'.format(epoch + 1))

        # Set the data loader epoch to shuffle the index iterator. The epoch
        # key is seed + epoch, so a resumed run should see the same order.
        train_dataloader.sampler.set_epoch(args.seed + epoch)

        # For all the batches in the dataset.
        for iteration_, batch in enumerate(train_dataloader):

            # Ignore the iterations before starting value
            if iteration_ < start_iteration:
                continue
            # Set to zero so the next epoch does not skip any batches.
            start_iteration = 0

            # Train for one step.
            out = train_step(forward_step, batch, model, optimizer, opt_param_scheduler)

            losses_dict, skipped_iter, grad_norm, num_zeros_in_grad = out
            iteration += 1

            # Logging.
            params_norm = None
            if args.log_params_norm:
                params_norm = calc_params_l2_norm(model)
            report_memory_flag = training_log(losses_dict, losses_dict_sum,
                                              optimizer.param_groups[0]['lr'],
                                              iteration,
                                              optimizer.get_loss_scale().item(),
                                              report_memory_flag, skipped_iter,
                                              grad_norm, params_norm, num_zeros_in_grad)

            # Autoresume
            if args.adlr_autoresume and \
               (iteration % args.adlr_autoresume_interval == 0):
                check_adlr_autoresume_termination(iteration, model,
                                                  optimizer, opt_param_scheduler)

            # Checkpointing
            saved_checkpoint = False
            if args.save and args.save_interval and \
               iteration % args.save_interval == 0:
                save_checkpoint(iteration, model, optimizer, opt_param_scheduler)
                saved_checkpoint = True

            # Evaluation
            if args.eval_interval and iteration % args.eval_interval == 0:
                prefix = 'iteration {}'.format(iteration)
                evaluate_and_print_results(prefix, forward_step,
                                           valid_dataloader, model,
                                           iteration, None, False)

            # Exiting based on iterations
            if args.exit_interval and iteration % args.exit_interval == 0:
                # Like the periodic and end-of-epoch saves, only save when a
                # save directory is configured; otherwise this would attempt
                # a checkpoint save with args.save unset.
                if args.save and not saved_checkpoint:
                    save_checkpoint(iteration, model, optimizer, opt_param_scheduler)
                torch.distributed.barrier()
                print_rank_0('exiting program at iteration {}'.format(iteration))
                sys.exit()

        # Checkpointing at the end of each epoch.
        if args.save:
            save_checkpoint(iteration, model, optimizer, opt_param_scheduler)

        # Callback at the end of each epoch.
        if end_of_epoch_callback is not None:
            end_of_epoch_callback(model, epoch)
236
237


Mohammad's avatar
Mohammad committed
238
def finetune(train_valid_datasets_provider, model_provider,
             model_type=ModelType.encoder_or_decoder,
             forward_step=_cross_entropy_forward_step,
             end_of_epoch_callback_provider=None,
             task_collate_fn=None):
    """Shared finetuning driver used by every task.

    Steps, each timed separately:

    1. When ``args.epochs > 0``, build the datasets via
       ``train_valid_datasets_provider()`` and their dataloaders; otherwise
       set ``args.train_iters`` to 0.
    2. Build the optional end-of-epoch callback from
       ``end_of_epoch_callback_provider``.
    3. Set up the model, optimizer and scheduler from ``model_provider``.
    4. When starting at iteration 0 with ``args.pretrained_checkpoint`` set,
       load those weights (without RNG state) and refresh the optimizer's
       copy of the parameters.

    Then train with ``_train`` when ``args.epochs > 0``. Otherwise run the
    callback once in evaluation-only mode (epoch -1, predictions written).
    """
    args = get_args()
    timers = get_timers()

    assert args.rampup_batch_size is None, \
        'batch size scaling is not supported for finetuning'

    # Data: only needed when we actually train.
    timers('train/valid/test dataset/dataloder', log_level=0).start()
    if args.epochs > 0:
        datasets = train_valid_datasets_provider()
        train_loader, valid_loader = _build_train_valid_dataloaders(
            datasets[0], datasets[1], task_collate_fn)
    else:
        args.train_iters = 0
    timers('train/valid/test dataset/dataloder').stop()

    # Optional end-of-epoch callback.
    timers('callback function', log_level=0).start()
    epoch_callback = (end_of_epoch_callback_provider()
                      if end_of_epoch_callback_provider is not None
                      else None)
    timers('callback function').stop()

    # Model, optimizer and learning-rate scheduler.
    timers('model and optimizer', log_level=0).start()
    model, optimizer, opt_param_scheduler = setup_model_and_optimizer(model_provider, model_type)
    timers('model and optimizer').stop()

    # Warm start from the pretrained checkpoint, but only for a fresh run.
    timers('pretrained checkpoint', log_level=0).start(barrier=True)
    if args.iteration == 0 and args.pretrained_checkpoint is not None:
        # Temporarily redirect load_checkpoint to the pretrained weights and
        # skip RNG-state loading, then put the caller's settings back.
        saved_load, saved_no_load_rng = args.load, args.no_load_rng
        args.load, args.no_load_rng = args.pretrained_checkpoint, True
        _ = load_checkpoint(model, None, None)
        args.load, args.no_load_rng = saved_load, saved_no_load_rng
        # Only the model was loaded; sync the optimizer's main parameters
        # with the new weights.
        optimizer.reload_model_params()
    timers('pretrained checkpoint').stop()

    # Report setup timing.
    print_rank_0('done with setups ...')
    timers.log(['train/valid/test dataset/dataloder', 'callback function',
                'model and optimizer', 'pretrained checkpoint'], barrier=True)
    print_rank_0('training ...')

    if args.epochs > 0:
        _train(model, optimizer, opt_param_scheduler, forward_step,
               train_loader, valid_loader, epoch_callback)
    elif epoch_callback is not None:
        # Evaluation-only mode.
        print_rank_0('evaluation only mode, setting epoch to -1')
        epoch_callback(model, epoch=-1, output_predictions=True)
    print_rank_0('done :-)')