# train.py
import datetime
import os
import time
4
5

import presets
6
7
8
9
10
import torch
import torch.utils.data
import torchvision
import torchvision.datasets.video_utils
import utils
11
12
13
from torch import nn
from torch.utils.data.dataloader import default_collate
from torchvision.datasets.samplers import DistributedSampler, UniformClipSampler, RandomClipSampler
14

15
16
17
18
19
20
try:
    from torchvision.prototype import models as PM
except ImportError:
    PM = None


21
def train_one_epoch(model, criterion, optimizer, lr_scheduler, data_loader, device, epoch, print_freq, scaler=None):
    """Run one training pass over ``data_loader``.

    Args:
        model: network to optimise; it is switched to training mode here.
        criterion: loss callable, invoked as ``criterion(output, target)``.
        optimizer: optimizer stepped once per batch.
        lr_scheduler: scheduler stepped once per batch (per iteration, not per epoch).
        data_loader: iterable yielding ``(video, target)`` batches.
        device: device every batch is moved to before the forward pass.
        epoch: epoch index; only used to build the log header.
        print_freq: logging interval forwarded to ``MetricLogger.log_every``.
        scaler: optional gradient scaler. When it is not ``None`` the forward pass
            runs under ``torch.cuda.amp.autocast`` and the backward pass and the
            optimizer step go through the scaler.
    """
    model.train()
    metric_logger = utils.MetricLogger(delimiter="  ")
    metric_logger.add_meter("lr", utils.SmoothedValue(window_size=1, fmt="{value}"))
    metric_logger.add_meter("clips/s", utils.SmoothedValue(window_size=10, fmt="{value:.3f}"))

    header = f"Epoch: [{epoch}]"
    for video, target in metric_logger.log_every(data_loader, print_freq, header):
        # Timing starts before the host-to-device copy so that "clips/s" covers
        # the transfer, the forward/backward pass and the metric bookkeeping.
        start_time = time.time()
        video, target = video.to(device), target.to(device)
        with torch.cuda.amp.autocast(enabled=scaler is not None):
            output = model(video)
            loss = criterion(output, target)

        optimizer.zero_grad()

        if scaler is not None:
            # Mixed precision: scale the loss before backward, then let the
            # scaler drive the optimizer step and update its scale factor.
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            loss.backward()
            optimizer.step()

        acc1, acc5 = utils.accuracy(output, target, topk=(1, 5))
        batch_size = video.shape[0]
        metric_logger.update(loss=loss.item(), lr=optimizer.param_groups[0]["lr"])
        metric_logger.meters["acc1"].update(acc1.item(), n=batch_size)
        metric_logger.meters["acc5"].update(acc5.item(), n=batch_size)
        metric_logger.meters["clips/s"].update(batch_size / (time.time() - start_time))
        # The schedule is expressed in iterations, so it advances every batch.
        lr_scheduler.step()


def evaluate(model, criterion, data_loader, device):
    """Evaluate ``model`` on ``data_loader`` and return the clip-level top-1 accuracy.

    The model is put in eval mode and the whole loop runs under
    ``torch.inference_mode()``. Loss, top-1 and top-5 accuracy are accumulated
    per batch, synchronised across processes, and a summary line is printed.

    Args:
        model: network to evaluate.
        criterion: loss callable, invoked as ``criterion(output, target)``.
        data_loader: iterable yielding ``(video, target)`` batches.
        device: device every batch is moved to.

    Returns:
        The global average of the ``acc1`` meter.
    """
    model.eval()
    metric_logger = utils.MetricLogger(delimiter="  ")
    header = "Test:"
    with torch.inference_mode():
        # Progress is logged every 100 batches.
        for video, target in metric_logger.log_every(data_loader, 100, header):
            video = video.to(device, non_blocking=True)
            target = target.to(device, non_blocking=True)
            output = model(video)
            loss = criterion(output, target)

            acc1, acc5 = utils.accuracy(output, target, topk=(1, 5))
            # FIXME need to take into account that the datasets
            # could have been padded in distributed setup
            batch_size = video.shape[0]
            metric_logger.update(loss=loss.item())
            metric_logger.meters["acc1"].update(acc1.item(), n=batch_size)
            metric_logger.meters["acc5"].update(acc5.item(), n=batch_size)
    # gather the stats from all processes
    metric_logger.synchronize_between_processes()

    print(
        " * Clip Acc@1 {top1.global_avg:.3f} Clip Acc@5 {top5.global_avg:.3f}".format(
            top1=metric_logger.acc1, top5=metric_logger.acc5
        )
    )
    return metric_logger.acc1.global_avg


def _get_cache_path(filepath):
    import hashlib
85

86
87
88
89
90
91
92
93
94
95
96
97
98
    h = hashlib.sha1(filepath.encode()).hexdigest()
    cache_path = os.path.join("~", ".torch", "vision", "datasets", "kinetics", h[:10] + ".pt")
    cache_path = os.path.expanduser(cache_path)
    return cache_path


def collate_fn(batch):
    """Collate a list of samples into a batch, leaving the audio entry out.

    Only the elements at index 0 and index 2 of every sample are kept; the
    element at index 1 (the audio) is dropped before the pairs are handed to
    ``default_collate``.
    """
    without_audio = []
    for sample in batch:
        without_audio.append((sample[0], sample[2]))
    return default_collate(without_audio)


def main(args):
    """Build the datasets, model and optimizer described by ``args``, then train or evaluate.

    Steps, in order: create the output directory, initialise distributed mode,
    build (or load from cache) the training and validation Kinetics400 datasets,
    create clip samplers and data loaders, create the model, the SGD optimizer
    and the per-iteration learning-rate schedule, optionally resume from a
    checkpoint, and finally either run a single evaluation (``--test-only``) or
    the training loop with a checkpoint written after every epoch.

    Args:
        args: namespace with the options defined in ``parse_args``. It is also
            read for ``args.distributed`` and ``args.gpu``, which ``parse_args``
            does not define; presumably ``utils.init_distributed_mode`` sets
            them (TODO confirm against ``utils``).

    Raises:
        ImportError: if ``--weights`` is given but the torchvision prototype
            module could not be imported.
        RuntimeError: if ``--lr-warmup-method`` is neither ``linear`` nor ``constant``.
    """
    if args.weights and PM is None:
        raise ImportError("The prototype module couldn't be found. Please install the latest torchvision nightly.")

    if args.output_dir:
        utils.mkdir(args.output_dir)

    utils.init_distributed_mode(args)
    print(args)
    print("torch version: ", torch.__version__)
    print("torchvision version: ", torchvision.__version__)

    device = torch.device(args.device)

    torch.backends.cudnn.benchmark = True

    # Data loading code
    print("Loading data")
    traindir = os.path.join(args.data_path, args.train_dir)
    valdir = os.path.join(args.data_path, args.val_dir)

    print("Loading training data")
    st = time.time()
    cache_path = _get_cache_path(traindir)
    transform_train = presets.VideoClassificationPresetTrain((128, 171), (112, 112))

    if args.cache_dataset and os.path.exists(cache_path):
        print(f"Loading dataset_train from {cache_path}")
        # NOTE(review): torch.load unpickles the cached dataset object; only use
        # cache files that this script wrote itself.
        dataset, _ = torch.load(cache_path)
        # The cached object carries whatever transform it was saved with, so the
        # freshly built one is installed instead.
        dataset.transform = transform_train
    else:
        if args.distributed:
            print("It is recommended to pre-compute the dataset cache on a single-gpu first, as it will be faster")
        dataset = torchvision.datasets.Kinetics400(
            traindir,
            frames_per_clip=args.clip_len,
            step_between_clips=1,
            transform=transform_train,
            frame_rate=15,
            extensions=(
                "avi",
                "mp4",
            ),
        )
        if args.cache_dataset:
            print(f"Saving dataset_train to {cache_path}")
            utils.mkdir(os.path.dirname(cache_path))
            utils.save_on_master((dataset, traindir), cache_path)

    print("Took", time.time() - st)

    print("Loading validation data")
    cache_path = _get_cache_path(valdir)

    # With prototype weights the evaluation transform comes from the weights
    # entry itself; otherwise the reference eval preset is used.
    if not args.weights:
        transform_test = presets.VideoClassificationPresetEval((128, 171), (112, 112))
    else:
        weights = PM.get_weight(args.weights)
        transform_test = weights.transforms()

    if args.cache_dataset and os.path.exists(cache_path):
        print(f"Loading dataset_test from {cache_path}")
        # NOTE(review): same unpickling caveat as for the training cache above.
        dataset_test, _ = torch.load(cache_path)
        dataset_test.transform = transform_test
    else:
        if args.distributed:
            print("It is recommended to pre-compute the dataset cache on a single-gpu first, as it will be faster")
        dataset_test = torchvision.datasets.Kinetics400(
            valdir,
            frames_per_clip=args.clip_len,
            step_between_clips=1,
            transform=transform_test,
            frame_rate=15,
            extensions=(
                "avi",
                "mp4",
            ),
        )
        if args.cache_dataset:
            print(f"Saving dataset_test to {cache_path}")
            utils.mkdir(os.path.dirname(cache_path))
            utils.save_on_master((dataset_test, valdir), cache_path)

    print("Creating data loaders")
    # Random clips per video for training, uniformly spaced clips for evaluation.
    train_sampler = RandomClipSampler(dataset.video_clips, args.clips_per_video)
    test_sampler = UniformClipSampler(dataset_test.video_clips, args.clips_per_video)
    if args.distributed:
        train_sampler = DistributedSampler(train_sampler)
        test_sampler = DistributedSampler(test_sampler)

    data_loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=args.batch_size,
        sampler=train_sampler,
        num_workers=args.workers,
        pin_memory=True,
        collate_fn=collate_fn,
    )

    data_loader_test = torch.utils.data.DataLoader(
        dataset_test,
        batch_size=args.batch_size,
        sampler=test_sampler,
        num_workers=args.workers,
        pin_memory=True,
        collate_fn=collate_fn,
    )

    print("Creating model")
    if not args.weights:
        model = torchvision.models.video.__dict__[args.model](pretrained=args.pretrained)
    else:
        model = PM.video.__dict__[args.model](weights=args.weights)
    model.to(device)
    if args.distributed and args.sync_bn:
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)

    criterion = nn.CrossEntropyLoss()

    # The base learning rate is multiplied by the number of processes.
    lr = args.lr * args.world_size
    optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=args.momentum, weight_decay=args.weight_decay)
    scaler = torch.cuda.amp.GradScaler() if args.amp else None

    # convert scheduler to be per iteration, not per epoch, for warmup that lasts
    # between different epochs
    iters_per_epoch = len(data_loader)
    # Milestones are given in epochs; they are shifted back by the warmup length
    # and converted to iterations for the main scheduler.
    lr_milestones = [iters_per_epoch * (m - args.lr_warmup_epochs) for m in args.lr_milestones]
    main_lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=lr_milestones, gamma=args.lr_gamma)

    if args.lr_warmup_epochs > 0:
        warmup_iters = iters_per_epoch * args.lr_warmup_epochs
        args.lr_warmup_method = args.lr_warmup_method.lower()
        if args.lr_warmup_method == "linear":
            warmup_lr_scheduler = torch.optim.lr_scheduler.LinearLR(
                optimizer, start_factor=args.lr_warmup_decay, total_iters=warmup_iters
            )
        elif args.lr_warmup_method == "constant":
            warmup_lr_scheduler = torch.optim.lr_scheduler.ConstantLR(
                optimizer, factor=args.lr_warmup_decay, total_iters=warmup_iters
            )
        else:
            raise RuntimeError(
                f"Invalid warmup lr method '{args.lr_warmup_method}'. Only linear and constant are supported."
            )

        # Warmup runs first; the main schedule takes over after warmup_iters steps.
        lr_scheduler = torch.optim.lr_scheduler.SequentialLR(
            optimizer, schedulers=[warmup_lr_scheduler, main_lr_scheduler], milestones=[warmup_iters]
        )
    else:
        lr_scheduler = main_lr_scheduler

    # Keep a handle on the unwrapped model so checkpoints hold plain state dicts.
    model_without_ddp = model
    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
        model_without_ddp = model.module

    if args.resume:
        checkpoint = torch.load(args.resume, map_location="cpu")
        model_without_ddp.load_state_dict(checkpoint["model"])
        optimizer.load_state_dict(checkpoint["optimizer"])
        lr_scheduler.load_state_dict(checkpoint["lr_scheduler"])
        # The checkpoint stores the last finished epoch; continue with the next one.
        args.start_epoch = checkpoint["epoch"] + 1
        if args.amp:
            scaler.load_state_dict(checkpoint["scaler"])

    if args.test_only:
        evaluate(model, criterion, data_loader_test, device=device)
        return

    print("Start training")
    start_time = time.time()
    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            train_sampler.set_epoch(epoch)
        train_one_epoch(model, criterion, optimizer, lr_scheduler, data_loader, device, epoch, args.print_freq, scaler)
        evaluate(model, criterion, data_loader_test, device=device)
        if args.output_dir:
            checkpoint = {
                "model": model_without_ddp.state_dict(),
                "optimizer": optimizer.state_dict(),
                "lr_scheduler": lr_scheduler.state_dict(),
                "epoch": epoch,
                "args": args,
            }
            if args.amp:
                checkpoint["scaler"] = scaler.state_dict()
            # One file per epoch plus a rolling "checkpoint.pth" with the latest state.
            utils.save_on_master(checkpoint, os.path.join(args.output_dir, f"model_{epoch}.pth"))
            utils.save_on_master(checkpoint, os.path.join(args.output_dir, "checkpoint.pth"))

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print(f"Training time {total_time_str}")
290
291
292
293


def parse_args(argv=None):
    """Parse the command-line options of the video classification training script.

    Args:
        argv: optional list of argument strings. When ``None`` (the default),
            ``argparse`` reads the arguments from ``sys.argv[1:]``, which is the
            behaviour existing callers of ``parse_args()`` rely on.

    Returns:
        The ``argparse.Namespace`` holding all parsed options.
    """
    import argparse

    parser = argparse.ArgumentParser(description="PyTorch Video Classification Training")

    parser.add_argument("--data-path", default="/datasets01_101/kinetics/070618/", type=str, help="dataset path")
    parser.add_argument("--train-dir", default="train_avi-480p", type=str, help="name of train dir")
    parser.add_argument("--val-dir", default="val_avi-480p", type=str, help="name of val dir")
    parser.add_argument("--model", default="r2plus1d_18", type=str, help="model name")
    parser.add_argument("--device", default="cuda", type=str, help="device (Use cuda or cpu Default: cuda)")
    parser.add_argument("--clip-len", default=16, type=int, metavar="N", help="number of frames per clip")
    parser.add_argument(
        "--clips-per-video", default=5, type=int, metavar="N", help="maximum number of clips per video to consider"
    )
    parser.add_argument(
        "-b", "--batch-size", default=24, type=int, help="images per gpu, the total batch size is $NGPU x batch_size"
    )
    parser.add_argument("--epochs", default=45, type=int, metavar="N", help="number of total epochs to run")
    parser.add_argument(
        "-j", "--workers", default=10, type=int, metavar="N", help="number of data loading workers (default: 10)"
    )
    parser.add_argument("--lr", default=0.01, type=float, help="initial learning rate")
    parser.add_argument("--momentum", default=0.9, type=float, metavar="M", help="momentum")
    parser.add_argument(
        "--wd",
        "--weight-decay",
        default=1e-4,
        type=float,
        metavar="W",
        help="weight decay (default: 1e-4)",
        dest="weight_decay",
    )
    parser.add_argument("--lr-milestones", nargs="+", default=[20, 30, 40], type=int, help="decrease lr on milestones")
    parser.add_argument("--lr-gamma", default=0.1, type=float, help="decrease lr by a factor of lr-gamma")
    parser.add_argument("--lr-warmup-epochs", default=10, type=int, help="the number of epochs to warmup (default: 10)")
    parser.add_argument("--lr-warmup-method", default="linear", type=str, help="the warmup method (default: linear)")
    parser.add_argument("--lr-warmup-decay", default=0.001, type=float, help="the decay for lr")
    parser.add_argument("--print-freq", default=10, type=int, help="print frequency")
    parser.add_argument("--output-dir", default=".", type=str, help="path to save outputs")
    parser.add_argument("--resume", default="", type=str, help="path of checkpoint")
    parser.add_argument("--start-epoch", default=0, type=int, metavar="N", help="start epoch")
    parser.add_argument(
        "--cache-dataset",
        dest="cache_dataset",
        help="Cache the datasets for quicker initialization. It also serializes the transforms",
        action="store_true",
    )
    parser.add_argument(
        "--sync-bn",
        dest="sync_bn",
        help="Use sync batch norm",
        action="store_true",
    )
    parser.add_argument(
        "--test-only",
        dest="test_only",
        help="Only test the model",
        action="store_true",
    )
    parser.add_argument(
        "--pretrained",
        dest="pretrained",
        help="Use pre-trained models from the modelzoo",
        action="store_true",
    )

    # distributed training parameters
    parser.add_argument("--world-size", default=1, type=int, help="number of distributed processes")
    parser.add_argument("--dist-url", default="env://", type=str, help="url used to set up distributed training")

    # Prototype models only
    parser.add_argument("--weights", default=None, type=str, help="the weights enum name to load")

    # Mixed precision training parameters
    parser.add_argument("--amp", action="store_true", help="Use torch.cuda.amp for mixed precision training")

    args = parser.parse_args(argv)

    return args


if __name__ == "__main__":
    # Script entry point: parse the command line and hand the options to main().
    main(parse_args())