import datetime
import os
import time
import warnings

import presets

import torch
import torch.utils.data
import torchvision
import torchvision.transforms
import transforms
import utils
from sampler import RASampler

from torch import nn
from torch.utils.data.dataloader import default_collate
from torchvision.transforms.functional import InterpolationMode


def train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, args, model_ema=None, scaler=None):
    """Run a single training epoch over ``data_loader``.

    For every batch: forward pass (optionally under AMP autocast when a
    ``scaler`` is supplied), backward pass, optional gradient clipping
    (``args.clip_grad_norm``), optimizer step, periodic EMA update every
    ``args.model_ema_steps`` iterations, and metric logging (loss, lr,
    top-1/top-5 accuracy, images/sec).
    """
    model.train()
    logger = utils.MetricLogger(delimiter="  ")
    logger.add_meter("lr", utils.SmoothedValue(window_size=1, fmt="{value}"))
    logger.add_meter("img/s", utils.SmoothedValue(window_size=10, fmt="{value}"))

    use_amp = scaler is not None
    header = f"Epoch: [{epoch}]"
    for step, (images, targets) in enumerate(logger.log_every(data_loader, args.print_freq, header)):
        tic = time.time()
        images = images.to(device)
        targets = targets.to(device)

        with torch.cuda.amp.autocast(enabled=use_amp):
            output = model(images)
            loss = criterion(output, targets)

        optimizer.zero_grad()
        if use_amp:
            scaler.scale(loss).backward()
            if args.clip_grad_norm is not None:
                # Gradients must be unscaled before clipping so the norm is
                # measured in the true (unscaled) gradient space.
                scaler.unscale_(optimizer)
                nn.utils.clip_grad_norm_(model.parameters(), args.clip_grad_norm)
            scaler.step(optimizer)
            scaler.update()
        else:
            loss.backward()
            if args.clip_grad_norm is not None:
                nn.utils.clip_grad_norm_(model.parameters(), args.clip_grad_norm)
            optimizer.step()

        if model_ema and step % args.model_ema_steps == 0:
            model_ema.update_parameters(model)
            if epoch < args.lr_warmup_epochs:
                # During warmup, keep the EMA identical to the live model by
                # resetting the averaging counter each update.
                model_ema.n_averaged.fill_(0)

        acc1, acc5 = utils.accuracy(output, targets, topk=(1, 5))
        n = images.shape[0]
        logger.update(loss=loss.item(), lr=optimizer.param_groups[0]["lr"])
        logger.meters["acc1"].update(acc1.item(), n=n)
        logger.meters["acc5"].update(acc5.item(), n=n)
        logger.meters["img/s"].update(n / (time.time() - tic))

def evaluate(model, criterion, data_loader, device, print_freq=100, log_suffix=""):
    """Evaluate ``model`` on ``data_loader`` and return the top-1 accuracy.

    Runs under ``torch.inference_mode`` (no autograd state), logs loss and
    top-1/top-5 accuracy, synchronizes meters across processes, and warns
    (rank 0 only) when the number of processed samples differs from the
    dataset length — which can happen with distributed padding.
    """
    model.eval()
    metric_logger = utils.MetricLogger(delimiter="  ")
    header = f"Test: {log_suffix}"

    num_processed_samples = 0
    with torch.inference_mode():
        for image, target in metric_logger.log_every(data_loader, print_freq, header):
            image = image.to(device, non_blocking=True)
            target = target.to(device, non_blocking=True)
            output = model(image)
            loss = criterion(output, target)

            acc1, acc5 = utils.accuracy(output, target, topk=(1, 5))
            # FIXME need to take into account that the datasets
            # could have been padded in distributed setup
            batch_size = image.shape[0]
            metric_logger.update(loss=loss.item())
            metric_logger.meters["acc1"].update(acc1.item(), n=batch_size)
            metric_logger.meters["acc5"].update(acc5.item(), n=batch_size)
            num_processed_samples += batch_size
    # gather the stats from all processes

    num_processed_samples = utils.reduce_across_processes(num_processed_samples)
    # Only warn once (rank 0). Guard get_rank(): calling it without an
    # initialized default process group (plain single-process runs) raises
    # a RuntimeError instead of producing the intended warning.
    is_main_process = (
        not torch.distributed.is_available()
        or not torch.distributed.is_initialized()
        or torch.distributed.get_rank() == 0
    )
    if (
        hasattr(data_loader.dataset, "__len__")
        and len(data_loader.dataset) != num_processed_samples
        and is_main_process
    ):
        # See FIXME above
        warnings.warn(
            f"It looks like the dataset has {len(data_loader.dataset)} samples, but {num_processed_samples} "
            "samples were used for the validation, which might bias the results. "
            "Try adjusting the batch size and / or the world size. "
            "Setting the world size to 1 is always a safe bet."
        )

    metric_logger.synchronize_between_processes()

    print(f"{header} Acc@1 {metric_logger.acc1.global_avg:.3f} Acc@5 {metric_logger.acc5.global_avg:.3f}")
    return metric_logger.acc1.global_avg

def _get_cache_path(filepath):
    """Map a dataset directory path to a stable cache file location.

    The cache lives under ``~/.torch/vision/datasets/imagefolder`` and the
    file name is the first 10 hex digits of the SHA-1 of ``filepath``, so
    every distinct directory gets its own ``.pt`` entry.
    """
    import hashlib

    digest = hashlib.sha1(filepath.encode()).hexdigest()
    return os.path.expanduser(
        os.path.join("~", ".torch", "vision", "datasets", "imagefolder", digest[:10] + ".pt")
    )

def _build_train_dataset(traindir, train_crop_size, interpolation, args):
    """Build (or load from cache) the training ImageFolder dataset."""
    st = time.time()
    cache_path = _get_cache_path(traindir)
    if args.cache_dataset and os.path.exists(cache_path):
        # Attention, as the transforms are also cached!
        print(f"Loading dataset_train from {cache_path}")
        dataset, _ = torch.load(cache_path)
    else:
        # We need a default value for the variables below because args may come
        # from train_quantization.py which doesn't define them.
        auto_augment_policy = getattr(args, "auto_augment", None)
        random_erase_prob = getattr(args, "random_erase", 0.0)
        ra_magnitude = getattr(args, "ra_magnitude", None)
        augmix_severity = getattr(args, "augmix_severity", None)
        dataset = torchvision.datasets.ImageFolder(
            traindir,
            presets.ClassificationPresetTrain(
                crop_size=train_crop_size,
                interpolation=interpolation,
                auto_augment_policy=auto_augment_policy,
                random_erase_prob=random_erase_prob,
                ra_magnitude=ra_magnitude,
                augmix_severity=augmix_severity,
                backend=args.backend,
                use_v2=args.use_v2,
            ),
        )
        if args.cache_dataset:
            print(f"Saving dataset_train to {cache_path}")
            utils.mkdir(os.path.dirname(cache_path))
            utils.save_on_master((dataset, traindir), cache_path)
    print("Took", time.time() - st)
    return dataset


def _build_val_dataset(valdir, val_resize_size, val_crop_size, interpolation, args):
    """Build (or load from cache) the validation ImageFolder dataset."""
    cache_path = _get_cache_path(valdir)
    if args.cache_dataset and os.path.exists(cache_path):
        # Attention, as the transforms are also cached!
        print(f"Loading dataset_test from {cache_path}")
        dataset_test, _ = torch.load(cache_path)
        return dataset_test

    if args.weights and args.test_only:
        # When only evaluating pretrained weights, use the exact preprocessing
        # those weights ship with instead of the script's eval preset.
        weights = torchvision.models.get_weight(args.weights)
        preprocessing = weights.transforms(antialias=True)
        if args.backend == "tensor":
            preprocessing = torchvision.transforms.Compose([torchvision.transforms.PILToTensor(), preprocessing])
    else:
        preprocessing = presets.ClassificationPresetEval(
            crop_size=val_crop_size,
            resize_size=val_resize_size,
            interpolation=interpolation,
            backend=args.backend,
            use_v2=args.use_v2,
        )

    dataset_test = torchvision.datasets.ImageFolder(
        valdir,
        preprocessing,
    )
    if args.cache_dataset:
        print(f"Saving dataset_test to {cache_path}")
        utils.mkdir(os.path.dirname(cache_path))
        utils.save_on_master((dataset_test, valdir), cache_path)
    return dataset_test


def load_data(traindir, valdir, args):
    """Build train/val datasets and their samplers.

    Returns ``(dataset, dataset_test, train_sampler, test_sampler)``.
    Datasets may be deserialized from an on-disk cache when
    ``args.cache_dataset`` is set (note: cached entries include the
    transforms that were active when they were saved).
    """
    # Data loading code
    print("Loading data")
    interpolation = InterpolationMode(args.interpolation)

    print("Loading training data")
    dataset = _build_train_dataset(traindir, args.train_crop_size, interpolation, args)

    print("Loading validation data")
    dataset_test = _build_val_dataset(valdir, args.val_resize_size, args.val_crop_size, interpolation, args)

    print("Creating data loaders")
    if args.distributed:
        if hasattr(args, "ra_sampler") and args.ra_sampler:
            # Repeated Augmentation sampler (hasattr guard: args may come from
            # train_quantization.py, which doesn't define ra_sampler).
            train_sampler = RASampler(dataset, shuffle=True, repetitions=args.ra_reps)
        else:
            train_sampler = torch.utils.data.distributed.DistributedSampler(dataset)
        test_sampler = torch.utils.data.distributed.DistributedSampler(dataset_test, shuffle=False)
    else:
        train_sampler = torch.utils.data.RandomSampler(dataset)
        test_sampler = torch.utils.data.SequentialSampler(dataset_test)

    return dataset, dataset_test, train_sampler, test_sampler

def main(args):
    """Entry point: build data, model, optimizer and scheduler from ``args``,
    then either evaluate (``--test-only``) or run the training loop.

    Side effects: creates ``args.output_dir``, initializes distributed mode,
    mutates ``args`` in place (``lr_scheduler`` is lower-cased,
    ``start_epoch`` is overwritten when resuming), and writes per-epoch
    checkpoints plus a rolling ``checkpoint.pth``.
    """
    if args.output_dir:
        utils.mkdir(args.output_dir)

    utils.init_distributed_mode(args)
    print(args)

    device = torch.device(args.device)

    # Reproducibility vs. speed: deterministic algorithms require disabling
    # cudnn autotuning; otherwise enable it for faster convolutions.
    if args.use_deterministic_algorithms:
        torch.backends.cudnn.benchmark = False
        torch.use_deterministic_algorithms(True)
    else:
        torch.backends.cudnn.benchmark = True

    train_dir = os.path.join(args.data_path, "train")
    val_dir = os.path.join(args.data_path, "val")
    dataset, dataset_test, train_sampler, test_sampler = load_data(train_dir, val_dir, args)

    # Mixup/CutMix operate on whole batches, so they are applied through a
    # custom collate function rather than per-sample transforms.
    collate_fn = None
    num_classes = len(dataset.classes)
    mixup_transforms = []
    if args.mixup_alpha > 0.0:
        mixup_transforms.append(transforms.RandomMixup(num_classes, p=1.0, alpha=args.mixup_alpha))
    if args.cutmix_alpha > 0.0:
        mixup_transforms.append(transforms.RandomCutmix(num_classes, p=1.0, alpha=args.cutmix_alpha))
    if mixup_transforms:
        mixupcutmix = torchvision.transforms.RandomChoice(mixup_transforms)

        def collate_fn(batch):
            # Collate first, then apply mixup/cutmix to the assembled batch.
            return mixupcutmix(*default_collate(batch))

    data_loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=args.batch_size,
        sampler=train_sampler,
        num_workers=args.workers,
        pin_memory=True,
        collate_fn=collate_fn,
    )
    data_loader_test = torch.utils.data.DataLoader(
        dataset_test, batch_size=args.batch_size, sampler=test_sampler, num_workers=args.workers, pin_memory=True
    )

    print("Creating model")
    model = torchvision.models.get_model(args.model, weights=args.weights, num_classes=num_classes)
    model.to(device)

    if args.distributed and args.sync_bn:
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)

    criterion = nn.CrossEntropyLoss(label_smoothing=args.label_smoothing)

    # Per-parameter-group weight decay: biases and transformer embedding
    # parameters can get their own decay values, distinct from --wd.
    custom_keys_weight_decay = []
    if args.bias_weight_decay is not None:
        custom_keys_weight_decay.append(("bias", args.bias_weight_decay))
    if args.transformer_embedding_decay is not None:
        for key in ["class_token", "position_embedding", "relative_position_bias_table"]:
            custom_keys_weight_decay.append((key, args.transformer_embedding_decay))
    parameters = utils.set_weight_decay(
        model,
        args.weight_decay,
        norm_weight_decay=args.norm_weight_decay,
        custom_keys_weight_decay=custom_keys_weight_decay if len(custom_keys_weight_decay) > 0 else None,
    )

    opt_name = args.opt.lower()
    if opt_name.startswith("sgd"):
        optimizer = torch.optim.SGD(
            parameters,
            lr=args.lr,
            momentum=args.momentum,
            weight_decay=args.weight_decay,
            # "sgd_nesterov" (any name containing "nesterov") enables Nesterov momentum.
            nesterov="nesterov" in opt_name,
        )
    elif opt_name == "rmsprop":
        optimizer = torch.optim.RMSprop(
            parameters, lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay, eps=0.0316, alpha=0.9
        )
    elif opt_name == "adamw":
        optimizer = torch.optim.AdamW(parameters, lr=args.lr, weight_decay=args.weight_decay)
    else:
        raise RuntimeError(f"Invalid optimizer {args.opt}. Only SGD, RMSprop and AdamW are supported.")

    scaler = torch.cuda.amp.GradScaler() if args.amp else None

    args.lr_scheduler = args.lr_scheduler.lower()
    if args.lr_scheduler == "steplr":
        main_lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=args.lr_step_size, gamma=args.lr_gamma)
    elif args.lr_scheduler == "cosineannealinglr":
        main_lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            # Warmup epochs are handled by a separate scheduler below, so the
            # cosine horizon excludes them.
            optimizer, T_max=args.epochs - args.lr_warmup_epochs, eta_min=args.lr_min
        )
    elif args.lr_scheduler == "exponentiallr":
        main_lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=args.lr_gamma)
    else:
        raise RuntimeError(
            f"Invalid lr scheduler '{args.lr_scheduler}'. Only StepLR, CosineAnnealingLR and ExponentialLR "
            "are supported."
        )

    if args.lr_warmup_epochs > 0:
        if args.lr_warmup_method == "linear":
            warmup_lr_scheduler = torch.optim.lr_scheduler.LinearLR(
                optimizer, start_factor=args.lr_warmup_decay, total_iters=args.lr_warmup_epochs
            )
        elif args.lr_warmup_method == "constant":
            warmup_lr_scheduler = torch.optim.lr_scheduler.ConstantLR(
                optimizer, factor=args.lr_warmup_decay, total_iters=args.lr_warmup_epochs
            )
        else:
            raise RuntimeError(
                f"Invalid warmup lr method '{args.lr_warmup_method}'. Only linear and constant are supported."
            )
        # Chain warmup then the main scheduler, switching at lr_warmup_epochs.
        lr_scheduler = torch.optim.lr_scheduler.SequentialLR(
            optimizer, schedulers=[warmup_lr_scheduler, main_lr_scheduler], milestones=[args.lr_warmup_epochs]
        )
    else:
        lr_scheduler = main_lr_scheduler

    # Keep a handle to the unwrapped model for state_dict save/load.
    model_without_ddp = model
    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
        model_without_ddp = model.module

    model_ema = None
    if args.model_ema:
        # Decay adjustment that aims to keep the decay independent of other hyper-parameters originally proposed at:
        # https://github.com/facebookresearch/pycls/blob/f8cd9627/pycls/core/net.py#L123
        #
        # total_ema_updates = (Dataset_size / n_GPUs) * epochs / (batch_size_per_gpu * EMA_steps)
        # We consider constant = Dataset_size for a given dataset/setup and omit it. Thus:
        # adjust = 1 / total_ema_updates ~= n_GPUs * batch_size_per_gpu * EMA_steps / epochs
        adjust = args.world_size * args.batch_size * args.model_ema_steps / args.epochs
        alpha = 1.0 - args.model_ema_decay
        alpha = min(1.0, alpha * adjust)
        model_ema = utils.ExponentialMovingAverage(model_without_ddp, device=device, decay=1.0 - alpha)

    if args.resume:
        checkpoint = torch.load(args.resume, map_location="cpu")
        model_without_ddp.load_state_dict(checkpoint["model"])
        # Optimizer/scheduler state is only needed when training continues.
        if not args.test_only:
            optimizer.load_state_dict(checkpoint["optimizer"])
            lr_scheduler.load_state_dict(checkpoint["lr_scheduler"])
        args.start_epoch = checkpoint["epoch"] + 1
        if model_ema:
            model_ema.load_state_dict(checkpoint["model_ema"])
        if scaler:
            scaler.load_state_dict(checkpoint["scaler"])

    if args.test_only:
        # We disable the cudnn benchmarking because it can noticeably affect the accuracy
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True
        if model_ema:
            evaluate(model_ema, criterion, data_loader_test, device=device, log_suffix="EMA")
        else:
            evaluate(model, criterion, data_loader_test, device=device)
        return

    print("Start training")
    start_time = time.time()
    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            # Reseed the distributed sampler so each epoch gets a new shuffle.
            train_sampler.set_epoch(epoch)
        train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, args, model_ema, scaler)
        lr_scheduler.step()
        evaluate(model, criterion, data_loader_test, device=device)
        if model_ema:
            evaluate(model_ema, criterion, data_loader_test, device=device, log_suffix="EMA")
        if args.output_dir:
            checkpoint = {
                "model": model_without_ddp.state_dict(),
                "optimizer": optimizer.state_dict(),
                "lr_scheduler": lr_scheduler.state_dict(),
                "epoch": epoch,
                "args": args,
            }
            if model_ema:
                checkpoint["model_ema"] = model_ema.state_dict()
            if scaler:
                checkpoint["scaler"] = scaler.state_dict()
            # Keep both a per-epoch snapshot and a rolling "latest" checkpoint.
            utils.save_on_master(checkpoint, os.path.join(args.output_dir, f"model_{epoch}.pth"))
            utils.save_on_master(checkpoint, os.path.join(args.output_dir, "checkpoint.pth"))

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print(f"Training time {total_time_str}")

def get_args_parser(add_help=True):
    """Build the argparse parser for classification training.

    Also imported by sibling scripts (e.g. quantization training), hence the
    ``add_help`` knob and the local ``argparse`` import.
    """
    import argparse

    parser = argparse.ArgumentParser(description="PyTorch Classification Training", add_help=add_help)

    # Dataset / model / basic run configuration
    parser.add_argument("--data-path", default="/datasets01/imagenet_full_size/061417/", type=str, help="dataset path")
    parser.add_argument("--model", default="resnet18", type=str, help="model name")
    parser.add_argument("--device", default="cuda", type=str, help="device (Use cuda or cpu Default: cuda)")
    parser.add_argument(
        "-b", "--batch-size", default=32, type=int, help="images per gpu, the total batch size is $NGPU x batch_size"
    )
    parser.add_argument("--epochs", default=90, type=int, metavar="N", help="number of total epochs to run")
    parser.add_argument(
        "-j", "--workers", default=16, type=int, metavar="N", help="number of data loading workers (default: 16)"
    )
    # Optimizer hyper-parameters
    parser.add_argument("--opt", default="sgd", type=str, help="optimizer")
    parser.add_argument("--lr", default=0.1, type=float, help="initial learning rate")
    parser.add_argument("--momentum", default=0.9, type=float, metavar="M", help="momentum")
    parser.add_argument(
        "--wd",
        "--weight-decay",
        default=1e-4,
        type=float,
        metavar="W",
        help="weight decay (default: 1e-4)",
        dest="weight_decay",
    )
    parser.add_argument(
        "--norm-weight-decay",
        default=None,
        type=float,
        help="weight decay for Normalization layers (default: None, same value as --wd)",
    )
    parser.add_argument(
        "--bias-weight-decay",
        default=None,
        type=float,
        help="weight decay for bias parameters of all layers (default: None, same value as --wd)",
    )
    parser.add_argument(
        "--transformer-embedding-decay",
        default=None,
        type=float,
        help="weight decay for embedding parameters for vision transformer models (default: None, same value as --wd)",
    )
    # Regularization / augmentation mixing
    parser.add_argument(
        "--label-smoothing", default=0.0, type=float, help="label smoothing (default: 0.0)", dest="label_smoothing"
    )
    parser.add_argument("--mixup-alpha", default=0.0, type=float, help="mixup alpha (default: 0.0)")
    parser.add_argument("--cutmix-alpha", default=0.0, type=float, help="cutmix alpha (default: 0.0)")
    # Learning-rate schedule
    parser.add_argument("--lr-scheduler", default="steplr", type=str, help="the lr scheduler (default: steplr)")
    parser.add_argument("--lr-warmup-epochs", default=0, type=int, help="the number of epochs to warmup (default: 0)")
    parser.add_argument(
        "--lr-warmup-method", default="constant", type=str, help="the warmup method (default: constant)"
    )
    parser.add_argument("--lr-warmup-decay", default=0.01, type=float, help="the decay for lr")
    parser.add_argument("--lr-step-size", default=30, type=int, help="decrease lr every step-size epochs")
    parser.add_argument("--lr-gamma", default=0.1, type=float, help="decrease lr by a factor of lr-gamma")
    parser.add_argument("--lr-min", default=0.0, type=float, help="minimum lr of lr schedule (default: 0.0)")
    # Logging / checkpointing
    parser.add_argument("--print-freq", default=10, type=int, help="print frequency")
    parser.add_argument("--output-dir", default=".", type=str, help="path to save outputs")
    parser.add_argument("--resume", default="", type=str, help="path of checkpoint")
    parser.add_argument("--start-epoch", default=0, type=int, metavar="N", help="start epoch")
    parser.add_argument(
        "--cache-dataset",
        dest="cache_dataset",
        help="Cache the datasets for quicker initialization. It also serializes the transforms",
        action="store_true",
    )
    parser.add_argument(
        "--sync-bn",
        dest="sync_bn",
        help="Use sync batch norm",
        action="store_true",
    )
    parser.add_argument(
        "--test-only",
        dest="test_only",
        help="Only test the model",
        action="store_true",
    )
    # Per-sample augmentation policies
    parser.add_argument("--auto-augment", default=None, type=str, help="auto augment policy (default: None)")
    parser.add_argument("--ra-magnitude", default=9, type=int, help="magnitude of auto augment policy")
    parser.add_argument("--augmix-severity", default=3, type=int, help="severity of augmix policy")
    parser.add_argument("--random-erase", default=0.0, type=float, help="random erasing probability (default: 0.0)")

    # Mixed precision training parameters
    parser.add_argument("--amp", action="store_true", help="Use torch.cuda.amp for mixed precision training")

    # distributed training parameters
    parser.add_argument("--world-size", default=1, type=int, help="number of distributed processes")
    parser.add_argument("--dist-url", default="env://", type=str, help="url used to set up distributed training")
    # Model EMA (Exponential Moving Average) tracking
    parser.add_argument(
        "--model-ema", action="store_true", help="enable tracking Exponential Moving Average of model parameters"
    )
    parser.add_argument(
        "--model-ema-steps",
        type=int,
        default=32,
        help="the number of iterations that controls how often to update the EMA model (default: 32)",
    )
    parser.add_argument(
        "--model-ema-decay",
        type=float,
        default=0.99998,
        help="decay factor for Exponential Moving Average of model parameters (default: 0.99998)",
    )
    parser.add_argument(
        "--use-deterministic-algorithms", action="store_true", help="Forces the use of deterministic algorithms only."
    )
    # Preprocessing sizes / interpolation
    parser.add_argument(
        "--interpolation", default="bilinear", type=str, help="the interpolation method (default: bilinear)"
    )
    parser.add_argument(
        "--val-resize-size", default=256, type=int, help="the resize size used for validation (default: 256)"
    )
    parser.add_argument(
        "--val-crop-size", default=224, type=int, help="the central crop size used for validation (default: 224)"
    )
    parser.add_argument(
        "--train-crop-size", default=224, type=int, help="the random crop size used for training (default: 224)"
    )
    parser.add_argument("--clip-grad-norm", default=None, type=float, help="the maximum gradient norm (default None)")
    # Repeated Augmentation sampler
    parser.add_argument("--ra-sampler", action="store_true", help="whether to use Repeated Augmentation in training")
    parser.add_argument(
        "--ra-reps", default=3, type=int, help="number of repetitions for Repeated Augmentation (default: 3)"
    )
    parser.add_argument("--weights", default=None, type=str, help="the weights enum name to load")
    parser.add_argument("--backend", default="PIL", type=str.lower, help="PIL or tensor - case insensitive")
    parser.add_argument("--use-v2", action="store_true", help="Use V2 transforms")
    return parser

if __name__ == "__main__":
    # Parse CLI arguments and hand off to the training entry point.
    parser = get_args_parser()
    main(parser.parse_args())