"vscode:/vscode.git/clone" did not exist on "6ddbf6222c3874989a176062746f804623fb7346"
import datetime
import os
import time

import presets
import torch
import torch.utils.data
import torchvision
import transforms
import utils
from torch import nn
from torch.utils.data.dataloader import default_collate
from torchvision.transforms.functional import InterpolationMode


def train_one_epoch(
    model, criterion, optimizer, data_loader, device, epoch, print_freq, amp=False, model_ema=None, scaler=None
):
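    """Train ``model`` for one epoch and log loss, accuracy, learning rate and images/s.

    With ``amp=True`` the loss is computed under ``torch.cuda.amp.autocast`` and ``scaler``
    (a ``torch.cuda.amp.GradScaler``) scales the backward pass and drives the optimizer step;
    otherwise a plain backward/step is used. If ``model_ema`` is given, its parameters are
    updated once at the end of the epoch.
    """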
    model.train()
    metric_logger = utils.MetricLogger(delimiter="  ")
    metric_logger.add_meter("lr", utils.SmoothedValue(window_size=1, fmt="{value}"))
    metric_logger.add_meter("img/s", utils.SmoothedValue(window_size=10, fmt="{value}"))

    header = "Epoch: [{}]".format(epoch)
    for image, target in metric_logger.log_every(data_loader, print_freq, header):
        start_time = time.time()
        image, target = image.to(device), target.to(device)
        output = model(image)

        optimizer.zero_grad()
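        # Mixed-precision path: compute the loss under autocast and let the GradScaler
        # scale the backward pass and perform the optimizer step.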
        if amp:
            with torch.cuda.amp.autocast():
                loss = criterion(output, target)
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()

        acc1, acc5 = utils.accuracy(output, target, topk=(1, 5))
        batch_size = image.shape[0]
        metric_logger.update(loss=loss.item(), lr=optimizer.param_groups[0]["lr"])
        metric_logger.meters["acc1"].update(acc1.item(), n=batch_size)
        metric_logger.meters["acc5"].update(acc5.item(), n=batch_size)
        metric_logger.meters["img/s"].update(batch_size / (time.time() - start_time))

    if model_ema:
        model_ema.update_parameters(model)


def evaluate(model, criterion, data_loader, device, print_freq=100, log_suffix=""):
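    """Evaluate ``model`` on ``data_loader`` and return the global-average top-1 accuracy.

    Runs under ``torch.no_grad()``; metrics are synchronized across processes before printing.
    """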
    model.eval()
    metric_logger = utils.MetricLogger(delimiter="  ")
    header = f"Test: {log_suffix}"
    with torch.no_grad():
        for image, target in metric_logger.log_every(data_loader, print_freq, header):
            image = image.to(device, non_blocking=True)
            target = target.to(device, non_blocking=True)
            output = model(image)
            loss = criterion(output, target)

            acc1, acc5 = utils.accuracy(output, target, topk=(1, 5))
            # FIXME need to take into account that the datasets
            # could have been padded in distributed setup
            batch_size = image.shape[0]
            metric_logger.update(loss=loss.item())
            metric_logger.meters["acc1"].update(acc1.item(), n=batch_size)
            metric_logger.meters["acc5"].update(acc5.item(), n=batch_size)
    # gather the stats from all processes
    metric_logger.synchronize_between_processes()

    print(f"{header} Acc@1 {metric_logger.acc1.global_avg:.3f} Acc@5 {metric_logger.acc5.global_avg:.3f}")
    return metric_logger.acc1.global_avg


def _get_cache_path(filepath):
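    """Map a dataset directory to a cache file under ~/.torch/vision/datasets/imagefolder,
    keyed by a SHA-1 hash of the path."""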
    import hashlib

    h = hashlib.sha1(filepath.encode()).hexdigest()
    cache_path = os.path.join("~", ".torch", "vision", "datasets", "imagefolder", h[:10] + ".pt")
    cache_path = os.path.expanduser(cache_path)
    return cache_path


def load_data(traindir, valdir, args):
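    """Build the train/val ``ImageFolder`` datasets (optionally restored from a cached,
    serialized copy) and their samplers (distributed or single-process).

    Resize/crop sizes and the interpolation mode are adjusted for ``inception_v3`` and
    ``efficientnet_*`` models.
    """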
    # Data loading code
    print("Loading data")
    resize_size, crop_size = 256, 224
    interpolation = InterpolationMode.BILINEAR
    if args.model == "inception_v3":
        resize_size, crop_size = 342, 299
    elif args.model.startswith("efficientnet_"):
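        # (resize_size, crop_size) per EfficientNet variant.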
        sizes = {
            "b0": (256, 224),
            "b1": (256, 240),
            "b2": (288, 288),
            "b3": (320, 300),
            "b4": (384, 380),
            "b5": (456, 456),
            "b6": (528, 528),
            "b7": (600, 600),
        }
        e_type = args.model.replace("efficientnet_", "")
        resize_size, crop_size = sizes[e_type]
        interpolation = InterpolationMode.BICUBIC

    print("Loading training data")
    st = time.time()
    cache_path = _get_cache_path(traindir)
    if args.cache_dataset and os.path.exists(cache_path):
        # Attention: the cached dataset also includes the transforms!
        print("Loading dataset_train from {}".format(cache_path))
        dataset, _ = torch.load(cache_path)
    else:
        auto_augment_policy = getattr(args, "auto_augment", None)
        random_erase_prob = getattr(args, "random_erase", 0.0)
        dataset = torchvision.datasets.ImageFolder(
            traindir,
            presets.ClassificationPresetTrain(
                crop_size=crop_size, auto_augment_policy=auto_augment_policy, random_erase_prob=random_erase_prob
            ),
        )
        if args.cache_dataset:
            print("Saving dataset_train to {}".format(cache_path))
            utils.mkdir(os.path.dirname(cache_path))
            utils.save_on_master((dataset, traindir), cache_path)
    print("Took", time.time() - st)

    print("Loading validation data")
    cache_path = _get_cache_path(valdir)
    if args.cache_dataset and os.path.exists(cache_path):
        # Attention: the cached dataset also includes the transforms!
        print("Loading dataset_test from {}".format(cache_path))
        dataset_test, _ = torch.load(cache_path)
    else:
        dataset_test = torchvision.datasets.ImageFolder(
            valdir,
            presets.ClassificationPresetEval(crop_size=crop_size, resize_size=resize_size, interpolation=interpolation),
        )
        if args.cache_dataset:
            print("Saving dataset_test to {}".format(cache_path))
            utils.mkdir(os.path.dirname(cache_path))
            utils.save_on_master((dataset_test, valdir), cache_path)

    print("Creating data loaders")
    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(dataset)
        test_sampler = torch.utils.data.distributed.DistributedSampler(dataset_test)
    else:
        train_sampler = torch.utils.data.RandomSampler(dataset)
        test_sampler = torch.utils.data.SequentialSampler(dataset_test)

    return dataset, dataset_test, train_sampler, test_sampler


def main(args):
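    """Entry point: set up (optionally distributed) training, build the data loaders, model,
    optimizer and LR schedulers, then run the train/evaluate loop, saving checkpoints to
    ``args.output_dir``."""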
    if args.output_dir:
        utils.mkdir(args.output_dir)

    utils.init_distributed_mode(args)
    print(args)

    device = torch.device(args.device)

    torch.backends.cudnn.benchmark = True

    train_dir = os.path.join(args.data_path, "train")
    val_dir = os.path.join(args.data_path, "val")
    dataset, dataset_test, train_sampler, test_sampler = load_data(train_dir, val_dir, args)

    collate_fn = None
    num_classes = len(dataset.classes)
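    # MixUp/CutMix operate on whole batches, so they are applied in the collate_fn
    # rather than as per-sample dataset transforms.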
    mixup_transforms = []
    if args.mixup_alpha > 0.0:
        mixup_transforms.append(transforms.RandomMixup(num_classes, p=1.0, alpha=args.mixup_alpha))
    if args.cutmix_alpha > 0.0:
        mixup_transforms.append(transforms.RandomCutmix(num_classes, p=1.0, alpha=args.cutmix_alpha))
    if mixup_transforms:
        mixupcutmix = torchvision.transforms.RandomChoice(mixup_transforms)
        collate_fn = lambda batch: mixupcutmix(*default_collate(batch))  # noqa: E731
    data_loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=args.batch_size,
        sampler=train_sampler,
        num_workers=args.workers,
        pin_memory=True,
        collate_fn=collate_fn,
    )
    data_loader_test = torch.utils.data.DataLoader(
        dataset_test, batch_size=args.batch_size, sampler=test_sampler, num_workers=args.workers, pin_memory=True
    )

    print("Creating model")
    model = torchvision.models.__dict__[args.model](pretrained=args.pretrained, num_classes=num_classes)
    model.to(device)

    if args.distributed and args.sync_bn:
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)

    criterion = nn.CrossEntropyLoss(label_smoothing=args.label_smoothing)

    opt_name = args.opt.lower()
    if opt_name.startswith("sgd"):
        optimizer = torch.optim.SGD(
            model.parameters(),
            lr=args.lr,
            momentum=args.momentum,
            weight_decay=args.weight_decay,
            nesterov="nesterov" in opt_name,
        )
    elif opt_name == "rmsprop":
        optimizer = torch.optim.RMSprop(
            model.parameters(),
            lr=args.lr,
            momentum=args.momentum,
            weight_decay=args.weight_decay,
            eps=0.0316,
            alpha=0.9,
        )
    else:
        raise RuntimeError("Invalid optimizer {}. Only SGD and RMSprop are supported.".format(args.opt))

    scaler = torch.cuda.amp.GradScaler() if args.amp else None

    args.lr_scheduler = args.lr_scheduler.lower()
    if args.lr_scheduler == "steplr":
        main_lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=args.lr_step_size, gamma=args.lr_gamma)
    elif args.lr_scheduler == "cosineannealinglr":
        main_lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, T_max=args.epochs - args.lr_warmup_epochs
        )
    elif args.lr_scheduler == "exponentiallr":
        main_lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=args.lr_gamma)
    else:
        raise RuntimeError(
            "Invalid lr scheduler '{}'. Only StepLR, CosineAnnealingLR and ExponentialLR "
            "are supported.".format(args.lr_scheduler)
        )

    if args.lr_warmup_epochs > 0:
        if args.lr_warmup_method == "linear":
            warmup_lr_scheduler = torch.optim.lr_scheduler.LinearLR(
                optimizer, start_factor=args.lr_warmup_decay, total_iters=args.lr_warmup_epochs
            )
        elif args.lr_warmup_method == "constant":
            warmup_lr_scheduler = torch.optim.lr_scheduler.ConstantLR(
                optimizer, factor=args.lr_warmup_decay, total_iters=args.lr_warmup_epochs
            )
        else:
            raise RuntimeError(
                f"Invalid warmup lr method '{args.lr_warmup_method}'. Only linear and constant are supported."
            )
        lr_scheduler = torch.optim.lr_scheduler.SequentialLR(
            optimizer, schedulers=[warmup_lr_scheduler, main_lr_scheduler], milestones=[args.lr_warmup_epochs]
        )
    else:
        lr_scheduler = main_lr_scheduler

    model_without_ddp = model
    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
        model_without_ddp = model.module

    model_ema = None
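    # Optionally track an exponential moving average of the (non-DDP) model weights;
    # the EMA copy is evaluated separately after each epoch.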
    if args.model_ema:
        model_ema = utils.ExponentialMovingAverage(model_without_ddp, device=device, decay=args.model_ema_decay)

    if args.resume:
        checkpoint = torch.load(args.resume, map_location="cpu")
        model_without_ddp.load_state_dict(checkpoint["model"])
        optimizer.load_state_dict(checkpoint["optimizer"])
        lr_scheduler.load_state_dict(checkpoint["lr_scheduler"])
        args.start_epoch = checkpoint["epoch"] + 1
        if model_ema:
            model_ema.load_state_dict(checkpoint["model_ema"])

    if args.test_only:
        evaluate(model, criterion, data_loader_test, device=device)
        return

    print("Start training")
    start_time = time.time()
    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            train_sampler.set_epoch(epoch)
        train_one_epoch(
            model, criterion, optimizer, data_loader, device, epoch, args.print_freq, args.amp, model_ema, scaler
        )
        lr_scheduler.step()
        evaluate(model, criterion, data_loader_test, device=device)
        if model_ema:
            evaluate(model_ema, criterion, data_loader_test, device=device, log_suffix="EMA")
        if args.output_dir:
            checkpoint = {
                "model": model_without_ddp.state_dict(),
                "optimizer": optimizer.state_dict(),
                "lr_scheduler": lr_scheduler.state_dict(),
                "epoch": epoch,
                "args": args,
            }
            if model_ema:
                checkpoint["model_ema"] = model_ema.state_dict()
            utils.save_on_master(checkpoint, os.path.join(args.output_dir, "model_{}.pth".format(epoch)))
            utils.save_on_master(checkpoint, os.path.join(args.output_dir, "checkpoint.pth"))

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print("Training time {}".format(total_time_str))


def get_args_parser(add_help=True):
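    """Return the argparse parser holding all training hyper-parameters and runtime options."""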
    import argparse

    parser = argparse.ArgumentParser(description="PyTorch Classification Training", add_help=add_help)

    parser.add_argument("--data-path", default="/datasets01/imagenet_full_size/061417/", help="dataset")
    parser.add_argument("--model", default="resnet18", help="model")
    parser.add_argument("--device", default="cuda", help="device")
    parser.add_argument("-b", "--batch-size", default=32, type=int)
    parser.add_argument("--epochs", default=90, type=int, metavar="N", help="number of total epochs to run")
    parser.add_argument(
        "-j", "--workers", default=16, type=int, metavar="N", help="number of data loading workers (default: 16)"
    )
    parser.add_argument("--opt", default="sgd", type=str, help="optimizer")
    parser.add_argument("--lr", default=0.1, type=float, help="initial learning rate")
    parser.add_argument("--momentum", default=0.9, type=float, metavar="M", help="momentum")
    parser.add_argument(
        "--wd",
        "--weight-decay",
        default=1e-4,
        type=float,
        metavar="W",
        help="weight decay (default: 1e-4)",
        dest="weight_decay",
    )
    parser.add_argument(
        "--label-smoothing", default=0.0, type=float, help="label smoothing (default: 0.0)", dest="label_smoothing"
    )
    parser.add_argument("--mixup-alpha", default=0.0, type=float, help="mixup alpha (default: 0.0)")
    parser.add_argument("--cutmix-alpha", default=0.0, type=float, help="cutmix alpha (default: 0.0)")
    parser.add_argument("--lr-scheduler", default="steplr", help="the lr scheduler (default: steplr)")
    parser.add_argument("--lr-warmup-epochs", default=0, type=int, help="the number of epochs to warmup (default: 0)")
    parser.add_argument(
        "--lr-warmup-method", default="constant", type=str, help="the warmup method (default: constant)"
    )
    parser.add_argument("--lr-warmup-decay", default=0.01, type=float, help="the decay for lr")
    parser.add_argument("--lr-step-size", default=30, type=int, help="decrease lr every step-size epochs")
    parser.add_argument("--lr-gamma", default=0.1, type=float, help="decrease lr by a factor of lr-gamma")
    parser.add_argument("--print-freq", default=10, type=int, help="print frequency")
    parser.add_argument("--output-dir", default=".", help="path where to save")
    parser.add_argument("--resume", default="", help="resume from checkpoint")
    parser.add_argument("--start-epoch", default=0, type=int, metavar="N", help="start epoch")
    parser.add_argument(
        "--cache-dataset",
        dest="cache_dataset",
        help="Cache the datasets for quicker initialization. It also serializes the transforms",
        action="store_true",
    )
    parser.add_argument(
        "--sync-bn",
        dest="sync_bn",
        help="Use sync batch norm",
        action="store_true",
    )
    parser.add_argument(
        "--test-only",
        dest="test_only",
        help="Only test the model",
        action="store_true",
    )
    parser.add_argument(
        "--pretrained",
        dest="pretrained",
        help="Use pre-trained models from the modelzoo",
        action="store_true",
    )
    parser.add_argument("--auto-augment", default=None, help="auto augment policy (default: None)")
    parser.add_argument("--random-erase", default=0.0, type=float, help="random erasing probability (default: 0.0)")

    # Mixed precision training parameters
    parser.add_argument("--amp", action="store_true", help="Use torch.cuda.amp for mixed precision training")

    # distributed training parameters
    parser.add_argument("--world-size", default=1, type=int, help="number of distributed processes")
    parser.add_argument("--dist-url", default="env://", help="url used to set up distributed training")
    parser.add_argument(
        "--model-ema", action="store_true", help="enable tracking Exponential Moving Average of model parameters"
    )
    parser.add_argument(
        "--model-ema-decay",
        type=float,
        default=0.9,
        help="decay factor for Exponential Moving Average of model parameters (default: 0.9)",
    )

    return parser


if __name__ == "__main__":
    args = get_args_parser().parse_args()
    main(args)
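
# Example invocations (illustrative only; the data path, learning rates and batch sizes below
# are placeholders to adapt to your setup, not values prescribed by this script):
#
#   # single GPU
#   python train.py --data-path /path/to/imagenet --model resnet50 --batch-size 32 --lr 0.1
#
#   # multi-GPU, assuming a torchrun-style launcher that sets the env:// rendezvous variables
#   torchrun --nproc_per_node=8 train.py --data-path /path/to/imagenet --model resnet50 \
#       --batch-size 32 --lr 0.5 --amp --model-ema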