finetune_utils.py
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.

"""Finetune utilities."""

import torch
import torch.nn.functional as F
from megatron import get_args
from megatron import print_rank_0
from megatron import get_timers
from megatron import mpu, utils
from megatron.checkpointing import load_checkpoint
from megatron.checkpointing import save_checkpoint
from megatron.training import evaluate_and_print_results
from megatron.training import setup_model_and_optimizer
from megatron.training import train_step
from megatron.training import training_log
from megatron.utils import check_adlr_autoresume_termination
from megatron.utils import average_losses_across_data_parallel_group, print_params_min_max_norm
from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
from megatron.model import DistributedDataParallel as LocalDDP
from megatron.model import Float16Module, ModelType


def process_batch(batch):
    """Process batch and produce inputs for the model."""
    images = batch[0].cuda().contiguous()
    labels = batch[1].cuda().contiguous()
    return images, labels


def build_data_loader(dataset, micro_batch_size,
                      num_workers, drop_last, shuffle):
    """Data loader. Note that batch-size is the local (per GPU) batch-size."""

    # Sampler.
    world_size = mpu.get_data_parallel_world_size()
    rank = mpu.get_data_parallel_rank()
    sampler = torch.utils.data.distributed.DistributedSampler(
        dataset, num_replicas=world_size, rank=rank,
        drop_last=drop_last, shuffle=shuffle
    )

    # Data loader. Note that batch size is the per GPU batch size.
    data_loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=micro_batch_size,
        sampler=sampler,
        shuffle=False,
        num_workers=num_workers,
        drop_last=drop_last,
        pin_memory=True,
    )

    return data_loader
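
# Usage sketch (illustrative, not part of the original file): how a task
# script might build a per-GPU loader with this helper. `my_dataset` is a
# hypothetical torch.utils.data.Dataset; the argument values mirror the
# call sites in _build_train_valid_dataloaders below.
#
#     args = get_args()
#     train_loader = build_data_loader(my_dataset, args.micro_batch_size,
#                                      args.num_workers, drop_last=False,
#                                      shuffle=True)
#     images, labels = process_batch(next(iter(train_loader)))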


def _build_infinite_size_dataloader(dataloader):
    """Build a looped dataloader with infinite size."""

    iterator = iter(dataloader)
    while True:
        try:
            yield next(iterator)
        except StopIteration:
            iterator = iter(dataloader)
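
# Usage sketch (illustrative): the generator above yields batches forever,
# re-creating the underlying iterator whenever the dataloader is exhausted,
# so evaluation code can draw a fixed number of batches without tracking
# epoch boundaries. `args.eval_iters` is assumed here for illustration.
#
#     valid_iter = _build_infinite_size_dataloader(valid_dataloader_)
#     for _ in range(args.eval_iters):
#         images, labels = process_batch(next(valid_iter))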


def _build_train_valid_dataloaders(train_dataset, valid_dataset):
    """Traing and validation dataloaders."""
    args = get_args()

    print_rank_0('building train and validation dataloaders ...')
    # Training dataset.
    train_dataloader = build_data_loader(train_dataset, args.micro_batch_size,
                                         args.num_workers, False, True)
    # Set the training iterations.
    args.train_iters_per_epoch = len(train_dataloader)
    args.train_iters = args.epochs * args.train_iters_per_epoch
    # Validation dataset. For this dataset, we do not need to set up
    # shuffling so we can just use a simple infinite loop.
    valid_dataloader_ = build_data_loader(valid_dataset, args.micro_batch_size,
                                          args.num_workers, True,  False)
    valid_dataloader = _build_infinite_size_dataloader(valid_dataloader_)

    # Now that we've built the data loaders, set batch_size arguments
    # to the actual batch size the model will see for this dataset.
    # This is necessary so pipeline transfers know what size they are
    # and the LR schedule, which is based on samples seen, gets set
    # correctly.
    args.orig_micro_batch_size = args.micro_batch_size
    args.orig_global_batch_size = args.global_batch_size

    return train_dataloader, valid_dataloader
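
# Worked example (hypothetical numbers, not from the original file): with
# len(train_dataloader) == 1000 batches per epoch and args.epochs == 5, the
# code above sets args.train_iters_per_epoch = 1000 and args.train_iters =
# 5000, so the sample-based LR schedule and the save/eval intervals cover
# the whole finetuning run.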


def _train(
    model,
    optimizer,
    opt_param_scheduler,
    forward_step,
    train_dataloader,
    valid_dataloader,
    end_of_epoch_callback,
    process_non_loss_data_func=None
):
    """Train the model."""
    args = get_args()
    timers = get_timers()

    # Turn on training mode which enables dropout.
    for m in model:
        m.train()

    # Tracking loss.
    losses_dict_sum = {}

    # Starting epoch and iteration
    start_epoch = args.iteration // args.train_iters_per_epoch
    start_iteration = args.iteration % args.train_iters_per_epoch
    iteration = args.iteration

    # Memory reporting flag.
    report_memory_flag = True

    # For each remaining epoch
    timers("interval-time", log_level=0).start(barrier=True)
    for epoch in range(start_epoch, args.epochs):
        print_rank_0("working on epoch {} ...".format(epoch + 1))

        # Set the data loader epoch to shuffle the index iterator.
        train_dataloader.sampler.set_epoch(args.seed + epoch)
        train_dataloader.dataset.set_epoch(epoch)

        # For all the batches in the dataset.
        for iteration_, batch in enumerate(train_dataloader):

            # Ignore the iterations before starting value
            if iteration_ < start_iteration:
                continue
            # Set to zero so the next epoch does not skip any batches.
            start_iteration = 0

            # Train for one step.
            losses_dict, skipped_iter, grad_norm, num_zeros_in_grad = train_step(
                forward_step, batch, model, optimizer, opt_param_scheduler
            )
            iteration += 1

            # Logging.
            params_norm = None

            report_memory_flag = training_log(
                losses_dict,
                losses_dict_sum,
                optimizer.param_groups[0]["lr"],
                iteration,
                optimizer.get_loss_scale().item(),
                report_memory_flag,
                skipped_iter,
                grad_norm,
                params_norm,
                num_zeros_in_grad
            )

            # Autoresume
            if args.adlr_autoresume and \
                    iteration % args.adlr_autoresume_interval == 0:
                check_adlr_autoresume_termination(iteration, model, optimizer,
                                                  opt_param_scheduler)

            # Checkpointing
            if args.save and args.save_interval and \
                    iteration % args.save_interval == 0:
                save_checkpoint(iteration, model, optimizer,
                                opt_param_scheduler)

            # Evaluation
            if args.eval_interval and iteration % args.eval_interval == 0:
                prefix = "iteration {}".format(iteration)
                evaluate_and_print_results(
                    prefix,
                    forward_step,
                    valid_dataloader,
                    model,
                    iteration,
                    process_non_loss_data_func,
                    False,
                )

        # Callback at the end of each epoch.
        if end_of_epoch_callback is not None:
            end_of_epoch_callback(model, epoch)


def finetune(
    train_valid_datasets_provider,
    model_provider,
    forward_step,
    model_type=ModelType.encoder_or_decoder,
    process_non_loss_data_func=None,
    end_of_epoch_callback_provider=None,
):
    """Main finetune function used across all tasks."""
    args = get_args()
    timers = get_timers()

    # Train and validation data loaders.
    timers("train/valid/test dataset/dataloder", log_level=0).start()
    if args.epochs > 0:
        train_dataset, valid_dataset = train_valid_datasets_provider()
        train_dataloader, valid_dataloader = _build_train_valid_dataloaders(
            train_dataset, valid_dataset
        )
    timers("train/valid/test dataset/dataloder").stop()

    # Build callback function.
    timers("callback function", log_level=0).start()
    end_of_epoch_callback = None
    if end_of_epoch_callback_provider is not None:
        end_of_epoch_callback = end_of_epoch_callback_provider()
    timers("callback function").stop()

    # Build model, optimizer and learning rate scheduler.
    timers("model and optimizer", log_level=0).start()
    model, optimizer, opt_param_scheduler = \
        setup_model_and_optimizer(
            model_provider,
            model_type,
            scale_lr_cond=lambda name, param: ".head." in name,
            lr_mult=args.head_lr_mult)
    timers("model and optimizer").stop()

    # If pretrained checkpoint is provided and we have not trained for
    # any iteration (i.e., iteration is zero), then load the pretrained
    # checkpoint.
    timers("pretrained checkpoint", log_level=0).start(barrier=True)
    if args.iteration == 0 and args.pretrained_checkpoint is not None:
        if args.pretrained_checkpoint_type == 'default':
            original_load = args.load
            args.load = args.pretrained_checkpoint
            _ = load_checkpoint(model, None, None, strict=False)
            args.load = original_load
        elif args.pretrained_checkpoint_type == 'external':
            unwrap_model = utils.unwrap_model(model)
            state_dict = torch.load(args.pretrained_checkpoint,
                                    map_location="cpu")
            unwrap_model[0].module.backbone.load_state_dict(state_dict,
                                                            strict=False)
        elif args.pretrained_checkpoint_type == 'constrastive':
            unwrap_model = utils.unwrap_model(model)
            state_dict = torch.load(args.pretrained_checkpoint,
                                    map_location="cpu")
            state_dict = state_dict["model"]
            state_dict = {k.replace("teacher.backbone.", ""): v
                          for k, v in state_dict.items()
                          if k.startswith("teacher.backbone.")}
            unwrap_model[0].module.backbone.load_state_dict(state_dict,
                                                            strict=False)
        else:
            raise Exception("pretrained checkpoint type {} not supported".format(args.pretrained_checkpoint_type))

        # This is critical when only model is loaded. We should make sure
        # master parameters are also updated.
        optimizer.reload_model_params()

    timers("pretrained checkpoint").stop()

    # Print setup timing.
    print_rank_0("done with setups ...")
    timers.log(
        [
            "train/valid/test dataset/dataloder",
            "callback function",
            "model and optimizer",
            "pretrained checkpoint",
        ]
    )
    print_rank_0("training ...")

    # Finetune the model.
    if args.epochs > 0:
        _train(
            model,
            optimizer,
            opt_param_scheduler,
            forward_step,
            train_dataloader,
            valid_dataloader,
            end_of_epoch_callback,
            process_non_loss_data_func,
        )
    # Or just evaluate.
    else:
        if end_of_epoch_callback is not None:
            print_rank_0("evaluation only mode, setting epoch to -1")
            end_of_epoch_callback(model, epoch=-1)

    print_rank_0("done :-)")