"""PyTorch image-classification training script.

Trains (or, with ``--test-only``, only evaluates) a torchvision classification model
on a dataset laid out as ``<data-path>/train`` and ``<data-path>/val`` ImageFolder
directories. Run with ``--help`` for the full list of options.
"""
import datetime
import os
import time
import warnings

import presets
import torch
import torch.utils.data
import torchvision
import transforms
import utils
from references.classification.sampler import RASampler
from torch import nn
from torch.utils.data.dataloader import default_collate
from torchvision.transforms.functional import InterpolationMode

# The prototype multi-weight model API is optional. When it is unavailable, PM is None
# and main() refuses to run with --weights.
try:
    from torchvision.prototype import models as PM
except ImportError:
    PM = None


def train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, args, model_ema=None, scaler=None):
    """Run one training epoch and log loss, top-1/top-5 accuracy, learning rate and throughput.

    Args:
        model: the network to train (possibly DDP-wrapped); switched to train mode here.
        criterion: loss function, called as ``criterion(output, target)``.
        optimizer: optimizer stepped once per batch.
        data_loader: iterable of ``(image, target)`` batches.
        device: device each batch is moved to.
        epoch: current epoch index; used in the log header and to decide whether the
            EMA averaging counter is reset (during LR warmup).
        args: parsed CLI namespace; reads ``print_freq``, ``clip_grad_norm``,
            ``model_ema_steps`` and ``lr_warmup_epochs``.
        model_ema: optional EMA wrapper, updated every ``args.model_ema_steps`` iterations.
        scaler: optional ``torch.cuda.amp.GradScaler``. When given, the forward pass runs
            under autocast and the backward pass is loss-scaled.
    """
    model.train()
    metric_logger = utils.MetricLogger(delimiter="  ")
    metric_logger.add_meter("lr", utils.SmoothedValue(window_size=1, fmt="{value}"))
    metric_logger.add_meter("img/s", utils.SmoothedValue(window_size=10, fmt="{value}"))

    header = f"Epoch: [{epoch}]"
    for i, (image, target) in enumerate(metric_logger.log_every(data_loader, args.print_freq, header)):
        # Timing starts after the batch has been fetched, so img/s excludes data loading.
        start_time = time.time()
        # Non-blocking copies (as in evaluate) can overlap with compute when the loader
        # hands out pinned-memory tensors.
        image = image.to(device, non_blocking=True)
        target = target.to(device, non_blocking=True)
        with torch.cuda.amp.autocast(enabled=scaler is not None):
            output = model(image)
            loss = criterion(output, target)

        optimizer.zero_grad()
        if scaler is not None:
            scaler.scale(loss).backward()
            if args.clip_grad_norm is not None:
                # Gradients are still scaled here; unscale them so the clipping threshold
                # applies to their true magnitude.
                scaler.unscale_(optimizer)
                nn.utils.clip_grad_norm_(model.parameters(), args.clip_grad_norm)
            scaler.step(optimizer)
            scaler.update()
        else:
            loss.backward()
            if args.clip_grad_norm is not None:
                nn.utils.clip_grad_norm_(model.parameters(), args.clip_grad_norm)
            optimizer.step()

        if model_ema is not None and i % args.model_ema_steps == 0:
            model_ema.update_parameters(model)
            if epoch < args.lr_warmup_epochs:
                # Reset ema buffer to keep copying weights during warmup period
                model_ema.n_averaged.fill_(0)

        acc1, acc5 = utils.accuracy(output, target, topk=(1, 5))
        batch_size = image.shape[0]
        metric_logger.update(loss=loss.item(), lr=optimizer.param_groups[0]["lr"])
        metric_logger.meters["acc1"].update(acc1.item(), n=batch_size)
        metric_logger.meters["acc5"].update(acc5.item(), n=batch_size)
        metric_logger.meters["img/s"].update(batch_size / (time.time() - start_time))


def evaluate(model, criterion, data_loader, device, print_freq=100, log_suffix=""):
    """Evaluate ``model`` on ``data_loader`` and return the global top-1 accuracy.

    Args:
        model: network to evaluate; switched to eval mode here.
        criterion: loss function, called as ``criterion(output, target)``.
        data_loader: iterable of ``(image, target)`` batches.
        device: device each batch is moved to.
        print_freq: logging interval, in iterations.
        log_suffix: appended to the ``"Test: "`` log header (e.g. ``"EMA"``).

    Returns:
        ``metric_logger.acc1.global_avg`` after synchronizing meters across processes.
    """
    model.eval()
    metric_logger = utils.MetricLogger(delimiter="  ")
    header = f"Test: {log_suffix}"

    num_processed_samples = 0
    with torch.inference_mode():
        for image, target in metric_logger.log_every(data_loader, print_freq, header):
            image = image.to(device, non_blocking=True)
            target = target.to(device, non_blocking=True)
            output = model(image)
            loss = criterion(output, target)

            acc1, acc5 = utils.accuracy(output, target, topk=(1, 5))
            # FIXME need to take into account that the datasets
            # could have been padded in distributed setup
            batch_size = image.shape[0]
            metric_logger.update(loss=loss.item())
            metric_logger.meters["acc1"].update(acc1.item(), n=batch_size)
            metric_logger.meters["acc5"].update(acc5.item(), n=batch_size)
            num_processed_samples += batch_size

    # gather the stats from all processes
    num_processed_samples = utils.reduce_across_processes(num_processed_samples)
    # Only one process should warn. get_rank() raises when no default process group has
    # been initialized, so a non-distributed run counts as the main process.
    dist = torch.distributed
    is_main_process = not (dist.is_available() and dist.is_initialized()) or dist.get_rank() == 0
    if (
        hasattr(data_loader.dataset, "__len__")
        and len(data_loader.dataset) != num_processed_samples
        and is_main_process
    ):
        # See FIXME above
        warnings.warn(
            f"It looks like the dataset has {len(data_loader.dataset)} samples, but {num_processed_samples} "
            "samples were used for the validation, which might bias the results. "
            "Try adjusting the batch size and / or the world size. "
            "Setting the world size to 1 is always a safe bet."
        )

    metric_logger.synchronize_between_processes()

    print(f"{header} Acc@1 {metric_logger.acc1.global_avg:.3f} Acc@5 {metric_logger.acc5.global_avg:.3f}")
    return metric_logger.acc1.global_avg


def _get_cache_path(filepath):
    import hashlib
112

113
114
115
116
117
118
    h = hashlib.sha1(filepath.encode()).hexdigest()
    cache_path = os.path.join("~", ".torch", "vision", "datasets", "imagefolder", h[:10] + ".pt")
    cache_path = os.path.expanduser(cache_path)
    return cache_path


119
def _load_or_build_dataset(root, name, build_dataset, use_cache):
    """Return the dataset for ``root``, reading or writing the on-disk cache when ``use_cache`` is set.

    ``build_dataset`` is a zero-argument callable that is only invoked on a cache miss
    (or when caching is disabled). ``name`` is used only in the log messages.
    """
    cache_path = _get_cache_path(root)
    if use_cache and os.path.exists(cache_path):
        # Attention, as the transforms are also cached!
        # NOTE: torch.load unpickles the file; only load caches this script wrote itself.
        print(f"Loading {name} from {cache_path}")
        dataset, _ = torch.load(cache_path)
        return dataset

    dataset = build_dataset()
    if use_cache:
        print(f"Saving {name} to {cache_path}")
        utils.mkdir(os.path.dirname(cache_path))
        utils.save_on_master((dataset, root), cache_path)
    return dataset


def load_data(traindir, valdir, args):
    """Build (or load from cache) the train/val ImageFolder datasets and their samplers.

    Reads from ``args``: ``val_resize_size``, ``val_crop_size``, ``train_crop_size``,
    ``interpolation``, ``cache_dataset``, ``weights``, ``distributed``, ``ra_sampler`` and,
    if present, ``auto_augment`` / ``random_erase``.

    Returns:
        ``(dataset, dataset_test, train_sampler, test_sampler)``.
    """
    # Data loading code
    print("Loading data")
    val_resize_size, val_crop_size, train_crop_size = args.val_resize_size, args.val_crop_size, args.train_crop_size
    interpolation = InterpolationMode(args.interpolation)

    def build_train():
        auto_augment_policy = getattr(args, "auto_augment", None)
        random_erase_prob = getattr(args, "random_erase", 0.0)
        return torchvision.datasets.ImageFolder(
            traindir,
            presets.ClassificationPresetTrain(
                crop_size=train_crop_size,
                interpolation=interpolation,
                auto_augment_policy=auto_augment_policy,
                random_erase_prob=random_erase_prob,
            ),
        )

    def build_val():
        if not args.weights:
            preprocessing = presets.ClassificationPresetEval(
                crop_size=val_crop_size, resize_size=val_resize_size, interpolation=interpolation
            )
        else:
            # Prototype weights carry their own eval transforms; the --val-* sizes are ignored.
            weights = PM.get_weight(args.weights)
            preprocessing = weights.transforms()
        return torchvision.datasets.ImageFolder(valdir, preprocessing)

    print("Loading training data")
    st = time.time()
    dataset = _load_or_build_dataset(traindir, "dataset_train", build_train, args.cache_dataset)
    print("Took", time.time() - st)

    print("Loading validation data")
    dataset_test = _load_or_build_dataset(valdir, "dataset_test", build_val, args.cache_dataset)

    print("Creating data loaders")
    if args.distributed:
        if args.ra_sampler:
            train_sampler = RASampler(dataset, shuffle=True)
        else:
            train_sampler = torch.utils.data.distributed.DistributedSampler(dataset)
        test_sampler = torch.utils.data.distributed.DistributedSampler(dataset_test, shuffle=False)
    else:
        train_sampler = torch.utils.data.RandomSampler(dataset)
        test_sampler = torch.utils.data.SequentialSampler(dataset_test)

    return dataset, dataset_test, train_sampler, test_sampler


def _build_optimizer(model, args):
    """Create the optimizer selected by ``args.opt`` (case-insensitive) for ``model``.

    When ``args.norm_weight_decay`` is set, parameters are split into two groups using
    weight decays ``[args.norm_weight_decay, args.weight_decay]``; empty groups are dropped.

    Raises:
        RuntimeError: if ``args.opt`` is not an ``sgd*`` variant, ``rmsprop`` or ``adamw``.
    """
    if args.norm_weight_decay is None:
        parameters = model.parameters()
    else:
        param_groups = torchvision.ops._utils.split_normalization_params(model)
        wd_groups = [args.norm_weight_decay, args.weight_decay]
        parameters = [{"params": p, "weight_decay": w} for p, w in zip(param_groups, wd_groups) if p]

    opt_name = args.opt.lower()
    if opt_name.startswith("sgd"):
        # Any sgd name containing "nesterov" (e.g. "sgd_nesterov") enables Nesterov momentum.
        return torch.optim.SGD(
            parameters,
            lr=args.lr,
            momentum=args.momentum,
            weight_decay=args.weight_decay,
            nesterov="nesterov" in opt_name,
        )
    if opt_name == "rmsprop":
        return torch.optim.RMSprop(
            parameters, lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay, eps=0.0316, alpha=0.9
        )
    if opt_name == "adamw":
        return torch.optim.AdamW(parameters, lr=args.lr, weight_decay=args.weight_decay)
    raise RuntimeError(f"Invalid optimizer {args.opt}. Only SGD, RMSprop and AdamW are supported.")


def _build_lr_scheduler(optimizer, args):
    """Create the LR scheduler, preceded by a warmup phase when ``args.lr_warmup_epochs > 0``.

    Side effect: lower-cases ``args.lr_scheduler`` in place (``args`` is later stored in
    checkpoints).

    Raises:
        RuntimeError: for an unknown ``args.lr_scheduler`` or ``args.lr_warmup_method``.
    """
    args.lr_scheduler = args.lr_scheduler.lower()
    if args.lr_scheduler == "steplr":
        main_lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=args.lr_step_size, gamma=args.lr_gamma)
    elif args.lr_scheduler == "cosineannealinglr":
        main_lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, T_max=args.epochs - args.lr_warmup_epochs
        )
    elif args.lr_scheduler == "exponentiallr":
        main_lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=args.lr_gamma)
    else:
        raise RuntimeError(
            f"Invalid lr scheduler '{args.lr_scheduler}'. Only StepLR, CosineAnnealingLR and ExponentialLR "
            "are supported."
        )

    if args.lr_warmup_epochs <= 0:
        return main_lr_scheduler

    # Accept the warmup method case-insensitively, like --opt and --lr-scheduler.
    warmup_method = args.lr_warmup_method.lower()
    if warmup_method == "linear":
        warmup_lr_scheduler = torch.optim.lr_scheduler.LinearLR(
            optimizer, start_factor=args.lr_warmup_decay, total_iters=args.lr_warmup_epochs
        )
    elif warmup_method == "constant":
        warmup_lr_scheduler = torch.optim.lr_scheduler.ConstantLR(
            optimizer, factor=args.lr_warmup_decay, total_iters=args.lr_warmup_epochs
        )
    else:
        raise RuntimeError(
            f"Invalid warmup lr method '{args.lr_warmup_method}'. Only linear and constant are supported."
        )
    # main() steps the scheduler once per epoch, so the switch happens after lr_warmup_epochs epochs.
    return torch.optim.lr_scheduler.SequentialLR(
        optimizer, schedulers=[warmup_lr_scheduler, main_lr_scheduler], milestones=[args.lr_warmup_epochs]
    )


def _build_model_ema(model_without_ddp, device, args):
    """Wrap ``model_without_ddp`` in ``utils.ExponentialMovingAverage`` with an adjusted decay."""
    # Decay adjustment that aims to keep the decay independent from other hyper-parameters originally proposed at:
    # https://github.com/facebookresearch/pycls/blob/f8cd9627/pycls/core/net.py#L123
    #
    # total_ema_updates = (Dataset_size / n_GPUs) * epochs / (batch_size_per_gpu * EMA_steps)
    # We consider constant = Dataset_size for a given dataset/setup and omit it. Thus:
    # adjust = 1 / total_ema_updates ~= n_GPUs * batch_size_per_gpu * EMA_steps / epochs
    adjust = args.world_size * args.batch_size * args.model_ema_steps / args.epochs
    alpha = 1.0 - args.model_ema_decay
    alpha = min(1.0, alpha * adjust)
    return utils.ExponentialMovingAverage(model_without_ddp, device=device, decay=1.0 - alpha)


def main(args):
    """Train (or, with ``args.test_only``, only evaluate) a classification model.

    ``args`` is the namespace from ``get_args_parser().parse_args()``. It is mutated here:
    ``utils.init_distributed_mode`` presumably fills in ``distributed``/``gpu`` (they are
    not defined by the parser — TODO confirm), ``lr_scheduler`` is lower-cased, and
    ``start_epoch`` is overwritten when resuming.

    Raises:
        ImportError: if ``--weights`` is given but the prototype models are unavailable.
    """
    if args.weights and PM is None:
        raise ImportError("The prototype module couldn't be found. Please install the latest torchvision nightly.")
    if args.output_dir:
        utils.mkdir(args.output_dir)

    utils.init_distributed_mode(args)
    print(args)

    device = torch.device(args.device)

    if args.use_deterministic_algorithms:
        torch.backends.cudnn.benchmark = False
        torch.use_deterministic_algorithms(True)
    else:
        torch.backends.cudnn.benchmark = True

    train_dir = os.path.join(args.data_path, "train")
    val_dir = os.path.join(args.data_path, "val")
    dataset, dataset_test, train_sampler, test_sampler = load_data(train_dir, val_dir, args)

    collate_fn = None
    num_classes = len(dataset.classes)
    mixup_transforms = []
    if args.mixup_alpha > 0.0:
        mixup_transforms.append(transforms.RandomMixup(num_classes, p=1.0, alpha=args.mixup_alpha))
    if args.cutmix_alpha > 0.0:
        mixup_transforms.append(transforms.RandomCutmix(num_classes, p=1.0, alpha=args.cutmix_alpha))
    if mixup_transforms:
        # One of the enabled batch-level transforms is chosen at random for every batch.
        mixupcutmix = torchvision.transforms.RandomChoice(mixup_transforms)

        def collate_fn(batch):
            return mixupcutmix(*default_collate(batch))

    data_loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=args.batch_size,
        sampler=train_sampler,
        num_workers=args.workers,
        pin_memory=True,
        collate_fn=collate_fn,
    )
    data_loader_test = torch.utils.data.DataLoader(
        dataset_test, batch_size=args.batch_size, sampler=test_sampler, num_workers=args.workers, pin_memory=True
    )

    print("Creating model")
    if not args.weights:
        model = torchvision.models.__dict__[args.model](pretrained=args.pretrained, num_classes=num_classes)
    else:
        model = PM.__dict__[args.model](weights=args.weights, num_classes=num_classes)
    model.to(device)

    if args.distributed and args.sync_bn:
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)

    criterion = nn.CrossEntropyLoss(label_smoothing=args.label_smoothing)
    optimizer = _build_optimizer(model, args)
    scaler = torch.cuda.amp.GradScaler() if args.amp else None
    lr_scheduler = _build_lr_scheduler(optimizer, args)

    model_without_ddp = model
    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
        model_without_ddp = model.module

    model_ema = _build_model_ema(model_without_ddp, device, args) if args.model_ema else None

    if args.resume:
        # NOTE: torch.load unpickles the checkpoint; only resume from trusted files.
        checkpoint = torch.load(args.resume, map_location="cpu")
        model_without_ddp.load_state_dict(checkpoint["model"])
        if not args.test_only:
            optimizer.load_state_dict(checkpoint["optimizer"])
            lr_scheduler.load_state_dict(checkpoint["lr_scheduler"])
        args.start_epoch = checkpoint["epoch"] + 1
        if model_ema is not None:
            model_ema.load_state_dict(checkpoint["model_ema"])
        if scaler is not None:
            scaler.load_state_dict(checkpoint["scaler"])

    if args.test_only:
        # We disable the cudnn benchmarking because it can noticeably affect the accuracy
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True
        if model_ema is not None:
            evaluate(model_ema, criterion, data_loader_test, device=device, log_suffix="EMA")
        else:
            evaluate(model, criterion, data_loader_test, device=device)
        return

    print("Start training")
    start_time = time.time()
    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            # Reseed the distributed sampler so each epoch sees a different shuffle.
            train_sampler.set_epoch(epoch)
        train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, args, model_ema, scaler)
        lr_scheduler.step()
        evaluate(model, criterion, data_loader_test, device=device)
        if model_ema is not None:
            evaluate(model_ema, criterion, data_loader_test, device=device, log_suffix="EMA")
        if args.output_dir:
            checkpoint = {
                "model": model_without_ddp.state_dict(),
                "optimizer": optimizer.state_dict(),
                "lr_scheduler": lr_scheduler.state_dict(),
                "epoch": epoch,
                "args": args,
            }
            if model_ema is not None:
                checkpoint["model_ema"] = model_ema.state_dict()
            if scaler is not None:
                checkpoint["scaler"] = scaler.state_dict()
            utils.save_on_master(checkpoint, os.path.join(args.output_dir, f"model_{epoch}.pth"))
            utils.save_on_master(checkpoint, os.path.join(args.output_dir, "checkpoint.pth"))

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print(f"Training time {total_time_str}")


def get_args_parser(add_help=True):
    """Build the command-line parser for this training script.

    Args:
        add_help: forwarded to ``argparse.ArgumentParser`` (set False to embed this
            parser as a parent of another parser).

    Returns:
        The configured ``argparse.ArgumentParser``.
    """
    import argparse

    parser = argparse.ArgumentParser(description="PyTorch Classification Training", add_help=add_help)

    # Data, model and basic optimization settings
    parser.add_argument("--data-path", default="/datasets01/imagenet_full_size/061417/", type=str, help="dataset path")
    parser.add_argument("--model", default="resnet18", type=str, help="model name")
    parser.add_argument("--device", default="cuda", type=str, help="device (Use cuda or cpu Default: cuda)")
    parser.add_argument(
        "-b", "--batch-size", default=32, type=int, help="images per gpu, the total batch size is $NGPU x batch_size"
    )
    parser.add_argument("--epochs", default=90, type=int, metavar="N", help="number of total epochs to run")
    parser.add_argument(
        "-j", "--workers", default=16, type=int, metavar="N", help="number of data loading workers (default: 16)"
    )
    parser.add_argument("--opt", default="sgd", type=str, help="optimizer")
    parser.add_argument("--lr", default=0.1, type=float, help="initial learning rate")
    parser.add_argument("--momentum", default=0.9, type=float, metavar="M", help="momentum")
    parser.add_argument(
        "--wd",
        "--weight-decay",
        default=1e-4,
        type=float,
        metavar="W",
        help="weight decay (default: 1e-4)",
        dest="weight_decay",
    )
    parser.add_argument(
        "--norm-weight-decay",
        default=None,
        type=float,
        help="weight decay for Normalization layers (default: None, same value as --wd)",
    )
    parser.add_argument(
        "--label-smoothing", default=0.0, type=float, help="label smoothing (default: 0.0)", dest="label_smoothing"
    )
    parser.add_argument("--mixup-alpha", default=0.0, type=float, help="mixup alpha (default: 0.0)")
    parser.add_argument("--cutmix-alpha", default=0.0, type=float, help="cutmix alpha (default: 0.0)")

    # Learning-rate schedule
    parser.add_argument("--lr-scheduler", default="steplr", type=str, help="the lr scheduler (default: steplr)")
    parser.add_argument("--lr-warmup-epochs", default=0, type=int, help="the number of epochs to warmup (default: 0)")
    parser.add_argument(
        "--lr-warmup-method", default="constant", type=str, help="the warmup method (default: constant)"
    )
    parser.add_argument("--lr-warmup-decay", default=0.01, type=float, help="the decay for lr")
    parser.add_argument("--lr-step-size", default=30, type=int, help="decrease lr every step-size epochs")
    parser.add_argument("--lr-gamma", default=0.1, type=float, help="decrease lr by a factor of lr-gamma")

    # Logging, checkpointing and run mode
    parser.add_argument("--print-freq", default=10, type=int, help="print frequency")
    parser.add_argument("--output-dir", default=".", type=str, help="path to save outputs")
    parser.add_argument("--resume", default="", type=str, help="path of checkpoint")
    parser.add_argument("--start-epoch", default=0, type=int, metavar="N", help="start epoch")
    parser.add_argument(
        "--cache-dataset",
        dest="cache_dataset",
        help="Cache the datasets for quicker initialization. It also serializes the transforms",
        action="store_true",
    )
    parser.add_argument(
        "--sync-bn",
        dest="sync_bn",
        help="Use sync batch norm",
        action="store_true",
    )
    parser.add_argument(
        "--test-only",
        dest="test_only",
        help="Only test the model",
        action="store_true",
    )
    parser.add_argument(
        "--pretrained",
        dest="pretrained",
        help="Use pre-trained models from the modelzoo",
        action="store_true",
    )
    parser.add_argument("--auto-augment", default=None, type=str, help="auto augment policy (default: None)")
    parser.add_argument("--random-erase", default=0.0, type=float, help="random erasing probability (default: 0.0)")

    # Mixed precision training parameters
    parser.add_argument("--amp", action="store_true", help="Use torch.cuda.amp for mixed precision training")

    # distributed training parameters
    parser.add_argument("--world-size", default=1, type=int, help="number of distributed processes")
    parser.add_argument("--dist-url", default="env://", type=str, help="url used to set up distributed training")

    # Exponential moving average of the weights
    parser.add_argument(
        "--model-ema", action="store_true", help="enable tracking Exponential Moving Average of model parameters"
    )
    parser.add_argument(
        "--model-ema-steps",
        type=int,
        default=32,
        help="the number of iterations that controls how often to update the EMA model (default: 32)",
    )
    parser.add_argument(
        "--model-ema-decay",
        type=float,
        default=0.99998,
        help="decay factor for Exponential Moving Average of model parameters (default: 0.99998)",
    )
    parser.add_argument(
        "--use-deterministic-algorithms", action="store_true", help="Forces the use of deterministic algorithms only."
    )

    # Preprocessing sizes and interpolation
    parser.add_argument(
        "--interpolation", default="bilinear", type=str, help="the interpolation method (default: bilinear)"
    )
    parser.add_argument(
        "--val-resize-size", default=256, type=int, help="the resize size used for validation (default: 256)"
    )
    parser.add_argument(
        "--val-crop-size", default=224, type=int, help="the central crop size used for validation (default: 224)"
    )
    parser.add_argument(
        "--train-crop-size", default=224, type=int, help="the random crop size used for training (default: 224)"
    )
    parser.add_argument("--clip-grad-norm", default=None, type=float, help="the maximum gradient norm (default None)")
    parser.add_argument("--ra-sampler", action="store_true", help="whether to use ra_sampler in training")

    # Prototype models only
    parser.add_argument("--weights", default=None, type=str, help="the weights enum name to load")

    return parser


if __name__ == "__main__":
    # Script entry point: parse the command line and run training/evaluation.
    args = get_args_parser().parse_args()
    main(args)