train.py 14 KB
Newer Older
1
2
import argparse
import warnings
3
from math import ceil
4
5
6
from pathlib import Path

import torch
7
import torchvision.models.optical_flow
8
9
10
import utils
from presets import OpticalFlowPresetTrain, OpticalFlowPresetEval
from torchvision.datasets import KittiFlow, FlyingChairs, FlyingThings3D, Sintel, HD1K
11
12

try:
13
    from torchvision import prototype
14
except ImportError:
15
    prototype = None
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133


def get_train_dataset(stage, dataset_root):
    """Build the training dataset, with its augmentation preset, for a given training stage.

    Args:
        stage: one of ``"chairs"``, ``"things"``, ``"sintel_SKH"`` or ``"kitti"``.
        dataset_root: folder passed as ``root`` to every dataset constructor.

    Returns:
        The dataset for ``stage``. For ``"sintel_SKH"`` this is a weighted mixture of Sintel,
        KITTI, HD1K and the clean pass of FlyingThings3D, combined with integer ``*`` and ``+``.

    Raises:
        ValueError: if ``stage`` is not a supported stage name.
    """
    if stage == "chairs":
        preset = OpticalFlowPresetTrain(crop_size=(368, 496), min_scale=0.1, max_scale=1.0, do_flip=True)
        return FlyingChairs(root=dataset_root, split="train", transforms=preset)

    if stage == "things":
        preset = OpticalFlowPresetTrain(crop_size=(400, 720), min_scale=-0.4, max_scale=0.8, do_flip=True)
        return FlyingThings3D(root=dataset_root, split="train", pass_name="both", transforms=preset)

    if stage == "sintel_SKH":  # S + K + H as from paper
        shared_crop = (368, 768)
        preset = OpticalFlowPresetTrain(crop_size=shared_crop, min_scale=-0.2, max_scale=0.6, do_flip=True)

        things_clean_ds = FlyingThings3D(root=dataset_root, split="train", pass_name="clean", transforms=preset)
        sintel_ds = Sintel(root=dataset_root, split="train", pass_name="both", transforms=preset)

        # KITTI and HD1K get their own scale ranges but share the crop size.
        kitti_preset = OpticalFlowPresetTrain(crop_size=shared_crop, min_scale=-0.3, max_scale=0.5, do_flip=True)
        kitti_ds = KittiFlow(root=dataset_root, split="train", transforms=kitti_preset)

        hd1k_preset = OpticalFlowPresetTrain(crop_size=shared_crop, min_scale=-0.5, max_scale=0.2, do_flip=True)
        hd1k_ds = HD1K(root=dataset_root, split="train", transforms=hd1k_preset)

        # As future improvement, we could probably be using a distributed sampler here
        # The distribution is S(.71), T(.135), K(.135), H(.02)
        return 100 * sintel_ds + 200 * kitti_ds + 5 * hd1k_ds + things_clean_ds

    if stage == "kitti":
        preset = OpticalFlowPresetTrain(
            # resize and crop params
            crop_size=(288, 960),
            min_scale=-0.2,
            max_scale=0.4,
            stretch_prob=0,
            # flip params
            do_flip=False,
            # jitter params
            brightness=0.3,
            contrast=0.3,
            saturation=0.3,
            hue=0.3 / 3.14,
            asymmetric_jitter_prob=0,
        )
        return KittiFlow(root=dataset_root, split="train", transforms=preset)

    raise ValueError(f"Unknown stage {stage}")


@torch.no_grad()
def _validate(model, args, val_dataset, *, padder_mode, num_flow_updates=None, batch_size=None, header=None):
    """Helper function to compute various metrics (epe, etc.) for a model on a given dataset.

    We process as many samples as possible with ddp, and process the rest on a single worker.
    The ``DistributedSampler`` uses ``drop_last=True``, so samples past the evenly-divisible
    prefix of the dataset are skipped by the loader and handled one at a time on rank 0.

    Args:
        model: the (DDP-wrapped) flow model; it is put in eval mode.
        args: parsed CLI namespace. Reads ``batch_size``, ``num_workers``,
            ``num_flow_updates`` and ``rank``.
        val_dataset: dataset yielding ``(image1, image2, flow_gt[, valid_flow_mask])``.
        padder_mode: forwarded as ``mode`` to ``utils.InputPadder``.
        num_flow_updates: flow refinement iterations; falsy values fall back to ``args.num_flow_updates``.
        batch_size: loader batch size; falsy values fall back to ``args.batch_size``.
        header: label used by the progress log and the final metrics print.
    """
    batch_size = batch_size or args.batch_size
    num_flow_updates = num_flow_updates or args.num_flow_updates

    model.eval()

    sampler = torch.utils.data.distributed.DistributedSampler(val_dataset, shuffle=False, drop_last=True)
    val_loader = torch.utils.data.DataLoader(
        val_dataset,
        sampler=sampler,
        batch_size=batch_size,
        pin_memory=True,
        num_workers=args.num_workers,
    )

    logger = utils.MetricLogger()
    for meter_name in ("epe", "1px", "3px", "5px", "per_image_epe", "f1"):
        logger.add_meter(meter_name, fmt="{global_avg:.4f}")

    def inner_loop(blob):
        """Run the model on one batch (or a single un-batched sample) and update the meters."""
        if blob[0].dim() == 3:
            # input is not batched so we add an extra dim for consistency
            blob = [x[None, :, :, :] if x is not None else None for x in blob]

        image1, image2, flow_gt = blob[:3]
        valid_flow_mask = None if len(blob) == 3 else blob[-1]

        # Actual number of image pairs in this blob. The last DataLoader batch and the
        # single-sample remainder path can both hold fewer than ``batch_size`` images, so
        # weighting by ``batch_size`` would bias the per-image average.
        num_images = image1.shape[0]

        image1, image2 = image1.cuda(), image2.cuda()

        padder = utils.InputPadder(image1.shape, mode=padder_mode)
        image1, image2 = padder.pad(image1, image2)

        flow_predictions = model(image1, image2, num_flow_updates=num_flow_updates)
        flow_pred = flow_predictions[-1]  # only the final refinement is evaluated
        flow_pred = padder.unpad(flow_pred).cpu()

        metrics, num_pixels_tot = utils.compute_metrics(flow_pred, flow_gt, valid_flow_mask)

        # We compute per-pixel epe (epe) and per-image epe (called f1-epe in RAFT paper).
        # per-pixel epe: average epe of all pixels of all images
        # per-image epe: average epe on each image independently, then average over images
        for name in ("epe", "1px", "3px", "5px", "f1"):  # f1 is called f1-all in paper
            logger.meters[name].update(metrics[name], n=num_pixels_tot)
        logger.meters["per_image_epe"].update(metrics["epe"], n=num_images)

    num_processed_samples = 0
    for blob in logger.log_every(val_loader, header=header, print_freq=None):
        inner_loop(blob)
        num_processed_samples += blob[0].shape[0]  # batch size

    # int() so range() and the message below get a plain integer. The exact return type of
    # utils.reduce_across_processes is not visible here (presumably a tensor); verify if needed.
    num_processed_samples = int(utils.reduce_across_processes(num_processed_samples))
    print(
        f"Batch-processed {num_processed_samples} / {len(val_dataset)} samples. "
        "Going to process the remaining samples individually, if any."
    )

    if args.rank == 0:  # we only need to process the rest on a single worker
        for i in range(num_processed_samples, len(val_dataset)):
            inner_loop(val_dataset[i])

    logger.synchronize_between_processes()
    print(header, logger)


def validate(model, args):
    """Evaluate ``model`` on every dataset name listed in ``args.val_dataset``.

    Supported names are ``"kitti"`` (train split, evaluated with batch size 1 and 24 flow
    updates) and ``"sintel"`` (train split, both the clean and final passes, 32 flow updates).
    Any other name emits a warning and is skipped.

    Preprocessing comes from the prototype weights' transforms when ``--prototype --weights``
    are given, from ``prototype.transforms.OpticalFlowEval()`` with ``--prototype`` alone,
    and from ``OpticalFlowPresetEval()`` otherwise.
    """
    val_datasets = args.val_dataset or []

    if args.prototype:
        if args.weights:
            weights = prototype.models.get_weight(args.weights)
            preprocessing = weights.transforms()
        else:
            preprocessing = prototype.transforms.OpticalFlowEval()
    else:
        preprocessing = OpticalFlowPresetEval()

    for name in val_datasets:
        if name == "kitti":
            # Kitti has different image sizes so we need to individually pad them, we can't batch.
            # see comment in InputPadder
            if args.batch_size != 1 and args.rank == 0:
                warnings.warn(
                    f"Batch-size={args.batch_size} was passed. For technical reasons, evaluating on Kitti can only be done with a batch-size of 1."
                )

            val_dataset = KittiFlow(root=args.dataset_root, split="train", transforms=preprocessing)
            _validate(
                model, args, val_dataset, num_flow_updates=24, padder_mode="kitti", header="Kitti val", batch_size=1
            )
        elif name == "sintel":
            for pass_name in ("clean", "final"):
                val_dataset = Sintel(
                    root=args.dataset_root, split="train", pass_name=pass_name, transforms=preprocessing
                )
                _validate(
                    model,
                    args,
                    val_dataset,
                    num_flow_updates=32,
                    padder_mode="sintel",
                    header=f"Sintel val {pass_name}",
                )
        else:
            # Report the requested name. The previous code interpolated `val_dataset`, which is
            # a leftover dataset object from an earlier iteration, or unbound on the first one.
            warnings.warn(f"Can't validate on {name}, skipping.")


174
def train_one_epoch(model, optimizer, scheduler, train_loader, logger, args):
    """Train ``model`` for one pass over ``train_loader``.

    For every batch:
    - compute the sequence loss over all flow predictions and log it together with the
      metrics of the last prediction (minus ``f1``);
    - backpropagate, clip gradients to norm 1, and step both the optimizer and the LR scheduler.

    Args:
        model: flow model called as ``model(image1, image2, num_flow_updates=...)``.
        optimizer: optimizer over the model parameters.
        scheduler: LR scheduler, stepped once per batch.
        train_loader: loader yielding ``(image1, image2, flow_gt, valid_flow_mask)`` batches.
        logger: metric logger providing ``log_every`` and ``update``.
        args: parsed CLI namespace. Reads ``num_flow_updates`` and ``gamma``.
    """
    for batch in logger.log_every(train_loader):
        optimizer.zero_grad()

        img1, img2, gt_flow, valid_mask = [tensor.cuda() for tensor in batch]
        predictions = model(img1, img2, num_flow_updates=args.num_flow_updates)

        loss = utils.sequence_loss(predictions, gt_flow, valid_mask, args.gamma)
        batch_metrics, _ = utils.compute_metrics(predictions[-1], gt_flow, valid_mask)

        # f1 is dropped so it is not part of the per-step training log.
        batch_metrics.pop("f1")
        logger.update(loss=loss, **batch_metrics)

        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1)

        optimizer.step()
        scheduler.step()


def main(args):
    """Build the model and either evaluate it or train it, depending on ``--train-dataset``.

    - Without ``--train-dataset``, runs ``validate`` once with deterministic cuDNN and returns.
    - Otherwise trains for ``args.epochs`` epochs with AdamW and a linear OneCycle schedule.
      On rank 0, checkpoints are saved after every epoch. Validation runs every
      ``args.val_freq`` epochs and always after the final epoch.

    Raises:
        ImportError: ``--prototype`` was requested but the prototype module is unavailable.
        ValueError: ``--weights`` was given without ``--prototype``.
    """
    if args.prototype and prototype is None:
        raise ImportError("The prototype module couldn't be found. Please install the latest torchvision nightly.")
    if not args.prototype and args.weights:
        raise ValueError("The weights parameter works only in prototype mode. Please pass the --prototype argument.")

    utils.setup_ddp(args)

    if args.prototype:
        model = prototype.models.optical_flow.__dict__[args.model](weights=args.weights)
    else:
        model = torchvision.models.optical_flow.__dict__[args.model](pretrained=args.pretrained)

    model = model.to(args.local_rank)
    model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank])

    if args.resume is not None:
        # NOTE(review): torch.load unpickles the file, so only resume from trusted checkpoints.
        # The state dict is loaded into the DDP wrapper, which matches how checkpoints are saved below.
        d = torch.load(args.resume, map_location="cpu")
        model.load_state_dict(d, strict=True)

    if args.train_dataset is None:
        # Set deterministic CUDNN algorithms, since they can affect epe a fair bit.
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True
        validate(model, args)
        return

    print(f"Parameter Count: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")

    torch.backends.cudnn.benchmark = True

    model.train()
    if args.freeze_batch_norm:
        utils.freeze_batch_norm(model.module)

    train_dataset = get_train_dataset(args.train_dataset, args.dataset_root)

    sampler = torch.utils.data.distributed.DistributedSampler(train_dataset, shuffle=True, drop_last=True)
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        sampler=sampler,
        batch_size=args.batch_size,
        pin_memory=True,
        num_workers=args.num_workers,
    )

    optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.weight_decay, eps=args.adamw_eps)

    scheduler = torch.optim.lr_scheduler.OneCycleLR(
        optimizer=optimizer,
        max_lr=args.lr,
        epochs=args.epochs,
        # Never smaller than the per-rank number of batches (the sampler drops the remainder),
        # so the scheduler cannot be stepped past its total step count.
        steps_per_epoch=ceil(len(train_dataset) / (args.world_size * args.batch_size)),
        pct_start=0.05,
        cycle_momentum=False,
        anneal_strategy="linear",
    )

    logger = utils.MetricLogger()

    for current_epoch in range(args.epochs):
        print(f"EPOCH {current_epoch}")
        # Previously `done` was initialised to False and never updated, so the final epoch was
        # only validated when it happened to fall on a --val-freq boundary.
        done = current_epoch == args.epochs - 1

        sampler.set_epoch(current_epoch)  # needed, otherwise the data loading order would be the same for all epochs
        train_one_epoch(
            model=model,
            optimizer=optimizer,
            scheduler=scheduler,
            train_loader=train_loader,
            logger=logger,
            args=args,
        )

        # Note: we don't sync the SmoothedValues across processes, so the printed metrics are just those of rank 0
        print(f"Epoch {current_epoch} done. ", logger)

        if args.rank == 0:
            # TODO: Also save the optimizer and scheduler
            torch.save(model.state_dict(), Path(args.output_dir) / f"{args.name}_{current_epoch}.pth")
            torch.save(model.state_dict(), Path(args.output_dir) / f"{args.name}.pth")

        if current_epoch % args.val_freq == 0 or done:
            validate(model, args)
            # validate() switches the model to eval mode; restore the training configuration.
            model.train()
            if args.freeze_batch_norm:
                utils.freeze_batch_norm(model.module)


def get_args_parser(add_help=True):
    """Return the command-line parser for training / evaluating an optical-flow model.

    Args:
        add_help: forwarded to ``argparse.ArgumentParser`` (controls the ``-h/--help`` flag).

    Returns:
        The configured ``argparse.ArgumentParser``. ``--dataset-root`` is required.
    """
    parser = argparse.ArgumentParser(add_help=add_help, description="Train or evaluate an optical-flow model.")
    parser.add_argument(
        "--name",
        default="raft",
        type=str,
        help="The name of the experiment - determines the name of the files where weights are saved.",
    )
    parser.add_argument(
        "--output-dir", default="checkpoints", type=str, help="Output dir where checkpoints will be stored."
    )
    parser.add_argument(
        "--resume",
        type=str,
        help="A path to previously saved weights. Used to re-start training from, or evaluate a pre-saved model.",
    )

    parser.add_argument("--num-workers", type=int, default=12, help="Number of workers for the data loading part.")

    parser.add_argument(
        "--train-dataset",
        type=str,
        help="The dataset to use for training. If not passed, only validation is performed (and you probably want to pass --resume).",
    )
    parser.add_argument("--val-dataset", type=str, nargs="+", help="The dataset(s) to use for validation.")
    parser.add_argument("--val-freq", type=int, default=2, help="Validate every X epochs")
    parser.add_argument("--epochs", type=int, default=20, help="The total number of epochs to train.")
    parser.add_argument(
        "--batch-size", type=int, default=2, help="Batch size per process (each DDP worker loads this many samples)."
    )

    parser.add_argument("--lr", type=float, default=0.00002, help="Learning rate for AdamW optimizer")
    parser.add_argument("--weight-decay", type=float, default=0.00005, help="Weight decay for AdamW optimizer")
    parser.add_argument("--adamw-eps", type=float, default=1e-8, help="eps value for AdamW optimizer")

    parser.add_argument(
        "--freeze-batch-norm", action="store_true", help="Set BatchNorm modules of the model in eval mode."
    )

    parser.add_argument(
        "--model", type=str, default="raft_large", help="The name of the model to use - either raft_large or raft_small"
    )
    # TODO: resume, pretrained, and weights should be in an exclusive arg group
    parser.add_argument("--pretrained", action="store_true", help="Whether to use pretrained weights")

    # The underscore spelling is kept for backward compatibility (it determines `dest`);
    # the hyphenated alias matches every other flag of this script.
    parser.add_argument(
        "--num_flow_updates",
        "--num-flow-updates",
        type=int,
        default=12,
        help="number of updates (or 'iters') in the update operator of the model.",
    )

    parser.add_argument("--gamma", type=float, default=0.8, help="exponential weighting for loss. Must be < 1.")

    parser.add_argument("--dist-url", default="env://", help="URL used to set up distributed training")

    parser.add_argument(
        "--dataset-root",
        help="Root folder where the datasets are stored. Will be passed as the 'root' parameter of the datasets.",
        required=True,
    )

    # Prototype models only
    parser.add_argument(
        "--prototype",
        dest="prototype",
        help="Use prototype model builders instead of those from main area",
        action="store_true",
    )
    parser.add_argument("--weights", default=None, type=str, help="the weights enum name to load.")

    return parser


if __name__ == "__main__":
    args = get_args_parser().parse_args()
    # parents=True so a nested --output-dir (e.g. "runs/exp1/checkpoints") is created in one go
    # instead of failing with FileNotFoundError when an intermediate folder is missing.
    Path(args.output_dir).mkdir(parents=True, exist_ok=True)
    main(args)