"vscode:/vscode.git/clone" did not exist on "72b5f3d0bb79818aa9906b4fd43b75ba15572f45"
train.py 17 KB
Newer Older
1
2
3
4
5
6
import datetime
import os
import time

import torch
import torch.utils.data
from torch.utils.data.dataloader import default_collate
from torch import nn
import torchvision
from torchvision.transforms.functional import InterpolationMode

import presets
import transforms
import utils

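# apex is optional at import time; it is only needed when mixed-precision training is enabled via --apex.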
try:
    from apex import amp
except ImportError:
    amp = None


def train_one_epoch(model, criterion, optimizer, data_loader, device, epoch,
                    print_freq, apex=False, model_ema=None):
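    """Run one training epoch, logging loss, learning rate, top-1/top-5 accuracy and images/s."""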
    model.train()
    metric_logger = utils.MetricLogger(delimiter="  ")
    metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value}'))
    metric_logger.add_meter('img/s', utils.SmoothedValue(window_size=10, fmt='{value}'))

    header = 'Epoch: [{}]'.format(epoch)
    for image, target in metric_logger.log_every(data_loader, print_freq, header):
        start_time = time.time()
        image, target = image.to(device), target.to(device)
        output = model(image)
        loss = criterion(output, target)

        optimizer.zero_grad()
        if apex:
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()
        optimizer.step()

        acc1, acc5 = utils.accuracy(output, target, topk=(1, 5))
        batch_size = image.shape[0]
        metric_logger.update(loss=loss.item(), lr=optimizer.param_groups[0]["lr"])
        metric_logger.meters['acc1'].update(acc1.item(), n=batch_size)
        metric_logger.meters['acc5'].update(acc5.item(), n=batch_size)
        metric_logger.meters['img/s'].update(batch_size / (time.time() - start_time))

    if model_ema:
        model_ema.update_parameters(model)


def evaluate(model, criterion, data_loader, device, print_freq=100, log_suffix=''):
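    """Evaluate the model on a data loader and return the global top-1 accuracy."""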
    model.eval()
    metric_logger = utils.MetricLogger(delimiter="  ")
    header = f'Test: {log_suffix}'
    with torch.no_grad():
        for image, target in metric_logger.log_every(data_loader, print_freq, header):
            image = image.to(device, non_blocking=True)
            target = target.to(device, non_blocking=True)
            output = model(image)
            loss = criterion(output, target)

            acc1, acc5 = utils.accuracy(output, target, topk=(1, 5))
            # FIXME need to take into account that the datasets
            # could have been padded in distributed setup
            batch_size = image.shape[0]
            metric_logger.update(loss=loss.item())
            metric_logger.meters['acc1'].update(acc1.item(), n=batch_size)
            metric_logger.meters['acc5'].update(acc5.item(), n=batch_size)
    # gather the stats from all processes
    metric_logger.synchronize_between_processes()

    print(f'{header} Acc@1 {metric_logger.acc1.global_avg:.3f} Acc@5 {metric_logger.acc5.global_avg:.3f}')
    return metric_logger.acc1.global_avg


def _get_cache_path(filepath):
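    """Map a dataset directory to a cache file under ~/.torch/vision/datasets/imagefolder, keyed by a hash of the path."""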
    import hashlib
    h = hashlib.sha1(filepath.encode()).hexdigest()
    cache_path = os.path.join("~", ".torch", "vision", "datasets", "imagefolder", h[:10] + ".pt")
    cache_path = os.path.expanduser(cache_path)
    return cache_path


def load_data(traindir, valdir, args):
    # Data loading code
    print("Loading data")
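    # Per-architecture input sizes: inception_v3 and the efficientnet family need larger
    # inputs than the 256/224 default, and efficientnet resizes with bicubic interpolation.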
    resize_size, crop_size = 256, 224
    interpolation = InterpolationMode.BILINEAR
    if args.model == 'inception_v3':
        resize_size, crop_size = 342, 299
    elif args.model.startswith('efficientnet_'):
        sizes = {
            'b0': (256, 224), 'b1': (256, 240), 'b2': (288, 288), 'b3': (320, 300),
            'b4': (384, 380), 'b5': (456, 456), 'b6': (528, 528), 'b7': (600, 600),
        }
        e_type = args.model.replace('efficientnet_', '')
        resize_size, crop_size = sizes[e_type]
        interpolation = InterpolationMode.BICUBIC

    print("Loading training data")
    st = time.time()
    cache_path = _get_cache_path(traindir)
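    # With --cache-dataset, the pre-built ImageFolder is serialized to a per-directory cache file and reused on later runs.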
    if args.cache_dataset and os.path.exists(cache_path):
        # Attention, as the transforms are also cached!
        print("Loading dataset_train from {}".format(cache_path))
        dataset, _ = torch.load(cache_path)
    else:
        auto_augment_policy = getattr(args, "auto_augment", None)
        random_erase_prob = getattr(args, "random_erase", 0.0)
        dataset = torchvision.datasets.ImageFolder(
            traindir,
            presets.ClassificationPresetTrain(crop_size=crop_size, auto_augment_policy=auto_augment_policy,
                                              random_erase_prob=random_erase_prob))
        if args.cache_dataset:
            print("Saving dataset_train to {}".format(cache_path))
            utils.mkdir(os.path.dirname(cache_path))
            utils.save_on_master((dataset, traindir), cache_path)
    print("Took", time.time() - st)

    print("Loading validation data")
    cache_path = _get_cache_path(valdir)
    if args.cache_dataset and os.path.exists(cache_path):
        # Attention, as the transforms are also cached!
        print("Loading dataset_test from {}".format(cache_path))
        dataset_test, _ = torch.load(cache_path)
    else:
        dataset_test = torchvision.datasets.ImageFolder(
            valdir,
            presets.ClassificationPresetEval(crop_size=crop_size, resize_size=resize_size,
                                             interpolation=interpolation))
        if args.cache_dataset:
            print("Saving dataset_test to {}".format(cache_path))
            utils.mkdir(os.path.dirname(cache_path))
            utils.save_on_master((dataset_test, valdir), cache_path)

    print("Creating data loaders")
    if args.distributed:
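        # Each process gets a distinct shard of the data via DistributedSampler.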
        train_sampler = torch.utils.data.distributed.DistributedSampler(dataset)
        test_sampler = torch.utils.data.distributed.DistributedSampler(dataset_test)
    else:
        train_sampler = torch.utils.data.RandomSampler(dataset)
        test_sampler = torch.utils.data.SequentialSampler(dataset_test)

    return dataset, dataset_test, train_sampler, test_sampler


def main(args):
    if args.apex and amp is None:
        raise RuntimeError("Failed to import apex. Please install apex from https://www.github.com/nvidia/apex "
                           "to enable mixed-precision training.")

    if args.output_dir:
        utils.mkdir(args.output_dir)

    utils.init_distributed_mode(args)
    print(args)

    device = torch.device(args.device)

    torch.backends.cudnn.benchmark = True

    train_dir = os.path.join(args.data_path, 'train')
    val_dir = os.path.join(args.data_path, 'val')
    dataset, dataset_test, train_sampler, test_sampler = load_data(train_dir, val_dir, args)

    collate_fn = None
    num_classes = len(dataset.classes)
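    # MixUp / CutMix operate on whole batches, so when enabled they are applied through the
    # DataLoader's collate_fn rather than as per-sample dataset transforms.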
    mixup_transforms = []
    if args.mixup_alpha > 0.0:
        mixup_transforms.append(transforms.RandomMixup(num_classes, p=1.0, alpha=args.mixup_alpha))
    if args.cutmix_alpha > 0.0:
        mixup_transforms.append(transforms.RandomCutmix(num_classes, p=1.0, alpha=args.cutmix_alpha))
    if mixup_transforms:
        mixupcutmix = torchvision.transforms.RandomChoice(mixup_transforms)
        collate_fn = lambda batch: mixupcutmix(*default_collate(batch))  # noqa: E731
    data_loader = torch.utils.data.DataLoader(
        dataset, batch_size=args.batch_size,
        sampler=train_sampler, num_workers=args.workers, pin_memory=True,
        collate_fn=collate_fn)
    data_loader_test = torch.utils.data.DataLoader(
        dataset_test, batch_size=args.batch_size,
        sampler=test_sampler, num_workers=args.workers, pin_memory=True)

    print("Creating model")
    model = torchvision.models.__dict__[args.model](pretrained=args.pretrained, num_classes=num_classes)
    model.to(device)

    if args.distributed and args.sync_bn:
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)

    criterion = nn.CrossEntropyLoss(label_smoothing=args.label_smoothing)

    opt_name = args.opt.lower()
    if opt_name.startswith("sgd"):
        optimizer = torch.optim.SGD(
            model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay,
            nesterov="nesterov" in opt_name)
    elif opt_name == 'rmsprop':
        optimizer = torch.optim.RMSprop(model.parameters(), lr=args.lr, momentum=args.momentum,
                                        weight_decay=args.weight_decay, eps=0.0316, alpha=0.9)
    else:
        raise RuntimeError("Invalid optimizer {}. Only SGD and RMSprop are supported.".format(args.opt))

    if args.apex:
        model, optimizer = amp.initialize(model, optimizer,
                                          opt_level=args.apex_opt_level
                                          )

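    # Main LR schedule; an optional warmup schedule is chained in front of it below.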
    args.lr_scheduler = args.lr_scheduler.lower()
    if args.lr_scheduler == 'steplr':
        main_lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=args.lr_step_size, gamma=args.lr_gamma)
    elif args.lr_scheduler == 'cosineannealinglr':
        main_lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer,
                                                                       T_max=args.epochs - args.lr_warmup_epochs)
    elif args.lr_scheduler == 'exponentiallr':
        main_lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=args.lr_gamma)
    else:
        raise RuntimeError("Invalid lr scheduler '{}'. Only StepLR, CosineAnnealingLR and ExponentialLR "
                           "are supported.".format(args.lr_scheduler))

    if args.lr_warmup_epochs > 0:
        if args.lr_warmup_method == 'linear':
            warmup_lr_scheduler = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=args.lr_warmup_decay,
                                                                    total_iters=args.lr_warmup_epochs)
        elif args.lr_warmup_method == 'constant':
            warmup_lr_scheduler = torch.optim.lr_scheduler.ConstantLR(optimizer, factor=args.lr_warmup_decay,
                                                                      total_iters=args.lr_warmup_epochs)
        else:
            raise RuntimeError(f"Invalid warmup lr method '{args.lr_warmup_method}'. Only linear and constant "
                               "are supported.")
        lr_scheduler = torch.optim.lr_scheduler.SequentialLR(
            optimizer,
            schedulers=[warmup_lr_scheduler, main_lr_scheduler],
            milestones=[args.lr_warmup_epochs]
        )
    else:
        lr_scheduler = main_lr_scheduler

    model_without_ddp = model
    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
        model_without_ddp = model.module

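    # Optionally track an exponential moving average of the weights; it is evaluated separately after each epoch.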
    model_ema = None
    if args.model_ema:
        model_ema = utils.ExponentialMovingAverage(model_without_ddp, device=device, decay=args.model_ema_decay)

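    # Restore model, optimizer, LR scheduler (and EMA, if enabled) state when resuming from a checkpoint.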
    if args.resume:
        checkpoint = torch.load(args.resume, map_location='cpu')
        model_without_ddp.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
        args.start_epoch = checkpoint['epoch'] + 1
        if model_ema:
            model_ema.load_state_dict(checkpoint['model_ema'])

    if args.test_only:
        evaluate(model, criterion, data_loader_test, device=device)
        return

    print("Start training")
    start_time = time.time()
    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            train_sampler.set_epoch(epoch)
        train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, args.print_freq, args.apex, model_ema)
        lr_scheduler.step()
        evaluate(model, criterion, data_loader_test, device=device)
        if model_ema:
            evaluate(model_ema, criterion, data_loader_test, device=device, log_suffix='EMA')
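        # Save a per-epoch snapshot plus a rolling 'checkpoint.pth' that --resume can pick up.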
        if args.output_dir:
            checkpoint = {
                'model': model_without_ddp.state_dict(),
                'optimizer': optimizer.state_dict(),
                'lr_scheduler': lr_scheduler.state_dict(),
                'epoch': epoch,
                'args': args}
            if model_ema:
                checkpoint['model_ema'] = model_ema.state_dict()
            utils.save_on_master(
                checkpoint,
                os.path.join(args.output_dir, 'model_{}.pth'.format(epoch)))
            utils.save_on_master(
                checkpoint,
                os.path.join(args.output_dir, 'checkpoint.pth'))

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print('Training time {}'.format(total_time_str))


def get_args_parser(add_help=True):
    import argparse
    parser = argparse.ArgumentParser(description='PyTorch Classification Training', add_help=add_help)

    parser.add_argument('--data-path', default='/datasets01/imagenet_full_size/061417/', help='dataset')
    parser.add_argument('--model', default='resnet18', help='model')
    parser.add_argument('--device', default='cuda', help='device')
    parser.add_argument('-b', '--batch-size', default=32, type=int)
    parser.add_argument('--epochs', default=90, type=int, metavar='N',
                        help='number of total epochs to run')
    parser.add_argument('-j', '--workers', default=16, type=int, metavar='N',
                        help='number of data loading workers (default: 16)')
    parser.add_argument('--opt', default='sgd', type=str, help='optimizer')
    parser.add_argument('--lr', default=0.1, type=float, help='initial learning rate')
    parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
                        help='momentum')
    parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float,
                        metavar='W', help='weight decay (default: 1e-4)',
                        dest='weight_decay')
    parser.add_argument('--label-smoothing', default=0.0, type=float,
                        help='label smoothing (default: 0.0)',
                        dest='label_smoothing')
    parser.add_argument('--mixup-alpha', default=0.0, type=float, help='mixup alpha (default: 0.0)')
    parser.add_argument('--cutmix-alpha', default=0.0, type=float, help='cutmix alpha (default: 0.0)')
    parser.add_argument('--lr-scheduler', default="steplr", help='the lr scheduler (default: steplr)')
    parser.add_argument('--lr-warmup-epochs', default=0, type=int, help='the number of epochs to warmup (default: 0)')
    parser.add_argument('--lr-warmup-method', default="constant", type=str,
                        help='the warmup method (default: constant)')
    parser.add_argument('--lr-warmup-decay', default=0.01, type=float,
                        help='the factor applied to lr during warmup (default: 0.01)')
    parser.add_argument('--lr-step-size', default=30, type=int, help='decrease lr every step-size epochs')
    parser.add_argument('--lr-gamma', default=0.1, type=float, help='decrease lr by a factor of lr-gamma')
    parser.add_argument('--print-freq', default=10, type=int, help='print frequency')
    parser.add_argument('--output-dir', default='.', help='path where to save')
    parser.add_argument('--resume', default='', help='resume from checkpoint')
    parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
                        help='start epoch')
    parser.add_argument(
        "--cache-dataset",
        dest="cache_dataset",
        help="Cache the datasets for quicker initialization. It also serializes the transforms",
        action="store_true",
    )
    parser.add_argument(
        "--sync-bn",
        dest="sync_bn",
        help="Use sync batch norm",
        action="store_true",
    )
    parser.add_argument(
        "--test-only",
        dest="test_only",
        help="Only test the model",
        action="store_true",
    )
    parser.add_argument(
        "--pretrained",
        dest="pretrained",
        help="Use pre-trained models from the modelzoo",
        action="store_true",
    )
    parser.add_argument('--auto-augment', default=None, help='auto augment policy (default: None)')
    parser.add_argument('--random-erase', default=0.0, type=float, help='random erasing probability (default: 0.0)')

    # Mixed precision training parameters
    parser.add_argument('--apex', action='store_true',
                        help='Use apex for mixed precision training')
    parser.add_argument('--apex-opt-level', default='O1', type=str,
                        help='For apex mixed precision training. '
                             'O0 for FP32 training, O1 for mixed precision training. '
                             'For further detail, see https://github.com/NVIDIA/apex/tree/master/examples/imagenet'
                        )

    # distributed training parameters
    parser.add_argument('--world-size', default=1, type=int,
                        help='number of distributed processes')
    parser.add_argument('--dist-url', default='env://', help='url used to set up distributed training')
    parser.add_argument(
        '--model-ema', action='store_true',
        help='enable tracking Exponential Moving Average of model parameters')
    parser.add_argument(
        '--model-ema-decay', type=float, default=0.9,
        help='decay factor for Exponential Moving Average of model parameters (default: 0.9)')

    return parser

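# Example invocations (a sketch, assuming an ImageNet-style train/val directory layout; adjust paths and flags as needed):
#   single GPU:  python train.py --data-path /path/to/imagenet --model resnet50
#   multi GPU:   python -m torch.distributed.launch --nproc_per_node=8 --use_env train.py \
#                    --data-path /path/to/imagenet --model resnet50 --batch-size 32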

if __name__ == "__main__":
    args = get_args_parser().parse_args()
    main(args)