import datetime
import os
import time
import warnings

import presets
import torch
import torch.utils.data
import torchvision
import transforms
import utils

from torch import nn
from torch.utils.data.dataloader import default_collate
from torchvision.transforms.functional import InterpolationMode


def train_one_epoch(
    model, criterion, optimizer, data_loader, device, epoch, print_freq, amp=False, model_ema=None, scaler=None
):
    """Train ``model`` for a single epoch over ``data_loader``.

    When ``amp`` is True the forward pass and loss are computed under
    ``torch.cuda.amp.autocast`` and the optimizer is stepped through
    ``scaler`` (which must then be provided); otherwise a plain fp32
    backward/step is taken. Loss, lr, acc1/acc5 and throughput (img/s)
    are logged every ``print_freq`` iterations. If ``model_ema`` is given,
    its parameters are updated once at the end of the epoch.
    """
    model.train()
    metric_logger = utils.MetricLogger(delimiter="  ")
    metric_logger.add_meter("lr", utils.SmoothedValue(window_size=1, fmt="{value}"))
    metric_logger.add_meter("img/s", utils.SmoothedValue(window_size=10, fmt="{value}"))

    header = "Epoch: [{}]".format(epoch)
    for image, target in metric_logger.log_every(data_loader, print_freq, header):
        start_time = time.time()
        image, target = image.to(device), target.to(device)

        optimizer.zero_grad()
        if amp:
            # Run the forward pass AND the loss under autocast so conv/matmul
            # kernels actually execute in reduced precision; the scaler takes
            # care of gradient scaling/unscaling and skipping inf/nan steps.
            with torch.cuda.amp.autocast():
                output = model(image)
                loss = criterion(output, target)
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            output = model(image)
            loss = criterion(output, target)
            loss.backward()
            # Step only in this branch: in the AMP path scaler.step() already
            # stepped the optimizer, so stepping again would double-apply the
            # update (this was a bug in the previous version).
            optimizer.step()

        acc1, acc5 = utils.accuracy(output, target, topk=(1, 5))
        batch_size = image.shape[0]
        metric_logger.update(loss=loss.item(), lr=optimizer.param_groups[0]["lr"])
        metric_logger.meters["acc1"].update(acc1.item(), n=batch_size)
        metric_logger.meters["acc5"].update(acc5.item(), n=batch_size)
        metric_logger.meters["img/s"].update(batch_size / (time.time() - start_time))

    if model_ema:
        model_ema.update_parameters(model)


def evaluate(model, criterion, data_loader, device, print_freq=100, log_suffix=""):
    """Evaluate ``model`` on ``data_loader`` and return the global-average top-1 accuracy.

    Runs without gradients, moves batches to ``device``, and logs loss,
    acc1 and acc5. If the number of processed samples differs from the
    dataset length (e.g. distributed padding), a warning is emitted from a
    single process only.
    """
    model.eval()
    metric_logger = utils.MetricLogger(delimiter="  ")
    header = f"Test: {log_suffix}"

    num_processed_samples = 0
    with torch.no_grad():
        for image, target in metric_logger.log_every(data_loader, print_freq, header):
            image = image.to(device, non_blocking=True)
            target = target.to(device, non_blocking=True)
            output = model(image)
            loss = criterion(output, target)

            acc1, acc5 = utils.accuracy(output, target, topk=(1, 5))
            # FIXME need to take into account that the datasets
            # could have been padded in distributed setup
            batch_size = image.shape[0]
            metric_logger.update(loss=loss.item())
            metric_logger.meters["acc1"].update(acc1.item(), n=batch_size)
            metric_logger.meters["acc5"].update(acc5.item(), n=batch_size)
            num_processed_samples += batch_size
    # gather the stats from all processes

    num_processed_samples = utils.reduce_across_processes(num_processed_samples)
    # Only warn from one process. Guard the get_rank() call: it raises a
    # RuntimeError when torch.distributed was never initialized (plain
    # single-process runs), which previously crashed on a count mismatch.
    is_main_process = (
        not torch.distributed.is_available()
        or not torch.distributed.is_initialized()
        or torch.distributed.get_rank() == 0
    )
    if (
        hasattr(data_loader.dataset, "__len__")
        and len(data_loader.dataset) != num_processed_samples
        and is_main_process
    ):
        # See FIXME above
        warnings.warn(
            f"It looks like the dataset has {len(data_loader.dataset)} samples, but {num_processed_samples} "
            "samples were used for the validation, which might bias the results. "
            "Try adjusting the batch size and / or the world size. "
            "Setting the world size to 1 is always a safe bet."
        )

    metric_logger.synchronize_between_processes()

    print(f"{header} Acc@1 {metric_logger.acc1.global_avg:.3f} Acc@5 {metric_logger.acc5.global_avg:.3f}")
    return metric_logger.acc1.global_avg


def _get_cache_path(filepath):
    import hashlib
99

100
101
102
103
104
105
    h = hashlib.sha1(filepath.encode()).hexdigest()
    cache_path = os.path.join("~", ".torch", "vision", "datasets", "imagefolder", h[:10] + ".pt")
    cache_path = os.path.expanduser(cache_path)
    return cache_path


def load_data(traindir, valdir, args):
    """Build the train/val ``ImageFolder`` datasets and their samplers.

    Resize/crop sizes and the interpolation mode are chosen per model family
    (Inception and EfficientNet differ from the 256/224 default). When
    ``args.cache_dataset`` is set, the datasets — including their transforms —
    are serialized to / loaded from a per-path cache file.

    Returns a tuple ``(dataset, dataset_test, train_sampler, test_sampler)``.
    """
    # Data loading code
    print("Loading data")
    # Defaults used by most torchvision classification models.
    resize_size, crop_size = 256, 224
    interpolation = InterpolationMode.BILINEAR
    if args.model == "inception_v3":
        resize_size, crop_size = 342, 299
    elif args.model.startswith("efficientnet_"):
        # Per-variant (resize, crop) sizes; EfficientNet resizes bicubically.
        efficientnet_sizes = {
            "b0": (256, 224),
            "b1": (256, 240),
            "b2": (288, 288),
            "b3": (320, 300),
            "b4": (384, 380),
            "b5": (456, 456),
            "b6": (528, 528),
            "b7": (600, 600),
        }
        variant = args.model.replace("efficientnet_", "")
        resize_size, crop_size = efficientnet_sizes[variant]
        interpolation = InterpolationMode.BICUBIC

    print("Loading training data")
    start = time.time()
    train_cache = _get_cache_path(traindir)
    if args.cache_dataset and os.path.exists(train_cache):
        # Attention, as the transforms are also cached!
        print("Loading dataset_train from {}".format(train_cache))
        dataset, _ = torch.load(train_cache)
    else:
        auto_augment_policy = getattr(args, "auto_augment", None)
        random_erase_prob = getattr(args, "random_erase", 0.0)
        train_transform = presets.ClassificationPresetTrain(
            crop_size=crop_size, auto_augment_policy=auto_augment_policy, random_erase_prob=random_erase_prob
        )
        dataset = torchvision.datasets.ImageFolder(traindir, train_transform)
        if args.cache_dataset:
            print("Saving dataset_train to {}".format(train_cache))
            utils.mkdir(os.path.dirname(train_cache))
            utils.save_on_master((dataset, traindir), train_cache)
    print("Took", time.time() - start)

    print("Loading validation data")
    val_cache = _get_cache_path(valdir)
    if args.cache_dataset and os.path.exists(val_cache):
        # Attention, as the transforms are also cached!
        print("Loading dataset_test from {}".format(val_cache))
        dataset_test, _ = torch.load(val_cache)
    else:
        eval_transform = presets.ClassificationPresetEval(
            crop_size=crop_size, resize_size=resize_size, interpolation=interpolation
        )
        dataset_test = torchvision.datasets.ImageFolder(valdir, eval_transform)
        if args.cache_dataset:
            print("Saving dataset_test to {}".format(val_cache))
            utils.mkdir(os.path.dirname(val_cache))
            utils.save_on_master((dataset_test, valdir), val_cache)

    print("Creating data loaders")
    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(dataset)
        test_sampler = torch.utils.data.distributed.DistributedSampler(dataset_test, shuffle=False)
    else:
        train_sampler = torch.utils.data.RandomSampler(dataset)
        test_sampler = torch.utils.data.SequentialSampler(dataset_test)

    return dataset, dataset_test, train_sampler, test_sampler


def main(args):
    """Run the full classification training job described by ``args``.

    Initializes (possibly distributed) training, builds datasets/loaders,
    the model, criterion, optimizer and LR schedulers, optionally resumes
    from a checkpoint, then trains for ``args.epochs`` epochs, evaluating
    and checkpointing after each one. With ``--test-only`` it evaluates once
    and returns.
    """
    if args.output_dir:
        utils.mkdir(args.output_dir)

    utils.init_distributed_mode(args)
    print(args)

    device = torch.device(args.device)

    if args.use_deterministic_algorithms:
        # cudnn benchmarking autotunes kernels non-deterministically.
        torch.backends.cudnn.benchmark = False
        torch.use_deterministic_algorithms(True)
    else:
        torch.backends.cudnn.benchmark = True

    train_dir = os.path.join(args.data_path, "train")
    val_dir = os.path.join(args.data_path, "val")
    dataset, dataset_test, train_sampler, test_sampler = load_data(train_dir, val_dir, args)

    collate_fn = None
    num_classes = len(dataset.classes)
    mixup_transforms = []
    if args.mixup_alpha > 0.0:
        mixup_transforms.append(transforms.RandomMixup(num_classes, p=1.0, alpha=args.mixup_alpha))
    if args.cutmix_alpha > 0.0:
        mixup_transforms.append(transforms.RandomCutmix(num_classes, p=1.0, alpha=args.cutmix_alpha))
    if mixup_transforms:
        # Mixup/CutMix operate on whole batches, so they are applied in the
        # collate function rather than as per-sample transforms.
        mixupcutmix = torchvision.transforms.RandomChoice(mixup_transforms)
        collate_fn = lambda batch: mixupcutmix(*default_collate(batch))  # noqa: E731
    data_loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=args.batch_size,
        sampler=train_sampler,
        num_workers=args.workers,
        pin_memory=True,
        collate_fn=collate_fn,
    )
    data_loader_test = torch.utils.data.DataLoader(
        dataset_test, batch_size=args.batch_size, sampler=test_sampler, num_workers=args.workers, pin_memory=True
    )

    print("Creating model")
    model = torchvision.models.__dict__[args.model](pretrained=args.pretrained, num_classes=num_classes)
    model.to(device)

    if args.distributed and args.sync_bn:
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)

    criterion = nn.CrossEntropyLoss(label_smoothing=args.label_smoothing)

    opt_name = args.opt.lower()
    if opt_name.startswith("sgd"):
        optimizer = torch.optim.SGD(
            model.parameters(),
            lr=args.lr,
            momentum=args.momentum,
            weight_decay=args.weight_decay,
            nesterov="nesterov" in opt_name,  # e.g. --opt sgd_nesterov
        )
    elif opt_name == "rmsprop":
        optimizer = torch.optim.RMSprop(
            model.parameters(),
            lr=args.lr,
            momentum=args.momentum,
            weight_decay=args.weight_decay,
            eps=0.0316,
            alpha=0.9,
        )
    else:
        raise RuntimeError("Invalid optimizer {}. Only SGD and RMSprop are supported.".format(args.opt))

    scaler = torch.cuda.amp.GradScaler() if args.amp else None

    args.lr_scheduler = args.lr_scheduler.lower()
    if args.lr_scheduler == "steplr":
        main_lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=args.lr_step_size, gamma=args.lr_gamma)
    elif args.lr_scheduler == "cosineannealinglr":
        # The cosine schedule only spans the post-warmup epochs.
        main_lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, T_max=args.epochs - args.lr_warmup_epochs
        )
    elif args.lr_scheduler == "exponentiallr":
        main_lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=args.lr_gamma)
    else:
        raise RuntimeError(
            "Invalid lr scheduler '{}'. Only StepLR, CosineAnnealingLR and ExponentialLR "
            "are supported.".format(args.lr_scheduler)
        )

    if args.lr_warmup_epochs > 0:
        if args.lr_warmup_method == "linear":
            warmup_lr_scheduler = torch.optim.lr_scheduler.LinearLR(
                optimizer, start_factor=args.lr_warmup_decay, total_iters=args.lr_warmup_epochs
            )
        elif args.lr_warmup_method == "constant":
            warmup_lr_scheduler = torch.optim.lr_scheduler.ConstantLR(
                optimizer, factor=args.lr_warmup_decay, total_iters=args.lr_warmup_epochs
            )
        else:
            raise RuntimeError(
                f"Invalid warmup lr method '{args.lr_warmup_method}'. Only linear and constant " "are supported."
            )
        lr_scheduler = torch.optim.lr_scheduler.SequentialLR(
            optimizer, schedulers=[warmup_lr_scheduler, main_lr_scheduler], milestones=[args.lr_warmup_epochs]
        )
    else:
        lr_scheduler = main_lr_scheduler

    model_without_ddp = model
    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
        model_without_ddp = model.module

    model_ema = None
    if args.model_ema:
        model_ema = utils.ExponentialMovingAverage(model_without_ddp, device=device, decay=args.model_ema_decay)

    if args.resume:
        checkpoint = torch.load(args.resume, map_location="cpu")
        model_without_ddp.load_state_dict(checkpoint["model"])
        optimizer.load_state_dict(checkpoint["optimizer"])
        lr_scheduler.load_state_dict(checkpoint["lr_scheduler"])
        args.start_epoch = checkpoint["epoch"] + 1
        if model_ema:
            model_ema.load_state_dict(checkpoint["model_ema"])
        # Fix: restore the AMP GradScaler state, otherwise resuming an --amp
        # run restarts loss scaling from scratch. Older checkpoints may lack
        # the key, so guard for backward compatibility.
        if scaler and "scaler" in checkpoint:
            scaler.load_state_dict(checkpoint["scaler"])

    if args.test_only:
        # We disable the cudnn benchmarking because it can noticeably affect the accuracy
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True
        evaluate(model, criterion, data_loader_test, device=device)
        return

    print("Start training")
    start_time = time.time()
    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            # Re-seed the sampler so each epoch sees a different shard order.
            train_sampler.set_epoch(epoch)
        train_one_epoch(
            model, criterion, optimizer, data_loader, device, epoch, args.print_freq, args.amp, model_ema, scaler
        )
        lr_scheduler.step()
        evaluate(model, criterion, data_loader_test, device=device)
        if model_ema:
            evaluate(model_ema, criterion, data_loader_test, device=device, log_suffix="EMA")
        if args.output_dir:
            checkpoint = {
                "model": model_without_ddp.state_dict(),
                "optimizer": optimizer.state_dict(),
                "lr_scheduler": lr_scheduler.state_dict(),
                "epoch": epoch,
                "args": args,
            }
            if model_ema:
                checkpoint["model_ema"] = model_ema.state_dict()
            if scaler:
                # Persist the AMP scaler so --amp runs can resume seamlessly.
                checkpoint["scaler"] = scaler.state_dict()
            utils.save_on_master(checkpoint, os.path.join(args.output_dir, "model_{}.pth".format(epoch)))
            utils.save_on_master(checkpoint, os.path.join(args.output_dir, "checkpoint.pth"))

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print("Training time {}".format(total_time_str))


def get_args_parser(add_help=True):
    """Build and return the argument parser for classification training.

    ``add_help`` is forwarded to ``argparse.ArgumentParser``; pass False when
    nesting this parser inside another one to avoid duplicate -h options.
    """
    import argparse

    parser = argparse.ArgumentParser(description="PyTorch Classification Training", add_help=add_help)

    # Data and model selection
    parser.add_argument("--data-path", default="/datasets01/imagenet_full_size/061417/", help="dataset")
    parser.add_argument("--model", default="resnet18", help="model")
    parser.add_argument("--device", default="cuda", help="device")

    # Core optimization hyper-parameters
    parser.add_argument("-b", "--batch-size", type=int, default=32)
    parser.add_argument("--epochs", type=int, default=90, metavar="N", help="number of total epochs to run")
    parser.add_argument("-j", "--workers", type=int, default=16, metavar="N", help="number of data loading workers (default: 16)")
    parser.add_argument("--opt", type=str, default="sgd", help="optimizer")
    parser.add_argument("--lr", type=float, default=0.1, help="initial learning rate")
    parser.add_argument("--momentum", type=float, default=0.9, metavar="M", help="momentum")
    parser.add_argument("--wd", "--weight-decay", dest="weight_decay", type=float, default=1e-4, metavar="W", help="weight decay (default: 1e-4)")
    parser.add_argument("--label-smoothing", dest="label_smoothing", type=float, default=0.0, help="label smoothing (default: 0.0)")
    parser.add_argument("--mixup-alpha", type=float, default=0.0, help="mixup alpha (default: 0.0)")
    parser.add_argument("--cutmix-alpha", type=float, default=0.0, help="cutmix alpha (default: 0.0)")

    # Learning-rate schedule
    parser.add_argument("--lr-scheduler", default="steplr", help="the lr scheduler (default: steplr)")
    parser.add_argument("--lr-warmup-epochs", type=int, default=0, help="the number of epochs to warmup (default: 0)")
    parser.add_argument("--lr-warmup-method", type=str, default="constant", help="the warmup method (default: constant)")
    parser.add_argument("--lr-warmup-decay", type=float, default=0.01, help="the decay for lr")
    parser.add_argument("--lr-step-size", type=int, default=30, help="decrease lr every step-size epochs")
    parser.add_argument("--lr-gamma", type=float, default=0.1, help="decrease lr by a factor of lr-gamma")

    # Logging and checkpointing
    parser.add_argument("--print-freq", type=int, default=10, help="print frequency")
    parser.add_argument("--output-dir", default=".", help="path where to save")
    parser.add_argument("--resume", default="", help="resume from checkpoint")
    parser.add_argument("--start-epoch", type=int, default=0, metavar="N", help="start epoch")

    # Behaviour flags
    parser.add_argument("--cache-dataset", dest="cache_dataset", action="store_true", help="Cache the datasets for quicker initialization. It also serializes the transforms")
    parser.add_argument("--sync-bn", dest="sync_bn", action="store_true", help="Use sync batch norm")
    parser.add_argument("--test-only", dest="test_only", action="store_true", help="Only test the model")
    parser.add_argument("--pretrained", dest="pretrained", action="store_true", help="Use pre-trained models from the modelzoo")

    # Augmentation
    parser.add_argument("--auto-augment", default=None, help="auto augment policy (default: None)")
    parser.add_argument("--random-erase", type=float, default=0.0, help="random erasing probability (default: 0.0)")

    # Mixed precision training parameters
    parser.add_argument("--amp", action="store_true", help="Use torch.cuda.amp for mixed precision training")

    # distributed training parameters
    parser.add_argument("--world-size", type=int, default=1, help="number of distributed processes")
    parser.add_argument("--dist-url", default="env://", help="url used to set up distributed training")

    # Exponential moving average of model weights
    parser.add_argument("--model-ema", action="store_true", help="enable tracking Exponential Moving Average of model parameters")
    parser.add_argument("--model-ema-decay", type=float, default=0.9, help="decay factor for Exponential Moving Average of model parameters(default: 0.9)")

    parser.add_argument("--use-deterministic-algorithms", action="store_true", help="Forces the use of deterministic algorithms only.")

    return parser


if __name__ == "__main__":
    # Parse CLI arguments and hand them straight to the training entry point.
    main(get_args_parser().parse_args())