train.py 16.7 KB
Newer Older
1
2
3
4
import datetime
import os
import time

5
import presets
6
7
8
import torch
import torch.utils.data
import torchvision
9
import transforms
10
import utils
11
12
13
from torch import nn
from torch.utils.data.dataloader import default_collate
from torchvision.transforms.functional import InterpolationMode
14

15
16
17
18
try:
    from apex import amp
except ImportError:
    amp = None
19

20

21
def train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, print_freq, apex=False, model_ema=None):
    """Run one full training epoch over ``data_loader``, logging running metrics.

    Performs one optimizer step per batch, routing the backward pass through
    NVIDIA apex loss scaling when ``apex`` is True, and folds the weights into
    ``model_ema`` once at the end of the epoch when an EMA tracker is given.
    """
    model.train()

    # Meters: lr is displayed as its current value, img/s smoothed over 10 batches.
    metric_logger = utils.MetricLogger(delimiter="  ")
    metric_logger.add_meter("lr", utils.SmoothedValue(window_size=1, fmt="{value}"))
    metric_logger.add_meter("img/s", utils.SmoothedValue(window_size=10, fmt="{value}"))

    header = "Epoch: [{}]".format(epoch)
    for images, targets in metric_logger.log_every(data_loader, print_freq, header):
        batch_start = time.time()
        images = images.to(device)
        targets = targets.to(device)

        logits = model(images)
        loss = criterion(logits, targets)

        optimizer.zero_grad()
        if not apex:
            loss.backward()
        else:
            # apex rescales the loss before backward to avoid fp16 gradient underflow.
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
        optimizer.step()

        acc1, acc5 = utils.accuracy(logits, targets, topk=(1, 5))
        n = images.shape[0]
        metric_logger.update(loss=loss.item(), lr=optimizer.param_groups[0]["lr"])
        metric_logger.meters["acc1"].update(acc1.item(), n=n)
        metric_logger.meters["acc5"].update(acc5.item(), n=n)
        metric_logger.meters["img/s"].update(n / (time.time() - batch_start))

    # EMA weights are refreshed once per epoch, after all optimizer steps.
    if model_ema:
        model_ema.update_parameters(model)
51

52

53
def evaluate(model, criterion, data_loader, device, print_freq=100, log_suffix=""):
    """Evaluate ``model`` on ``data_loader`` and return the global top-1 accuracy.

    Runs under ``torch.no_grad`` and synchronizes the metric state across
    distributed workers before printing the summary line.
    """
    model.eval()
    metric_logger = utils.MetricLogger(delimiter="  ")
    header = f"Test: {log_suffix}"

    with torch.no_grad():
        for images, targets in metric_logger.log_every(data_loader, print_freq, header):
            images = images.to(device, non_blocking=True)
            targets = targets.to(device, non_blocking=True)

            logits = model(images)
            loss = criterion(logits, targets)
            acc1, acc5 = utils.accuracy(logits, targets, topk=(1, 5))

            # FIXME need to take into account that the datasets
            # could have been padded in distributed setup
            n = images.shape[0]
            metric_logger.update(loss=loss.item())
            metric_logger.meters["acc1"].update(acc1.item(), n=n)
            metric_logger.meters["acc5"].update(acc5.item(), n=n)

    # gather the stats from all processes
    metric_logger.synchronize_between_processes()

    print(f"{header} Acc@1 {metric_logger.acc1.global_avg:.3f} Acc@5 {metric_logger.acc5.global_avg:.3f}")
    return metric_logger.acc1.global_avg


78
79
def _get_cache_path(filepath):
    import hashlib
80

81
82
83
84
85
86
    h = hashlib.sha1(filepath.encode()).hexdigest()
    cache_path = os.path.join("~", ".torch", "vision", "datasets", "imagefolder", h[:10] + ".pt")
    cache_path = os.path.expanduser(cache_path)
    return cache_path


87
def load_data(traindir, valdir, args):
    """Build the train/val ImageFolder datasets (optionally disk-cached) and samplers.

    Returns ``(dataset, dataset_test, train_sampler, test_sampler)``. When
    ``args.cache_dataset`` is set, serialized (dataset, path) pairs are read
    from / written to the path produced by ``_get_cache_path``.
    """
    # Data loading code
    print("Loading data")

    # Default ImageNet-style sizes; a few architectures need larger inputs.
    resize_size, crop_size = 256, 224
    interpolation = InterpolationMode.BILINEAR
    if args.model == "inception_v3":
        resize_size, crop_size = 342, 299
    elif args.model.startswith("efficientnet_"):
        # (resize, crop) pairs per EfficientNet variant.
        efficientnet_sizes = {
            "b0": (256, 224),
            "b1": (256, 240),
            "b2": (288, 288),
            "b3": (320, 300),
            "b4": (384, 380),
            "b5": (456, 456),
            "b6": (528, 528),
            "b7": (600, 600),
        }
        variant = args.model.replace("efficientnet_", "")
        resize_size, crop_size = efficientnet_sizes[variant]
        interpolation = InterpolationMode.BICUBIC

    print("Loading training data")
    st = time.time()
    cache_path = _get_cache_path(traindir)
    if args.cache_dataset and os.path.exists(cache_path):
        # Attention, as the transforms are also cached!
        print("Loading dataset_train from {}".format(cache_path))
        dataset, _ = torch.load(cache_path)
    else:
        auto_augment_policy = getattr(args, "auto_augment", None)
        random_erase_prob = getattr(args, "random_erase", 0.0)
        train_preset = presets.ClassificationPresetTrain(
            crop_size=crop_size, auto_augment_policy=auto_augment_policy, random_erase_prob=random_erase_prob
        )
        dataset = torchvision.datasets.ImageFolder(traindir, train_preset)
        if args.cache_dataset:
            print("Saving dataset_train to {}".format(cache_path))
            utils.mkdir(os.path.dirname(cache_path))
            utils.save_on_master((dataset, traindir), cache_path)
    print("Took", time.time() - st)

    print("Loading validation data")
    cache_path = _get_cache_path(valdir)
    if args.cache_dataset and os.path.exists(cache_path):
        # Attention, as the transforms are also cached!
        print("Loading dataset_test from {}".format(cache_path))
        dataset_test, _ = torch.load(cache_path)
    else:
        eval_preset = presets.ClassificationPresetEval(
            crop_size=crop_size, resize_size=resize_size, interpolation=interpolation
        )
        dataset_test = torchvision.datasets.ImageFolder(valdir, eval_preset)
        if args.cache_dataset:
            print("Saving dataset_test to {}".format(cache_path))
            utils.mkdir(os.path.dirname(cache_path))
            utils.save_on_master((dataset_test, valdir), cache_path)

    print("Creating data loaders")
    if args.distributed:
        # Each distributed worker gets its own shard of both datasets.
        train_sampler = torch.utils.data.distributed.DistributedSampler(dataset)
        test_sampler = torch.utils.data.distributed.DistributedSampler(dataset_test)
    else:
        train_sampler = torch.utils.data.RandomSampler(dataset)
        test_sampler = torch.utils.data.SequentialSampler(dataset_test)

    return dataset, dataset_test, train_sampler, test_sampler


def main(args):
    """Entry point: build data, model, optimizer and schedulers, then train/evaluate.

    ``args`` is the namespace produced by ``get_args_parser()``. Side effects:
    creates ``args.output_dir``, initializes distributed mode (which mutates
    ``args``), and writes per-epoch checkpoints when an output dir is set.
    """
    # apex is an optional dependency; fail early if mixed precision was requested
    # but the import at the top of the file fell back to amp = None.
    if args.apex and amp is None:
        raise RuntimeError(
            "Failed to import apex. Please install apex from https://www.github.com/nvidia/apex "
            "to enable mixed-precision training."
        )

    if args.output_dir:
        utils.mkdir(args.output_dir)

    # NOTE(review): presumably sets args.distributed / args.gpu and creates the
    # process group when run under a distributed launcher — confirm in utils.py.
    utils.init_distributed_mode(args)
    print(args)

    device = torch.device(args.device)

    # Let cuDNN pick the fastest conv algorithms for the (fixed) input sizes.
    torch.backends.cudnn.benchmark = True

    # Expects an ImageNet-style layout: <data_path>/train and <data_path>/val.
    train_dir = os.path.join(args.data_path, "train")
    val_dir = os.path.join(args.data_path, "val")
    dataset, dataset_test, train_sampler, test_sampler = load_data(train_dir, val_dir, args)

    # Mixup/CutMix operate on whole batches, so they are applied in the
    # collate function rather than as per-sample transforms.
    collate_fn = None
    num_classes = len(dataset.classes)
    mixup_transforms = []
    if args.mixup_alpha > 0.0:
        mixup_transforms.append(transforms.RandomMixup(num_classes, p=1.0, alpha=args.mixup_alpha))
    if args.cutmix_alpha > 0.0:
        mixup_transforms.append(transforms.RandomCutmix(num_classes, p=1.0, alpha=args.cutmix_alpha))
    if mixup_transforms:
        # One of the enabled transforms is chosen at random per batch.
        mixupcutmix = torchvision.transforms.RandomChoice(mixup_transforms)
        collate_fn = lambda batch: mixupcutmix(*default_collate(batch))  # noqa: E731
    data_loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=args.batch_size,
        sampler=train_sampler,
        num_workers=args.workers,
        pin_memory=True,
        collate_fn=collate_fn,
    )
    data_loader_test = torch.utils.data.DataLoader(
        dataset_test, batch_size=args.batch_size, sampler=test_sampler, num_workers=args.workers, pin_memory=True
    )

    print("Creating model")
    # Look the model constructor up by name in torchvision.models.
    model = torchvision.models.__dict__[args.model](pretrained=args.pretrained, num_classes=num_classes)
    model.to(device)

    if args.distributed and args.sync_bn:
        # Replace BatchNorm layers with SyncBatchNorm so stats are shared across workers.
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)

    criterion = nn.CrossEntropyLoss(label_smoothing=args.label_smoothing)

    # "sgd" or "sgd_nesterov" select SGD; nesterov is enabled by the suffix.
    opt_name = args.opt.lower()
    if opt_name.startswith("sgd"):
        optimizer = torch.optim.SGD(
            model.parameters(),
            lr=args.lr,
            momentum=args.momentum,
            weight_decay=args.weight_decay,
            nesterov="nesterov" in opt_name,
        )
    elif opt_name == "rmsprop":
        optimizer = torch.optim.RMSprop(
            model.parameters(),
            lr=args.lr,
            momentum=args.momentum,
            weight_decay=args.weight_decay,
            eps=0.0316,
            alpha=0.9,
        )
    else:
        raise RuntimeError("Invalid optimizer {}. Only SGD and RMSprop are supported.".format(args.opt))

    # apex must wrap the model/optimizer pair after both exist and before DDP.
    if args.apex:
        model, optimizer = amp.initialize(model, optimizer, opt_level=args.apex_opt_level)

    args.lr_scheduler = args.lr_scheduler.lower()
    if args.lr_scheduler == "steplr":
        main_lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=args.lr_step_size, gamma=args.lr_gamma)
    elif args.lr_scheduler == "cosineannealinglr":
        # The cosine horizon excludes warmup epochs, which are handled separately below.
        main_lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, T_max=args.epochs - args.lr_warmup_epochs
        )
    elif args.lr_scheduler == "exponentiallr":
        main_lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=args.lr_gamma)
    else:
        raise RuntimeError(
            "Invalid lr scheduler '{}'. Only StepLR, CosineAnnealingLR and ExponentialLR "
            "are supported.".format(args.lr_scheduler)
        )

    if args.lr_warmup_epochs > 0:
        if args.lr_warmup_method == "linear":
            warmup_lr_scheduler = torch.optim.lr_scheduler.LinearLR(
                optimizer, start_factor=args.lr_warmup_decay, total_iters=args.lr_warmup_epochs
            )
        elif args.lr_warmup_method == "constant":
            warmup_lr_scheduler = torch.optim.lr_scheduler.ConstantLR(
                optimizer, factor=args.lr_warmup_decay, total_iters=args.lr_warmup_epochs
            )
        else:
            raise RuntimeError(
                f"Invalid warmup lr method '{args.lr_warmup_method}'. Only linear and constant " "are supported."
            )
        # Chain warmup then the main schedule, switching at lr_warmup_epochs.
        lr_scheduler = torch.optim.lr_scheduler.SequentialLR(
            optimizer, schedulers=[warmup_lr_scheduler, main_lr_scheduler], milestones=[args.lr_warmup_epochs]
        )
    else:
        lr_scheduler = main_lr_scheduler

    # Keep a handle to the unwrapped model for checkpointing/EMA after DDP wraps it.
    model_without_ddp = model
    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
        model_without_ddp = model.module

    model_ema = None
    if args.model_ema:
        model_ema = utils.ExponentialMovingAverage(model_without_ddp, device=device, decay=args.model_ema_decay)

    if args.resume:
        # Load on CPU first to avoid GPU OOM / device mismatches.
        checkpoint = torch.load(args.resume, map_location="cpu")
        model_without_ddp.load_state_dict(checkpoint["model"])
        optimizer.load_state_dict(checkpoint["optimizer"])
        lr_scheduler.load_state_dict(checkpoint["lr_scheduler"])
        args.start_epoch = checkpoint["epoch"] + 1
        # Raises KeyError if the checkpoint was saved without EMA enabled.
        if model_ema:
            model_ema.load_state_dict(checkpoint["model_ema"])

    if args.test_only:
        evaluate(model, criterion, data_loader_test, device=device)
        return

    print("Start training")
    start_time = time.time()
    for epoch in range(args.start_epoch, args.epochs):
        # Reseed the sampler so each epoch shuffles differently across workers.
        if args.distributed:
            train_sampler.set_epoch(epoch)
        train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, args.print_freq, args.apex, model_ema)
        lr_scheduler.step()
        evaluate(model, criterion, data_loader_test, device=device)
        if model_ema:
            evaluate(model_ema, criterion, data_loader_test, device=device, log_suffix="EMA")
        if args.output_dir:
            checkpoint = {
                "model": model_without_ddp.state_dict(),
                "optimizer": optimizer.state_dict(),
                "lr_scheduler": lr_scheduler.state_dict(),
                "epoch": epoch,
                "args": args,
            }
            if model_ema:
                checkpoint["model_ema"] = model_ema.state_dict()
            # Keep every epoch's snapshot plus a rolling "latest" checkpoint.
            utils.save_on_master(checkpoint, os.path.join(args.output_dir, "model_{}.pth".format(epoch)))
            utils.save_on_master(checkpoint, os.path.join(args.output_dir, "checkpoint.pth"))

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print("Training time {}".format(total_time_str))
316
317


318
def get_args_parser(add_help=True):
    """Build the argument parser for classification training.

    Parameters
    ----------
    add_help : bool
        Forwarded to ``argparse.ArgumentParser``; pass False when composing
        this parser into a larger one.

    Returns
    -------
    argparse.ArgumentParser
    """
    import argparse

    parser = argparse.ArgumentParser(description="PyTorch Classification Training", add_help=add_help)

    # Data / model / device
    parser.add_argument("--data-path", default="/datasets01/imagenet_full_size/061417/", help="dataset")
    parser.add_argument("--model", default="resnet18", help="model")
    parser.add_argument("--device", default="cuda", help="device")
    parser.add_argument("-b", "--batch-size", default=32, type=int)
    parser.add_argument("--epochs", default=90, type=int, metavar="N", help="number of total epochs to run")
    parser.add_argument(
        "-j", "--workers", default=16, type=int, metavar="N", help="number of data loading workers (default: 16)"
    )

    # Optimizer
    parser.add_argument("--opt", default="sgd", type=str, help="optimizer")
    parser.add_argument("--lr", default=0.1, type=float, help="initial learning rate")
    parser.add_argument("--momentum", default=0.9, type=float, metavar="M", help="momentum")
    parser.add_argument(
        "--wd",
        "--weight-decay",
        default=1e-4,
        type=float,
        metavar="W",
        help="weight decay (default: 1e-4)",
        dest="weight_decay",
    )

    # Regularization / augmentation
    parser.add_argument(
        "--label-smoothing", default=0.0, type=float, help="label smoothing (default: 0.0)", dest="label_smoothing"
    )
    parser.add_argument("--mixup-alpha", default=0.0, type=float, help="mixup alpha (default: 0.0)")
    parser.add_argument("--cutmix-alpha", default=0.0, type=float, help="cutmix alpha (default: 0.0)")

    # LR schedule
    parser.add_argument("--lr-scheduler", default="steplr", help="the lr scheduler (default: steplr)")
    parser.add_argument("--lr-warmup-epochs", default=0, type=int, help="the number of epochs to warmup (default: 0)")
    parser.add_argument(
        "--lr-warmup-method", default="constant", type=str, help="the warmup method (default: constant)"
    )
    parser.add_argument("--lr-warmup-decay", default=0.01, type=float, help="the decay for lr")
    parser.add_argument("--lr-step-size", default=30, type=int, help="decrease lr every step-size epochs")
    parser.add_argument("--lr-gamma", default=0.1, type=float, help="decrease lr by a factor of lr-gamma")

    # Logging / checkpointing
    parser.add_argument("--print-freq", default=10, type=int, help="print frequency")
    parser.add_argument("--output-dir", default=".", help="path where to save")
    parser.add_argument("--resume", default="", help="resume from checkpoint")
    parser.add_argument("--start-epoch", default=0, type=int, metavar="N", help="start epoch")

    # Flags
    parser.add_argument(
        "--cache-dataset",
        dest="cache_dataset",
        help="Cache the datasets for quicker initialization. It also serializes the transforms",
        action="store_true",
    )
    parser.add_argument(
        "--sync-bn",
        dest="sync_bn",
        help="Use sync batch norm",
        action="store_true",
    )
    parser.add_argument(
        "--test-only",
        dest="test_only",
        help="Only test the model",
        action="store_true",
    )
    parser.add_argument(
        "--pretrained",
        dest="pretrained",
        help="Use pre-trained models from the modelzoo",
        action="store_true",
    )
    parser.add_argument("--auto-augment", default=None, help="auto augment policy (default: None)")
    parser.add_argument("--random-erase", default=0.0, type=float, help="random erasing probability (default: 0.0)")

    # Mixed precision training parameters
    parser.add_argument("--apex", action="store_true", help="Use apex for mixed precision training")
    parser.add_argument(
        "--apex-opt-level",
        default="O1",
        type=str,
        # Fixed: adjacent string literals previously concatenated without spaces,
        # rendering as "...trainingO0 for FP32..." in --help output.
        help="For apex mixed precision training. "
        "O0 for FP32 training, O1 for mixed precision training. "
        "For further detail, see https://github.com/NVIDIA/apex/tree/master/examples/imagenet",
    )

    # distributed training parameters
    parser.add_argument("--world-size", default=1, type=int, help="number of distributed processes")
    parser.add_argument("--dist-url", default="env://", help="url used to set up distributed training")

    # EMA
    parser.add_argument(
        "--model-ema", action="store_true", help="enable tracking Exponential Moving Average of model parameters"
    )
    parser.add_argument(
        "--model-ema-decay",
        type=float,
        default=0.9,
        # Fixed missing space before "(default: 0.9)".
        help="decay factor for Exponential Moving Average of model parameters (default: 0.9)",
    )

    return parser
412

413
414

if __name__ == "__main__":
    # Parse CLI arguments and launch training.
    main(get_args_parser().parse_args())