train.py 22.6 KB
Newer Older
1
2
3
import datetime
import os
import time
4
import warnings
5

6
import presets
7
8
9
import torch
import torch.utils.data
import torchvision
10
import torchvision.transforms
11
import transforms
12
import utils
13
from sampler import RASampler
14
15
16
from torch import nn
from torch.utils.data.dataloader import default_collate
from torchvision.transforms.functional import InterpolationMode
17
18


19
def train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, args, model_ema=None, scaler=None):
    """Run one training epoch over ``data_loader``, updating ``model`` in place.

    Args:
        model: network being trained; switched to train mode here.
        criterion: loss function applied as ``criterion(output, target)``.
        optimizer: optimizer stepped once per batch.
        data_loader: iterable of ``(image, target)`` batches.
        device: device every batch is moved to.
        epoch: current epoch index, used in the log header and for the EMA warmup check.
        args: parsed options; reads ``print_freq``, ``clip_grad_norm``,
            ``model_ema_steps`` and ``lr_warmup_epochs``.
        model_ema: optional EMA wrapper refreshed every ``args.model_ema_steps`` batches.
        scaler: optional AMP grad scaler; when given, the forward pass runs under
            autocast and backward/step go through the scaler.
    """
    model.train()
    logger = utils.MetricLogger(delimiter="  ")
    logger.add_meter("lr", utils.SmoothedValue(window_size=1, fmt="{value}"))
    logger.add_meter("img/s", utils.SmoothedValue(window_size=10, fmt="{value}"))

    use_amp = scaler is not None
    max_norm = args.clip_grad_norm
    batches = logger.log_every(data_loader, args.print_freq, f"Epoch: [{epoch}]")

    for step, (images, targets) in enumerate(batches):
        tic = time.time()
        images, targets = images.to(device), targets.to(device)

        with torch.cuda.amp.autocast(enabled=use_amp):
            output = model(images)
            loss = criterion(output, targets)

        optimizer.zero_grad()

        # Stage 1: backward. Under AMP the loss is scaled, so gradients must be
        # unscaled before clipping for the threshold to apply to their real magnitude.
        if use_amp:
            scaler.scale(loss).backward()
            if max_norm is not None:
                scaler.unscale_(optimizer)
        else:
            loss.backward()

        # Stage 2: optional gradient clipping (same call for both paths).
        if max_norm is not None:
            nn.utils.clip_grad_norm_(model.parameters(), max_norm)

        # Stage 3: parameter update.
        if use_amp:
            scaler.step(optimizer)
            scaler.update()
        else:
            optimizer.step()

        if model_ema and step % args.model_ema_steps == 0:
            model_ema.update_parameters(model)
            if epoch < args.lr_warmup_epochs:
                # Reset ema buffer to keep copying weights during warmup period
                model_ema.n_averaged.fill_(0)

        acc1, acc5 = utils.accuracy(output, targets, topk=(1, 5))
        batch_size = images.shape[0]
        logger.update(loss=loss.item(), lr=optimizer.param_groups[0]["lr"])
        logger.meters["acc1"].update(acc1.item(), n=batch_size)
        logger.meters["acc5"].update(acc5.item(), n=batch_size)
        logger.meters["img/s"].update(batch_size / (time.time() - tic))
60

61

62
def evaluate(model, criterion, data_loader, device, print_freq=100, log_suffix=""):
    """Evaluate ``model`` on ``data_loader`` and return the global top-1 accuracy.

    Loss, Acc@1 and Acc@5 are accumulated in a ``utils.MetricLogger``, synchronized
    across processes, and the final Acc@1/Acc@5 are printed under a
    ``"Test: {log_suffix}"`` header.

    Args:
        model: network to evaluate; switched to eval mode here.
        criterion: loss function applied as ``criterion(output, target)``.
        data_loader: iterable of ``(image, target)`` batches.
        device: device every batch is moved to.
        print_freq: logging interval passed to ``MetricLogger.log_every``.
        log_suffix: text appended to the log header (e.g. ``"EMA"``).

    Returns:
        The global average of the Acc@1 meter.
    """
    model.eval()
    metric_logger = utils.MetricLogger(delimiter="  ")
    header = f"Test: {log_suffix}"

    num_processed_samples = 0
    with torch.inference_mode():
        for image, target in metric_logger.log_every(data_loader, print_freq, header):
            image = image.to(device, non_blocking=True)
            target = target.to(device, non_blocking=True)
            output = model(image)
            loss = criterion(output, target)

            acc1, acc5 = utils.accuracy(output, target, topk=(1, 5))
            # FIXME need to take into account that the datasets
            # could have been padded in distributed setup
            batch_size = image.shape[0]
            metric_logger.update(loss=loss.item())
            metric_logger.meters["acc1"].update(acc1.item(), n=batch_size)
            metric_logger.meters["acc5"].update(acc5.item(), n=batch_size)
            num_processed_samples += batch_size

    # gather the stats from all processes
    # int() turns the reduced count (which may come back as a tensor) into a plain
    # number, so the comparison is a plain bool and the warning prints "50000", not "tensor(50000)".
    num_processed_samples = int(utils.reduce_across_processes(num_processed_samples))
    # torch.distributed.get_rank() raises when no process group is initialized,
    # so a non-distributed run is treated as the (only) main rank.
    is_main_rank = (
        not torch.distributed.is_available()
        or not torch.distributed.is_initialized()
        or torch.distributed.get_rank() == 0
    )
    if (
        hasattr(data_loader.dataset, "__len__")
        and len(data_loader.dataset) != num_processed_samples
        and is_main_rank
    ):
        # See FIXME above
        warnings.warn(
            f"It looks like the dataset has {len(data_loader.dataset)} samples, but {num_processed_samples} "
            "samples were used for the validation, which might bias the results. "
            "Try adjusting the batch size and / or the world size. "
            "Setting the world size to 1 is always a safe bet."
        )

    metric_logger.synchronize_between_processes()

    print(f"{header} Acc@1 {metric_logger.acc1.global_avg:.3f} Acc@5 {metric_logger.acc5.global_avg:.3f}")
    return metric_logger.acc1.global_avg


105
106
def _get_cache_path(filepath):
    import hashlib
107

108
109
110
111
112
113
    h = hashlib.sha1(filepath.encode()).hexdigest()
    cache_path = os.path.join("~", ".torch", "vision", "datasets", "imagefolder", h[:10] + ".pt")
    cache_path = os.path.expanduser(cache_path)
    return cache_path


114
def load_data(traindir, valdir, args):
    """Build the training/validation ``ImageFolder`` datasets and their samplers.

    Args:
        traindir: root directory of the training ``ImageFolder``.
        valdir: root directory of the validation ``ImageFolder``.
        args: parsed options. Some augmentation options are read with ``getattr``
            fallbacks because ``args`` may come from train_quantization.py, which
            doesn't define them.

    Returns:
        ``(dataset, dataset_test, train_sampler, test_sampler)``.

    With ``args.cache_dataset`` each dataset (transforms included) is saved to the
    path returned by ``_get_cache_path`` and reloaded from there on later runs.
    """
    # Data loading code
    print("Loading data")
    val_resize_size, val_crop_size, train_crop_size = (
        args.val_resize_size,
        args.val_crop_size,
        args.train_crop_size,
    )
    interpolation = InterpolationMode(args.interpolation)

    print("Loading training data")
    st = time.time()
    cache_path = _get_cache_path(traindir)
    if args.cache_dataset and os.path.exists(cache_path):
        # Attention, as the transforms are also cached!
        print(f"Loading dataset_train from {cache_path}")
        # The cache is a pickled (ImageFolder, path) tuple written below. Restoring it
        # needs weights_only=False (torch>=2.6 defaults to True and would reject it).
        # This unpickles arbitrary objects: only use cache files you trust.
        dataset, _ = torch.load(cache_path, weights_only=False)
    else:
        # We need a default value for the variables below because args may come
        # from train_quantization.py which doesn't define them.
        auto_augment_policy = getattr(args, "auto_augment", None)
        random_erase_prob = getattr(args, "random_erase", 0.0)
        ra_magnitude = getattr(args, "ra_magnitude", None)
        augmix_severity = getattr(args, "augmix_severity", None)
        dataset = torchvision.datasets.ImageFolder(
            traindir,
            presets.ClassificationPresetTrain(
                crop_size=train_crop_size,
                interpolation=interpolation,
                auto_augment_policy=auto_augment_policy,
                random_erase_prob=random_erase_prob,
                ra_magnitude=ra_magnitude,
                augmix_severity=augmix_severity,
                backend=args.backend,
            ),
        )
        if args.cache_dataset:
            print(f"Saving dataset_train to {cache_path}")
            utils.mkdir(os.path.dirname(cache_path))
            utils.save_on_master((dataset, traindir), cache_path)
    print("Took", time.time() - st)

    print("Loading validation data")
    cache_path = _get_cache_path(valdir)
    if args.cache_dataset and os.path.exists(cache_path):
        # Attention, as the transforms are also cached!
        print(f"Loading dataset_test from {cache_path}")
        # Same trusted-pickle caveat as for the training cache above.
        dataset_test, _ = torch.load(cache_path, weights_only=False)
    else:
        if args.weights and args.test_only:
            # Evaluating published weights: use the preprocessing bundled with them.
            weights = torchvision.models.get_weight(args.weights)
            preprocessing = weights.transforms(antialias=True)
            if args.backend == "tensor":
                preprocessing = torchvision.transforms.Compose([torchvision.transforms.PILToTensor(), preprocessing])
        else:
            preprocessing = presets.ClassificationPresetEval(
                crop_size=val_crop_size,
                resize_size=val_resize_size,
                interpolation=interpolation,
                backend=args.backend,
            )

        dataset_test = torchvision.datasets.ImageFolder(
            valdir,
            preprocessing,
        )
        if args.cache_dataset:
            print(f"Saving dataset_test to {cache_path}")
            utils.mkdir(os.path.dirname(cache_path))
            utils.save_on_master((dataset_test, valdir), cache_path)

    print("Creating data loaders")
    if args.distributed:
        if getattr(args, "ra_sampler", False):
            train_sampler = RASampler(dataset, shuffle=True, repetitions=args.ra_reps)
        else:
            train_sampler = torch.utils.data.distributed.DistributedSampler(dataset)
        test_sampler = torch.utils.data.distributed.DistributedSampler(dataset_test, shuffle=False)
    else:
        train_sampler = torch.utils.data.RandomSampler(dataset)
        test_sampler = torch.utils.data.SequentialSampler(dataset_test)

    return dataset, dataset_test, train_sampler, test_sampler


def _build_optimizer(args, parameters):
    """Return the optimizer named by ``args.opt`` (case-insensitive) over ``parameters``.

    Any name starting with ``"sgd"`` selects SGD, with Nesterov momentum when the
    name also contains ``"nesterov"``. ``"rmsprop"`` and ``"adamw"`` are the other
    accepted values.

    Raises:
        RuntimeError: if ``args.opt`` names an unsupported optimizer.
    """
    opt_name = args.opt.lower()
    if opt_name.startswith("sgd"):
        return torch.optim.SGD(
            parameters,
            lr=args.lr,
            momentum=args.momentum,
            weight_decay=args.weight_decay,
            nesterov="nesterov" in opt_name,
        )
    if opt_name == "rmsprop":
        return torch.optim.RMSprop(
            parameters, lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay, eps=0.0316, alpha=0.9
        )
    if opt_name == "adamw":
        return torch.optim.AdamW(parameters, lr=args.lr, weight_decay=args.weight_decay)
    raise RuntimeError(f"Invalid optimizer {args.opt}. Only SGD, RMSprop and AdamW are supported.")


def _build_lr_scheduler(args, optimizer):
    """Return the per-epoch LR scheduler described by ``args``.

    ``args.lr_scheduler`` is lower-cased in place (the normalized value is what gets
    stored with ``args`` in checkpoints). The main schedule is StepLR,
    CosineAnnealingLR or ExponentialLR. When ``args.lr_warmup_epochs > 0`` it is
    preceded by a linear or constant warmup, chained with ``SequentialLR``.

    Raises:
        RuntimeError: for an unknown scheduler or warmup method.
    """
    args.lr_scheduler = args.lr_scheduler.lower()
    if args.lr_scheduler == "steplr":
        main_lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=args.lr_step_size, gamma=args.lr_gamma)
    elif args.lr_scheduler == "cosineannealinglr":
        # The cosine schedule spans only the epochs left after warmup.
        main_lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, T_max=args.epochs - args.lr_warmup_epochs, eta_min=args.lr_min
        )
    elif args.lr_scheduler == "exponentiallr":
        main_lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=args.lr_gamma)
    else:
        raise RuntimeError(
            f"Invalid lr scheduler '{args.lr_scheduler}'. Only StepLR, CosineAnnealingLR and ExponentialLR "
            "are supported."
        )

    if args.lr_warmup_epochs <= 0:
        return main_lr_scheduler

    if args.lr_warmup_method == "linear":
        warmup_lr_scheduler = torch.optim.lr_scheduler.LinearLR(
            optimizer, start_factor=args.lr_warmup_decay, total_iters=args.lr_warmup_epochs
        )
    elif args.lr_warmup_method == "constant":
        warmup_lr_scheduler = torch.optim.lr_scheduler.ConstantLR(
            optimizer, factor=args.lr_warmup_decay, total_iters=args.lr_warmup_epochs
        )
    else:
        raise RuntimeError(
            f"Invalid warmup lr method '{args.lr_warmup_method}'. Only linear and constant are supported."
        )
    return torch.optim.lr_scheduler.SequentialLR(
        optimizer, schedulers=[warmup_lr_scheduler, main_lr_scheduler], milestones=[args.lr_warmup_epochs]
    )


def main(args):
    """Train (or, with ``args.test_only``, only evaluate) a classification model.

    ``args`` is the namespace produced by ``get_args_parser()``. When
    ``args.output_dir`` is set, a checkpoint is written after every epoch as both
    ``model_{epoch}.pth`` and ``checkpoint.pth``. ``args.resume`` restores model,
    optimizer, scheduler, EMA and AMP-scaler state from such a file.
    """
    if args.output_dir:
        utils.mkdir(args.output_dir)

    utils.init_distributed_mode(args)
    print(args)

    device = torch.device(args.device)

    if args.use_deterministic_algorithms:
        torch.backends.cudnn.benchmark = False
        torch.use_deterministic_algorithms(True)
    else:
        torch.backends.cudnn.benchmark = True

    train_dir = os.path.join(args.data_path, "train")
    val_dir = os.path.join(args.data_path, "val")
    dataset, dataset_test, train_sampler, test_sampler = load_data(train_dir, val_dir, args)

    collate_fn = None
    num_classes = len(dataset.classes)
    mixup_transforms = []
    if args.mixup_alpha > 0.0:
        mixup_transforms.append(transforms.RandomMixup(num_classes, p=1.0, alpha=args.mixup_alpha))
    if args.cutmix_alpha > 0.0:
        mixup_transforms.append(transforms.RandomCutmix(num_classes, p=1.0, alpha=args.cutmix_alpha))
    if mixup_transforms:
        # Batch-level transforms: RandomChoice applies one of the enabled ones to each collated batch.
        mixupcutmix = torchvision.transforms.RandomChoice(mixup_transforms)

        def collate_fn(batch):
            return mixupcutmix(*default_collate(batch))

    data_loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=args.batch_size,
        sampler=train_sampler,
        num_workers=args.workers,
        pin_memory=True,
        collate_fn=collate_fn,
    )
    data_loader_test = torch.utils.data.DataLoader(
        dataset_test, batch_size=args.batch_size, sampler=test_sampler, num_workers=args.workers, pin_memory=True
    )

    print("Creating model")
    model = torchvision.models.get_model(args.model, weights=args.weights, num_classes=num_classes)
    model.to(device)

    if args.distributed and args.sync_bn:
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)

    criterion = nn.CrossEntropyLoss(label_smoothing=args.label_smoothing)

    custom_keys_weight_decay = []
    if args.bias_weight_decay is not None:
        custom_keys_weight_decay.append(("bias", args.bias_weight_decay))
    if args.transformer_embedding_decay is not None:
        for key in ["class_token", "position_embedding", "relative_position_bias_table"]:
            custom_keys_weight_decay.append((key, args.transformer_embedding_decay))
    parameters = utils.set_weight_decay(
        model,
        args.weight_decay,
        norm_weight_decay=args.norm_weight_decay,
        # An empty list means "no custom keys", which set_weight_decay expects as None.
        custom_keys_weight_decay=custom_keys_weight_decay or None,
    )

    optimizer = _build_optimizer(args, parameters)

    scaler = torch.cuda.amp.GradScaler() if args.amp else None

    lr_scheduler = _build_lr_scheduler(args, optimizer)

    model_without_ddp = model
    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
        model_without_ddp = model.module

    model_ema = None
    if args.model_ema:
        # Decay adjustment that aims to keep the decay independent of other hyper-parameters originally proposed at:
        # https://github.com/facebookresearch/pycls/blob/f8cd9627/pycls/core/net.py#L123
        #
        # total_ema_updates = (Dataset_size / n_GPUs) * epochs / (batch_size_per_gpu * EMA_steps)
        # We consider constant = Dataset_size for a given dataset/setup and omit it. Thus:
        # adjust = 1 / total_ema_updates ~= n_GPUs * batch_size_per_gpu * EMA_steps / epochs
        adjust = args.world_size * args.batch_size * args.model_ema_steps / args.epochs
        alpha = 1.0 - args.model_ema_decay
        alpha = min(1.0, alpha * adjust)
        model_ema = utils.ExponentialMovingAverage(model_without_ddp, device=device, decay=1.0 - alpha)

    if args.resume:
        # Checkpoints written below also store ``args`` (an argparse.Namespace), which
        # torch.load rejects under weights_only=True (the default from torch 2.6 on).
        # This unpickles arbitrary objects: only resume from checkpoints you trust.
        checkpoint = torch.load(args.resume, map_location="cpu", weights_only=False)
        model_without_ddp.load_state_dict(checkpoint["model"])
        if not args.test_only:
            optimizer.load_state_dict(checkpoint["optimizer"])
            lr_scheduler.load_state_dict(checkpoint["lr_scheduler"])
        args.start_epoch = checkpoint["epoch"] + 1
        if model_ema:
            model_ema.load_state_dict(checkpoint["model_ema"])
        if scaler:
            scaler.load_state_dict(checkpoint["scaler"])

    if args.test_only:
        # We disable the cudnn benchmarking because it can noticeably affect the accuracy
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True
        if model_ema:
            evaluate(model_ema, criterion, data_loader_test, device=device, log_suffix="EMA")
        else:
            evaluate(model, criterion, data_loader_test, device=device)
        return

    print("Start training")
    start_time = time.time()
    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            # Lets the distributed sampler vary its shuffling from epoch to epoch.
            train_sampler.set_epoch(epoch)
        train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, args, model_ema, scaler)
        lr_scheduler.step()
        evaluate(model, criterion, data_loader_test, device=device)
        if model_ema:
            evaluate(model_ema, criterion, data_loader_test, device=device, log_suffix="EMA")
        if args.output_dir:
            checkpoint = {
                "model": model_without_ddp.state_dict(),
                "optimizer": optimizer.state_dict(),
                "lr_scheduler": lr_scheduler.state_dict(),
                "epoch": epoch,
                "args": args,
            }
            if model_ema:
                checkpoint["model_ema"] = model_ema.state_dict()
            if scaler:
                checkpoint["scaler"] = scaler.state_dict()
            utils.save_on_master(checkpoint, os.path.join(args.output_dir, f"model_{epoch}.pth"))
            utils.save_on_master(checkpoint, os.path.join(args.output_dir, "checkpoint.pth"))

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print(f"Training time {total_time_str}")
388
389


390
def get_args_parser(add_help=True):
    """Create the command-line parser for the classification training script.

    Args:
        add_help: forwarded to ``argparse.ArgumentParser`` (controls whether the
            ``-h/--help`` option is added).

    Returns:
        The configured ``argparse.ArgumentParser``; nothing is parsed here.
    """
    import argparse

    parser = argparse.ArgumentParser(description="PyTorch Classification Training", add_help=add_help)
    add = parser.add_argument

    # Data, model, device and run length.
    add("--data-path", type=str, default="/datasets01/imagenet_full_size/061417/", help="dataset path")
    add("--model", type=str, default="resnet18", help="model name")
    add("--device", type=str, default="cuda", help="device (Use cuda or cpu Default: cuda)")
    add(
        "-b",
        "--batch-size",
        type=int,
        default=32,
        help="images per gpu, the total batch size is $NGPU x batch_size",
    )
    add("--epochs", type=int, default=90, metavar="N", help="number of total epochs to run")
    add("-j", "--workers", type=int, default=16, metavar="N", help="number of data loading workers (default: 16)")

    # Optimizer and weight decay.
    add("--opt", type=str, default="sgd", help="optimizer")
    add("--lr", type=float, default=0.1, help="initial learning rate")
    add("--momentum", type=float, default=0.9, metavar="M", help="momentum")
    add(
        "--wd",
        "--weight-decay",
        dest="weight_decay",
        type=float,
        default=1e-4,
        metavar="W",
        help="weight decay (default: 1e-4)",
    )
    add(
        "--norm-weight-decay",
        type=float,
        default=None,
        help="weight decay for Normalization layers (default: None, same value as --wd)",
    )
    add(
        "--bias-weight-decay",
        type=float,
        default=None,
        help="weight decay for bias parameters of all layers (default: None, same value as --wd)",
    )
    add(
        "--transformer-embedding-decay",
        type=float,
        default=None,
        help="weight decay for embedding parameters for vision transformer models (default: None, same value as --wd)",
    )

    # Label-level regularization.
    add("--label-smoothing", dest="label_smoothing", type=float, default=0.0, help="label smoothing (default: 0.0)")
    add("--mixup-alpha", type=float, default=0.0, help="mixup alpha (default: 0.0)")
    add("--cutmix-alpha", type=float, default=0.0, help="cutmix alpha (default: 0.0)")

    # Learning-rate schedule.
    add("--lr-scheduler", type=str, default="steplr", help="the lr scheduler (default: steplr)")
    add("--lr-warmup-epochs", type=int, default=0, help="the number of epochs to warmup (default: 0)")
    add("--lr-warmup-method", type=str, default="constant", help="the warmup method (default: constant)")
    add("--lr-warmup-decay", type=float, default=0.01, help="the decay for lr")
    add("--lr-step-size", type=int, default=30, help="decrease lr every step-size epochs")
    add("--lr-gamma", type=float, default=0.1, help="decrease lr by a factor of lr-gamma")
    add("--lr-min", type=float, default=0.0, help="minimum lr of lr schedule (default: 0.0)")

    # Logging, checkpointing and run modes.
    add("--print-freq", type=int, default=10, help="print frequency")
    add("--output-dir", type=str, default=".", help="path to save outputs")
    add("--resume", type=str, default="", help="path of checkpoint")
    add("--start-epoch", type=int, default=0, metavar="N", help="start epoch")
    add(
        "--cache-dataset",
        dest="cache_dataset",
        action="store_true",
        help="Cache the datasets for quicker initialization. It also serializes the transforms",
    )
    add("--sync-bn", dest="sync_bn", action="store_true", help="Use sync batch norm")
    add("--test-only", dest="test_only", action="store_true", help="Only test the model")

    # Data augmentation.
    add("--auto-augment", type=str, default=None, help="auto augment policy (default: None)")
    add("--ra-magnitude", type=int, default=9, help="magnitude of auto augment policy")
    add("--augmix-severity", type=int, default=3, help="severity of augmix policy")
    add("--random-erase", type=float, default=0.0, help="random erasing probability (default: 0.0)")

    # Mixed precision training parameters
    add("--amp", action="store_true", help="Use torch.cuda.amp for mixed precision training")

    # distributed training parameters
    add("--world-size", type=int, default=1, help="number of distributed processes")
    add("--dist-url", type=str, default="env://", help="url used to set up distributed training")

    # Exponential moving average of the weights.
    add("--model-ema", action="store_true", help="enable tracking Exponential Moving Average of model parameters")
    add(
        "--model-ema-steps",
        type=int,
        default=32,
        help="the number of iterations that controls how often to update the EMA model (default: 32)",
    )
    add(
        "--model-ema-decay",
        type=float,
        default=0.99998,
        help="decay factor for Exponential Moving Average of model parameters (default: 0.99998)",
    )

    add("--use-deterministic-algorithms", action="store_true", help="Forces the use of deterministic algorithms only.")

    # Preprocessing geometry.
    add("--interpolation", type=str, default="bilinear", help="the interpolation method (default: bilinear)")
    add("--val-resize-size", type=int, default=256, help="the resize size used for validation (default: 256)")
    add("--val-crop-size", type=int, default=224, help="the central crop size used for validation (default: 224)")
    add("--train-crop-size", type=int, default=224, help="the random crop size used for training (default: 224)")

    add("--clip-grad-norm", type=float, default=None, help="the maximum gradient norm (default None)")
    add("--ra-sampler", action="store_true", help="whether to use Repeated Augmentation in training")
    add("--ra-reps", type=int, default=3, help="number of repetitions for Repeated Augmentation (default: 3)")
    add("--weights", type=str, default=None, help="the weights enum name to load")
    # str.lower is also applied to the string default, so the default backend parses as "pil".
    add("--backend", type=str.lower, default="PIL", help="PIL or tensor - case insensitive")
    return parser
520

521
522

if __name__ == "__main__":
    # Script entry point: parse sys.argv and hand the namespace to main().
    main(get_args_parser().parse_args())