# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.

"""Finetune utilities."""

import torch
import torch.nn.functional as F
from megatron import get_args
from megatron import print_rank_0
from megatron import get_timers
from megatron import utils
from megatron.core import mpu
from megatron.checkpointing import load_checkpoint
from megatron.checkpointing import save_checkpoint
from megatron.training import evaluate_and_print_results
from megatron.training import setup_model_and_optimizer
from megatron.training import train_step
from megatron.training import training_log
from megatron.utils import check_adlr_autoresume_termination
from megatron.utils import average_losses_across_data_parallel_group, print_params_min_max_norm
from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
from megatron.model import DistributedDataParallel as LocalDDP
from megatron.model import Float16Module
from megatron.core.enums import ModelType

def process_batch(batch):
    """Process batch and produce inputs for the model."""
    images = batch[0].cuda().contiguous()
    labels = batch[1].cuda().contiguous()
    return images, labels


def build_data_loader(dataset, micro_batch_size,
                      num_workers, drop_last, shuffle):
    """Data loader. Note that batch-size is the local (per GPU) batch-size."""

    # Sampler: partition the dataset across the data-parallel group.
    world_size = mpu.get_data_parallel_world_size()
    rank = mpu.get_data_parallel_rank()
    sampler = torch.utils.data.distributed.DistributedSampler(
        dataset, num_replicas=world_size, rank=rank,
        drop_last=drop_last, shuffle=shuffle
    )

    # Data loader. Note that batch size is the per GPU batch size.
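    # shuffle must remain False here; shuffling (if any) is owned by the
    # DistributedSampler built above.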
    data_loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=micro_batch_size,
        sampler=sampler,
        shuffle=False,
        num_workers=num_workers,
        drop_last=drop_last,
        pin_memory=True,
    )

    return data_loader


def _build_infinite_size_dataloader(dataloader):
    """Build a looped dataloader with infinite size."""

    iterator = iter(dataloader)
    while True:
        try:
            yield next(iterator)
        except StopIteration:
            iterator = iter(dataloader)


def _build_train_valid_dataloaders(train_dataset, valid_dataset):
    """Traing and validation dataloaders."""
    args = get_args()

    print_rank_0('building train and validation dataloaders ...')
    # Training dataset.
    train_dataloader = build_data_loader(train_dataset, args.micro_batch_size,
                                         args.num_workers, False, True)
    # Set the training iterations.
    args.train_iters_per_epoch = len(train_dataloader)
    args.train_iters = args.epochs * args.train_iters_per_epoch
    # Validation dataset. For this dataset, we do not need to set up
    # shuffling so we can just use a simple infinite loop.
    valid_dataloader_ = build_data_loader(valid_dataset, args.micro_batch_size,
                                          args.num_workers, True, False)
    valid_dataloader = _build_infinite_size_dataloader(valid_dataloader_)

    # Now that we've built the data loaders, set batch_size arguments
    # to the actual batch size the model will see for this dataset.
    # This is necessary so pipeline transfers know what size they are
    # and the LR schedule, which is based on samples seen, gets set
    # correctly.
    args.orig_micro_batch_size = args.micro_batch_size
    args.orig_global_batch_size = args.global_batch_size

    return train_dataloader, valid_dataloader


def _train(
    model,
    optimizer,
    opt_param_scheduler,
    forward_step,
    train_dataloader,
    valid_dataloader,
    end_of_epoch_callback,
    process_non_loss_data_func=None
):
    """Train the model."""
    args = get_args()
    timers = get_timers()

    # Turn on training mode which enables dropout.
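    # `model` is a list of model chunks (one per virtual pipeline stage).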
    for m in model:
        m.train()

    # Tracking loss.
    losses_dict_sum = {}

    # Starting epoch and iteration
    start_epoch = args.iteration // args.train_iters_per_epoch
    start_iteration = args.iteration % args.train_iters_per_epoch
    iteration = args.iteration

    # Memory reporting flag.
    report_memory_flag = True

    # For each remaining epoch
    timers("interval-time", log_level=0).start(barrier=True)
    for epoch in range(start_epoch, args.epochs):
        print_rank_0("working on epoch {} ...".format(epoch + 1))

        # Set the data loader epoch to shuffle the index iterator.
        train_dataloader.sampler.set_epoch(args.seed + epoch)
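        # The dataset is also expected to expose set_epoch (e.g. Megatron's
        # RandomSeedDataset wrapper) so per-sample augmentation seeds advance
        # with the epoch.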
        train_dataloader.dataset.set_epoch(epoch)

        # For all the batches in the dataset.
        for iteration_, batch in enumerate(train_dataloader):

            # Skip iterations before the starting value (when resuming mid-epoch).
            if iteration_ < start_iteration:
                continue
            # Set to zero so the next epoch does not skip any batches.
            start_iteration = 0

            # Train for one step.
            losses_dict, skipped_iter, grad_norm, num_zeros_in_grad = train_step(
                forward_step, batch, model, optimizer, opt_param_scheduler
            )
            iteration += 1

            # Logging.
            params_norm = None

            report_memory_flag = training_log(
                losses_dict,
                losses_dict_sum,
                optimizer.param_groups[0]["lr"],
                iteration,
                optimizer.get_loss_scale().item(),
                report_memory_flag,
                skipped_iter,
                grad_norm,
                params_norm,
                num_zeros_in_grad
            )

            # Autoresume
            if args.adlr_autoresume and \
                    iteration % args.adlr_autoresume_interval == 0:
                check_adlr_autoresume_termination(iteration, model, optimizer,
                                                  opt_param_scheduler)

            # Checkpointing
            if args.save and args.save_interval and \
                    iteration % args.save_interval == 0:
                save_checkpoint(iteration, model, optimizer,
                                opt_param_scheduler)

            # Evaluation
            if args.eval_interval and iteration % args.eval_interval == 0:
                prefix = "iteration {}".format(iteration)
                evaluate_and_print_results(
                    prefix,
                    forward_step,
                    valid_dataloader,
                    model,
                    iteration,
                    process_non_loss_data_func,
                    False,
                )

        # Callback at the end of each epoch.
        if end_of_epoch_callback is not None:
            end_of_epoch_callback(model, epoch)


def finetune(
    train_valid_datasets_provider,
    model_provider,
    forward_step,
    model_type=ModelType.encoder_or_decoder,
    process_non_loss_data_func=None,
    end_of_epoch_callback_provider=None,
):
    """Main finetune function used across all tasks."""
    args = get_args()
    timers = get_timers()

    # Train and validation data loaders.
    timers("train/valid/test dataset/dataloder", log_level=0).start()
    if args.epochs > 0:
        train_dataset, valid_dataset = train_valid_datasets_provider()
        train_dataloader, valid_dataloader = _build_train_valid_dataloaders(
            train_dataset, valid_dataset
        )
    timers("train/valid/test dataset/dataloder").stop()

    # Build callback function.
    timers("callback function", log_level=0).start()
    end_of_epoch_callback = None
    if end_of_epoch_callback_provider is not None:
        end_of_epoch_callback = end_of_epoch_callback_provider()
    timers("callback function").stop()

    # Build model, optimizer and learning rate scheduler.
    timers("model and optimizer", log_level=0).start()
    model, optimizer, opt_param_scheduler = \
        setup_model_and_optimizer(
            model_provider,
            model_type,
            scale_lr_cond=lambda name, param: ".head." in name,
            lr_mult=args.head_lr_mult)
    timers("model and optimizer").stop()

    # If pretrained checkpoint is provided and we have not trained for
    # any iteration (i.e., iteration is zero), then load the pretrained
    # checkpoint.
    timers("pretrained checkpoint", log_level=0).start(barrier=True)
    if args.iteration == 0 and args.pretrained_checkpoint is not None:
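        # 'default': a Megatron-format checkpoint; temporarily point args.load
        # at the pretrained checkpoint and reuse the standard loader.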
        if args.pretrained_checkpoint_type == 'default':
            original_load = args.load
            args.load = args.pretrained_checkpoint
            _ = load_checkpoint(model, None, None, strict=False)
            args.load = original_load
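        # 'external': a plain PyTorch state dict holding backbone weights.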
        elif args.pretrained_checkpoint_type == 'external':
            unwrap_model = utils.unwrap_model(model)
            state_dict = torch.load(args.pretrained_checkpoint,
                                    map_location="cpu")
            unwrap_model[0].module.backbone.load_state_dict(state_dict,
                                                            strict=False)
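        # 'constrastive': a self-supervised (e.g. DINO-style) checkpoint;
        # only the teacher backbone weights are kept.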
        elif args.pretrained_checkpoint_type == 'constrastive':
            unwrap_model = utils.unwrap_model(model)
            state_dict = torch.load(args.pretrained_checkpoint,
                                    map_location="cpu")
            state_dict = state_dict["model"]
            state_dict = {k.replace("teacher.backbone.", ""): v
                          for k, v in state_dict.items()
                          if k.startswith("teacher.backbone.")}
            unwrap_model[0].module.backbone.load_state_dict(state_dict,
                                                            strict=False)
        else:
            raise Exception("pretrained checkpoint type {} not supported".format(args.pretrained_checkpoint_type))

        # This is critical when only the model weights are loaded: make sure
        # the optimizer's master (fp32) parameters are updated as well.
        optimizer.reload_model_params()

    timers("pretrained checkpoint").stop()

    # Print setup timing.
    print_rank_0("done with setups ...")
    timers.log(
        [
            "train/valid/test dataset/dataloder",
            "callback function",
            "model and optimizer",
            "pretrained checkpoint",
        ]
    )
    print_rank_0("training ...")

    # Finetune the model.
    if args.epochs > 0:
        _train(
            model,
            optimizer,
            opt_param_scheduler,
            forward_step,
            train_dataloader,
            valid_dataloader,
            end_of_epoch_callback,
            process_non_loss_data_func,
        )
    # Or just evaluate.
    else:
        if end_of_epoch_callback is not None:
            print_rank_0("evaluation only mode, setting epoch to -1")
            end_of_epoch_callback(model, epoch=-1)

    print_rank_0("done :-)")