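"""Video classification reference training script (Kinetics 400/600).

A sketch of a typical multi-GPU launch, assuming the companion ``utils.py``,
``presets.py`` and ``datasets.py`` modules from the same reference folder are
importable (``utils.init_distributed_mode`` picks up the env:// variables that
torchrun sets); all flags are defined in ``get_args_parser`` below::

    torchrun --nproc_per_node=8 train.py --data-path /path/to/kinetics --model r2plus1d_18 --cache-dataset --amp
"""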
import datetime
import os
import time
import warnings

import datasets
import presets
import torch
import torch.utils.data
import torchvision
import torchvision.datasets.video_utils
import utils
from torch import nn
from torch.utils.data.dataloader import default_collate
from torchvision.datasets.samplers import DistributedSampler, RandomClipSampler, UniformClipSampler


def train_one_epoch(model, criterion, optimizer, lr_scheduler, data_loader, device, epoch, print_freq, scaler=None):
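    """Train ``model`` for one epoch.

    The lr scheduler is stepped once per iteration (``main`` builds it with
    per-iteration milestones), and mixed precision is used whenever a
    ``GradScaler`` is passed via ``scaler``.
    """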
    model.train()
    metric_logger = utils.MetricLogger(delimiter="  ")
    metric_logger.add_meter("lr", utils.SmoothedValue(window_size=1, fmt="{value}"))
    metric_logger.add_meter("clips/s", utils.SmoothedValue(window_size=10, fmt="{value:.3f}"))

    header = f"Epoch: [{epoch}]"
    for video, target, _ in metric_logger.log_every(data_loader, print_freq, header):
        start_time = time.time()
        video, target = video.to(device), target.to(device)
        with torch.cuda.amp.autocast(enabled=scaler is not None):
            output = model(video)
            loss = criterion(output, target)

        optimizer.zero_grad()

        if scaler is not None:
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            loss.backward()
            optimizer.step()

        acc1, acc5 = utils.accuracy(output, target, topk=(1, 5))
        batch_size = video.shape[0]
        metric_logger.update(loss=loss.item(), lr=optimizer.param_groups[0]["lr"])
        metric_logger.meters["acc1"].update(acc1.item(), n=batch_size)
        metric_logger.meters["acc5"].update(acc5.item(), n=batch_size)
        metric_logger.meters["clips/s"].update(batch_size / (time.time() - start_time))
        lr_scheduler.step()


def evaluate(model, criterion, data_loader, device):
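    """Evaluate ``model``, reporting both clip-level and video-level accuracy.

    Clip accuracy scores every sampled clip independently; video accuracy sums
    the per-clip softmax scores of each video across all processes before
    taking the top-k prediction.
    """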
    model.eval()
    metric_logger = utils.MetricLogger(delimiter="  ")
    header = "Test:"
    num_processed_samples = 0
    # Group and aggregate the per-clip outputs of each video
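    # UniformClipSampler yields args.clips_per_video clips per video; summing
    # their softmax scores at each dataset-wide video_idx (third element of the
    # batch) produces a single aggregated prediction per video.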
    num_videos = len(data_loader.dataset.samples)
    num_classes = len(data_loader.dataset.classes)
    agg_preds = torch.zeros((num_videos, num_classes), dtype=torch.float32, device=device)
    agg_targets = torch.zeros((num_videos), dtype=torch.int32, device=device)
    with torch.inference_mode():
        for video, target, video_idx in metric_logger.log_every(data_loader, 100, header):
            video = video.to(device, non_blocking=True)
            target = target.to(device, non_blocking=True)
            output = model(video)
            loss = criterion(output, target)

            # Use softmax to convert the logits into prediction probabilities
            preds = torch.softmax(output, dim=1)
            for b in range(video.size(0)):
                idx = video_idx[b].item()
                agg_preds[idx] += preds[b].detach()
                agg_targets[idx] = target[b].detach().item()

            acc1, acc5 = utils.accuracy(output, target, topk=(1, 5))
            # FIXME need to take into account that the datasets
            # could have been padded in distributed setup
            batch_size = video.shape[0]
            metric_logger.update(loss=loss.item())
            metric_logger.meters["acc1"].update(acc1.item(), n=batch_size)
            metric_logger.meters["acc5"].update(acc5.item(), n=batch_size)
            num_processed_samples += batch_size
    # gather the stats from all processes
    num_processed_samples = utils.reduce_across_processes(num_processed_samples)
    if isinstance(data_loader.sampler, DistributedSampler):
        # Get the length of the clip sampler (e.g. UniformClipSampler) that the DistributedSampler wraps as its dataset
        num_data_from_sampler = len(data_loader.sampler.dataset)
    else:
        num_data_from_sampler = len(data_loader.sampler)

    if (
        hasattr(data_loader.dataset, "__len__")
        and num_data_from_sampler != num_processed_samples
        and torch.distributed.get_rank() == 0
    ):
        # See FIXME above
        warnings.warn(
            f"It looks like the sampler has {num_data_from_sampler} samples, but {num_processed_samples} "
            "samples were used for the validation, which might bias the results. "
            "Try adjusting the batch size and / or the world size. "
            "Setting the world size to 1 is always a safe bet."
        )

    metric_logger.synchronize_between_processes()

    print(
        " * Clip Acc@1 {top1.global_avg:.3f} Clip Acc@5 {top5.global_avg:.3f}".format(
            top1=metric_logger.acc1, top5=metric_logger.acc5
        )
    )
    # Reduce the agg_preds and agg_targets from all gpu and show result
    agg_preds = utils.reduce_across_processes(agg_preds)
    agg_targets = utils.reduce_across_processes(agg_targets, op=torch.distributed.ReduceOp.MAX)
    agg_acc1, agg_acc5 = utils.accuracy(agg_preds, agg_targets, topk=(1, 5))
    print(" * Video Acc@1 {acc1:.3f} Video Acc@5 {acc5:.3f}".format(acc1=agg_acc1, acc5=agg_acc5))
    return metric_logger.acc1.global_avg


def _get_cache_path(filepath, args):
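    """Return a cache file path keyed on the dataset path and the clip
    parameters, so that differently-configured datasets never share a cache
    entry."""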
    import hashlib

    value = f"{filepath}-{args.clip_len}-{args.kinetics_version}-{args.frame_rate}"
    h = hashlib.sha1(value.encode()).hexdigest()
    cache_path = os.path.join("~", ".torch", "vision", "datasets", "kinetics", h[:10] + ".pt")
    cache_path = os.path.expanduser(cache_path)
    return cache_path


def collate_fn(batch):
    # each dataset sample is (video, audio, label, video_idx); drop the audio
    batch = [(d[0], d[2], d[3]) for d in batch]
    return default_collate(batch)


def main(args):
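    """Build the Kinetics datasets, data loaders, model, optimizer and lr
    schedulers, then run the training loop (or a single evaluation pass when
    ``--test-only`` is set)."""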
    if args.output_dir:
        utils.mkdir(args.output_dir)

    utils.init_distributed_mode(args)
    print(args)

    device = torch.device(args.device)

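    # Note: on CUDA, fully deterministic execution may additionally require the
    # CUBLAS_WORKSPACE_CONFIG environment variable to be set (see the PyTorch
    # notes on reproducibility).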
    if args.use_deterministic_algorithms:
        torch.backends.cudnn.benchmark = False
        torch.use_deterministic_algorithms(True)
    else:
        torch.backends.cudnn.benchmark = True

    # Data loading code
    print("Loading data")
    traindir = os.path.join(args.data_path, "train")
    valdir = os.path.join(args.data_path, "val")

    print("Loading training data")
    st = time.time()
    cache_path = _get_cache_path(traindir, args)
    transform_train = presets.VideoClassificationPresetTrain(crop_size=(112, 112), resize_size=(128, 171))

    if args.cache_dataset and os.path.exists(cache_path):
        print(f"Loading dataset_train from {cache_path}")
        dataset, _ = torch.load(cache_path)
        dataset.transform = transform_train
    else:
        if args.distributed:
            print("It is recommended to pre-compute the dataset cache on a single-gpu first, as it will be faster")
        dataset = datasets.KineticsWithVideoId(
            args.data_path,
            frames_per_clip=args.clip_len,
            num_classes=args.kinetics_version,
            split="train",
            step_between_clips=1,
            transform=transform_train,
            frame_rate=args.frame_rate,
            extensions=(
                "avi",
                "mp4",
            ),
            output_format="TCHW",
        )
        if args.cache_dataset:
            print(f"Saving dataset_train to {cache_path}")
            utils.mkdir(os.path.dirname(cache_path))
            utils.save_on_master((dataset, traindir), cache_path)

    print("Took", time.time() - st)

    print("Loading validation data")
    cache_path = _get_cache_path(valdir, args)

    if args.weights and args.test_only:
        weights = torchvision.models.get_weight(args.weights)
        transform_test = weights.transforms()
    else:
        transform_test = presets.VideoClassificationPresetEval(crop_size=(112, 112), resize_size=(128, 171))

    if args.cache_dataset and os.path.exists(cache_path):
        print(f"Loading dataset_test from {cache_path}")
        dataset_test, _ = torch.load(cache_path)
        dataset_test.transform = transform_test
    else:
        if args.distributed:
            print("It is recommended to pre-compute the dataset cache on a single-gpu first, as it will be faster")
        dataset_test = datasets.KineticsWithVideoId(
            args.data_path,
            frames_per_clip=args.clip_len,
            num_classes=args.kinetics_version,
            split="val",
            step_between_clips=1,
            transform=transform_test,
            frame_rate=args.frame_rate,
            extensions=(
                "avi",
                "mp4",
            ),
            output_format="TCHW",
        )
        if args.cache_dataset:
            print(f"Saving dataset_test to {cache_path}")
            utils.mkdir(os.path.dirname(cache_path))
            utils.save_on_master((dataset_test, valdir), cache_path)

    print("Creating data loaders")
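    # RandomClipSampler draws up to args.clips_per_video random clips from each
    # training video, while UniformClipSampler takes that many equally spaced
    # clips per video so that validation covers every video deterministically.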
    train_sampler = RandomClipSampler(dataset.video_clips, args.clips_per_video)
    test_sampler = UniformClipSampler(dataset_test.video_clips, args.clips_per_video)
    if args.distributed:
        train_sampler = DistributedSampler(train_sampler)
        test_sampler = DistributedSampler(test_sampler, shuffle=False)

    data_loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=args.batch_size,
        sampler=train_sampler,
        num_workers=args.workers,
        pin_memory=True,
        collate_fn=collate_fn,
    )

    data_loader_test = torch.utils.data.DataLoader(
        dataset_test,
        batch_size=args.batch_size,
        sampler=test_sampler,
        num_workers=args.workers,
        pin_memory=True,
        collate_fn=collate_fn,
    )

    print("Creating model")
    model = torchvision.models.get_model(args.model, weights=args.weights)
    model.to(device)
    if args.distributed and args.sync_bn:
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)

    criterion = nn.CrossEntropyLoss()

    # linear scaling rule: scale the base lr by the number of distributed processes
    lr = args.lr * args.world_size
    optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=args.momentum, weight_decay=args.weight_decay)
    scaler = torch.cuda.amp.GradScaler() if args.amp else None

    # convert the scheduler to be per-iteration rather than per-epoch, so that
    # the warmup phase can span multiple epochs
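    # e.g. with the defaults (--lr-milestones 20 30 40, --lr-warmup-epochs 10),
    # the MultiStepLR milestones become [10, 20, 30] * iters_per_epoch, counted
    # from the end of warmup, i.e. the lr decays at overall epochs 20, 30 and 40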
    iters_per_epoch = len(data_loader)
    lr_milestones = [iters_per_epoch * (m - args.lr_warmup_epochs) for m in args.lr_milestones]
    main_lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=lr_milestones, gamma=args.lr_gamma)

    if args.lr_warmup_epochs > 0:
        warmup_iters = iters_per_epoch * args.lr_warmup_epochs
        args.lr_warmup_method = args.lr_warmup_method.lower()
        if args.lr_warmup_method == "linear":
            warmup_lr_scheduler = torch.optim.lr_scheduler.LinearLR(
                optimizer, start_factor=args.lr_warmup_decay, total_iters=warmup_iters
            )
        elif args.lr_warmup_method == "constant":
            warmup_lr_scheduler = torch.optim.lr_scheduler.ConstantLR(
                optimizer, factor=args.lr_warmup_decay, total_iters=warmup_iters
            )
        else:
            raise RuntimeError(
                f"Invalid warmup lr method '{args.lr_warmup_method}'. Only linear and constant are supported."
            )

        lr_scheduler = torch.optim.lr_scheduler.SequentialLR(
            optimizer, schedulers=[warmup_lr_scheduler, main_lr_scheduler], milestones=[warmup_iters]
        )
    else:
        lr_scheduler = main_lr_scheduler

    model_without_ddp = model
    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
        model_without_ddp = model.module

    if args.resume:
        checkpoint = torch.load(args.resume, map_location="cpu")
        model_without_ddp.load_state_dict(checkpoint["model"])
        optimizer.load_state_dict(checkpoint["optimizer"])
        lr_scheduler.load_state_dict(checkpoint["lr_scheduler"])
        args.start_epoch = checkpoint["epoch"] + 1
        if args.amp:
            scaler.load_state_dict(checkpoint["scaler"])

    if args.test_only:
        # We disable the cudnn benchmarking because it can noticeably affect the accuracy
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True
        evaluate(model, criterion, data_loader_test, device=device)
        return

    print("Start training")
    start_time = time.time()
    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            train_sampler.set_epoch(epoch)
        train_one_epoch(model, criterion, optimizer, lr_scheduler, data_loader, device, epoch, args.print_freq, scaler)
        evaluate(model, criterion, data_loader_test, device=device)
        if args.output_dir:
            checkpoint = {
                "model": model_without_ddp.state_dict(),
                "optimizer": optimizer.state_dict(),
                "lr_scheduler": lr_scheduler.state_dict(),
                "epoch": epoch,
                "args": args,
            }
            if args.amp:
                checkpoint["scaler"] = scaler.state_dict()
            utils.save_on_master(checkpoint, os.path.join(args.output_dir, f"model_{epoch}.pth"))
            utils.save_on_master(checkpoint, os.path.join(args.output_dir, "checkpoint.pth"))

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print(f"Training time {total_time_str}")


def get_args_parser(add_help=True):
    import argparse

    parser = argparse.ArgumentParser(description="PyTorch Video Classification Training", add_help=add_help)

    parser.add_argument("--data-path", default="/datasets01_101/kinetics/070618/", type=str, help="dataset path")
    parser.add_argument(
        "--kinetics-version", default="400", type=str, choices=["400", "600"], help="Select kinetics version"
    )
    parser.add_argument("--model", default="r2plus1d_18", type=str, help="model name")
    parser.add_argument("--device", default="cuda", type=str, help="device (Use cuda or cpu Default: cuda)")
    parser.add_argument("--clip-len", default=16, type=int, metavar="N", help="number of frames per clip")
    parser.add_argument("--frame-rate", default=15, type=int, metavar="N", help="the frame rate")
    parser.add_argument(
        "--clips-per-video", default=5, type=int, metavar="N", help="maximum number of clips per video to consider"
    )
    parser.add_argument(
        "-b", "--batch-size", default=24, type=int, help="clips per GPU; the total batch size is $NGPU x batch_size"
    )
    parser.add_argument("--epochs", default=45, type=int, metavar="N", help="number of total epochs to run")
    parser.add_argument(
        "-j", "--workers", default=10, type=int, metavar="N", help="number of data loading workers (default: 10)"
    )
    parser.add_argument("--lr", default=0.01, type=float, help="initial learning rate")
    parser.add_argument("--momentum", default=0.9, type=float, metavar="M", help="momentum")
    parser.add_argument(
        "--wd",
        "--weight-decay",
        default=1e-4,
        type=float,
        metavar="W",
        help="weight decay (default: 1e-4)",
        dest="weight_decay",
    )
    parser.add_argument("--lr-milestones", nargs="+", default=[20, 30, 40], type=int, help="decrease lr on milestones")
    parser.add_argument("--lr-gamma", default=0.1, type=float, help="decrease lr by a factor of lr-gamma")
    parser.add_argument("--lr-warmup-epochs", default=10, type=int, help="the number of epochs to warmup (default: 10)")
    parser.add_argument("--lr-warmup-method", default="linear", type=str, help="the warmup method (default: linear)")
    parser.add_argument("--lr-warmup-decay", default=0.001, type=float, help="the decay for lr")
    parser.add_argument("--print-freq", default=10, type=int, help="print frequency")
    parser.add_argument("--output-dir", default=".", type=str, help="path to save outputs")
    parser.add_argument("--resume", default="", type=str, help="path of checkpoint")
    parser.add_argument("--start-epoch", default=0, type=int, metavar="N", help="start epoch")
    parser.add_argument(
        "--cache-dataset",
        dest="cache_dataset",
        help="Cache the datasets for quicker initialization. It also serializes the transforms",
        action="store_true",
    )
    parser.add_argument(
        "--sync-bn",
        dest="sync_bn",
        help="Use sync batch norm",
        action="store_true",
    )
    parser.add_argument(
        "--test-only",
        dest="test_only",
        help="Only test the model",
        action="store_true",
    )
    parser.add_argument(
        "--use-deterministic-algorithms", action="store_true", help="Forces the use of deterministic algorithms only."
    )

    # distributed training parameters
    parser.add_argument("--world-size", default=1, type=int, help="number of distributed processes")
    parser.add_argument("--dist-url", default="env://", type=str, help="url used to set up distributed training")

    parser.add_argument("--weights", default=None, type=str, help="the weights enum name to load")

    # Mixed precision training parameters
    parser.add_argument("--amp", action="store_true", help="Use torch.cuda.amp for mixed precision training")

    return parser


if __name__ == "__main__":
    args = get_args_parser().parse_args()
    main(args)