import datetime
import os
import time
import warnings

import datasets
import presets
import torch
import torch.utils.data
import torchvision
import torchvision.datasets.video_utils
import utils
from torch import nn
from torch.utils.data.dataloader import default_collate
from torchvision.datasets.samplers import DistributedSampler, RandomClipSampler, UniformClipSampler


def train_one_epoch(model, criterion, optimizer, lr_scheduler, data_loader, device, epoch, print_freq, scaler=None):
    model.train()
    metric_logger = utils.MetricLogger(delimiter="  ")
    metric_logger.add_meter("lr", utils.SmoothedValue(window_size=1, fmt="{value}"))
    metric_logger.add_meter("clips/s", utils.SmoothedValue(window_size=10, fmt="{value:.3f}"))

    header = f"Epoch: [{epoch}]"
    for video, target, _ in metric_logger.log_every(data_loader, print_freq, header):
        start_time = time.time()
        video, target = video.to(device), target.to(device)
        with torch.cuda.amp.autocast(enabled=scaler is not None):
            output = model(video)
            loss = criterion(output, target)

        optimizer.zero_grad()

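        # With AMP enabled, GradScaler scales the loss before backward and
        # unscales gradients inside step(), skipping the update if inf/NaN
        # gradients are found; otherwise do a plain backward and step.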
        if scaler is not None:
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            loss.backward()
            optimizer.step()

        acc1, acc5 = utils.accuracy(output, target, topk=(1, 5))
        batch_size = video.shape[0]
        metric_logger.update(loss=loss.item(), lr=optimizer.param_groups[0]["lr"])
        metric_logger.meters["acc1"].update(acc1.item(), n=batch_size)
        metric_logger.meters["acc5"].update(acc5.item(), n=batch_size)
        metric_logger.meters["clips/s"].update(batch_size / (time.time() - start_time))
        lr_scheduler.step()


def evaluate(model, criterion, data_loader, device):
    model.eval()
    metric_logger = utils.MetricLogger(delimiter="  ")
    header = "Test:"
    num_processed_samples = 0
    # Group and aggregate output of a video
    num_videos = len(data_loader.dataset.samples)
    num_classes = len(data_loader.dataset.classes)
    agg_preds = torch.zeros((num_videos, num_classes), dtype=torch.float32, device=device)
    agg_targets = torch.zeros((num_videos), dtype=torch.int32, device=device)
    with torch.inference_mode():
        for video, target, video_idx in metric_logger.log_every(data_loader, 100, header):
            video = video.to(device, non_blocking=True)
            target = target.to(device, non_blocking=True)
            output = model(video)
            loss = criterion(output, target)

            # Use softmax to convert output into prediction probability
            preds = torch.softmax(output, dim=1)
            for b in range(video.size(0)):
                idx = video_idx[b].item()
                agg_preds[idx] += preds[b].detach()
                agg_targets[idx] = target[b].detach().item()

            acc1, acc5 = utils.accuracy(output, target, topk=(1, 5))
            # FIXME need to take into account that the datasets
            # could have been padded in distributed setup
            batch_size = video.shape[0]
            metric_logger.update(loss=loss.item())
            metric_logger.meters["acc1"].update(acc1.item(), n=batch_size)
            metric_logger.meters["acc5"].update(acc5.item(), n=batch_size)
            num_processed_samples += batch_size
    # gather the stats from all processes
    num_processed_samples = utils.reduce_across_processes(num_processed_samples)
    if isinstance(data_loader.sampler, DistributedSampler):
        # Get the len of UniformClipSampler inside DistributedSampler
        num_data_from_sampler = len(data_loader.sampler.dataset)
    else:
        num_data_from_sampler = len(data_loader.sampler)

    if (
        hasattr(data_loader.dataset, "__len__")
        and num_data_from_sampler != num_processed_samples
        and torch.distributed.get_rank() == 0
    ):
        # See FIXME above
        warnings.warn(
            f"It looks like the sampler has {num_data_from_sampler} samples, but {num_processed_samples} "
            "samples were used for the validation, which might bias the results. "
            "Try adjusting the batch size and / or the world size. "
            "Setting the world size to 1 is always a safe bet."
        )

    metric_logger.synchronize_between_processes()

    print(
        " * Clip Acc@1 {top1.global_avg:.3f} Clip Acc@5 {top5.global_avg:.3f}".format(
            top1=metric_logger.acc1, top5=metric_logger.acc5
        )
    )
    # Reduce the agg_preds and agg_targets from all GPUs and show the result
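    # MAX is used for the targets because each rank only fills the entries of the
    # videos it processed (the rest stay 0), and class labels are non-negative.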
    agg_preds = utils.reduce_across_processes(agg_preds)
    agg_targets = utils.reduce_across_processes(agg_targets, op=torch.distributed.ReduceOp.MAX)
    agg_acc1, agg_acc5 = utils.accuracy(agg_preds, agg_targets, topk=(1, 5))
    print(" * Video Acc@1 {acc1:.3f} Video Acc@5 {acc5:.3f}".format(acc1=agg_acc1, acc5=agg_acc5))
    return metric_logger.acc1.global_avg


def _get_cache_path(filepath, args):
    import hashlib

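    # Key the cache on the dataset path and the clip parameters, so changing
    # --clip-len, --kinetics-version or --frame-rate produces a new cache file.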
    value = f"{filepath}-{args.clip_len}-{args.kinetics_version}-{args.frame_rate}"
    h = hashlib.sha1(value.encode()).hexdigest()
    cache_path = os.path.join("~", ".torch", "vision", "datasets", "kinetics", h[:10] + ".pt")
    cache_path = os.path.expanduser(cache_path)
    return cache_path


def collate_fn(batch):
    # remove audio from the batch
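    # Each dataset item is assumed to be (video, audio, label, video_idx);
    # keep only video, label and video_idx, matching what the loops unpack.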
    batch = [(d[0], d[2], d[3]) for d in batch]
    return default_collate(batch)


def main(args):
    if args.output_dir:
        utils.mkdir(args.output_dir)

    utils.init_distributed_mode(args)
    print(args)

    device = torch.device(args.device)

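    # cudnn benchmark autotuning is non-deterministic, so it is disabled when
    # deterministic algorithms are requested.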
    if args.use_deterministic_algorithms:
        torch.backends.cudnn.benchmark = False
        torch.use_deterministic_algorithms(True)
    else:
        torch.backends.cudnn.benchmark = True

    # Data loading code
    print("Loading data")
    val_resize_size = tuple(args.val_resize_size)
    val_crop_size = tuple(args.val_crop_size)
    train_resize_size = tuple(args.train_resize_size)
    train_crop_size = tuple(args.train_crop_size)

    traindir = os.path.join(args.data_path, "train")
    valdir = os.path.join(args.data_path, "val")

    print("Loading training data")
    st = time.time()
    cache_path = _get_cache_path(traindir, args)
    transform_train = presets.VideoClassificationPresetTrain(crop_size=train_crop_size, resize_size=train_resize_size)

    if args.cache_dataset and os.path.exists(cache_path):
        print(f"Loading dataset_train from {cache_path}")
        # The cache stores a pickled dataset object, so weights_only loading cannot be used here
        dataset, _ = torch.load(cache_path, weights_only=False)
        dataset.transform = transform_train
    else:
        if args.distributed:
            print("It is recommended to pre-compute the dataset cache on a single-gpu first, as it will be faster")
        dataset = datasets.KineticsWithVideoId(
            args.data_path,
            frames_per_clip=args.clip_len,
            num_classes=args.kinetics_version,
            split="train",
            step_between_clips=1,
            transform=transform_train,
            frame_rate=args.frame_rate,
            extensions=(
                "avi",
                "mp4",
            ),
            output_format="TCHW",
        )
        if args.cache_dataset:
            print(f"Saving dataset_train to {cache_path}")
            utils.mkdir(os.path.dirname(cache_path))
            utils.save_on_master((dataset, traindir), cache_path)

    print("Took", time.time() - st)

    print("Loading validation data")
    cache_path = _get_cache_path(valdir, args)

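    # When evaluating pretrained weights, reuse the preprocessing bundled with
    # those weights instead of the local eval preset.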
    if args.weights and args.test_only:
        weights = torchvision.models.get_weight(args.weights)
        transform_test = weights.transforms()
    else:
        transform_test = presets.VideoClassificationPresetEval(crop_size=val_crop_size, resize_size=val_resize_size)

    if args.cache_dataset and os.path.exists(cache_path):
        print(f"Loading dataset_test from {cache_path}")
        # As above, the cache stores a pickled dataset object, so weights_only loading cannot be used here
        dataset_test, _ = torch.load(cache_path, weights_only=False)
        dataset_test.transform = transform_test
    else:
        if args.distributed:
            print("It is recommended to pre-compute the dataset cache on a single-gpu first, as it will be faster")
        dataset_test = datasets.KineticsWithVideoId(
            args.data_path,
            frames_per_clip=args.clip_len,
            num_classes=args.kinetics_version,
            split="val",
            step_between_clips=1,
            transform=transform_test,
            frame_rate=args.frame_rate,
            extensions=(
                "avi",
                "mp4",
            ),
            output_format="TCHW",
        )
        if args.cache_dataset:
            print(f"Saving dataset_test to {cache_path}")
            utils.mkdir(os.path.dirname(cache_path))
            utils.save_on_master((dataset_test, valdir), cache_path)

    print("Creating data loaders")
    train_sampler = RandomClipSampler(dataset.video_clips, args.clips_per_video)
    test_sampler = UniformClipSampler(dataset_test.video_clips, args.clips_per_video)
    if args.distributed:
        train_sampler = DistributedSampler(train_sampler)
        test_sampler = DistributedSampler(test_sampler, shuffle=False)

    data_loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=args.batch_size,
        sampler=train_sampler,
        num_workers=args.workers,
        pin_memory=True,
        collate_fn=collate_fn,
    )

    data_loader_test = torch.utils.data.DataLoader(
        dataset_test,
        batch_size=args.batch_size,
        sampler=test_sampler,
        num_workers=args.workers,
        pin_memory=True,
        collate_fn=collate_fn,
    )

    print("Creating model")
    model = torchvision.models.get_model(args.model, weights=args.weights)
    model.to(device)
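    # Optionally replace BatchNorm layers with SyncBatchNorm so that batch
    # statistics are synchronized across processes.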
    if args.distributed and args.sync_bn:
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)

    criterion = nn.CrossEntropyLoss()

    optimizer = torch.optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)
    scaler = torch.cuda.amp.GradScaler() if args.amp else None

    # Convert the scheduler to step per iteration rather than per epoch, so that
    # warmup can span multiple epochs.
    iters_per_epoch = len(data_loader)
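    # Milestones are counted in iterations from the end of warmup, since
    # SequentialLR restarts the main scheduler's counter after the warmup phase.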
    lr_milestones = [iters_per_epoch * (m - args.lr_warmup_epochs) for m in args.lr_milestones]
    main_lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=lr_milestones, gamma=args.lr_gamma)

    if args.lr_warmup_epochs > 0:
        warmup_iters = iters_per_epoch * args.lr_warmup_epochs
        args.lr_warmup_method = args.lr_warmup_method.lower()
        if args.lr_warmup_method == "linear":
            warmup_lr_scheduler = torch.optim.lr_scheduler.LinearLR(
                optimizer, start_factor=args.lr_warmup_decay, total_iters=warmup_iters
            )
        elif args.lr_warmup_method == "constant":
            warmup_lr_scheduler = torch.optim.lr_scheduler.ConstantLR(
                optimizer, factor=args.lr_warmup_decay, total_iters=warmup_iters
            )
        else:
            raise RuntimeError(
                f"Invalid warmup lr method '{args.lr_warmup_method}'. Only linear and constant are supported."
            )

        lr_scheduler = torch.optim.lr_scheduler.SequentialLR(
            optimizer, schedulers=[warmup_lr_scheduler, main_lr_scheduler], milestones=[warmup_iters]
        )
    else:
        lr_scheduler = main_lr_scheduler

    model_without_ddp = model
    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
        model_without_ddp = model.module

    if args.resume:
        checkpoint = torch.load(args.resume, map_location="cpu", weights_only=True)
        model_without_ddp.load_state_dict(checkpoint["model"])
        optimizer.load_state_dict(checkpoint["optimizer"])
        lr_scheduler.load_state_dict(checkpoint["lr_scheduler"])
        args.start_epoch = checkpoint["epoch"] + 1
        if args.amp:
            scaler.load_state_dict(checkpoint["scaler"])

    if args.test_only:
        # We disable the cudnn benchmarking because it can noticeably affect the accuracy
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True
        evaluate(model, criterion, data_loader_test, device=device)
        return

    print("Start training")
    start_time = time.time()
    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            train_sampler.set_epoch(epoch)
        train_one_epoch(model, criterion, optimizer, lr_scheduler, data_loader, device, epoch, args.print_freq, scaler)
        evaluate(model, criterion, data_loader_test, device=device)
        if args.output_dir:
            checkpoint = {
                "model": model_without_ddp.state_dict(),
                "optimizer": optimizer.state_dict(),
                "lr_scheduler": lr_scheduler.state_dict(),
                "epoch": epoch,
                "args": args,
            }
            if args.amp:
                checkpoint["scaler"] = scaler.state_dict()
            utils.save_on_master(checkpoint, os.path.join(args.output_dir, f"model_{epoch}.pth"))
            utils.save_on_master(checkpoint, os.path.join(args.output_dir, "checkpoint.pth"))

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print(f"Training time {total_time_str}")


def get_args_parser(add_help=True):
    import argparse

    parser = argparse.ArgumentParser(description="PyTorch Video Classification Training", add_help=add_help)

    parser.add_argument("--data-path", default="/datasets01_101/kinetics/070618/", type=str, help="dataset path")
    parser.add_argument(
        "--kinetics-version", default="400", type=str, choices=["400", "600"], help="Select kinetics version"
    )
    parser.add_argument("--model", default="r2plus1d_18", type=str, help="model name")
    parser.add_argument("--device", default="cuda", type=str, help="device (Use cuda or cpu Default: cuda)")
    parser.add_argument("--clip-len", default=16, type=int, metavar="N", help="number of frames per clip")
    parser.add_argument("--frame-rate", default=15, type=int, metavar="N", help="the frame rate")
    parser.add_argument(
        "--clips-per-video", default=5, type=int, metavar="N", help="maximum number of clips per video to consider"
    )
    parser.add_argument(
        "-b", "--batch-size", default=24, type=int, help="images per gpu, the total batch size is $NGPU x batch_size"
    )
    parser.add_argument("--epochs", default=45, type=int, metavar="N", help="number of total epochs to run")
    parser.add_argument(
        "-j", "--workers", default=10, type=int, metavar="N", help="number of data loading workers (default: 10)"
    )
    parser.add_argument("--lr", default=0.64, type=float, help="initial learning rate")
    parser.add_argument("--momentum", default=0.9, type=float, metavar="M", help="momentum")
    parser.add_argument(
        "--wd",
        "--weight-decay",
        default=1e-4,
        type=float,
        metavar="W",
        help="weight decay (default: 1e-4)",
        dest="weight_decay",
    )
    parser.add_argument("--lr-milestones", nargs="+", default=[20, 30, 40], type=int, help="decrease lr on milestones")
    parser.add_argument("--lr-gamma", default=0.1, type=float, help="decrease lr by a factor of lr-gamma")
    parser.add_argument("--lr-warmup-epochs", default=10, type=int, help="the number of epochs to warmup (default: 10)")
    parser.add_argument("--lr-warmup-method", default="linear", type=str, help="the warmup method (default: linear)")
    parser.add_argument("--lr-warmup-decay", default=0.001, type=float, help="the decay for lr")
    parser.add_argument("--print-freq", default=10, type=int, help="print frequency")
    parser.add_argument("--output-dir", default=".", type=str, help="path to save outputs")
    parser.add_argument("--resume", default="", type=str, help="path of checkpoint")
    parser.add_argument("--start-epoch", default=0, type=int, metavar="N", help="start epoch")
    parser.add_argument(
        "--cache-dataset",
        dest="cache_dataset",
        help="Cache the datasets for quicker initialization. It also serializes the transforms",
        action="store_true",
    )
    parser.add_argument(
        "--sync-bn",
        dest="sync_bn",
        help="Use sync batch norm",
        action="store_true",
    )
    parser.add_argument(
        "--test-only",
        dest="test_only",
        help="Only test the model",
        action="store_true",
    )
    parser.add_argument(
        "--use-deterministic-algorithms", action="store_true", help="Forces the use of deterministic algorithms only."
    )

    # distributed training parameters
    parser.add_argument("--world-size", default=1, type=int, help="number of distributed processes")
    parser.add_argument("--dist-url", default="env://", type=str, help="url used to set up distributed training")

    parser.add_argument(
        "--val-resize-size",
        default=(128, 171),
        nargs="+",
        type=int,
        help="the resize size used for validation (default: (128, 171))",
    )
    parser.add_argument(
        "--val-crop-size",
        default=(112, 112),
        nargs="+",
        type=int,
        help="the central crop size used for validation (default: (112, 112))",
    )
    parser.add_argument(
        "--train-resize-size",
        default=(128, 171),
        nargs="+",
        type=int,
        help="the resize size used for training (default: (128, 171))",
    )
    parser.add_argument(
        "--train-crop-size",
        default=(112, 112),
        nargs="+",
        type=int,
        help="the random crop size used for training (default: (112, 112))",
    )

    parser.add_argument("--weights", default=None, type=str, help="the weights enum name to load")

    # Mixed precision training parameters
    parser.add_argument("--amp", action="store_true", help="Use torch.cuda.amp for mixed precision training")

    return parser


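# Example launch (paths are placeholders):
#   torchrun --nproc_per_node=8 train.py --data-path /path/to/kinetics --kinetics-version 400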
if __name__ == "__main__":
    args = get_args_parser().parse_args()
    main(args)