train.py 32.6 KB
Newer Older
Ponku's avatar
Ponku committed
1
2
3
4
5
6
7
8
9
10
11
12
import argparse
import os
import warnings
from pathlib import Path
from typing import List, Union

import numpy as np
import torch
import torch.distributed as dist
import torchvision.models.optical_flow
import torchvision.prototype.models.depth.stereo
import utils
13
import visualization
Ponku's avatar
Ponku committed
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150

from parsing import make_dataset, make_eval_transform, make_train_transform, VALID_DATASETS
from torch import nn
from torchvision.transforms.functional import get_dimensions, InterpolationMode, resize
from utils.metrics import AVAILABLE_METRICS
from utils.norm import freeze_batch_norm


def make_stereo_flow(flow: Union[torch.Tensor, List[torch.Tensor]], model_out_channels: int) -> torch.Tensor:
    """Pad a single-channel flow with a zero second channel when the model emits two channels.

    A list input is processed element by element and a list is returned.
    A 4-D tensor is returned untouched unless it has exactly one channel
    while ``model_out_channels`` is 2, in which case an all-zero channel is
    concatenated after the existing one.
    """
    if not isinstance(flow, list):
        # unpacking four values also rejects anything that is not 4-D
        _, n_channels, _, _ = flow.shape
        needs_zero_channel = n_channels == 1 and model_out_channels == 2
        if not needs_zero_channel:
            return flow
        # flow channels follow the X-Y convention, hence the zero (Y) channel goes last
        return torch.cat([flow, torch.zeros_like(flow)], dim=1)

    return [make_stereo_flow(single_flow, model_out_channels) for single_flow in flow]


def make_lr_schedule(
    args: argparse.Namespace, optimizer: torch.optim.Optimizer
) -> torch.optim.lr_scheduler.SequentialLR:
    """Build the warmup -> flat -> decay learning rate schedule for CRE-stereo.

    Args:
        args: parsed options. Reads ``warmup_steps``, ``decay_after_steps``,
            ``total_iterations``, ``lr``, ``min_lr``, ``lr_warmup_method``,
            ``lr_warmup_factor``, ``lr_decay_method`` and (for the exponential
            decay only) ``lr_decay_gamma``.
        optimizer: optimizer whose learning rate is driven by the schedule.

    Returns:
        A ``SequentialLR`` chaining whichever of the three phases have a
        positive number of steps.

    Raises:
        ValueError: if ``decay_after_steps`` is smaller than ``warmup_steps``
            or if the warmup / decay method name is not recognised.
    """
    if args.decay_after_steps < args.warmup_steps:
        # the message used to format a non-existent ``args.function`` attribute, which
        # turned this validation error into an AttributeError
        raise ValueError(
            f"decay_after_steps: {args.decay_after_steps} must be greater than warmup_steps: {args.warmup_steps}"
        )

    warmup_steps = args.warmup_steps if args.warmup_steps else 0
    flat_lr_steps = args.decay_after_steps - warmup_steps if args.decay_after_steps else 0
    decay_lr_steps = args.total_iterations - flat_lr_steps

    max_lr = args.lr
    min_lr = args.min_lr

    schedulers = []
    # step counts at which SequentialLR hands over to the next scheduler
    milestones = []

    if warmup_steps > 0:
        if args.lr_warmup_method == "linear":
            warmup_lr_scheduler = torch.optim.lr_scheduler.LinearLR(
                optimizer, start_factor=args.lr_warmup_factor, total_iters=warmup_steps
            )
        elif args.lr_warmup_method == "constant":
            warmup_lr_scheduler = torch.optim.lr_scheduler.ConstantLR(
                optimizer, factor=args.lr_warmup_factor, total_iters=warmup_steps
            )
        else:
            raise ValueError(f"Unknown lr warmup method {args.lr_warmup_method}")
        schedulers.append(warmup_lr_scheduler)
        milestones.append(warmup_steps)

    if flat_lr_steps > 0:
        # NOTE(review): ``factor`` is a multiplier applied to the optimizer's base lr, yet the
        # absolute learning rate is passed here -- confirm this is the intended behaviour.
        flat_lr_scheduler = torch.optim.lr_scheduler.ConstantLR(optimizer, factor=max_lr, total_iters=flat_lr_steps)
        schedulers.append(flat_lr_scheduler)
        milestones.append(flat_lr_steps + warmup_steps)

    if decay_lr_steps > 0:
        if args.lr_decay_method == "cosine":
            decay_lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
                optimizer, T_max=decay_lr_steps, eta_min=min_lr
            )
        elif args.lr_decay_method == "linear":
            # NOTE(review): same remark as above, start/end factors are multipliers -- confirm.
            decay_lr_scheduler = torch.optim.lr_scheduler.LinearLR(
                optimizer, start_factor=max_lr, end_factor=min_lr, total_iters=decay_lr_steps
            )
        elif args.lr_decay_method == "exponential":
            decay_lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(
                optimizer, gamma=args.lr_decay_gamma, last_epoch=-1
            )
        else:
            raise ValueError(f"Unknown lr decay method {args.lr_decay_method}")
        # the last phase needs no milestone: it runs until training stops
        schedulers.append(decay_lr_scheduler)

    scheduler = torch.optim.lr_scheduler.SequentialLR(optimizer, schedulers, milestones=milestones)
    return scheduler


def shuffle_dataset(dataset):
    """Return a ``Subset`` view over ``dataset`` whose indices are a random permutation."""
    n_samples = len(dataset)
    shuffled_indices = torch.randperm(n_samples)
    return torch.utils.data.Subset(dataset, shuffled_indices)


def resize_dataset_to_n_steps(
    dataset: torch.utils.data.Dataset, dataset_steps: int, samples_per_step: int, args: argparse.Namespace
) -> torch.utils.data.Dataset:
    """Repeat and/or truncate ``dataset`` so it holds ``dataset_steps`` steps worth of samples.

    The target length is ``dataset_steps * samples_per_step``; when
    ``args.steps_is_epochs`` is set, one step counts as a full pass over the
    dataset instead. The result concatenates as many whole copies as fit plus
    a leading slice for the remainder. With ``args.dataset_shuffle`` every
    piece is shuffled independently before concatenation.
    """
    n_original = len(dataset)
    # in "epochs" mode a single step consumes the whole dataset
    per_step = n_original if args.steps_is_epochs else samples_per_step
    n_full_copies, n_leftover = divmod(dataset_steps * per_step, n_original)

    pieces = [dataset] * n_full_copies
    if n_leftover > 0:
        pieces.append(torch.utils.data.Subset(dataset, list(range(n_leftover))))

    if args.dataset_shuffle:
        pieces = [shuffle_dataset(piece) for piece in pieces]

    return torch.utils.data.ConcatDataset(pieces)


def get_train_dataset(dataset_root: str, args: argparse.Namespace) -> torch.utils.data.Dataset:
    """Assemble the training dataset from every name listed in ``args.train_datasets``.

    Each dataset is built with its own training transform, resized to the
    matching entry of ``args.dataset_steps``, and all of them are then
    concatenated. When ``args.dataset_order_shuffle`` is set the concatenation
    is shuffled as a whole.

    Raises:
        ValueError: if ``args.train_datasets`` is empty.
    """
    per_dataset = [
        make_dataset(name, dataset_root, make_train_transform(args)) for name in args.train_datasets
    ]
    if not per_dataset:
        raise ValueError("No datasets specified for training")

    # number of samples consumed by one optimisation step across all processes
    samples_per_step = args.world_size * args.batch_size

    # zip stops at the shorter sequence: datasets without a step count are left at their original size
    for position, (current, n_steps) in enumerate(zip(per_dataset, args.dataset_steps)):
        per_dataset[position] = resize_dataset_to_n_steps(current, n_steps, samples_per_step, args)

    combined = torch.utils.data.ConcatDataset(per_dataset)
    if args.dataset_order_shuffle:
        combined = shuffle_dataset(combined)

    print(f"Training dataset: {len(combined)} samples")
    return combined


@torch.inference_mode()
def _evaluate(
    model,
    args,
    val_loader,
    *,
    padder_mode,
    print_freq=10,
    writer=None,
    step=None,
    iterations=None,
    batch_size=None,
    header=None,
):
    """Helper function to compute various metrics (epe, etc.) for a model on a given dataset.

    Args:
        model: stereo model, called as ``model(left, right, flow_init=None, num_iters=...)``.
        args: parsed options. Reads ``device``, ``recurrent_updates``, ``metrics``,
            ``mixed_precision`` and ``rank``.
        val_loader: loader yielding ``(image_left, image_right, disp_gt, valid_disp_mask)`` batches.
        padder_mode: mode forwarded to ``utils.InputPadder``.
        print_freq: logging period forwarded to ``MetricLogger.log_every``.
        writer: optional tensorboard writer; scalars are only written when ``args.rank == 0``.
        step: global step attached to the tensorboard scalars.
        iterations: number of recurrent updates; falls back to ``args.recurrent_updates``.
        batch_size: accepted for call-site compatibility, not used in this function.
        header: prefix for the printed / logged results, defaults to ``"Test:"``.
    """
    model.eval()
    header = header or "Test:"
    device = torch.device(args.device)
    metric_logger = utils.MetricLogger(delimiter="  ")

    iterations = iterations or args.recurrent_updates

    logger = utils.MetricLogger()
    for meter_name in args.metrics:
        logger.add_meter(meter_name, fmt="{global_avg:.4f}")
    # fl-all is always tracked during evaluation, even if it was not requested
    if "fl-all" not in args.metrics:
        logger.add_meter("fl-all", fmt="{global_avg:.4f}")

    num_processed_samples = 0
    with torch.cuda.amp.autocast(enabled=args.mixed_precision, dtype=torch.float16):
        for blob in metric_logger.log_every(val_loader, print_freq, header):
            image_left, image_right, disp_gt, valid_disp_mask = (x.to(device) for x in blob)
            padder = utils.InputPadder(image_left.shape, mode=padder_mode)
            image_left, image_right = padder.pad(image_left, image_right)

            disp_predictions = model(image_left, image_right, flow_init=None, num_iters=iterations)
            # keep the first channel of the last prediction in the sequence
            disp_pred = disp_predictions[-1][:, :1, :, :]
            disp_pred = padder.unpad(disp_pred)

            metrics, _ = utils.compute_metrics(disp_pred, disp_gt, valid_disp_mask, metrics=logger.meters.keys())
            num_processed_samples += image_left.shape[0]
            for name in metrics:
                logger.meters[name].update(metrics[name], n=1)

    num_processed_samples = utils.reduce_across_processes(num_processed_samples)

    print("Num_processed_samples: ", num_processed_samples)
    # get_rank() raises when no process group exists, so only query it in distributed runs
    is_main_process = not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0
    if (
        hasattr(val_loader.dataset, "__len__")
        and len(val_loader.dataset) != num_processed_samples
        and is_main_process
    ):
        warnings.warn(
            f"Number of processed samples {num_processed_samples} is different "
            f"from the dataset size {len(val_loader.dataset)}. This may happen if "
            "the dataset is not divisible by the batch size. "
            "Try lowering the batch size or GPU number for more accurate results."
        )

    if writer is not None and args.rank == 0:
        for meter_name, meter_value in logger.meters.items():
            scalar_name = f"{meter_name} {header}"
            writer.add_scalar(scalar_name, meter_value.avg, step)

    logger.synchronize_between_processes()
    print(header, logger)


def make_eval_loader(dataset_name: str, args: argparse.Namespace) -> torch.utils.data.DataLoader:
    """Build the evaluation ``DataLoader`` for ``dataset_name``.

    When ``args.weights`` is set, the preprocessing bundled with those weights
    is applied to the image pair, and numpy disparity / validity maps are
    converted to tensors and resized to the transformed image size. Otherwise
    the transform returned by ``make_eval_transform(args)`` is used.

    Args:
        dataset_name: name forwarded to ``make_dataset``.
        args: parsed options. Reads ``weights``, ``dataset_root``, ``distributed``,
            ``batch_size`` and ``workers``.

    Returns:
        A loader that walks the dataset in order (no shuffling, nothing dropped).
    """
    if args.weights:
        weights = torchvision.models.get_weight(args.weights)
        trans = weights.transforms()

        def preprocessing(image_left, image_right, disp, valid_disp_mask):
            _, _, W_o = get_dimensions(image_left)
            image_left, image_right = trans(image_left, image_right)

            _, H_t, W_t = get_dimensions(image_left)
            # disparity values scale with the horizontal resize ratio
            scale_factor = W_t / W_o

            if disp is not None and not isinstance(disp, torch.Tensor):
                disp = torch.from_numpy(disp)
                if W_t != W_o:
                    # ``resize`` takes the method through ``interpolation``; there is no ``mode`` keyword
                    disp = resize(disp, (H_t, W_t), interpolation=InterpolationMode.BILINEAR) * scale_factor
            if valid_disp_mask is not None and not isinstance(valid_disp_mask, torch.Tensor):
                valid_disp_mask = torch.from_numpy(valid_disp_mask)
                if W_t != W_o:
                    valid_disp_mask = resize(valid_disp_mask, (H_t, W_t), interpolation=InterpolationMode.NEAREST)
            return image_left, image_right, disp, valid_disp_mask

    else:
        preprocessing = make_eval_transform(args)

    val_dataset = make_dataset(dataset_name, args.dataset_root, transforms=preprocessing)
    if args.distributed:
        sampler = torch.utils.data.distributed.DistributedSampler(val_dataset, shuffle=False, drop_last=False)
    else:
        sampler = torch.utils.data.SequentialSampler(val_dataset)

    val_loader = torch.utils.data.DataLoader(
        val_dataset,
        sampler=sampler,
        batch_size=args.batch_size,
        pin_memory=True,
        num_workers=args.workers,
    )

    return val_loader


252
def evaluate(model, loaders, args, writer=None, step=None):
    """Run ``_evaluate`` on every validation loader.

    Args:
        model: model under evaluation.
        loaders: mapping from a dataset name to its evaluation loader; the name
            is used to build the header of the reported results.
        args: parsed options. Reads ``recurrent_updates``, ``padder_type`` and ``batch_size``.
        writer: optional tensorboard writer forwarded to ``_evaluate``.
        step: global step forwarded to ``_evaluate``.
    """
    for loader_name, loader in loaders.items():
        _evaluate(
            model,
            args,
            loader,
            iterations=args.recurrent_updates,
            padder_mode=args.padder_type,
            header=f"{loader_name} evaluation",
            batch_size=args.batch_size,
            writer=writer,
            step=step,
        )


def _save_checkpoint(model, optimizer, scheduler, step, args):
    """Write the training state to ``<name>_<step>.pth`` and overwrite the rolling ``<name>.pth``."""
    model_without_ddp = model.module if isinstance(model, torch.nn.parallel.DistributedDataParallel) else model
    checkpoint = {
        "model": model_without_ddp.state_dict(),
        "optimizer": optimizer.state_dict(),
        "scheduler": scheduler.state_dict(),
        "step": step,
        "args": args,
    }
    os.makedirs(args.checkpoint_dir, exist_ok=True)
    torch.save(checkpoint, Path(args.checkpoint_dir) / f"{args.name}_{step}.pth")
    torch.save(checkpoint, Path(args.checkpoint_dir) / f"{args.name}.pth")


def run(model, optimizer, scheduler, train_loader, val_loaders, logger, writer, scaler, args):
    """Run the training loop from ``args.start_step + 1`` up to ``args.total_iterations``.

    Every step draws one batch, accumulates the weighted sequence loss plus the
    optional consistency / PSNR / photometric / smoothness terms, and performs
    one optimizer and one scheduler step. Tensorboard logging, checkpointing
    and validation happen every ``args.tensorboard_log_frequency``,
    ``args.save_frequency`` and ``args.valid_frequency`` steps respectively.
    A final checkpoint is written once the loop is over.

    Args:
        model: stereo model, possibly wrapped in ``DistributedDataParallel``.
        optimizer: optimizer updating the model parameters.
        scheduler: learning rate scheduler, stepped once per iteration.
        train_loader: loader providing the training batches.
        val_loaders: mapping of dataset name to validation loader.
        logger: ``utils.MetricLogger`` wrapping the loader and collecting metrics.
        writer: optional tensorboard writer.
        scaler: optional ``GradScaler``; when ``None`` a plain backward / step is used.
        args: parsed options.
    """
    device = torch.device(args.device)
    # wrap the loader in a logger
    loader = iter(logger.log_every(train_loader))
    # output channels
    model_out_channels = model.module.output_channels if args.distributed else model.output_channels

    torch.set_num_threads(args.threads)

    sequence_criterion = utils.SequenceLoss(
        gamma=args.gamma,
        max_flow=args.max_disparity,
        exclude_large_flows=args.flow_loss_exclude_large,
    ).to(device)

    if args.consistency_weight:
        consistency_criterion = utils.FlowSequenceConsistencyLoss(
            args.gamma,
            resize_factor=0.25,
            rescale_factor=0.25,
            rescale_mode="bilinear",
        ).to(device)
    else:
        consistency_criterion = None

    if args.psnr_weight:
        psnr_criterion = utils.PSNRLoss().to(device)
    else:
        psnr_criterion = None

    if args.smoothness_weight:
        smoothness_criterion = utils.SmoothnessLoss().to(device)
    else:
        smoothness_criterion = None

    if args.photometric_weight:
        photometric_criterion = utils.FlowPhotoMetricLoss(
            ssim_weight=args.photometric_ssim_weight,
            max_displacement_ratio=args.photometric_max_displacement_ratio,
            ssim_use_padding=False,
        ).to(device)
    else:
        photometric_criterion = None

    # keep ``step`` bound for the final save even when the loop below has nothing left to do
    step = args.start_step
    for step in range(args.start_step + 1, args.total_iterations + 1):
        data_blob = next(loader)
        optimizer.zero_grad()

        # unpack the data blob
        image_left, image_right, disp_mask, valid_disp_mask = (x.to(device) for x in data_blob)
        with torch.cuda.amp.autocast(enabled=args.mixed_precision, dtype=torch.float16):
            disp_predictions = model(image_left, image_right, flow_init=None, num_iters=args.recurrent_updates)
            # different models have different outputs, make sure we get the right ones for this task
            disp_predictions = make_stereo_flow(disp_predictions, model_out_channels)
            # should the architecture or training loop require it, we have to adjust the disparity mask
            # target to possibly look like an optical flow mask
            disp_mask = make_stereo_flow(disp_mask, model_out_channels)

        # sequence loss on top of the model outputs
        loss = sequence_criterion(disp_predictions, disp_mask, valid_disp_mask) * args.flow_loss_weight

        if args.consistency_weight > 0:
            loss_consistency = consistency_criterion(disp_predictions)
            loss += loss_consistency * args.consistency_weight

        if args.psnr_weight > 0:
            loss_psnr = 0.0
            for pred in disp_predictions:
                # predictions might have 2 channels
                loss_psnr += psnr_criterion(
                    pred * valid_disp_mask.unsqueeze(1),
                    disp_mask * valid_disp_mask.unsqueeze(1),
                ).mean()  # mean the psnr loss over the batch
            loss += loss_psnr / len(disp_predictions) * args.psnr_weight

        if args.photometric_weight > 0:
            loss_photometric = 0.0
            for pred in disp_predictions:
                # predictions might have 1 channel, therefore we need to impute 0s for the second channel
                if model_out_channels == 1:
                    pred = torch.cat([pred, torch.zeros_like(pred)], dim=1)

                loss_photometric += photometric_criterion(
                    image_left, image_right, pred, valid_disp_mask
                )  # photometric loss already comes out meaned over the batch
            loss += loss_photometric / len(disp_predictions) * args.photometric_weight

        if args.smoothness_weight > 0:
            loss_smoothness = 0.0
            for pred in disp_predictions:
                # predictions might have 2 channels
                loss_smoothness += smoothness_criterion(
                    image_left, pred[:, :1, :, :]
                ).mean()  # mean the smoothness loss over the batch
            loss += loss_smoothness / len(disp_predictions) * args.smoothness_weight

        with torch.no_grad():
            metrics, _ = utils.compute_metrics(
                disp_predictions[-1][:, :1, :, :],  # predictions might have 2 channels
                disp_mask[:, :1, :, :],  # so does the ground truth
                valid_disp_mask,
                args.metrics,
            )

        # fl-all is not tracked by the training logger
        metrics.pop("fl-all", None)
        logger.update(loss=loss, **metrics)

        if scaler is not None:
            scaler.scale(loss).backward()
            # gradients must be unscaled before clipping so the norm threshold applies to real values
            scaler.unscale_(optimizer)
            if args.clip_grad_norm:
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=args.clip_grad_norm)
            scaler.step(optimizer)
            scaler.update()
        else:
            loss.backward()
            if args.clip_grad_norm:
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=args.clip_grad_norm)
            optimizer.step()

        scheduler.step()

        if not dist.is_initialized() or dist.get_rank() == 0:
            if writer is not None and step % args.tensorboard_log_frequency == 0:
                # log the loss and metrics to tensorboard
                writer.add_scalar("loss", loss, step)
                for name, value in logger.meters.items():
                    writer.add_scalar(name, value.avg, step)
                # log the images to tensorboard
                pred_grid = visualization.make_training_sample_grid(
                    image_left, image_right, disp_mask, valid_disp_mask, disp_predictions
                )
                writer.add_image("predictions", pred_grid, step, dataformats="HWC")

                # second thing we want to see is how relevant the iterative refinement is
                pred_sequence_grid = visualization.make_disparity_sequence_grid(disp_predictions, disp_mask)
                writer.add_image("sequence", pred_sequence_grid, step, dataformats="HWC")

        if step % args.save_frequency == 0:
            if not args.distributed or args.rank == 0:
                _save_checkpoint(model, optimizer, scheduler, step, args)

        if step % args.valid_frequency == 0:
            evaluate(model, val_loaders, args, writer, step)
            # evaluation switches the model to eval mode, restore the training configuration
            model.train()
            if args.freeze_batch_norm:
                if isinstance(model, nn.parallel.DistributedDataParallel):
                    freeze_batch_norm(model.module)
                else:
                    freeze_batch_norm(model)

    # one final save at the end
    if not args.distributed or args.rank == 0:
        _save_checkpoint(model, optimizer, scheduler, step, args)


def main(args):
    """Entry point: set up the model, data, optimizer and schedule, then evaluate or train.

    When ``args.train_datasets`` is ``None`` only the evaluation on
    ``args.test_datasets`` is executed. Otherwise the training dataset,
    optimizer, scheduler and optional tensorboard writer are created, an
    optional checkpoint is restored, and the training loop is started.

    Args:
        args: parsed command line options; ``total_iterations``, ``test_only``
            and ``start_step`` are written onto it.

    Raises:
        ValueError: if distributed mode is requested on the cpu device or the
            optimizer name is unknown.
    """
    args.total_iterations = sum(args.dataset_steps)

    # initialize DDP setting
    utils.setup_ddp(args)
    print(args)

    args.test_only = args.train_datasets is None

    # set the appropriate devices
    if args.distributed and args.device == "cpu":
        raise ValueError("The device must be cuda if we want to run in distributed mode using torchrun")
    device = torch.device(args.device)

    # select model architecture
    model = torchvision.prototype.models.depth.stereo.__dict__[args.model](weights=args.weights)

    # convert to DDP if need be
    if args.distributed:
        model = model.to(args.gpu)
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
        model_without_ddp = model.module
    else:
        model.to(device)
        model_without_ddp = model

    os.makedirs(args.checkpoint_dir, exist_ok=True)

    val_loaders = {name: make_eval_loader(name, args) for name in args.test_datasets}

    # EVAL ONLY configurations
    if args.test_only:
        evaluate(model, val_loaders, args)
        return

    # Sanity check for the parameter count
    print(f"Parameter Count: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")

    # Compose the training dataset
    train_dataset = get_train_dataset(args.dataset_root, args)

    # initialize the optimizer
    if args.optimizer == "adam":
        optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
    elif args.optimizer == "sgd":
        optimizer = torch.optim.SGD(model.parameters(), lr=args.lr, weight_decay=args.weight_decay, momentum=0.9)
    else:
        raise ValueError(f"Unknown optimizer {args.optimizer}. Please choose between adam and sgd")

    # initialize the learning rate schedule
    scheduler = make_lr_schedule(args, optimizer)

    # load them from checkpoint if needed
    args.start_step = 0
    if args.resume_path is not None:
        checkpoint = torch.load(args.resume_path, map_location="cpu")
        if "model" in checkpoint:
            # this means the user requested to resume from a training checkpoint
            model_without_ddp.load_state_dict(checkpoint["model"])
            # this means the user wants to continue training from where it was left off
            if args.resume_schedule:
                optimizer.load_state_dict(checkpoint["optimizer"])
                scheduler.load_state_dict(checkpoint["scheduler"])
                args.start_step = checkpoint["step"] + 1
                # move the starting point of the dataset past the samples that were already consumed
                sample_start_step = args.start_step * args.batch_size * args.world_size
                # the concatenated / subset datasets cannot be sliced, so take a Subset over
                # the remaining indices instead
                train_dataset = torch.utils.data.Subset(
                    train_dataset, range(sample_start_step, len(train_dataset))
                )

        else:
            # this means the user wants to finetune on top of a model state dict
            # and that no other changes are required
            model_without_ddp.load_state_dict(checkpoint)

    torch.backends.cudnn.benchmark = True

    # enable training mode
    model.train()
    if args.freeze_batch_norm:
        freeze_batch_norm(model_without_ddp)

    # put dataloader on top of the dataset
    # make sure to disable shuffling since the dataset is already shuffled
    # in order to guarantee quasi randomness whilst retaining a deterministic
    # dataset consumption order
    if args.distributed:
        # the train dataset is preshuffled in order to respect the iteration order
        sampler = torch.utils.data.distributed.DistributedSampler(train_dataset, shuffle=False, drop_last=True)
    else:
        # the train dataset is already shuffled, so we can use a simple SequentialSampler
        sampler = torch.utils.data.SequentialSampler(train_dataset)

    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        sampler=sampler,
        batch_size=args.batch_size,
        pin_memory=True,
        num_workers=args.workers,
    )

    # initialize the logger
    if args.tensorboard_summaries:
        from torch.utils.tensorboard import SummaryWriter

        tensorboard_path = Path(args.checkpoint_dir) / "tensorboard"
        os.makedirs(tensorboard_path, exist_ok=True)

        tensorboard_run = tensorboard_path / f"{args.name}"
        writer = SummaryWriter(tensorboard_run)
    else:
        writer = None

    logger = utils.MetricLogger(delimiter="  ")

    # gradient scaling is only needed when training with mixed precision
    scaler = torch.cuda.amp.GradScaler() if args.mixed_precision else None
    # run the training loop
    # this will perform optimization, respectively logging and saving checkpoints
    # when need be
    run(
        model=model,
        optimizer=optimizer,
        scheduler=scheduler,
        train_loader=train_loader,
        val_loaders=val_loaders,
        logger=logger,
        writer=writer,
        scaler=scaler,
        args=args,
    )


def _str_to_bool(value):
    """Convert a command-line token into a real boolean.

    ``type=bool`` must not be used with argparse: ``bool("False")`` is ``True``
    because every non-empty string is truthy, which makes it impossible to
    switch a flag off from the command line.

    Args:
        value: the raw string received from the command line (an actual
            ``bool`` is returned unchanged).

    Returns:
        ``True`` or ``False`` depending on the textual value.

    Raises:
        argparse.ArgumentTypeError: if the token is not a recognised boolean spelling.
    """
    if isinstance(value, bool):
        return value
    normalized = value.strip().lower()
    if normalized in {"true", "t", "yes", "y", "1", "on"}:
        return True
    # the empty string maps to False, exactly as it did with ``type=bool``
    if normalized in {"false", "f", "no", "n", "0", "off", ""}:
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def get_args_parser(add_help=True):
    """Build the command-line parser for the stereo matching training script.

    Args:
        add_help: forwarded to ``argparse.ArgumentParser``; controls whether the
            ``-h/--help`` option is added to the parser.

    Returns:
        The configured ``argparse.ArgumentParser``. Arguments are not parsed here.
    """
    parser = argparse.ArgumentParser(description="PyTorch Stereo Matching Training", add_help=add_help)
    # checkpointing
    parser.add_argument("--name", default="crestereo", help="name of the experiment")
    parser.add_argument("--resume", type=str, default=None, help="from which checkpoint to resume")
    parser.add_argument("--checkpoint-dir", type=str, default="checkpoints", help="path to the checkpoint directory")

    # dataset
    parser.add_argument("--dataset-root", type=str, default="", help="path to the dataset root directory")
    parser.add_argument(
        "--train-datasets",
        type=str,
        nargs="+",
        default=["crestereo"],
        help="dataset(s) to train on",
        choices=list(VALID_DATASETS.keys()),
    )
    parser.add_argument(
        "--dataset-steps", type=int, nargs="+", default=[300_000], help="number of steps for each dataset"
    )
    parser.add_argument(
        "--steps-is-epochs", action="store_true", help="if set, dataset-steps are interpreted as epochs"
    )
    parser.add_argument(
        "--test-datasets",
        type=str,
        nargs="+",
        default=["middlebury2014-train"],
        help="dataset(s) to test on",
        choices=["middlebury2014-train"],
    )
    # boolean options take an explicit value ("True"/"False"); _str_to_bool makes
    # "False" actually evaluate to False, which type=bool never did
    parser.add_argument("--dataset-shuffle", type=_str_to_bool, help="shuffle the dataset", default=True)
    parser.add_argument(
        "--dataset-order-shuffle", type=_str_to_bool, help="shuffle the dataset order", default=True
    )
    parser.add_argument("--batch-size", type=int, default=2, help="batch size per GPU")
    parser.add_argument("--workers", type=int, default=4, help="number of workers per GPU")
    parser.add_argument(
        "--threads",
        type=int,
        default=16,
        help="number of CPU threads per GPU. This can be changed around to speed-up transforms if needed. This can lead to worker thread contention so use with care.",
    )

    # model architecture
    parser.add_argument(
        "--model",
        type=str,
        default="crestereo_base",
        help="model architecture",
        choices=["crestereo_base", "raft_stereo"],
    )
    parser.add_argument("--recurrent-updates", type=int, default=10, help="number of recurrent updates")
    parser.add_argument("--freeze-batch-norm", action="store_true", help="freeze batch norm parameters")

    # loss parameters
    parser.add_argument("--gamma", type=float, default=0.8, help="gamma parameter for the flow sequence loss")
    parser.add_argument("--flow-loss-weight", type=float, default=1.0, help="weight for the flow loss")
    parser.add_argument(
        "--flow-loss-exclude-large",
        action="store_true",
        help="exclude large flow values from the loss. A large value is defined as a value greater than the ground truth flow norm",
        default=False,
    )
    parser.add_argument("--consistency-weight", type=float, default=0.0, help="consistency loss weight")
    parser.add_argument(
        "--consistency-resize-factor",
        type=float,
        default=0.25,
        help="consistency loss resize factor to account for the fact that the flow is computed on a downsampled image",
    )
    parser.add_argument("--psnr-weight", type=float, default=0.0, help="psnr loss weight")
    parser.add_argument("--smoothness-weight", type=float, default=0.0, help="smoothness loss weight")
    parser.add_argument("--photometric-weight", type=float, default=0.0, help="photometric loss weight")
    parser.add_argument(
        "--photometric-max-displacement-ratio",
        type=float,
        default=0.15,
        help="Only pixels with a displacement smaller than this ratio of the image width will be considered for the photometric loss",
    )
    parser.add_argument("--photometric-ssim-weight", type=float, default=0.85, help="photometric ssim loss weight")

    # transforms parameters
    parser.add_argument("--gpu-transforms", action="store_true", help="use GPU transforms")
    parser.add_argument(
        "--eval-size", type=int, nargs="+", default=[384, 512], help="size of the images for evaluation"
    )
    parser.add_argument("--resize-size", type=int, nargs=2, default=None, help="resize size")
    parser.add_argument("--crop-size", type=int, nargs=2, default=[384, 512], help="crop size")
    parser.add_argument("--scale-range", type=float, nargs=2, default=[0.6, 1.0], help="random scale range")
    parser.add_argument("--rescale-prob", type=float, default=1.0, help="probability of resizing the image")
    parser.add_argument(
        "--scaling-type", type=str, default="linear", help="scaling type", choices=["exponential", "linear"]
    )
    parser.add_argument("--flip-prob", type=float, default=0.5, help="probability of flipping the image")
    parser.add_argument(
        "--norm-mean", type=float, nargs="+", default=[0.5, 0.5, 0.5], help="mean for image normalization"
    )
    parser.add_argument(
        "--norm-std", type=float, nargs="+", default=[0.5, 0.5, 0.5], help="std for image normalization"
    )
    parser.add_argument(
        "--use-grayscale", action="store_true", help="use grayscale images instead of RGB", default=False
    )
    parser.add_argument("--max-disparity", type=float, default=None, help="maximum disparity")
    parser.add_argument(
        "--interpolation-strategy",
        type=str,
        default="bilinear",
        help="interpolation strategy",
        choices=["bilinear", "bicubic", "mixed"],
    )
    parser.add_argument("--spatial-shift-prob", type=float, default=1.0, help="probability of shifting the image")
    parser.add_argument(
        "--spatial-shift-max-angle", type=float, default=0.1, help="maximum angle for the spatial shift"
    )
    parser.add_argument(
        "--spatial-shift-max-displacement", type=float, default=2.0, help="maximum displacement for the spatial shift"
    )
    parser.add_argument("--gamma-range", type=float, nargs="+", default=[0.8, 1.2], help="range for gamma correction")
    parser.add_argument(
        "--brightness-range", type=float, nargs="+", default=[0.8, 1.2], help="range for brightness correction"
    )
    parser.add_argument(
        "--contrast-range", type=float, nargs="+", default=[0.8, 1.2], help="range for contrast correction"
    )
    # NOTE(review): the two defaults below are scalars, while values given on the
    # command line arrive as lists (nargs="+") - confirm the consumer accepts both
    parser.add_argument(
        "--saturation-range", type=float, nargs="+", default=0.0, help="range for saturation correction"
    )
    parser.add_argument("--hue-range", type=float, nargs="+", default=0.0, help="range for hue correction")
    parser.add_argument(
        "--asymmetric-jitter-prob",
        type=float,
        default=1.0,
        help="probability of using asymmetric jitter instead of symmetric jitter",
    )
    parser.add_argument("--occlusion-prob", type=float, default=0.5, help="probability of occluding the right image")
    parser.add_argument(
        "--occlusion-px-range", type=int, nargs="+", default=[50, 100], help="range for the number of occluded pixels"
    )
    parser.add_argument("--erase-prob", type=float, default=0.0, help="probability of erasing in both images")
    parser.add_argument(
        "--erase-px-range", type=int, nargs="+", default=[50, 100], help="range for the number of erased pixels"
    )
    parser.add_argument(
        "--erase-num-repeats", type=int, default=1, help="number of times to repeat the erase operation"
    )

    # optimizer parameters
    parser.add_argument("--optimizer", type=str, default="adam", help="optimizer", choices=["adam", "sgd"])
    parser.add_argument("--lr", type=float, default=4e-4, help="learning rate")
    parser.add_argument("--weight-decay", type=float, default=0.0, help="weight decay")
    parser.add_argument("--clip-grad-norm", type=float, default=0.0, help="clip grad norm")

    # lr_scheduler parameters
    parser.add_argument("--min-lr", type=float, default=2e-5, help="minimum learning rate")
    parser.add_argument("--warmup-steps", type=int, default=6_000, help="number of warmup steps")
    parser.add_argument(
        "--decay-after-steps", type=int, default=180_000, help="number of steps after which to start decay the lr"
    )
    parser.add_argument(
        "--lr-warmup-method", type=str, default="linear", help="warmup method", choices=["linear", "cosine"]
    )
    parser.add_argument("--lr-warmup-factor", type=float, default=0.02, help="warmup factor for the learning rate")
    parser.add_argument(
        "--lr-decay-method",
        type=str,
        default="linear",
        help="decay method",
        choices=["linear", "cosine", "exponential"],
    )
    parser.add_argument("--lr-decay-gamma", type=float, default=0.8, help="decay factor for the learning rate")

    # deterministic behaviour
    parser.add_argument("--seed", type=int, default=42, help="seed for random number generators")

    # mixed precision training
    parser.add_argument("--mixed-precision", action="store_true", help="use mixed precision training")

    # logging
    parser.add_argument("--tensorboard-summaries", action="store_true", help="log to tensorboard")
    parser.add_argument("--tensorboard-log-frequency", type=int, default=100, help="log frequency")
    parser.add_argument("--save-frequency", type=int, default=1_000, help="save frequency")
    parser.add_argument("--valid-frequency", type=int, default=1_000, help="validation frequency")
    parser.add_argument(
        "--metrics",
        type=str,
        nargs="+",
        default=["mae", "rmse", "1px", "3px", "5px", "relepe"],
        help="metrics to log",
        choices=AVAILABLE_METRICS,
    )

    # distributed parameters
    parser.add_argument("--world-size", type=int, default=8, help="number of distributed processes")
    parser.add_argument("--dist-url", type=str, default="env://", help="url used to set up distributed training")
    parser.add_argument("--device", type=str, default="cuda", help="device to use for training")

    # weights API
    parser.add_argument("--weights", type=str, default=None, help="weights API url")
    parser.add_argument(
        "--resume-path", type=str, default=None, help="a path from which to resume or start fine-tuning"
    )
    parser.add_argument("--resume-schedule", action="store_true", help="resume optimizer state")

    # padder parameters
    parser.add_argument("--padder-type", type=str, default="kitti", help="padder type", choices=["kitti", "sintel"])
    return parser


if __name__ == "__main__":
    # Script entry point: parse the command line and hand the resulting
    # namespace straight to the training routine.
    main(get_args_parser().parse_args())