"vscode:/vscode.git/clone" did not exist on "6fc9d2270786a76952a5d55ec75a9f9dcde73e26"
train.py 17.8 KB
Newer Older
1
2
3
import datetime
import os
import time
import warnings

import datasets
import presets
import torch
import torch.utils.data
import torchvision
import torchvision.datasets.video_utils
import utils
from torch import nn
from torch.utils.data.dataloader import default_collate
from torchvision.datasets.samplers import DistributedSampler, RandomClipSampler, UniformClipSampler


def train_one_epoch(model, criterion, optimizer, lr_scheduler, data_loader, device, epoch, print_freq, scaler=None):
    model.train()
    metric_logger = utils.MetricLogger(delimiter="  ")
    metric_logger.add_meter("lr", utils.SmoothedValue(window_size=1, fmt="{value}"))
    metric_logger.add_meter("clips/s", utils.SmoothedValue(window_size=10, fmt="{value:.3f}"))

    header = f"Epoch: [{epoch}]"
    for video, target, _ in metric_logger.log_every(data_loader, print_freq, header):
        start_time = time.time()
        video, target = video.to(device), target.to(device)
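        # Run the forward pass under autocast only when AMP is enabled; with
        # scaler=None the context is a no-op and everything runs in full precision.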
        with torch.cuda.amp.autocast(enabled=scaler is not None):
            output = model(video)
            loss = criterion(output, target)

        optimizer.zero_grad()

        if scaler is not None:
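            # Scale the loss so small fp16 gradients do not underflow during
            # backward; scaler.step() unscales them before the optimizer update
            # and scaler.update() adapts the scale factor for the next iteration.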
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            loss.backward()
            optimizer.step()

        acc1, acc5 = utils.accuracy(output, target, topk=(1, 5))
        batch_size = video.shape[0]
        metric_logger.update(loss=loss.item(), lr=optimizer.param_groups[0]["lr"])
        metric_logger.meters["acc1"].update(acc1.item(), n=batch_size)
        metric_logger.meters["acc5"].update(acc5.item(), n=batch_size)
        metric_logger.meters["clips/s"].update(batch_size / (time.time() - start_time))
        lr_scheduler.step()


def evaluate(model, criterion, data_loader, device):
    model.eval()
    metric_logger = utils.MetricLogger(delimiter="  ")
    header = "Test:"
    num_processed_samples = 0
    # Group and aggregate the per-clip outputs of each video
    num_videos = len(data_loader.dataset.samples)
    num_classes = len(data_loader.dataset.classes)
    agg_preds = torch.zeros((num_videos, num_classes), dtype=torch.float32, device=device)
    agg_targets = torch.zeros(num_videos, dtype=torch.int32, device=device)
    with torch.inference_mode():
        for video, target, video_idx in metric_logger.log_every(data_loader, 100, header):
            video = video.to(device, non_blocking=True)
            target = target.to(device, non_blocking=True)
            output = model(video)
            loss = criterion(output, target)

            # Use softmax to convert the outputs into prediction probabilities
            preds = torch.softmax(output, dim=1)
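            # Accumulate each clip's probabilities into the row of its source video;
            # video-level accuracy is computed from these summed distributions later.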
            for b in range(video.size(0)):
                idx = video_idx[b].item()
                agg_preds[idx] += preds[b].detach()
                agg_targets[idx] = target[b].detach().item()

            acc1, acc5 = utils.accuracy(output, target, topk=(1, 5))
            # FIXME need to take into account that the datasets
            # could have been padded in distributed setup
            batch_size = video.shape[0]
            metric_logger.update(loss=loss.item())
            metric_logger.meters["acc1"].update(acc1.item(), n=batch_size)
            metric_logger.meters["acc5"].update(acc5.item(), n=batch_size)
            num_processed_samples += batch_size
    # gather the stats from all processes
    num_processed_samples = utils.reduce_across_processes(num_processed_samples)
    if isinstance(data_loader.sampler, DistributedSampler):
        # Get the length of the clip sampler wrapped inside the DistributedSampler
        num_data_from_sampler = len(data_loader.sampler.dataset)
    else:
        num_data_from_sampler = len(data_loader.sampler)

    if (
        hasattr(data_loader.dataset, "__len__")
        and num_data_from_sampler != num_processed_samples
        and torch.distributed.get_rank() == 0
    ):
        # See FIXME above
        warnings.warn(
            f"It looks like the sampler has {num_data_from_sampler} samples, but {num_processed_samples} "
            "samples were used for the validation, which might bias the results. "
            "Try adjusting the batch size and / or the world size. "
            "Setting the world size to 1 is always a safe bet."
        )

    metric_logger.synchronize_between_processes()

    print(
        " * Clip Acc@1 {top1.global_avg:.3f} Clip Acc@5 {top5.global_avg:.3f}".format(
            top1=metric_logger.acc1, top5=metric_logger.acc5
        )
    )
    # Reduce agg_preds and agg_targets from all GPUs and report the video-level result
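    # A MAX reduction recovers the labels: slots a rank never saw stay zero, and
    # ranks that saw clips of the same video wrote the same (non-negative) label,
    # so taking the max neither drops nor double-counts any target.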
    agg_preds = utils.reduce_across_processes(agg_preds)
    agg_targets = utils.reduce_across_processes(agg_targets, op=torch.distributed.ReduceOp.MAX)
    agg_acc1, agg_acc5 = utils.accuracy(agg_preds, agg_targets, topk=(1, 5))
    print(" * Video Acc@1 {acc1:.3f} Video Acc@5 {acc5:.3f}".format(acc1=agg_acc1, acc5=agg_acc5))
    return metric_logger.acc1.global_avg


def _get_cache_path(filepath, args):
    import hashlib

    value = f"{filepath}-{args.clip_len}-{args.kinetics_version}-{args.frame_rate}"
    h = hashlib.sha1(value.encode()).hexdigest()
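    # The key includes the clip settings, so changing --clip-len, --kinetics-version
    # or --frame-rate invalidates the cache instead of silently reusing stale clips.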
    cache_path = os.path.join("~", ".torch", "vision", "datasets", "kinetics", h[:10] + ".pt")
    cache_path = os.path.expanduser(cache_path)
    return cache_path


def collate_fn(batch):
    # remove audio from the batch
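    # Each sample is (video, audio, label, video_idx); keep everything but the audio.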
    batch = [(d[0], d[2], d[3]) for d in batch]
    return default_collate(batch)


def main(args):
    if args.output_dir:
        utils.mkdir(args.output_dir)

    utils.init_distributed_mode(args)
    print(args)

    device = torch.device(args.device)

    if args.use_deterministic_algorithms:
        torch.backends.cudnn.benchmark = False
        torch.use_deterministic_algorithms(True)
    else:
        torch.backends.cudnn.benchmark = True

    # Data loading code
    print("Loading data")
    val_resize_size = tuple(args.val_resize_size)
    val_crop_size = tuple(args.val_crop_size)
    train_resize_size = tuple(args.train_resize_size)
    train_crop_size = tuple(args.train_crop_size)

    traindir = os.path.join(args.data_path, "train")
    valdir = os.path.join(args.data_path, "val")

    print("Loading training data")
    st = time.time()
    cache_path = _get_cache_path(traindir, args)
    transform_train = presets.VideoClassificationPresetTrain(crop_size=train_crop_size, resize_size=train_resize_size)

    if args.cache_dataset and os.path.exists(cache_path):
        print(f"Loading dataset_train from {cache_path}")
        dataset, _ = torch.load(cache_path)
        dataset.transform = transform_train
    else:
        if args.distributed:
            print("It is recommended to pre-compute the dataset cache on a single-gpu first, as it will be faster")
        dataset = datasets.KineticsWithVideoId(
            args.data_path,
            frames_per_clip=args.clip_len,
            num_classes=args.kinetics_version,
            split="train",
            step_between_clips=1,
            transform=transform_train,
            frame_rate=args.frame_rate,
            extensions=(
                "avi",
                "mp4",
            ),
            output_format="TCHW",
        )
        if args.cache_dataset:
            print(f"Saving dataset_train to {cache_path}")
            utils.mkdir(os.path.dirname(cache_path))
            utils.save_on_master((dataset, traindir), cache_path)

    print("Took", time.time() - st)

    print("Loading validation data")
    cache_path = _get_cache_path(valdir, args)

    if args.weights and args.test_only:
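        # When evaluating pretrained weights, reuse the preprocessing bundled with
        # the weight enum so it matches what the weights were trained with.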
        weights = torchvision.models.get_weight(args.weights)
        transform_test = weights.transforms()
    else:
        transform_test = presets.VideoClassificationPresetEval(crop_size=val_crop_size, resize_size=val_resize_size)

    if args.cache_dataset and os.path.exists(cache_path):
        print(f"Loading dataset_test from {cache_path}")
        dataset_test, _ = torch.load(cache_path)
        dataset_test.transform = transform_test
    else:
        if args.distributed:
            print("It is recommended to pre-compute the dataset cache on a single-gpu first, as it will be faster")
        dataset_test = datasets.KineticsWithVideoId(
            args.data_path,
            frames_per_clip=args.clip_len,
            num_classes=args.kinetics_version,
            split="val",
            step_between_clips=1,
            transform=transform_test,
            frame_rate=args.frame_rate,
            extensions=(
                "avi",
                "mp4",
            ),
            output_format="TCHW",
        )
        if args.cache_dataset:
            print(f"Saving dataset_test to {cache_path}")
            utils.mkdir(os.path.dirname(cache_path))
            utils.save_on_master((dataset_test, valdir), cache_path)

    print("Creating data loaders")
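    # RandomClipSampler draws at most --clips-per-video random clips from each video
    # for training; UniformClipSampler takes evenly spaced clips so every validation
    # video is covered deterministically.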
    train_sampler = RandomClipSampler(dataset.video_clips, args.clips_per_video)
    test_sampler = UniformClipSampler(dataset_test.video_clips, args.clips_per_video)
    if args.distributed:
        train_sampler = DistributedSampler(train_sampler)
        test_sampler = DistributedSampler(test_sampler, shuffle=False)

    data_loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=args.batch_size,
        sampler=train_sampler,
        num_workers=args.workers,
        pin_memory=True,
        collate_fn=collate_fn,
    )

    data_loader_test = torch.utils.data.DataLoader(
        dataset_test,
        batch_size=args.batch_size,
        sampler=test_sampler,
        num_workers=args.workers,
        pin_memory=True,
        collate_fn=collate_fn,
    )

    print("Creating model")
    model = torchvision.models.get_model(args.model, weights=args.weights)
    model.to(device)
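    # SyncBatchNorm computes batch-norm statistics across all GPUs instead of
    # per-process, which matters with small per-GPU batch sizes.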
    if args.distributed and args.sync_bn:
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)

    criterion = nn.CrossEntropyLoss()

    optimizer = torch.optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)
    scaler = torch.cuda.amp.GradScaler() if args.amp else None

    # convert scheduler to be per iteration, not per epoch, for warmup that lasts
    # between different epochs
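    # Milestones are given in epochs, so shift them by the warmup length and convert
    # them to iteration counts, since lr_scheduler.step() is called once per batch.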
    iters_per_epoch = len(data_loader)
    lr_milestones = [iters_per_epoch * (m - args.lr_warmup_epochs) for m in args.lr_milestones]
    main_lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=lr_milestones, gamma=args.lr_gamma)

    if args.lr_warmup_epochs > 0:
        warmup_iters = iters_per_epoch * args.lr_warmup_epochs
        args.lr_warmup_method = args.lr_warmup_method.lower()
        if args.lr_warmup_method == "linear":
            warmup_lr_scheduler = torch.optim.lr_scheduler.LinearLR(
                optimizer, start_factor=args.lr_warmup_decay, total_iters=warmup_iters
            )
        elif args.lr_warmup_method == "constant":
            warmup_lr_scheduler = torch.optim.lr_scheduler.ConstantLR(
                optimizer, factor=args.lr_warmup_decay, total_iters=warmup_iters
            )
        else:
            raise RuntimeError(
                f"Invalid warmup lr method '{args.lr_warmup_method}'. Only linear and constant are supported."
            )

        lr_scheduler = torch.optim.lr_scheduler.SequentialLR(
            optimizer, schedulers=[warmup_lr_scheduler, main_lr_scheduler], milestones=[warmup_iters]
        )
    else:
        lr_scheduler = main_lr_scheduler
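    # With the default flags this schedule warms up linearly from 0.64 * 0.001 to
    # 0.64 over the first 10 epochs, then decays the lr 10x at epochs 20, 30 and 40.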

    model_without_ddp = model
    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
        model_without_ddp = model.module

    if args.resume:
        checkpoint = torch.load(args.resume, map_location="cpu")
        model_without_ddp.load_state_dict(checkpoint["model"])
        optimizer.load_state_dict(checkpoint["optimizer"])
        lr_scheduler.load_state_dict(checkpoint["lr_scheduler"])
        args.start_epoch = checkpoint["epoch"] + 1
        if args.amp:
            scaler.load_state_dict(checkpoint["scaler"])

    if args.test_only:
        # We disable the cudnn benchmarking because it can noticeably affect the accuracy
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True
        evaluate(model, criterion, data_loader_test, device=device)
        return

    print("Start training")
    start_time = time.time()
    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            train_sampler.set_epoch(epoch)
        train_one_epoch(model, criterion, optimizer, lr_scheduler, data_loader, device, epoch, args.print_freq, scaler)
        evaluate(model, criterion, data_loader_test, device=device)
        if args.output_dir:
            checkpoint = {
                "model": model_without_ddp.state_dict(),
                "optimizer": optimizer.state_dict(),
                "lr_scheduler": lr_scheduler.state_dict(),
                "epoch": epoch,
                "args": args,
            }
            if args.amp:
                checkpoint["scaler"] = scaler.state_dict()
            utils.save_on_master(checkpoint, os.path.join(args.output_dir, f"model_{epoch}.pth"))
            utils.save_on_master(checkpoint, os.path.join(args.output_dir, "checkpoint.pth"))

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print(f"Training time {total_time_str}")


def get_args_parser(add_help=True):
    import argparse

    parser = argparse.ArgumentParser(description="PyTorch Video Classification Training", add_help=add_help)

    parser.add_argument("--data-path", default="/datasets01_101/kinetics/070618/", type=str, help="dataset path")
    parser.add_argument(
        "--kinetics-version", default="400", type=str, choices=["400", "600"], help="Select kinetics version"
    )
    parser.add_argument("--model", default="r2plus1d_18", type=str, help="model name")
    parser.add_argument("--device", default="cuda", type=str, help="device to use (cuda or cpu, default: cuda)")
    parser.add_argument("--clip-len", default=16, type=int, metavar="N", help="number of frames per clip")
    parser.add_argument("--frame-rate", default=15, type=int, metavar="N", help="the frame rate")
    parser.add_argument(
        "--clips-per-video", default=5, type=int, metavar="N", help="maximum number of clips per video to consider"
    )
    parser.add_argument(
        "-b", "--batch-size", default=24, type=int, help="clips per gpu, the total batch size is $NGPU x batch_size"
    )
    parser.add_argument("--epochs", default=45, type=int, metavar="N", help="number of total epochs to run")
    parser.add_argument(
        "-j", "--workers", default=10, type=int, metavar="N", help="number of data loading workers (default: 10)"
    )
    parser.add_argument("--lr", default=0.64, type=float, help="initial learning rate")
    parser.add_argument("--momentum", default=0.9, type=float, metavar="M", help="momentum")
    parser.add_argument(
        "--wd",
        "--weight-decay",
        default=1e-4,
        type=float,
        metavar="W",
        help="weight decay (default: 1e-4)",
        dest="weight_decay",
    )
    parser.add_argument("--lr-milestones", nargs="+", default=[20, 30, 40], type=int, help="decrease lr on milestones")
    parser.add_argument("--lr-gamma", default=0.1, type=float, help="decrease lr by a factor of lr-gamma")
    parser.add_argument("--lr-warmup-epochs", default=10, type=int, help="the number of epochs to warmup (default: 10)")
    parser.add_argument("--lr-warmup-method", default="linear", type=str, help="the warmup method (default: linear)")
    parser.add_argument("--lr-warmup-decay", default=0.001, type=float, help="the lr multiplier at the start of warmup")
    parser.add_argument("--print-freq", default=10, type=int, help="print frequency")
    parser.add_argument("--output-dir", default=".", type=str, help="path to save outputs")
    parser.add_argument("--resume", default="", type=str, help="path of checkpoint")
    parser.add_argument("--start-epoch", default=0, type=int, metavar="N", help="start epoch")
    parser.add_argument(
        "--cache-dataset",
        dest="cache_dataset",
        help="Cache the datasets for quicker initialization. It also serializes the transforms",
        action="store_true",
    )
    parser.add_argument(
        "--sync-bn",
        dest="sync_bn",
        help="Use sync batch norm",
        action="store_true",
    )
    parser.add_argument(
        "--test-only",
        dest="test_only",
        help="Only test the model",
        action="store_true",
    )
    parser.add_argument(
        "--use-deterministic-algorithms", action="store_true", help="Forces the use of deterministic algorithms only."
    )

    # distributed training parameters
    parser.add_argument("--world-size", default=1, type=int, help="number of distributed processes")
    parser.add_argument("--dist-url", default="env://", type=str, help="url used to set up distributed training")

    parser.add_argument(
        "--val-resize-size",
        default=(128, 171),
        nargs="+",
        type=int,
        help="the resize size used for validation (default: (128, 171))",
    )
    parser.add_argument(
        "--val-crop-size",
        default=(112, 112),
        nargs="+",
        type=int,
        help="the central crop size used for validation (default: (112, 112))",
    )
    parser.add_argument(
        "--train-resize-size",
        default=(128, 171),
        nargs="+",
        type=int,
        help="the resize size used for training (default: (128, 171))",
    )
    parser.add_argument(
        "--train-crop-size",
        default=(112, 112),
        nargs="+",
        type=int,
        help="the random crop size used for training (default: (112, 112))",
    )

    parser.add_argument("--weights", default=None, type=str, help="the weights enum name to load")

    # Mixed precision training parameters
    parser.add_argument("--amp", action="store_true", help="Use torch.cuda.amp for mixed precision training")

    return parser


if __name__ == "__main__":
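    # Typical multi-GPU launch (illustrative; the dataset path is a placeholder):
    #   torchrun --nproc_per_node=8 train.py --data-path /path/to/kinetics --cache-dataset --amp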
    args = get_args_parser().parse_args()
    main(args)