# train.py
1
2
3
import datetime
import os
import time
4
import warnings
5

6
import presets
7
8
9
import torch
import torch.utils.data
import torchvision
10
import transforms
11
import utils
12
from sampler import RASampler
13
14
15
from torch import nn
from torch.utils.data.dataloader import default_collate
from torchvision.transforms.functional import InterpolationMode
16
17


18
def train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, args, model_ema=None, scaler=None):
    """Run one training epoch over ``data_loader``.

    For every batch: move it to ``device``, run the forward pass (inside
    ``torch.cuda.amp.autocast``, enabled only when ``scaler`` is given),
    back-propagate, optionally clip the total gradient norm to
    ``args.clip_grad_norm`` and step ``optimizer``. Every
    ``args.model_ema_steps`` iterations the EMA model is updated; during the
    first ``args.lr_warmup_epochs`` epochs its ``n_averaged`` counter is reset.
    Loss, learning rate, top-1/top-5 accuracy and images/second are logged.

    Args:
        model: network being trained (put into train mode here).
        criterion: loss function called as ``criterion(output, target)``.
        optimizer: optimizer whose first param group's ``lr`` is logged.
        data_loader: iterable of ``(image, target)`` batches.
        device: device the batches are moved to.
        epoch: current epoch index (used in the log header and EMA warm-up check).
        args: parsed CLI namespace (``print_freq``, ``clip_grad_norm``,
            ``model_ema_steps``, ``lr_warmup_epochs`` are read).
        model_ema: optional EMA wrapper exposing ``update_parameters`` and ``n_averaged``.
        scaler: optional gradient scaler; when given, mixed precision is used.
    """
    model.train()
    logger = utils.MetricLogger(delimiter="  ")
    logger.add_meter("lr", utils.SmoothedValue(window_size=1, fmt="{value}"))
    logger.add_meter("img/s", utils.SmoothedValue(window_size=10, fmt="{value}"))

    use_amp = scaler is not None
    max_norm = args.clip_grad_norm
    header = f"Epoch: [{epoch}]"

    for step, (images, labels) in enumerate(logger.log_every(data_loader, args.print_freq, header)):
        # Throughput is timed from here, i.e. excluding data-loading time.
        tic = time.time()
        images = images.to(device)
        labels = labels.to(device)

        with torch.cuda.amp.autocast(enabled=use_amp):
            logits = model(images)
            loss = criterion(logits, labels)

        optimizer.zero_grad()
        if not use_amp:
            loss.backward()
            if max_norm is not None:
                nn.utils.clip_grad_norm_(model.parameters(), max_norm)
            optimizer.step()
        else:
            scaler.scale(loss).backward()
            if max_norm is not None:
                # Unscale first so the clipping threshold applies to the real gradient values.
                scaler.unscale_(optimizer)
                nn.utils.clip_grad_norm_(model.parameters(), max_norm)
            scaler.step(optimizer)
            scaler.update()

        if model_ema and step % args.model_ema_steps == 0:
            model_ema.update_parameters(model)
            if epoch < args.lr_warmup_epochs:
                # Reset ema buffer to keep copying weights during warmup period
                model_ema.n_averaged.fill_(0)

        acc1, acc5 = utils.accuracy(logits, labels, topk=(1, 5))
        n_images = images.shape[0]
        logger.update(loss=loss.item(), lr=optimizer.param_groups[0]["lr"])
        logger.meters["acc1"].update(acc1.item(), n=n_images)
        logger.meters["acc5"].update(acc5.item(), n=n_images)
        logger.meters["img/s"].update(n_images / (time.time() - tic))
59

60

61
def evaluate(model, criterion, data_loader, device, print_freq=100, log_suffix=""):
    """Evaluate ``model`` on ``data_loader`` and return its global top-1 accuracy.

    The model is put in eval mode and run under ``torch.inference_mode()``.
    Loss and top-1/top-5 accuracy are logged every ``print_freq`` iterations.
    Afterwards the number of processed samples is reduced across processes and
    compared with ``len(data_loader.dataset)``. A mismatch (e.g. sampler padding
    in a distributed run) triggers a warning, emitted by rank 0 only.

    Args:
        model: network to evaluate.
        criterion: loss function called as ``criterion(output, target)``.
        data_loader: iterable of ``(image, target)`` batches.
        device: device the batches are moved to.
        print_freq: logging interval, in iterations.
        log_suffix: text appended to the ``"Test: "`` log header (e.g. ``"EMA"``).

    Returns:
        ``metric_logger.acc1.global_avg`` after synchronising meters across processes.
    """
    model.eval()
    metric_logger = utils.MetricLogger(delimiter="  ")
    header = f"Test: {log_suffix}"

    num_processed_samples = 0
    with torch.inference_mode():
        for image, target in metric_logger.log_every(data_loader, print_freq, header):
            image = image.to(device, non_blocking=True)
            target = target.to(device, non_blocking=True)
            output = model(image)
            loss = criterion(output, target)

            acc1, acc5 = utils.accuracy(output, target, topk=(1, 5))
            # FIXME need to take into account that the datasets
            # could have been padded in distributed setup
            batch_size = image.shape[0]
            metric_logger.update(loss=loss.item())
            metric_logger.meters["acc1"].update(acc1.item(), n=batch_size)
            metric_logger.meters["acc5"].update(acc5.item(), n=batch_size)
            num_processed_samples += batch_size

    # gather the stats from all processes
    # int() turns a (possibly device-resident) scalar tensor into a plain number so
    # the comparison and the warning text below are clean.
    num_processed_samples = int(utils.reduce_across_processes(num_processed_samples))

    # get_rank() raises when no process group exists, so only query it in an
    # initialised distributed run; otherwise this process is the only (main) one.
    is_main_rank = (
        not torch.distributed.is_available()
        or not torch.distributed.is_initialized()
        or torch.distributed.get_rank() == 0
    )
    if (
        hasattr(data_loader.dataset, "__len__")
        and len(data_loader.dataset) != num_processed_samples
        and is_main_rank
    ):
        # See FIXME above
        warnings.warn(
            f"It looks like the dataset has {len(data_loader.dataset)} samples, but {num_processed_samples} "
            "samples were used for the validation, which might bias the results. "
            "Try adjusting the batch size and / or the world size. "
            "Setting the world size to 1 is always a safe bet."
        )

    metric_logger.synchronize_between_processes()

    print(f"{header} Acc@1 {metric_logger.acc1.global_avg:.3f} Acc@5 {metric_logger.acc5.global_avg:.3f}")
    return metric_logger.acc1.global_avg


104
105
def _get_cache_path(filepath):
    """Return the cache file location used for the dataset rooted at ``filepath``.

    The file name is the first 10 hex digits of the SHA-1 of ``filepath`` plus
    ``.pt``. It lives under ``~/.torch/vision/datasets/imagefolder`` with ``~``
    expanded to the user's home directory.
    """
    import hashlib

    digest = hashlib.sha1(filepath.encode()).hexdigest()
    file_name = f"{digest[:10]}.pt"
    relative = os.path.join("~", ".torch", "vision", "datasets", "imagefolder", file_name)
    return os.path.expanduser(relative)


113
def load_data(traindir, valdir, args):
    """Build the training/validation ``ImageFolder`` datasets and their samplers.

    With ``args.cache_dataset`` set, each dataset (transforms included) is saved
    to ``_get_cache_path(<dir>)`` the first time and reloaded from there on later
    runs.

    Args:
        traindir: root folder of the training split.
        valdir: root folder of the validation split.
        args: parsed command-line namespace.

    Returns:
        ``(dataset, dataset_test, train_sampler, test_sampler)``.
    """
    # Data loading code
    print("Loading data")
    val_resize_size, val_crop_size, train_crop_size = args.val_resize_size, args.val_crop_size, args.train_crop_size
    interpolation = InterpolationMode(args.interpolation)

    print("Loading training data")
    st = time.time()
    cache_path = _get_cache_path(traindir)
    if args.cache_dataset and os.path.exists(cache_path):
        # Attention, as the transforms are also cached!
        print(f"Loading dataset_train from {cache_path}")
        # The cache holds a (dataset, dir) tuple saved by utils.save_on_master, i.e.
        # arbitrary Python objects rather than tensors. Recent PyTorch releases default
        # to weights_only=True and would refuse to load it. This unpickles arbitrary
        # objects, so only load caches you created yourself.
        dataset, _ = torch.load(cache_path, weights_only=False)
    else:
        auto_augment_policy = getattr(args, "auto_augment", None)
        random_erase_prob = getattr(args, "random_erase", 0.0)
        dataset = torchvision.datasets.ImageFolder(
            traindir,
            presets.ClassificationPresetTrain(
                crop_size=train_crop_size,
                interpolation=interpolation,
                auto_augment_policy=auto_augment_policy,
                random_erase_prob=random_erase_prob,
            ),
        )
        if args.cache_dataset:
            print(f"Saving dataset_train to {cache_path}")
            utils.mkdir(os.path.dirname(cache_path))
            utils.save_on_master((dataset, traindir), cache_path)
    print("Took", time.time() - st)

    print("Loading validation data")
    cache_path = _get_cache_path(valdir)
    if args.cache_dataset and os.path.exists(cache_path):
        # Attention, as the transforms are also cached!
        # NOTE: the cache key is only the directory path. A cache written with one
        # preprocessing (preset vs. weights.transforms()) is reused regardless of the
        # current --weights/--test-only combination.
        print(f"Loading dataset_test from {cache_path}")
        # See the training-cache note above: pickled objects need weights_only=False.
        dataset_test, _ = torch.load(cache_path, weights_only=False)
    else:
        if args.weights and args.test_only:
            # Use the preprocessing bundled with the selected pretrained weights.
            weights = torchvision.models.get_weight(args.weights)
            preprocessing = weights.transforms()
        else:
            preprocessing = presets.ClassificationPresetEval(
                crop_size=val_crop_size, resize_size=val_resize_size, interpolation=interpolation
            )

        dataset_test = torchvision.datasets.ImageFolder(
            valdir,
            preprocessing,
        )
        if args.cache_dataset:
            print(f"Saving dataset_test to {cache_path}")
            utils.mkdir(os.path.dirname(cache_path))
            utils.save_on_master((dataset_test, valdir), cache_path)

    print("Creating data loaders")
    if args.distributed:
        if getattr(args, "ra_sampler", False):
            train_sampler = RASampler(dataset, shuffle=True, repetitions=args.ra_reps)
        else:
            train_sampler = torch.utils.data.distributed.DistributedSampler(dataset)
        test_sampler = torch.utils.data.distributed.DistributedSampler(dataset_test, shuffle=False)
    else:
        train_sampler = torch.utils.data.RandomSampler(dataset)
        test_sampler = torch.utils.data.SequentialSampler(dataset_test)

    return dataset, dataset_test, train_sampler, test_sampler


def _build_optimizer(args, parameters):
    """Instantiate the optimizer named by ``args.opt`` (case-insensitive).

    Names starting with ``sgd`` select SGD, with Nesterov momentum when the name
    contains ``"nesterov"``. ``rmsprop`` selects RMSprop and ``adamw`` selects AdamW.

    Raises:
        RuntimeError: for any other optimizer name.
    """
    opt_name = args.opt.lower()
    if opt_name.startswith("sgd"):
        return torch.optim.SGD(
            parameters,
            lr=args.lr,
            momentum=args.momentum,
            weight_decay=args.weight_decay,
            nesterov="nesterov" in opt_name,
        )
    if opt_name == "rmsprop":
        return torch.optim.RMSprop(
            parameters, lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay, eps=0.0316, alpha=0.9
        )
    if opt_name == "adamw":
        return torch.optim.AdamW(parameters, lr=args.lr, weight_decay=args.weight_decay)
    raise RuntimeError(f"Invalid optimizer {args.opt}. Only SGD, RMSprop and AdamW are supported.")


def _build_lr_scheduler(args, optimizer):
    """Build the LR scheduler that is stepped once per epoch, with optional warm-up.

    ``args.lr_scheduler`` is lower-cased in place. ``args`` is later stored in
    checkpoints, so the normalised name is what gets saved. The main schedule is
    StepLR, CosineAnnealingLR or ExponentialLR. When ``args.lr_warmup_epochs > 0``,
    a LinearLR/ConstantLR warm-up runs first and is chained with ``SequentialLR``.

    Raises:
        RuntimeError: for an unknown scheduler or warm-up method.
    """
    args.lr_scheduler = args.lr_scheduler.lower()
    if args.lr_scheduler == "steplr":
        main_lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=args.lr_step_size, gamma=args.lr_gamma)
    elif args.lr_scheduler == "cosineannealinglr":
        # The cosine decay only spans the epochs that follow the warm-up.
        main_lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, T_max=args.epochs - args.lr_warmup_epochs
        )
    elif args.lr_scheduler == "exponentiallr":
        main_lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=args.lr_gamma)
    else:
        raise RuntimeError(
            f"Invalid lr scheduler '{args.lr_scheduler}'. Only StepLR, CosineAnnealingLR and ExponentialLR "
            "are supported."
        )

    if args.lr_warmup_epochs <= 0:
        return main_lr_scheduler

    if args.lr_warmup_method == "linear":
        warmup_lr_scheduler = torch.optim.lr_scheduler.LinearLR(
            optimizer, start_factor=args.lr_warmup_decay, total_iters=args.lr_warmup_epochs
        )
    elif args.lr_warmup_method == "constant":
        warmup_lr_scheduler = torch.optim.lr_scheduler.ConstantLR(
            optimizer, factor=args.lr_warmup_decay, total_iters=args.lr_warmup_epochs
        )
    else:
        raise RuntimeError(
            f"Invalid warmup lr method '{args.lr_warmup_method}'. Only linear and constant are supported."
        )
    return torch.optim.lr_scheduler.SequentialLR(
        optimizer, schedulers=[warmup_lr_scheduler, main_lr_scheduler], milestones=[args.lr_warmup_epochs]
    )


def main(args):
    """Train an image classifier, or only evaluate it when ``args.test_only`` is set.

    Steps:
    1. Set up output dir, distributed mode and cuDNN flags.
    2. Build datasets/loaders (with optional MixUp/CutMix collation) and the model.
    3. Build optimizer, grad scaler and LR scheduler, plus optional DDP and EMA wrappers.
    4. Optionally resume from ``args.resume``.
    5. Either evaluate once and return, or train for the remaining epochs,
       evaluating and checkpointing after each one.
    """
    if args.output_dir:
        utils.mkdir(args.output_dir)

    utils.init_distributed_mode(args)
    print(args)

    device = torch.device(args.device)

    if args.use_deterministic_algorithms:
        torch.backends.cudnn.benchmark = False
        torch.use_deterministic_algorithms(True)
    else:
        torch.backends.cudnn.benchmark = True

    train_dir = os.path.join(args.data_path, "train")
    val_dir = os.path.join(args.data_path, "val")
    dataset, dataset_test, train_sampler, test_sampler = load_data(train_dir, val_dir, args)

    collate_fn = None
    num_classes = len(dataset.classes)
    mixup_transforms = []
    if args.mixup_alpha > 0.0:
        mixup_transforms.append(transforms.RandomMixup(num_classes, p=1.0, alpha=args.mixup_alpha))
    if args.cutmix_alpha > 0.0:
        mixup_transforms.append(transforms.RandomCutmix(num_classes, p=1.0, alpha=args.cutmix_alpha))
    if mixup_transforms:
        mixupcutmix = torchvision.transforms.RandomChoice(mixup_transforms)
        # Collate as usual, then apply one randomly chosen batch-level transform.
        collate_fn = lambda batch: mixupcutmix(*default_collate(batch))  # noqa: E731
    data_loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=args.batch_size,
        sampler=train_sampler,
        num_workers=args.workers,
        pin_memory=True,
        collate_fn=collate_fn,
    )
    data_loader_test = torch.utils.data.DataLoader(
        dataset_test, batch_size=args.batch_size, sampler=test_sampler, num_workers=args.workers, pin_memory=True
    )

    print("Creating model")
    model = torchvision.models.__dict__[args.model](weights=args.weights, num_classes=num_classes)
    model.to(device)

    if args.distributed and args.sync_bn:
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)

    criterion = nn.CrossEntropyLoss(label_smoothing=args.label_smoothing)

    custom_keys_weight_decay = []
    if args.bias_weight_decay is not None:
        custom_keys_weight_decay.append(("bias", args.bias_weight_decay))
    if args.transformer_embedding_decay is not None:
        for key in ["class_token", "position_embedding", "relative_position_bias"]:
            custom_keys_weight_decay.append((key, args.transformer_embedding_decay))
    parameters = utils.set_weight_decay(
        model,
        args.weight_decay,
        norm_weight_decay=args.norm_weight_decay,
        custom_keys_weight_decay=custom_keys_weight_decay if len(custom_keys_weight_decay) > 0 else None,
    )

    optimizer = _build_optimizer(args, parameters)

    scaler = torch.cuda.amp.GradScaler() if args.amp else None

    lr_scheduler = _build_lr_scheduler(args, optimizer)

    model_without_ddp = model
    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
        model_without_ddp = model.module

    model_ema = None
    if args.model_ema:
        # Decay adjustment that aims to keep the decay independent from other hyper-parameters originally proposed at:
        # https://github.com/facebookresearch/pycls/blob/f8cd9627/pycls/core/net.py#L123
        #
        # total_ema_updates = (Dataset_size / n_GPUs) * epochs / (batch_size_per_gpu * EMA_steps)
        # We consider constant = Dataset_size for a given dataset/setup and omit it. Thus:
        # adjust = 1 / total_ema_updates ~= n_GPUs * batch_size_per_gpu * EMA_steps / epochs
        adjust = args.world_size * args.batch_size * args.model_ema_steps / args.epochs
        alpha = 1.0 - args.model_ema_decay
        alpha = min(1.0, alpha * adjust)
        model_ema = utils.ExponentialMovingAverage(model_without_ddp, device=device, decay=1.0 - alpha)

    if args.resume:
        # Checkpoints written below contain non-tensor objects (notably the argparse
        # Namespace under "args"). The weights_only=True default of recent PyTorch
        # releases refuses to unpickle them. This executes pickle, so only resume
        # from trusted files.
        checkpoint = torch.load(args.resume, map_location="cpu", weights_only=False)
        model_without_ddp.load_state_dict(checkpoint["model"])
        if not args.test_only:
            optimizer.load_state_dict(checkpoint["optimizer"])
            lr_scheduler.load_state_dict(checkpoint["lr_scheduler"])
        args.start_epoch = checkpoint["epoch"] + 1
        if model_ema:
            model_ema.load_state_dict(checkpoint["model_ema"])
        if scaler:
            scaler.load_state_dict(checkpoint["scaler"])

    if args.test_only:
        # We disable the cudnn benchmarking because it can noticeably affect the accuracy
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True
        if model_ema:
            evaluate(model_ema, criterion, data_loader_test, device=device, log_suffix="EMA")
        else:
            evaluate(model, criterion, data_loader_test, device=device)
        return

    print("Start training")
    start_time = time.time()
    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            # Reseed the distributed sampler so each epoch gets a different shuffle.
            train_sampler.set_epoch(epoch)
        train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, args, model_ema, scaler)
        lr_scheduler.step()
        evaluate(model, criterion, data_loader_test, device=device)
        if model_ema:
            evaluate(model_ema, criterion, data_loader_test, device=device, log_suffix="EMA")
        if args.output_dir:
            checkpoint = {
                "model": model_without_ddp.state_dict(),
                "optimizer": optimizer.state_dict(),
                "lr_scheduler": lr_scheduler.state_dict(),
                "epoch": epoch,
                "args": args,
            }
            if model_ema:
                checkpoint["model_ema"] = model_ema.state_dict()
            if scaler:
                checkpoint["scaler"] = scaler.state_dict()
            utils.save_on_master(checkpoint, os.path.join(args.output_dir, f"model_{epoch}.pth"))
            utils.save_on_master(checkpoint, os.path.join(args.output_dir, "checkpoint.pth"))

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print(f"Training time {total_time_str}")
367
368


369
def get_args_parser(add_help=True):
    """Create the command-line parser for this training script.

    Args:
        add_help: forwarded to ``argparse.ArgumentParser``; controls whether
            ``-h/--help`` is registered.

    Returns:
        A configured ``argparse.ArgumentParser`` (not yet parsed).
    """
    import argparse

    parser = argparse.ArgumentParser(description="PyTorch Classification Training", add_help=add_help)
    add = parser.add_argument

    # --- data, model and hardware ---
    add("--data-path", type=str, default="/datasets01/imagenet_full_size/061417/", help="dataset path")
    add("--model", type=str, default="resnet18", help="model name")
    add("--device", type=str, default="cuda", help="device (Use cuda or cpu Default: cuda)")
    add(
        "-b",
        "--batch-size",
        type=int,
        default=32,
        help="images per gpu, the total batch size is $NGPU x batch_size",
    )
    add("--epochs", type=int, default=90, metavar="N", help="number of total epochs to run")
    add(
        "-j",
        "--workers",
        type=int,
        default=16,
        metavar="N",
        help="number of data loading workers (default: 16)",
    )

    # --- optimizer and regularisation ---
    add("--opt", type=str, default="sgd", help="optimizer")
    add("--lr", type=float, default=0.1, help="initial learning rate")
    add("--momentum", type=float, default=0.9, metavar="M", help="momentum")
    add(
        "--wd",
        "--weight-decay",
        dest="weight_decay",
        type=float,
        default=1e-4,
        metavar="W",
        help="weight decay (default: 1e-4)",
    )
    add(
        "--norm-weight-decay",
        type=float,
        default=None,
        help="weight decay for Normalization layers (default: None, same value as --wd)",
    )
    add(
        "--bias-weight-decay",
        type=float,
        default=None,
        help="weight decay for bias parameters of all layers (default: None, same value as --wd)",
    )
    add(
        "--transformer-embedding-decay",
        type=float,
        default=None,
        help="weight decay for embedding parameters for vision transformer models (default: None, same value as --wd)",
    )
    add("--label-smoothing", dest="label_smoothing", type=float, default=0.0, help="label smoothing (default: 0.0)")
    add("--mixup-alpha", type=float, default=0.0, help="mixup alpha (default: 0.0)")
    add("--cutmix-alpha", type=float, default=0.0, help="cutmix alpha (default: 0.0)")

    # --- learning-rate schedule ---
    add("--lr-scheduler", type=str, default="steplr", help="the lr scheduler (default: steplr)")
    add("--lr-warmup-epochs", type=int, default=0, help="the number of epochs to warmup (default: 0)")
    add("--lr-warmup-method", type=str, default="constant", help="the warmup method (default: constant)")
    add("--lr-warmup-decay", type=float, default=0.01, help="the decay for lr")
    add("--lr-step-size", type=int, default=30, help="decrease lr every step-size epochs")
    add("--lr-gamma", type=float, default=0.1, help="decrease lr by a factor of lr-gamma")

    # --- logging and checkpoints ---
    add("--print-freq", type=int, default=10, help="print frequency")
    add("--output-dir", type=str, default=".", help="path to save outputs")
    add("--resume", type=str, default="", help="path of checkpoint")
    add("--start-epoch", type=int, default=0, metavar="N", help="start epoch")

    # --- switches and augmentation ---
    add(
        "--cache-dataset",
        dest="cache_dataset",
        action="store_true",
        help="Cache the datasets for quicker initialization. It also serializes the transforms",
    )
    add("--sync-bn", dest="sync_bn", action="store_true", help="Use sync batch norm")
    add("--test-only", dest="test_only", action="store_true", help="Only test the model")
    add("--auto-augment", type=str, default=None, help="auto augment policy (default: None)")
    add("--random-erase", type=float, default=0.0, help="random erasing probability (default: 0.0)")

    # Mixed precision training parameters
    add("--amp", action="store_true", help="Use torch.cuda.amp for mixed precision training")

    # distributed training parameters
    add("--world-size", type=int, default=1, help="number of distributed processes")
    add("--dist-url", type=str, default="env://", help="url used to set up distributed training")

    # --- exponential moving average of weights ---
    add("--model-ema", action="store_true", help="enable tracking Exponential Moving Average of model parameters")
    add(
        "--model-ema-steps",
        type=int,
        default=32,
        help="the number of iterations that controls how often to update the EMA model (default: 32)",
    )
    add(
        "--model-ema-decay",
        type=float,
        default=0.99998,
        help="decay factor for Exponential Moving Average of model parameters (default: 0.99998)",
    )
    add("--use-deterministic-algorithms", action="store_true", help="Forces the use of deterministic algorithms only.")

    # --- resolution, clipping, sampling and weights ---
    add("--interpolation", type=str, default="bilinear", help="the interpolation method (default: bilinear)")
    add("--val-resize-size", type=int, default=256, help="the resize size used for validation (default: 256)")
    add("--val-crop-size", type=int, default=224, help="the central crop size used for validation (default: 224)")
    add("--train-crop-size", type=int, default=224, help="the random crop size used for training (default: 224)")
    add("--clip-grad-norm", type=float, default=None, help="the maximum gradient norm (default None)")
    add("--ra-sampler", action="store_true", help="whether to use Repeated Augmentation in training")
    add("--ra-reps", type=int, default=3, help="number of repetitions for Repeated Augmentation (default: 3)")
    add("--weights", type=str, default=None, help="the weights enum name to load")

    return parser
496

497
498

if __name__ == "__main__":
    # Script entry point: parse the command line and hand the namespace to main().
    main(get_args_parser().parse_args())