# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.

"""Finetune utilities."""

import torch
import torch.nn.functional as F

from megatron.training import get_args
from megatron.training import print_rank_0
from megatron.training import get_timers
from megatron.training import utils
from megatron.core import mpu

from megatron.training.checkpointing import load_checkpoint
from megatron.training.checkpointing import save_checkpoint

from megatron.training import evaluate_and_print_results
from megatron.training import setup_model_and_optimizer
from megatron.training import train_step
from megatron.training import training_log

from megatron.training.utils import check_adlr_autoresume_termination
from megatron.training.utils import average_losses_across_data_parallel_group, print_params_min_max_norm

from megatron.core.enums import ModelType

def process_batch(batch):
    """Move a (images, labels) pair onto the GPU as contiguous tensors."""
    images, labels = batch[0], batch[1]
    return images.cuda().contiguous(), labels.cuda().contiguous()


def build_data_loader(dataset, micro_batch_size,
                      num_workers, drop_last, shuffle):
    """Build a distributed data loader over `dataset`.

    The batch size is the local (per-GPU) micro batch size; each data
    parallel rank draws a disjoint shard via DistributedSampler.
    """
    # Shard the dataset across the data-parallel group.
    sampler = torch.utils.data.distributed.DistributedSampler(
        dataset,
        num_replicas=mpu.get_data_parallel_world_size(),
        rank=mpu.get_data_parallel_rank(),
        drop_last=drop_last,
        shuffle=shuffle,
    )

    # shuffle must stay False on the DataLoader: ordering is fully
    # delegated to the sampler above.
    return torch.utils.data.DataLoader(
        dataset,
        batch_size=micro_batch_size,
        sampler=sampler,
        shuffle=False,
        num_workers=num_workers,
        drop_last=drop_last,
        pin_memory=True,
    )


def _build_infinite_size_dataloader(dataloader):
    """Build a looped dataloader with infinite size."""

    iterator = dataloader.__iter__()
    while True:
        try:
            yield iterator.__next__()
        except StopIteration:
            iterator = dataloader.__iter__()


def _build_train_valid_dataloaders(train_dataset, valid_dataset):
    """Build the training and validation dataloaders."""
    args = get_args()

    print_rank_0('building train and validation dataloaders ...')

    # Training loader: shuffled, keep the final partial batch.
    train_dataloader = build_data_loader(
        train_dataset, args.micro_batch_size, args.num_workers,
        False, True)

    # Derive the iteration schedule from the dataset size.
    args.train_iters_per_epoch = len(train_dataloader)
    args.train_iters = args.epochs * args.train_iters_per_epoch

    # Validation loader: no shuffling needed, so wrap the finite loader
    # in a simple endless iterator.
    finite_valid_dataloader = build_data_loader(
        valid_dataset, args.micro_batch_size, args.num_workers,
        True, False)
    valid_dataloader = _build_infinite_size_dataloader(finite_valid_dataloader)

    # Record the batch-size arguments actually used for this dataset.
    # Pipeline transfers need the real size, and the sample-based LR
    # schedule depends on it as well.
    args.orig_micro_batch_size = args.micro_batch_size
    args.orig_global_batch_size = args.global_batch_size

    return train_dataloader, valid_dataloader


def _train(
    model,
    optimizer,
    opt_param_scheduler,
    forward_step,
    train_dataloader,
    valid_dataloader,
    end_of_epoch_callback,
    process_non_loss_data_func=None
):
    """Run the epoch-based finetuning loop.

    Args:
        model: list of (possibly wrapped) model chunks; ``.train()`` is
            called on each element to enable dropout.
        optimizer: Megatron optimizer; ``param_groups`` and
            ``get_loss_scale()`` are read for logging.
        opt_param_scheduler: optimizer-parameter scheduler forwarded to
            train_step() and checkpointing.
        forward_step: task-specific forward/loss function consumed by
            train_step() and evaluate_and_print_results().
        train_dataloader: epoch-based training loader; its sampler and
            dataset are expected to expose ``set_epoch()``.
        valid_dataloader: infinite validation iterator (see
            _build_infinite_size_dataloader).
        end_of_epoch_callback: optional ``callback(model, epoch)`` invoked
            after every epoch; may be None.
        process_non_loss_data_func: optional hook forwarded to
            evaluate_and_print_results().
    """
    args = get_args()
    timers = get_timers()

    # Turn on training mode which enables dropout.
    for m in model:
        m.train()

    # Running per-key loss sums used by training_log().
    losses_dict_sum = {}

    # Resume position: map the saved global iteration back to an
    # (epoch, intra-epoch iteration) pair.
    start_epoch = args.iteration // args.train_iters_per_epoch
    start_iteration = args.iteration % args.train_iters_per_epoch
    iteration = args.iteration

    # Report memory usage only until training_log() clears the flag.
    report_memory_flag = True

    # For each remaining epoch.
    timers("interval-time", log_level=0).start(barrier=True)
    for epoch in range(start_epoch, args.epochs):
        print_rank_0("working on epoch {} ...".format(epoch + 1))

        # Set the data loader epoch to shuffle the index iterator; the
        # seed offset makes the shuffle differ between seeds.
        train_dataloader.sampler.set_epoch(args.seed + epoch)
        # NOTE(review): assumes the dataset itself implements set_epoch()
        # (e.g. for epoch-dependent augmentation) — confirm for new tasks.
        train_dataloader.dataset.set_epoch(epoch)

        # For all the batches in the dataset.
        for iteration_, batch in enumerate(train_dataloader):

            # Ignore the iterations before starting value (resume case).
            if iteration_ < start_iteration:
                continue
            # Set to zero so the next epoch does not skip any batches.
            start_iteration = 0

            # Train for one step.
            losses_dict, skipped_iter, grad_norm, num_zeros_in_grad = train_step(
                forward_step, batch, model, optimizer, opt_param_scheduler
            )
            iteration += 1

            # Logging. params_norm is never computed here (always None).
            params_norm = None

            report_memory_flag = training_log(
                losses_dict,
                losses_dict_sum,
                optimizer.param_groups[0]["lr"],
                iteration,
                optimizer.get_loss_scale().item(),
                report_memory_flag,
                skipped_iter,
                grad_norm,
                params_norm,
                num_zeros_in_grad
            )

            # Autoresume: periodically check for cluster-requested
            # termination and checkpoint/exit if signalled.
            if args.adlr_autoresume and \
                    iteration % args.adlr_autoresume_interval == 0:
                check_adlr_autoresume_termination(iteration, model, optimizer,
                                                  opt_param_scheduler)

            # Checkpointing at the configured interval.
            if args.save and args.save_interval and \
                    iteration % args.save_interval == 0:
                save_checkpoint(iteration, model, optimizer,
                                opt_param_scheduler)

            # Evaluation at the configured interval.
            if args.eval_interval and iteration % args.eval_interval == 0:
                prefix = "iteration {}".format(iteration)
                evaluate_and_print_results(
                    prefix,
                    forward_step,
                    valid_dataloader,
                    model,
                    iteration,
                    process_non_loss_data_func,
                    False,
                )

        # Callback at the end of each epoch.
        if end_of_epoch_callback is not None:
            end_of_epoch_callback(model, epoch)


def finetune(
    train_valid_datasets_provider,
    model_provider,
    forward_step,
    model_type=ModelType.encoder_or_decoder,
    process_non_loss_data_func=None,
    end_of_epoch_callback_provider=None,
):
    """Main finetune function used across all tasks.

    Args:
        train_valid_datasets_provider: zero-arg callable returning
            (train_dataset, valid_dataset); only invoked when
            ``args.epochs > 0``.
        model_provider: model-building callable forwarded to
            setup_model_and_optimizer().
        forward_step: task-specific forward/loss function used for both
            training and evaluation.
        model_type: megatron.core.enums.ModelType of the model.
        process_non_loss_data_func: optional hook forwarded to evaluation
            for non-loss outputs.
        end_of_epoch_callback_provider: optional zero-arg callable that
            returns the per-epoch callback.
    """
    args = get_args()
    timers = get_timers()

    # Train and validation data loaders.
    # NOTE: the timer name "dataloder" is misspelled but load-bearing —
    # it must match the name passed to timers.log() below.
    timers("train/valid/test dataset/dataloder", log_level=0).start()
    if args.epochs > 0:
        train_dataset, valid_dataset = train_valid_datasets_provider()
        train_dataloader, valid_dataloader = _build_train_valid_dataloaders(
            train_dataset, valid_dataset
        )
    timers("train/valid/test dataset/dataloder").stop()

    # Build the end-of-epoch callback function, if a provider was given.
    timers("callback function", log_level=0).start()
    end_of_epoch_callback = None
    if end_of_epoch_callback_provider is not None:
        end_of_epoch_callback = end_of_epoch_callback_provider()
    timers("callback function").stop()

    # Build model, optimizer and learning rate scheduler. Parameters whose
    # name contains ".head." get their LR scaled by args.head_lr_mult.
    timers("model and optimizer", log_level=0).start()
    model, optimizer, opt_param_scheduler = \
        setup_model_and_optimizer(
            model_provider,
            model_type,
            scale_lr_cond=lambda name, param: ".head." in name,
            lr_mult=args.head_lr_mult)
    timers("model and optimizer").stop()

    # If pretrained checkpoint is provided and we have not trained for
    # any iteration (i.e., iteration is zero), then load the pretrained
    # checkpoint.
    timers("pretrained checkpoint", log_level=0).start(barrier=True)
    if args.iteration == 0 and args.pretrained_checkpoint is not None:
        if args.pretrained_checkpoint_type == 'default':
            # Standard Megatron checkpoint: temporarily redirect args.load
            # to the pretrained checkpoint; optimizer and scheduler args
            # are passed as None so only model weights are loaded.
            original_load = args.load
            args.load = args.pretrained_checkpoint
            _ = load_checkpoint(model, None, None, strict=False)
            args.load = original_load
        elif args.pretrained_checkpoint_type == 'external':
            # Raw torch state_dict produced outside Megatron: load it into
            # the backbone of the first model chunk.
            unwrap_model = utils.unwrap_model(model)
            state_dict = torch.load(args.pretrained_checkpoint,
                                    map_location="cpu")
            unwrap_model[0].module.backbone.load_state_dict(state_dict,
                                                            strict=False)
        elif args.pretrained_checkpoint_type == 'constrastive':
            # NOTE(review): 'constrastive' (sic) is the exact value callers
            # pass — correcting the spelling would break existing scripts.
            # Keep only the teacher backbone weights, stripping the
            # "teacher.backbone." prefix from the keys.
            unwrap_model = utils.unwrap_model(model)
            state_dict = torch.load(args.pretrained_checkpoint,
                                    map_location="cpu")
            state_dict = state_dict["model"]
            state_dict = {k.replace("teacher.backbone.", ""): v
                          for k, v in state_dict.items()
                          if k.startswith("teacher.backbone.")}
            unwrap_model[0].module.backbone.load_state_dict(state_dict,
                                                            strict=False)
        else:
            raise Exception("pretrained checkpoint type {} not supported".format(args.pretrained_checkpoint_type))

        # This is critical when only model is loaded. We should make sure
        # master parameters are also updated.
        optimizer.reload_model_params()

    timers("pretrained checkpoint").stop()

    # Print setup timing.
    print_rank_0("done with setups ...")
    timers.log(
        [
            "train/valid/test dataset/dataloder",
            "callback function",
            "model and optimizer",
            "pretrained checkpoint",
        ]
    )
    print_rank_0("training ...")

    # Finetune the model.
    if args.epochs > 0:
        _train(
            model,
            optimizer,
            opt_param_scheduler,
            forward_step,
            train_dataloader,
            valid_dataloader,
            end_of_epoch_callback,
            process_non_loss_data_func,
        )
    # Or just evaluate.
    else:
        if end_of_epoch_callback is not None:
            print_rank_0("evaluation only mode, setting epoch to -1")
            end_of_epoch_callback(model, epoch=-1)

    print_rank_0("done :-)")