# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Finetune utilities."""

from functools import partial

import torch
import torch.nn.functional as F

from megatron import get_args
from megatron import print_rank_0
from megatron import get_timers
from megatron import mpu
from megatron.checkpointing import load_checkpoint
from megatron.checkpointing import save_checkpoint
from megatron.training import evaluate_and_print_results
from megatron.training import setup_model_and_optimizer
from megatron.training import train_step
from megatron.training import training_log
from megatron.utils import average_losses_across_data_parallel_group
from megatron.utils import calc_params_l2_norm
from megatron.utils import check_adlr_autoresume_termination


def process_batch(batch):
    """Move an (images, labels) batch onto the current CUDA device.

    Args:
        batch: indexable whose element 0 holds the images and element 1 the
            labels (presumably tensors from the data loader — TODO confirm).

    Returns:
        Tuple ``(images, labels)``, each moved with ``.cuda()`` and then made
        contiguous.
    """

    def _to_device(item):
        return item.cuda().contiguous()

    # Tuple elements are evaluated left to right: images first, then labels.
    return _to_device(batch[0]), _to_device(batch[1])


def cross_entropy_loss_func(labels, output_tensor):
    """Compute the classification cross-entropy loss for one micro-batch.

    Args:
        labels: targets passed straight to ``F.cross_entropy``.
        output_tensor: model output, used as the logits.

    Returns:
        Tuple ``(loss, stats)``. ``loss`` is this rank's cross-entropy loss.
        ``stats`` maps ``'lm loss'`` to that loss averaged across the
        data-parallel group, for logging.
    """
    # The logits are cast to fp32 before the loss (presumably for numerical
    # stability with half-precision activations).
    local_loss = F.cross_entropy(output_tensor.contiguous().float(), labels)

    reported_loss = average_losses_across_data_parallel_group([local_loss])[0]
    return local_loss, {'lm loss': reported_loss}


def _cross_entropy_forward_step(batch, model):
    """Run the model on one batch and return its output plus a loss function.

    Args:
        batch: either an iterator yielding batches (``next`` is called on it)
            or an already-materialized batch, i.e. anything ``next`` rejects
            with ``TypeError`` such as a list/tuple from the data loader.
        model: callable applied to the images.

    Returns:
        ``(output_tensor, loss_func)`` where ``loss_func`` is
        ``cross_entropy_loss_func`` with ``labels`` already bound; call it
        with the output tensor.
    """
    timers = get_timers()

    # Get the batch.
    timers("batch generator").start()
    try:
        batch_ = next(batch)
    except TypeError:
        # ``batch`` is not an iterator, so it already is a single batch.
        # Only TypeError is caught so that genuine data-loading failures
        # (and KeyboardInterrupt / SystemExit) propagate instead of being
        # swallowed and resurfacing later as a confusing indexing error.
        batch_ = batch
    images, labels = process_batch(batch_)
    timers("batch generator").stop()

    # Forward model.
    output_tensor = model(images)

    return output_tensor, partial(cross_entropy_loss_func, labels)


def build_data_loader(dataset, micro_batch_size, num_workers, drop_last):
    """Build a data loader that shards ``dataset`` across data-parallel ranks.

    ``micro_batch_size`` is the local (per-GPU) batch size, not the global one.

    Args:
        dataset: dataset to load from (must support ``len`` for the sampler).
        micro_batch_size: number of samples per batch on this rank.
        num_workers: number of ``DataLoader`` worker processes.
        drop_last: whether to drop the final incomplete batch.

    Returns:
        A pinned-memory ``torch.utils.data.DataLoader`` whose order comes from
        a ``DistributedSampler`` over this rank's data-parallel shard.
    """
    sampler = torch.utils.data.distributed.DistributedSampler(
        dataset,
        num_replicas=mpu.get_data_parallel_world_size(),
        rank=mpu.get_data_parallel_rank(),
    )

    # Ordering is owned by the sampler, so the loader itself must not shuffle.
    loader_kwargs = dict(
        batch_size=micro_batch_size,
        sampler=sampler,
        shuffle=False,
        num_workers=num_workers,
        drop_last=drop_last,
        pin_memory=True,
    )
    return torch.utils.data.DataLoader(dataset, **loader_kwargs)


def _build_infinite_size_dataloader(dataloader):
    """Build a looped dataloader with infinite size."""

    iterator = dataloader.__iter__()
    while True:
        try:
            yield iterator.__next__()
        except StopIteration:
            iterator = dataloader.__iter__()


def _build_train_valid_dataloaders(train_dataset, valid_dataset):
    """Build the training and validation dataloaders.

    Side effects on the global args:
      * ``args.train_iters_per_epoch`` is set to the number of training
        batches on this rank and ``args.train_iters`` to ``args.epochs``
        times that;
      * ``args.orig_micro_batch_size`` / ``args.orig_global_batch_size`` are
        set to copies of the current batch-size settings (the settings
        themselves are left unchanged).

    Returns:
        ``(train_dataloader, valid_dataloader)``. The validation loader is
        wrapped so that it cycles forever.
    """
    args = get_args()

    print_rank_0('building train and validation dataloaders ...')
    drop_last = not args.keep_last

    # Training dataset.
    train_dataloader = build_data_loader(
        train_dataset, args.micro_batch_size, args.num_workers, drop_last
    )

    # Iteration counts are derived from this rank's training loader length.
    args.train_iters_per_epoch = len(train_dataloader)
    args.train_iters = args.epochs * args.train_iters_per_epoch

    # Validation dataset: no per-epoch reshuffling is set up for it; it is
    # simply looped forever.
    valid_dataloader = _build_infinite_size_dataloader(
        build_data_loader(
            valid_dataset, args.micro_batch_size, args.num_workers, drop_last
        )
    )

    # Record the batch sizes in effect for this task (presumably consumed
    # elsewhere, e.g. by pipeline transfers / the LR schedule — TODO confirm).
    args.orig_micro_batch_size = args.micro_batch_size
    args.orig_global_batch_size = args.global_batch_size

    return train_dataloader, valid_dataloader

def _train(
    model,
    optimizer,
    opt_param_scheduler,
    forward_step,
    train_dataloader,
    valid_dataloader,
    end_of_epoch_callback,
):
    """Run the finetuning loop for the remaining epochs.

    Resumes from ``args.iteration``, skipping batches of a partially finished
    epoch that were already consumed. Along the way it:
      * runs ``train_step`` and ``training_log`` on every step;
      * checks ADLR autoresume every ``args.adlr_autoresume_interval`` steps;
      * saves a checkpoint every ``args.save_interval`` steps and at the end
        of each epoch;
      * evaluates every ``args.eval_interval`` steps;
      * calls ``end_of_epoch_callback`` after each epoch.

    Args:
        model: list of model chunks; each is put in train mode.
        optimizer: optimizer; its first param group's ``lr`` and
            ``get_loss_scale()`` are logged.
        opt_param_scheduler: scheduler passed to ``train_step``, autoresume
            and checkpointing.
        forward_step: forward-step function for training and evaluation.
        train_dataloader: training loader whose ``sampler`` supports
            ``set_epoch``.
        valid_dataloader: iterator used for evaluation.
        end_of_epoch_callback: optional ``callback(model, epoch)``.

    Raises:
        ValueError: if ``args.train_iters_per_epoch`` is zero (empty training
            loader), which would otherwise surface as a ZeroDivisionError.
    """
    args = get_args()
    timers = get_timers()

    if args.train_iters_per_epoch == 0:
        raise ValueError(
            "training dataloader is empty (train_iters_per_epoch == 0); "
            "check the training dataset size, micro batch size and keep_last"
        )

    # Turn on training mode which enables dropout.
    for module in model:
        module.train()

    # Loss accumulator handed to training_log on every step.
    losses_dict_sum = {}

    # Resume position: the epoch we are in and how many of its batches were
    # already consumed before this run.
    start_epoch = args.iteration // args.train_iters_per_epoch
    start_iteration = args.iteration % args.train_iters_per_epoch
    iteration = args.iteration

    # Memory reporting flag.
    report_memory_flag = True

    # Iteration of the last checkpoint written by this loop, so the
    # end-of-epoch save is skipped when the interval save just wrote it.
    last_saved_iteration = None

    timers("interval-time").start()
    for epoch in range(start_epoch, args.epochs):
        print_rank_0("working on epoch {} ...".format(epoch + 1))

        # Set the data loader epoch to shuffle the index iterator.
        train_dataloader.sampler.set_epoch(args.seed + epoch)

        for iteration_, batch in enumerate(train_dataloader):
            # Skip batches consumed before resuming.
            if iteration_ < start_iteration:
                continue
            # Set to zero so the next epoch does not skip any batches.
            start_iteration = 0

            # Train for one step.
            losses_dict, skipped_iter, grad_norm, num_zeros_in_grad = train_step(
                forward_step, batch, model, optimizer, opt_param_scheduler
            )
            iteration += 1

            # Logging. calc_params_l2_norm is imported from megatron.utils at
            # module level.
            params_norm = None
            if args.log_params_norm:
                params_norm = calc_params_l2_norm(model)

            report_memory_flag = training_log(
                losses_dict,
                losses_dict_sum,
                optimizer.param_groups[0]["lr"],
                iteration,
                optimizer.get_loss_scale().item(),
                report_memory_flag,
                skipped_iter,
                grad_norm,
                params_norm,
                num_zeros_in_grad,
            )

            # Autoresume
            if args.adlr_autoresume and (
                iteration % args.adlr_autoresume_interval == 0
            ):
                check_adlr_autoresume_termination(
                    iteration, model, optimizer, opt_param_scheduler
                )

            # Checkpointing
            if (
                args.save
                and args.save_interval
                and iteration % args.save_interval == 0
            ):
                save_checkpoint(iteration, model, optimizer, opt_param_scheduler)
                last_saved_iteration = iteration

            # Evaluation
            if args.eval_interval and iteration % args.eval_interval == 0:
                prefix = "iteration {}".format(iteration)
                evaluate_and_print_results(
                    prefix,
                    forward_step,
                    valid_dataloader,
                    model,
                    iteration,
                    False,
                )

        # Checkpointing at the end of each epoch, unless this exact iteration
        # was already saved by the interval check above.
        if args.save and last_saved_iteration != iteration:
            save_checkpoint(iteration, model, optimizer, opt_param_scheduler)
            last_saved_iteration = iteration

        # Callback at the end of each epoch.
        if end_of_epoch_callback is not None:
            end_of_epoch_callback(model, epoch)


def finetune(
    train_valid_datasets_provider,
    model_provider,
    forward_step=_cross_entropy_forward_step,
    end_of_epoch_callback_provider=None,
):
    """Shared finetuning driver used by every task.

    Setup runs in four timed phases, and their durations are logged once
    setup is done:
      1. when ``args.epochs > 0``, build the datasets with
         ``train_valid_datasets_provider()`` and wrap them in dataloaders;
      2. build the end-of-epoch callback, if a provider was given;
      3. build the model, optimizer and parameter scheduler;
      4. on a fresh run (``args.iteration == 0``) with
         ``args.pretrained_checkpoint`` set, load those model weights
         (non-strict) and resync the optimizer's master parameters.

    After setup, it either trains for ``args.epochs`` epochs or, when there
    are no epochs, runs the callback once in evaluation-only mode
    (``epoch=-1``, ``output_predictions=True``).

    Args:
        train_valid_datasets_provider: zero-arg callable returning
            ``(train_dataset, valid_dataset)``.
        model_provider: passed to ``setup_model_and_optimizer``.
        forward_step: per-batch forward function for training and evaluation.
        end_of_epoch_callback_provider: optional zero-arg callable returning
            the end-of-epoch callback.
    """
    args = get_args()
    timers = get_timers()

    dataloader_timer = "train/valid/test dataset/dataloder"
    callback_timer = "callback function"
    model_timer = "model and optimizer"
    checkpoint_timer = "pretrained checkpoint"

    # Dataloaders are only built when there is something to train on.
    timers(dataloader_timer).start()
    if args.epochs > 0:
        train_dataset, valid_dataset = train_valid_datasets_provider()
        train_dataloader, valid_dataloader = _build_train_valid_dataloaders(
            train_dataset, valid_dataset
        )
    timers(dataloader_timer).stop()

    # End-of-epoch callback.
    timers(callback_timer).start()
    if end_of_epoch_callback_provider is None:
        end_of_epoch_callback = None
    else:
        end_of_epoch_callback = end_of_epoch_callback_provider()
    timers(callback_timer).stop()

    # Model, optimizer and parameter scheduler.
    timers(model_timer).start()
    model, optimizer, opt_param_scheduler = setup_model_and_optimizer(
        model_provider
    )
    timers(model_timer).stop()

    # Pretrained weights are applied only before any training has happened.
    timers(checkpoint_timer).start()
    if args.iteration == 0 and args.pretrained_checkpoint is not None:
        # load_checkpoint presumably reads its directory from args.load, so
        # point it at the pretrained checkpoint for the duration of the call.
        saved_load_dir = args.load
        args.load = args.pretrained_checkpoint
        load_checkpoint(model, None, None, strict=False)
        args.load = saved_load_dir
        # Only the model was loaded; copy the weights into the optimizer's
        # master parameters too.
        optimizer.reload_model_params()
    timers(checkpoint_timer).stop()

    # Report setup timing.
    print_rank_0("done with setups ...")
    timers.log([dataloader_timer, callback_timer, model_timer, checkpoint_timer])
    print_rank_0("training ...")

    if args.epochs > 0:
        _train(
            model,
            optimizer,
            opt_param_scheduler,
            forward_step,
            train_dataloader,
            valid_dataloader,
            end_of_epoch_callback,
        )
    elif end_of_epoch_callback is not None:
        # Evaluation-only mode.
        print_rank_0("evaluation only mode, setting epoch to -1")
        end_of_epoch_callback(model, epoch=-1, output_predictions=True)

    print_rank_0("done :-)")