Unverified Commit 10dafd9b authored by Ponku, committed by GitHub

Add stereo train loop (#6605)



* crestereo draft implementation

* minor model fixes. positional embedding changes.

* aligned base configuration with paper

* Addressing comments

* Broke down Adaptive Correlation Layer. Addressed some other comments.

* addressed some nits

* changed search size, added output channels to model attrs

* changed weights naming

* changed from iterations to num_iters

* removed _make_coords, addressed comments

* fixed jit test

* added script files

* added cascaded inference evaluation

* added optimizer option

* minor changes

* Update references/depth/stereo/train.py
Co-authored-by: vfdev <vfdev.5@gmail.com>

* addressed some comments

* change if-else to dict

* added manual resizing for masks and disparities during evaluation

* minor fixes after previous changes

* changed dataloader to be initialised once

* added distributed changes

* changed loader logic

* updated eval script to generate weight API like logs

* improved support for fine-tuning / training resume

* minor changes for finetuning

* updated with transforms from main

* logging distributed deadlock fix

* lint fix

* updated metrics

* weights API log support

* lint fix

* added readme

* updated readme

* updated readme

* read-me update

* remove hardcoded paths. improved valid dataset selection and sync

* removed extras from gitignore
Co-authored-by: Joao Gomes <jdsgomes@fb.com>
Co-authored-by: vfdev <vfdev.5@gmail.com>
Co-authored-by: YosuaMichael <yosuamichaelm@gmail.com>
parent 784ee2b8
# Stereo Matching reference training scripts
This folder contains reference training scripts for Stereo Matching.
They serve as a log of how to train specific models, so as to provide baseline
training and evaluation scripts to quickly bootstrap research.
### CREStereo
The CREStereo model was trained on a dataset mixture of **CREStereo**, **ETH3D** and the additional split from **Middlebury2014**.
A ratio of **88-6-6** was used to train the baseline weight set. We provide a multi-set variant as well.
Both variants were trained with 8 A100 GPUs and a batch size of 2 per GPU (so the effective batch size is 16). The
rest of the hyper-parameters loosely follow the recipe from https://github.com/megvii-research/CREStereo.
The original recipe trains for **300000** updates (or steps) on the dataset mixture. We modify the learning rate
schedule to one that starts decaying the learning rate much sooner. Throughout our experiments we found that this reduces
overfitting at evaluation time, and that gradient clipping helps stabilize the loss during a premature learning rate change.
```
torchrun --nproc_per_node 8 --nnodes 1 train.py \
--dataset-root $dataset_root \
--name $name_cre \
--model crestereo_base \
--train-datasets crestereo eth3d-train middlebury2014-other \
--dataset-steps 264000 18000 18000 \
--batch-size 2 \
--lr 0.0004 \
--min-lr 0.00002 \
--lr-decay-method cosine \
--warmup-steps 6000 \
--decay-after-steps 30000 \
--clip-grad-norm 1.0 \
```
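Note that the per-dataset step counts above reproduce the **88-6-6** mixture ratio: `264000 + 18000 + 18000 = 300000` total steps, with `264000 / 300000 = 88%` and `18000 / 300000 = 6%` each. For reference, below is a minimal sketch (not part of the reference scripts) of the schedule that `train.py` assembles from these flags via `SequentialLR`: a linear warmup, a flat phase at the base learning rate, then a cosine decay towards `--min-lr`.
```
# Illustrative sketch of the schedule built from the flags above; the model
# and the decay horizon are placeholders, not part of the reference scripts.
import torch

model = torch.nn.Linear(1, 1)  # stand-in for crestereo_base
optimizer = torch.optim.Adam(model.parameters(), lr=4e-4)
warmup = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=0.02, total_iters=6_000)
flat = torch.optim.lr_scheduler.ConstantLR(optimizer, factor=1.0, total_iters=24_000)
decay = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=270_000, eta_min=2e-5)
scheduler = torch.optim.lr_scheduler.SequentialLR(
    optimizer, [warmup, flat, decay], milestones=[6_000, 30_000]
)
```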
We employ a multi-set fine-tuning stage where we uniformly sample from multiple datasets. Given that some of these datasets have extremely large images (``2048x2048`` or more) we opt for a very aggressive scale range of ``[0.2 - 0.8]`` so that as much as possible of the original frame composition is captured inside the ``384x512`` crop.
```
torchrun --nproc_per_node 8 --nnodes 1 train.py \
--dataset-root $dataset_root \
--name $name_things \
--model crestereo_base \
--train-datasets crestereo eth3d-train middlebury2014-other instereo2k fallingthings carla-highres sintel sceneflow-monkaa sceneflow-driving \
--dataset-steps 12000 12000 12000 12000 12000 12000 12000 12000 12000 \
--batch-size 2 \
--scale-range 0.2 0.8 \
--lr 0.0004 \
--lr-decay-method cosine \
--decay-after-steps 0 \
--warmup-steps 0 \
--min-lr 0.00002 \
--resume-path $checkpoint_dir/$name_cre.pth
```
### Evaluation
Evaluating the base weights
```
torchrun --nproc_per_node 1 --nnodes 1 cascade_evaluation.py --dataset middlebury2014-train --batch-size 1 --dataset-root $dataset_root --model crestereo_base --weights CREStereo_Base_Weights.CRESTEREO_ETH_MBL_V1
```
This should give an **mae of about 1.416** on the train set of `Middlebury2014`. Results may vary slightly depending on the batch size and the number of GPUs. For the most accurate results use 1 GPU and `--batch-size 1`. The created log file should look like this, where the first key is the number of cascades and the nested key is the number of recurrent iterations:
```
Dataset: middlebury2014-train @size: [384, 512]:
{
1: {
2: {'mae': 2.363, 'rmse': 4.352, '1px': 0.611, '3px': 0.828, '5px': 0.891, 'relepe': 0.176, 'fl-all': 64.511}
5: {'mae': 1.618, 'rmse': 3.71, '1px': 0.761, '3px': 0.879, '5px': 0.918, 'relepe': 0.154, 'fl-all': 77.128}
10: {'mae': 1.416, 'rmse': 3.53, '1px': 0.777, '3px': 0.896, '5px': 0.933, 'relepe': 0.148, 'fl-all': 78.388}
20: {'mae': 1.448, 'rmse': 3.583, '1px': 0.771, '3px': 0.893, '5px': 0.931, 'relepe': 0.145, 'fl-all': 77.7}
},
}
{
2: {
2: {'mae': 1.972, 'rmse': 4.125, '1px': 0.73, '3px': 0.865, '5px': 0.908, 'relepe': 0.169, 'fl-all': 74.396}
5: {'mae': 1.403, 'rmse': 3.448, '1px': 0.793, '3px': 0.905, '5px': 0.937, 'relepe': 0.151, 'fl-all': 80.186}
10: {'mae': 1.312, 'rmse': 3.368, '1px': 0.799, '3px': 0.912, '5px': 0.943, 'relepe': 0.148, 'fl-all': 80.379}
20: {'mae': 1.376, 'rmse': 3.542, '1px': 0.796, '3px': 0.91, '5px': 0.942, 'relepe': 0.149, 'fl-all': 80.054}
},
}
```
You can also evaluate the Finetuned weights:
```
torchrun --nproc_per_node 1 --nnodes 1 cascade_evaluation.py --dataset middlebury2014-train --batch-size 1 --dataset-root $dataset_root --model crestereo_base --weights CREStereo_Base_Weights.CRESTEREO_FINETUNE_MULTI_V1
```
```
Dataset: middlebury2014-train @size: [384, 512]:
{
1: {
2: {'mae': 1.85, 'rmse': 3.797, '1px': 0.673, '3px': 0.862, '5px': 0.917, 'relepe': 0.171, 'fl-all': 69.736}
5: {'mae': 1.111, 'rmse': 3.166, '1px': 0.838, '3px': 0.93, '5px': 0.957, 'relepe': 0.134, 'fl-all': 84.596}
10: {'mae': 1.02, 'rmse': 3.073, '1px': 0.854, '3px': 0.938, '5px': 0.96, 'relepe': 0.129, 'fl-all': 86.042}
20: {'mae': 0.993, 'rmse': 3.059, '1px': 0.855, '3px': 0.942, '5px': 0.967, 'relepe': 0.126, 'fl-all': 85.784}
},
}
{
2: {
2: {'mae': 1.667, 'rmse': 3.867, '1px': 0.78, '3px': 0.891, '5px': 0.922, 'relepe': 0.165, 'fl-all': 78.89}
5: {'mae': 1.158, 'rmse': 3.278, '1px': 0.843, '3px': 0.926, '5px': 0.955, 'relepe': 0.135, 'fl-all': 84.556}
10: {'mae': 1.046, 'rmse': 3.13, '1px': 0.85, '3px': 0.934, '5px': 0.96, 'relepe': 0.13, 'fl-all': 85.464}
20: {'mae': 1.021, 'rmse': 3.102, '1px': 0.85, '3px': 0.935, '5px': 0.963, 'relepe': 0.129, 'fl-all': 85.417}
},
}
```
Evaluating the author provided weights:
```
torchrun --nproc_per_node 1 --nnodes 1 cascade_evaluation.py --dataset middlebury2014-train --batch-size 1 --dataset-root $dataset_root --model crestereo_base --weights CREStereo_Base_Weights.MEGVII_V1
```
```
Dataset: middlebury2014-train @size: [384, 512]:
{
1: {
2: {'mae': 1.704, 'rmse': 3.738, '1px': 0.738, '3px': 0.896, '5px': 0.933, 'relepe': 0.157, 'fl-all': 76.464}
5: {'mae': 0.956, 'rmse': 2.963, '1px': 0.88, '3px': 0.948, '5px': 0.965, 'relepe': 0.124, 'fl-all': 88.186}
10: {'mae': 0.792, 'rmse': 2.765, '1px': 0.905, '3px': 0.958, '5px': 0.97, 'relepe': 0.114, 'fl-all': 90.429}
20: {'mae': 0.749, 'rmse': 2.706, '1px': 0.907, '3px': 0.961, '5px': 0.972, 'relepe': 0.113, 'fl-all': 90.807}
},
}
{
2: {
2: {'mae': 1.702, 'rmse': 3.784, '1px': 0.784, '3px': 0.894, '5px': 0.924, 'relepe': 0.172, 'fl-all': 80.313}
5: {'mae': 0.932, 'rmse': 2.907, '1px': 0.877, '3px': 0.944, '5px': 0.963, 'relepe': 0.125, 'fl-all': 87.979}
10: {'mae': 0.773, 'rmse': 2.768, '1px': 0.901, '3px': 0.958, '5px': 0.972, 'relepe': 0.117, 'fl-all': 90.43}
20: {'mae': 0.854, 'rmse': 2.971, '1px': 0.9, '3px': 0.957, '5px': 0.97, 'relepe': 0.122, 'fl-all': 90.269}
},
}
```
# Concerns when training
We encourage users to be aware of the **aspect-ratio** and **disparity scale** they are targeting when doing any sort of training or fine-tuning. The model is highly sensitive to these two factors; as a consequence, with naive multi-set fine-tuning one can achieve `0.2 mae` relatively fast. We recommend that users pay close attention to how they **balance dataset sizing** when training such networks.
Ideally, dataset scaling should be treated at an individual level, and a thorough **EDA** of the disparity distribution in random crops at the desired training / inference size should be performed prior to any large compute investments.
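A minimal sketch of such an EDA is shown below, assuming a torchvision stereo dataset that yields `(image_left, image_right, disparity, valid_mask)` tuples; the dataset, sample count and crop size are placeholders:
```
# Illustrative only: summarize disparity statistics over random crops.
import numpy as np
from torchvision.datasets import Middlebury2014Stereo

dataset = Middlebury2014Stereo(root="datasets", split="train", calibration="perfect")
crop_h, crop_w = 384, 512
stats = []
for idx in np.random.choice(len(dataset), size=100):
    _, _, disparity, _ = dataset[idx]
    disparity = np.asarray(disparity)[0]  # assuming a (1, H, W) disparity map
    y = np.random.randint(0, max(1, disparity.shape[0] - crop_h))
    x = np.random.randint(0, max(1, disparity.shape[1] - crop_w))
    crop = disparity[y : y + crop_h, x : x + crop_w]
    crop = crop[np.isfinite(crop)]  # drop invalid / occluded pixels
    stats.append((crop.mean(), crop.max()))
print("mean crop disparity:", np.mean([s[0] for s in stats]))
print("mean crop max disparity:", np.mean([s[1] for s in stats]))
```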
### Disparity scaling
##### Sample A
The top row contains a sample from `Sintel` whereas the bottom row contains one from `Middlebury`.
![Disparity1](assets/Disparity%20domain%20drift.jpg)
From left to right: (`left_image`, `right_image`, `valid_mask`, `valid_mask & ground_truth`, `prediction`). **Darker is further away, lighter is closer**. In the case of `Sintel`, which is more closely aligned to the original distribution of `CREStereo`, we notice that the model accurately predicts the background scale, whereas in the case of `Middlebury2014` it cannot correctly estimate the continuous disparity. Notice that the frame composition is similar for both examples. The blue skybox in the `Sintel` scene behaves similarly to the `Middlebury` black background. However, because the `Middlebury` sample comes from an extremely large scene, the crop size of `384x512` does not correctly capture the general training distribution.
##### Sample B
The top row contains a scene from `Sceneflow` using the `Monkaa` split whilst the bottom row is a scene from `Middlebury`. This sample exhibits the same issues when it comes to **background estimation**. Given the exaggerated size of the `Middlebury` samples, the model **collapses the smooth background** of the sample to what it considers to be a mean background disparity value.
![Disparity2](assets/Disparity%20background%20mode%20collapse.jpg)
For more detail on why this behaviour occurs based on the training distribution proportions you can read more about the network at: https://github.com/pytorch/vision/pull/6629#discussion_r978160493
### Metric overfitting
##### Learning is critical in the beginning
We also advise users to make use of faster training schedules, as the performance gain over long periods of time is marginal. Here we exhibit the difference between an earlier decay schedule and a later decay schedule.
![Loss1](assets/Loss.jpg)
In **grey** we set the lr decay to begin after `30000` steps whilst in **orange** we opt for a very late learning rate decay at around `180000` steps. Although it exhibits stronger variance, we can notice that starting the learning rate decay earlier whilst employing `gradient-norm` clipping out-performs the default configuration.
##### Gradient norm saves time
![Loss2](assets/Gradient%20Norm%20Removal.jpg)
In **grey** we keep ``gradient norm`` clipping enabled whilst in **orange** we do not. We can notice that removing the gradient norm exacerbates the performance decrease in the early stages, whilst also showcasing an almost complete collapse around the `60000` steps mark where we started decaying the lr for **orange**.
Although both runs achieve an improvement of about ``0.1`` mae after the lr decay starts, the benefits are observable much faster when ``gradient norm`` clipping is employed, as the recovery period is no longer accounted for.
import os
import warnings
import torch
import torchvision
import torchvision.prototype.models.depth.stereo
import utils
from torch.nn import functional as F
from train import make_eval_loader
from utils.metrics import AVAILABLE_METRICS
from vizualization import make_prediction_image_side_to_side
def get_args_parser(add_help=True):
import argparse
parser = argparse.ArgumentParser(description="PyTorch Stereo Matching Evaluation", add_help=add_help)
parser.add_argument("--dataset", type=str, default="middlebury2014-train", help="dataset to use")
parser.add_argument("--dataset-root", type=str, default="", help="root of the dataset")
parser.add_argument("--checkpoint", type=str, default="", help="path to weights")
parser.add_argument("--weights", type=str, default=None, help="torchvision API weight")
parser.add_argument(
"--model",
type=str,
default="crestereo_base",
help="which model to use if not speciffying a training checkpoint",
)
parser.add_argument("--img-folder", type=str, default="images")
parser.add_argument("--batch-size", type=int, default=1, help="batch size")
parser.add_argument("--workers", type=int, default=0, help="number of workers")
parser.add_argument("--eval-size", type=int, nargs="+", default=[384, 512], help="resize size")
parser.add_argument(
"--norm-mean", type=float, nargs="+", default=[0.5, 0.5, 0.5], help="mean for image normalization"
)
parser.add_argument(
"--norm-std", type=float, nargs="+", default=[0.5, 0.5, 0.5], help="std for image normalization"
)
parser.add_argument(
"--use-grayscale", action="store_true", help="use grayscale images instead of RGB", default=False
)
parser.add_argument("--max-disparity", type=float, default=None, help="maximum disparity")
parser.add_argument(
"--interpolation-strategy",
type=str,
default="bilinear",
help="interpolation strategy",
choices=["bilinear", "bicubic", "mixed"],
)
parser.add_argument("--n_iterations", nargs="+", type=int, default=[10], help="number of recurent iterations")
parser.add_argument("--n_cascades", nargs="+", type=int, default=[1], help="number of cascades")
parser.add_argument(
"--metrics",
type=str,
nargs="+",
default=["mae", "rmse", "1px", "3px", "5px", "relepe"],
help="metrics to log",
choices=AVAILABLE_METRICS,
)
parser.add_argument("--mixed-precision", action="store_true", help="use mixed precision training")
parser.add_argument("--world-size", type=int, default=1, help="number of distributed processes")
parser.add_argument("--dist-url", type=str, default="env://", help="url used to set up distributed training")
parser.add_argument("--device", type=str, default="cuda", help="device to use for training")
parser.add_argument("--save-images", action="store_true", help="save images of the predictions")
parser.add_argument("--padder-type", type=str, default="kitti", help="padder type", choices=["kitti", "sintel"])
return parser
def cascade_inference(model, image_left, image_right, iterations, cascades):
# check that image size is divisible by 16 * (2 ** (cascades - 1))
for image in [image_left, image_right]:
if image.shape[-2] % (16 * (2 ** (cascades - 1))) != 0:
raise ValueError(
f"image height is not divisible by {16 * (2 ** (cascades - 1))}. Image shape: {image.shape[-2]}"
)
if image.shape[-1] % (16 * (2 ** (cascades - 1))) != 0:
raise ValueError(
f"image width is not divisible by {16 * (2 ** (cascades - 1))}. Image shape: {image.shape[-1]}"
)
left_image_pyramid = [image_left]
right_image_pyramid = [image_right]
for idx in range(0, cascades - 1):
ds_factor = int(2 ** (idx + 1))
ds_shape = (image_left.shape[-2] // ds_factor, image_left.shape[-1] // ds_factor)
left_image_pyramid.append(F.interpolate(image_left, size=ds_shape, mode="bilinear", align_corners=True))
right_image_pyramid.append(F.interpolate(image_right, size=ds_shape, mode="bilinear", align_corners=True))
flow_init = None
for left_image, right_image in zip(reversed(left_image_pyramid), reversed(right_image_pyramid)):
flow_pred = model(left_image, right_image, flow_init, num_iters=iterations)
# flow pred is a list
flow_init = flow_pred[-1]
return flow_init
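# Example (illustrative) usage, assuming `model` is a crestereo_base instance
# and both images are (N, 3, H, W) tensors that pass the divisibility check:
#
#   disparity = cascade_inference(model, image_left, image_right, iterations=10, cascades=2)
#
# The coarsest pyramid level is processed first and its prediction is fed as
# flow_init to the next, finer level.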
@torch.inference_mode()
def _evaluate(
model,
args,
val_loader,
*,
padder_mode,
print_freq=10,
writter=None,
step=None,
iterations=10,
cascades=1,
batch_size=None,
header=None,
save_images=False,
save_path="",
):
"""Helper function to compute various metrics (epe, etc.) for a model on a given dataset.
We process as many samples as possible with ddp.
"""
model.eval()
header = header or "Test:"
device = torch.device(args.device)
metric_logger = utils.MetricLogger(delimiter=" ")
iterations = iterations or args.recurrent_updates
logger = utils.MetricLogger()
for meter_name in args.metrics:
logger.add_meter(meter_name, fmt="{global_avg:.4f}")
if "fl-all" not in args.metrics:
logger.add_meter("fl-all", fmt="{global_avg:.4f}")
num_processed_samples = 0
with torch.cuda.amp.autocast(enabled=args.mixed_precision, dtype=torch.float16):
batch_idx = 0
for blob in metric_logger.log_every(val_loader, print_freq, header):
image_left, image_right, disp_gt, valid_disp_mask = (x.to(device) for x in blob)
padder = utils.InputPadder(image_left.shape, mode=padder_mode)
image_left, image_right = padder.pad(image_left, image_right)
disp_pred = cascade_inference(model, image_left, image_right, iterations, cascades)
disp_pred = disp_pred[:, :1, :, :]
disp_pred = padder.unpad(disp_pred)
if save_images:
if args.distributed:
rank_prefix = args.rank
else:
rank_prefix = 0
make_prediction_image_side_to_side(
disp_pred, disp_gt, valid_disp_mask, save_path, prefix=f"batch_{rank_prefix}_{batch_idx}"
)
metrics, _ = utils.compute_metrics(disp_pred, disp_gt, valid_disp_mask, metrics=logger.meters.keys())
num_processed_samples += image_left.shape[0]
for name in metrics:
logger.meters[name].update(metrics[name], n=1)
batch_idx += 1
num_processed_samples = utils.reduce_across_processes(num_processed_samples) / args.world_size
print("Num_processed_samples: ", num_processed_samples)
if (
hasattr(val_loader.dataset, "__len__")
and len(val_loader.dataset) != num_processed_samples
and torch.distributed.get_rank() == 0
):
warnings.warn(
f"Number of processed samples {num_processed_samples} is different "
f"from the dataset size {len(val_loader.dataset)}. This may happen if "
"the dataset is not divisible by the batch size. Try lowering the batch size for more accurate results."
)
if writter is not None and args.rank == 0:
for meter_name, meter_value in logger.meters.items():
scalar_name = f"{meter_name} {header}"
writter.add_scalar(scalar_name, meter_value.avg, step)
logger.synchronize_between_processes()
print(header, logger)
logger_metrics = {k: v.global_avg for k, v in logger.meters.items()}
return logger_metrics
def evaluate(model, loader, args, writter=None, step=None):
os.makedirs(args.img_folder, exist_ok=True)
checkpoint_name = os.path.basename(args.checkpoint) or args.weights
image_checkpoint_folder = os.path.join(args.img_folder, checkpoint_name)
metrics = {}
base_image_folder = os.path.join(image_checkpoint_folder, args.dataset)
os.makedirs(base_image_folder, exist_ok=True)
for n_cascades in args.n_cascades:
for n_iters in args.n_iterations:
config = f"{n_cascades}c_{n_iters}i"
config_image_folder = os.path.join(base_image_folder, config)
os.makedirs(config_image_folder, exist_ok=True)
metrics[config] = _evaluate(
model,
args,
loader,
padder_mode=args.padder_type,
header=f"{args.dataset} evaluation@ size:{args.eval_size} n_cascades:{n_cascades} n_iters:{n_iters}",
batch_size=args.batch_size,
writter=writter,
step=step,
iterations=n_iters,
cascades=n_cascades,
save_path=config_image_folder,
save_images=args.save_images,
)
metric_log = []
metric_log_dict = {}
# print the final results
for config in metrics:
config_tokens = config.split("_")
config_iters = config_tokens[1][:-1]
config_cascades = config_tokens[0][:-1]
metric_log_dict[config_cascades] = metric_log_dict.get(config_cascades, {})
metric_log_dict[config_cascades][config_iters] = metrics[config]
evaluation_str = f"{args.dataset} evaluation@ size:{args.eval_size} n_cascades:{config_cascades} recurrent_updates:{config_iters}"
metrics_str = f"Metrics: {metrics[config]}"
metric_log.extend([evaluation_str, metrics_str])
print(evaluation_str)
print(metrics_str)
eval_log_name = f"{checkpoint_name.replace('.pth', '')}_eval.log"
print("Saving eval log to: ", eval_log_name)
with open(eval_log_name, "w") as f:
f.write(f"Dataset: {args.dataset} @size: {args.eval_size}:\n")
# write the dict line by line for each key, and each value in the keys
for config_cascades in metric_log_dict:
f.write("{\n")
f.write(f"\t{config_cascades}: {{\n")
for config_iters in metric_log_dict[config_cascades]:
# convert every metric to 3 decimal places
metrics = metric_log_dict[config_cascades][config_iters]
metrics = {k: float(f"{v:.3f}") for k, v in metrics.items()}
f.write(f"\t\t{config_iters}: {metrics}\n")
f.write("\t},\n")
f.write("}\n")
def load_checkpoint(args):
utils.setup_ddp(args)
if not args.weights:
checkpoint = torch.load(args.checkpoint, map_location=torch.device("cpu"))
if "model" in checkpoint:
experiment_args = checkpoint["args"]
model = torchvision.prototype.models.depth.stereo.__dict__[experiment_args.model](weights=None)
model.load_state_dict(checkpoint["model"])
else:
model = torchvision.prototype.models.depth.stereo.__dict__[args.model](weights=None)
model.load_state_dict(checkpoint)
# set the appropriate devices
if args.distributed and args.device == "cpu":
raise ValueError("The device must be cuda if we want to run in distributed mode using torchrun")
device = torch.device(args.device)
else:
model = torchvision.prototype.models.depth.stereo.__dict__[args.model](weights=args.weights)
# convert to DDP if need be
if args.distributed:
model = model.to(args.device)
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
else:
model.to(device)
return model
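# Note: two checkpoint formats are supported above. A training checkpoint is a
# dict with "model" / "optimizer" / "scheduler" / "step" / "args" keys (as saved
# by train.py), whereas a raw state dict is loaded directly into the
# architecture selected via --model.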
def main(args):
model = load_checkpoint(args)
loader = make_eval_loader(args.dataset, args)
evaluate(model, loader, args)
if __name__ == "__main__":
args = get_args_parser().parse_args()
main(args)
import argparse
from functools import partial
import torch
from presets import StereoMatchingEvalPreset, StereoMatchingTrainPreset
from torchvision.datasets import (
CarlaStereo,
CREStereo,
ETH3DStereo,
FallingThingsStereo,
InStereo2k,
Kitti2012Stereo,
Kitti2015Stereo,
Middlebury2014Stereo,
SceneFlowStereo,
SintelStereo,
)
VALID_DATASETS = {
"crestereo": partial(CREStereo),
"carla-highres": partial(CarlaStereo),
"instereo2k": partial(InStereo2k),
"sintel": partial(SintelStereo),
"sceneflow-monkaa": partial(SceneFlowStereo, variant="Monkaa", pass_name="both"),
"sceneflow-flyingthings": partial(SceneFlowStereo, variant="FlyingThings3D", pass_name="both"),
"sceneflow-driving": partial(SceneFlowStereo, variant="Driving", pass_name="both"),
"fallingthings": partial(FallingThingsStereo, variant="both"),
"eth3d-train": partial(ETH3DStereo, split="train"),
"eth3d-test": partial(ETH3DStereo, split="test"),
"kitti2015-train": partial(Kitti2015Stereo, split="train"),
"kitti2015-test": partial(Kitti2015Stereo, split="test"),
"kitti2012-train": partial(Kitti2012Stereo, split="train"),
"kitti2012-test": partial(Kitti2012Stereo, split="train"),
"middlebury2014-other": partial(
Middlebury2014Stereo, split="additional", use_ambient_view=True, calibration="both"
),
"middlebury2014-train": partial(Middlebury2014Stereo, split="train", calibration="perfect"),
"middlebury2014-test": partial(Middlebury2014Stereo, split="test", calibration=None),
"middlebury2014-train-ambient": partial(
Middlebury2014Stereo, split="train", use_ambient_views=True, calibrartion="perfect"
),
}
def make_train_transform(args: argparse.Namespace) -> torch.nn.Module:
return StereoMatchingTrainPreset(
resize_size=args.resize_size,
crop_size=args.crop_size,
rescale_prob=args.rescale_prob,
scaling_type=args.scaling_type,
scale_range=args.scale_range,
scale_interpolation_type=args.interpolation_strategy,
use_grayscale=args.use_grayscale,
mean=args.norm_mean,
std=args.norm_std,
horizontal_flip_prob=args.flip_prob,
gpu_transforms=args.gpu_transforms,
max_disparity=args.max_disparity,
spatial_shift_prob=args.spatial_shift_prob,
spatial_shift_max_angle=args.spatial_shift_max_angle,
spatial_shift_max_displacement=args.spatial_shift_max_displacement,
spatial_shift_interpolation_type=args.interpolation_strategy,
gamma_range=args.gamma_range,
brightness=args.brightness_range,
contrast=args.contrast_range,
saturation=args.saturation_range,
hue=args.hue_range,
asymmetric_jitter_prob=args.asymmetric_jitter_prob,
)
def make_eval_transform(args: argparse.Namespace) -> torch.nn.Module:
if args.eval_size is None:
resize_size = args.crop_size
else:
resize_size = args.eval_size
return StereoMatchingEvalPreset(
mean=args.norm_mean,
std=args.norm_std,
use_grayscale=args.use_grayscale,
resize_size=resize_size,
interpolation_type=args.interpolation_strategy,
)
def make_dataset(dataset_name: str, dataset_root: str, transforms: torch.nn.Module) -> torch.utils.data.Dataset:
return VALID_DATASETS[dataset_name](root=dataset_root, transforms=transforms)
import argparse
import os
import warnings
from pathlib import Path
from typing import List, Union
import numpy as np
import torch
import torch.distributed as dist
import torchvision.models.optical_flow
import torchvision.prototype.models.depth.stereo
import utils
import vizualization
from parsing import make_dataset, make_eval_transform, make_train_transform, VALID_DATASETS
from torch import nn
from torchvision.transforms.functional import get_dimensions, InterpolationMode, resize
from utils.metrics import AVAILABLE_METRICS
from utils.norm import freeze_batch_norm
def make_stereo_flow(flow: Union[torch.Tensor, List[torch.Tensor]], model_out_channels: int) -> torch.Tensor:
"""Helper function to make stereo flow from a given model output"""
if isinstance(flow, list):
return [make_stereo_flow(flow_i, model_out_channels) for flow_i in flow]
B, C, H, W = flow.shape
# we need to add zero flow if the model outputs 2 channels
if C == 1 and model_out_channels == 2:
zero_flow = torch.zeros_like(flow)
# by convention the flow is X-Y axis, so we need the Y flow last
flow = torch.cat([flow, zero_flow], dim=1)
return flow
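# Example (illustrative): a crestereo_base prediction has a single disparity
# channel, so for a model with output_channels == 2 a zero Y-flow plane is
# appended, e.g.:
#
#   flow = make_stereo_flow(torch.rand(4, 1, 64, 64), model_out_channels=2)
#   assert flow.shape == (4, 2, 64, 64)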
def make_lr_schedule(args: argparse.Namespace, optimizer: torch.optim.Optimizer) -> np.ndarray:
"""Helper function to return a learning rate scheduler for CRE-stereo"""
if args.decay_after_steps < args.warmup_steps:
raise ValueError(f"decay_after_steps: {args.function} must be greater than warmup_steps: {args.warmup_steps}")
warmup_steps = args.warmup_steps if args.warmup_steps else 0
flat_lr_steps = args.decay_after_steps - warmup_steps if args.decay_after_steps else 0
decay_lr_steps = args.total_iterations - flat_lr_steps
max_lr = args.lr
min_lr = args.min_lr
schedulers = []
milestones = []
if warmup_steps > 0:
if args.lr_warmup_method == "linear":
warmup_lr_scheduler = torch.optim.lr_scheduler.LinearLR(
optimizer, start_factor=args.lr_warmup_factor, total_iters=warmup_steps
)
elif args.lr_warmup_method == "constant":
warmup_lr_scheduler = torch.optim.lr_scheduler.ConstantLR(
optimizer, factor=args.lr_warmup_factor, total_iters=warmup_steps
)
else:
raise ValueError(f"Unknown lr warmup method {args.lr_warmup_method}")
schedulers.append(warmup_lr_scheduler)
milestones.append(warmup_steps)
if flat_lr_steps > 0:
flat_lr_scheduler = torch.optim.lr_scheduler.ConstantLR(optimizer, factor=1.0, total_iters=flat_lr_steps)
schedulers.append(flat_lr_scheduler)
milestones.append(flat_lr_steps + warmup_steps)
if decay_lr_steps > 0:
if args.lr_decay_method == "cosine":
decay_lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
optimizer, T_max=decay_lr_steps, eta_min=min_lr
)
elif args.lr_decay_method == "linear":
decay_lr_scheduler = torch.optim.lr_scheduler.LinearLR(
optimizer, start_factor=max_lr, end_factor=min_lr, total_iters=decay_lr_steps
)
elif args.lr_decay_method == "exponential":
decay_lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(
optimizer, gamma=args.lr_decay_gamma, last_epoch=-1
)
else:
raise ValueError(f"Unknown lr decay method {args.lr_decay_method}")
schedulers.append(decay_lr_scheduler)
scheduler = torch.optim.lr_scheduler.SequentialLR(optimizer, schedulers, milestones=milestones)
return scheduler
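# Worked example (base README recipe): warmup_steps=6000, decay_after_steps=30000
# and total_iterations=300000 give flat_lr_steps = 30000 - 6000 = 24000 and
# milestones [6000, 30000]; the decay phase then anneals the lr towards min_lr
# for the remaining steps.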
def shuffle_dataset(dataset):
"""Shuffle the dataset"""
perm = torch.randperm(len(dataset))
return torch.utils.data.Subset(dataset, perm)
def resize_dataset_to_n_steps(
dataset: torch.utils.data.Dataset, dataset_steps: int, samples_per_step: int, args: argparse.Namespace
) -> torch.utils.data.Dataset:
original_size = len(dataset)
if args.steps_is_epochs:
samples_per_step = original_size
target_size = dataset_steps * samples_per_step
dataset_copies = []
n_expands, remainder = divmod(target_size, original_size)
for idx in range(n_expands):
dataset_copies.append(dataset)
if remainder > 0:
dataset_copies.append(torch.utils.data.Subset(dataset, list(range(remainder))))
if args.dataset_shuffle:
dataset_copies = [shuffle_dataset(dataset_copy) for dataset_copy in dataset_copies]
dataset = torch.utils.data.ConcatDataset(dataset_copies)
return dataset
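# Worked example (illustrative): with batch_size=2 and world_size=8, each
# optimization step consumes samples_per_step = 16 samples, so a dataset
# scheduled for 12000 steps is tiled (and truncated) to 12000 * 16 = 192000
# samples.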
def get_train_dataset(dataset_root: str, args: argparse.Namespace) -> torch.utils.data.Dataset:
datasets = []
for dataset_name in args.train_datasets:
transform = make_train_transform(args)
dataset = make_dataset(dataset_name, dataset_root, transform)
datasets.append(dataset)
if len(datasets) == 0:
raise ValueError("No datasets specified for training")
samples_per_step = args.world_size * args.batch_size
for idx, (dataset, steps_per_dataset) in enumerate(zip(datasets, args.dataset_steps)):
datasets[idx] = resize_dataset_to_n_steps(dataset, steps_per_dataset, samples_per_step, args)
dataset = torch.utils.data.ConcatDataset(datasets)
if args.dataset_order_shuffle:
dataset = shuffle_dataset(dataset)
print(f"Training dataset: {len(dataset)} samples")
return dataset
@torch.inference_mode()
def _evaluate(
model,
args,
val_loader,
*,
padder_mode,
print_freq=10,
writter=None,
step=None,
iterations=None,
batch_size=None,
header=None,
):
"""Helper function to compute various metrics (epe, etc.) for a model on a given dataset."""
model.eval()
header = header or "Test:"
device = torch.device(args.device)
metric_logger = utils.MetricLogger(delimiter=" ")
iterations = iterations or args.recurrent_updates
logger = utils.MetricLogger()
for meter_name in args.metrics:
logger.add_meter(meter_name, fmt="{global_avg:.4f}")
if "fl-all" not in args.metrics:
logger.add_meter("fl-all", fmt="{global_avg:.4f}")
num_processed_samples = 0
with torch.cuda.amp.autocast(enabled=args.mixed_precision, dtype=torch.float16):
for blob in metric_logger.log_every(val_loader, print_freq, header):
image_left, image_right, disp_gt, valid_disp_mask = (x.to(device) for x in blob)
padder = utils.InputPadder(image_left.shape, mode=padder_mode)
image_left, image_right = padder.pad(image_left, image_right)
disp_predictions = model(image_left, image_right, flow_init=None, num_iters=iterations)
disp_pred = disp_predictions[-1][:, :1, :, :]
disp_pred = padder.unpad(disp_pred)
metrics, _ = utils.compute_metrics(disp_pred, disp_gt, valid_disp_mask, metrics=logger.meters.keys())
num_processed_samples += image_left.shape[0]
for name in metrics:
logger.meters[name].update(metrics[name], n=1)
num_processed_samples = utils.reduce_across_processes(num_processed_samples)
print("Num_processed_samples: ", num_processed_samples)
if (
hasattr(val_loader.dataset, "__len__")
and len(val_loader.dataset) != num_processed_samples
and torch.distributed.get_rank() == 0
):
warnings.warn(
f"Number of processed samples {num_processed_samples} is different "
f"from the dataset size {len(val_loader.dataset)}. This may happen if "
"the dataset is not divisible by the batch size. Try lowering the batch size or GPU number for more accurate results."
)
if writter is not None and args.rank == 0:
for meter_name, meter_value in logger.meters.items():
scalar_name = f"{meter_name} {header}"
writter.add_scalar(scalar_name, meter_value.avg, step)
logger.synchronize_between_processes()
print(header, logger)
def make_eval_loader(dataset_name: str, args: argparse.Namespace) -> torch.utils.data.DataLoader:
if args.weights:
weights = torchvision.models.get_weight(args.weights)
trans = weights.transforms()
def preprocessing(image_left, image_right, disp, valid_disp_mask):
C_o, H_o, W_o = get_dimensions(image_left)
image_left, image_right = trans(image_left, image_right)
C_t, H_t, W_t = get_dimensions(image_left)
scale_factor = W_t / W_o
if disp is not None and not isinstance(disp, torch.Tensor):
disp = torch.from_numpy(disp)
if W_t != W_o:
disp = resize(disp, (H_t, W_t), interpolation=InterpolationMode.BILINEAR) * scale_factor
if valid_disp_mask is not None and not isinstance(valid_disp_mask, torch.Tensor):
valid_disp_mask = torch.from_numpy(valid_disp_mask)
if W_t != W_o:
valid_disp_mask = resize(valid_disp_mask, (H_t, W_t), interpolation=InterpolationMode.NEAREST)
return image_left, image_right, disp, valid_disp_mask
else:
preprocessing = make_eval_transform(args)
val_dataset = make_dataset(dataset_name, args.dataset_root, transforms=preprocessing)
if args.distributed:
sampler = torch.utils.data.distributed.DistributedSampler(val_dataset, shuffle=False, drop_last=False)
else:
sampler = torch.utils.data.SequentialSampler(val_dataset)
val_loader = torch.utils.data.DataLoader(
val_dataset,
sampler=sampler,
batch_size=args.batch_size,
pin_memory=True,
num_workers=args.workers,
)
return val_loader
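# Note: disparities are measured in pixels along the image width, so when the
# weights-API transforms resize an image from width W_o to W_t the ground-truth
# disparity is rescaled by scale_factor = W_t / W_o in `preprocessing` above,
# while the valid mask uses nearest interpolation so that it stays binary.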
def evaluate(model, loaders, args, writter=None, step=None):
for loader_name, loader in loaders.items():
_evaluate(
model,
args,
loader,
iterations=args.recurrent_updates,
padder_mode=args.padder_type,
header=f"{loader_name} evaluation",
batch_size=args.batch_size,
writter=writter,
step=step,
)
def run(model, optimizer, scheduler, train_loader, val_loaders, logger, writer, scaler, args):
device = torch.device(args.device)
# wrap the loader in a logger
loader = iter(logger.log_every(train_loader))
# output channels
model_out_channels = model.module.output_channels if args.distributed else model.output_channels
torch.set_num_threads(args.threads)
sequence_criterion = utils.SequenceLoss(
gamma=args.gamma,
max_flow=args.max_disparity,
exclude_large_flows=args.flow_loss_exclude_large,
).to(device)
if args.consistency_weight:
consistency_criterion = utils.FlowSequenceConsistencyLoss(
args.gamma,
resize_factor=0.25,
rescale_factor=0.25,
rescale_mode="bilinear",
).to(device)
else:
consistency_criterion = None
if args.psnr_weight:
psnr_criterion = utils.PSNRLoss().to(device)
else:
psnr_criterion = None
if args.smoothness_weight:
smoothness_criterion = utils.SmoothnessLoss().to(device)
else:
smoothness_criterion = None
if args.photometric_weight:
photometric_criterion = utils.FlowPhotoMetricLoss(
ssim_weight=args.photometric_ssim_weight,
max_displacement_ratio=args.photometric_max_displacement_ratio,
ssim_use_padding=False,
).to(device)
else:
photometric_criterion = None
for step in range(args.start_step + 1, args.total_iterations + 1):
data_blob = next(loader)
optimizer.zero_grad()
# unpack the data blob
image_left, image_right, disp_mask, valid_disp_mask = (x.to(device) for x in data_blob)
with torch.cuda.amp.autocast(enabled=args.mixed_precision, dtype=torch.float16):
disp_predictions = model(image_left, image_right, flow_init=None, num_iters=args.recurrent_updates)
# different models have different outputs, make sure we get the right ones for this task
disp_predictions = make_stereo_flow(disp_predictions, model_out_channels)
# should the architecture or training loop require it, we have to adjust the disparity mask
# target to possibly look like an optical flow mask
disp_mask = make_stereo_flow(disp_mask, model_out_channels)
# sequence loss on top of the model outputs
loss = sequence_criterion(disp_predictions, disp_mask, valid_disp_mask) * args.flow_loss_weight
if args.consistency_weight > 0:
loss_consistency = consistency_criterion(disp_predictions)
loss += loss_consistency * args.consistency_weight
if args.psnr_weight > 0:
loss_psnr = 0.0
for pred in disp_predictions:
# predictions might have 2 channels
loss_psnr += psnr_criterion(
pred * valid_disp_mask.unsqueeze(1),
disp_mask * valid_disp_mask.unsqueeze(1),
).mean() # mean the psnr loss over the batch
loss += loss_psnr / len(disp_predictions) * args.psnr_weight
if args.photometric_weight > 0:
loss_photometric = 0.0
for pred in disp_predictions:
# predictions might have 1 channel, therefore we need to impute 0s for the second channel
if model_out_channels == 1:
pred = torch.cat([pred, torch.zeros_like(pred)], dim=1)
loss_photometric += photometric_criterion(
image_left, image_right, pred, valid_disp_mask
) # photometric loss already comes out meaned over the batch
loss += loss_photometric / len(disp_predictions) * args.photometric_weight
if args.smoothness_weight > 0:
loss_smoothness = 0.0
for pred in disp_predictions:
# predictions might have 2 channels
loss_smoothness += smoothness_criterion(
image_left, pred[:, :1, :, :]
).mean() # mean the smoothness loss over the batch
loss += loss_smoothness / len(disp_predictions) * args.smoothness_weight
with torch.no_grad():
metrics, _ = utils.compute_metrics(
disp_predictions[-1][:, :1, :, :], # predictions might have 2 channels
disp_mask[:, :1, :, :], # so does the ground truth
valid_disp_mask,
args.metrics,
)
metrics.pop("fl-all", None)
logger.update(loss=loss, **metrics)
if scaler is not None:
scaler.scale(loss).backward()
scaler.unscale_(optimizer)
if args.clip_grad_norm:
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=args.clip_grad_norm)
scaler.step(optimizer)
scaler.update()
else:
loss.backward()
if args.clip_grad_norm:
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=args.clip_grad_norm)
optimizer.step()
scheduler.step()
if not dist.is_initialized() or dist.get_rank() == 0:
if writer is not None and step % args.tensorboard_log_frequency == 0:
# log the loss and metrics to tensorboard
writer.add_scalar("loss", loss, step)
for name, value in logger.meters.items():
writer.add_scalar(name, value.avg, step)
# log the images to tensorboard
pred_grid = vizualization.make_training_sample_grid(
image_left, image_right, disp_mask, valid_disp_mask, disp_predictions
)
writer.add_image("predictions", pred_grid, step, dataformats="HWC")
# second thing we want to see is how relevant the iterative refinement is
pred_sequence_grid = vizualization.make_disparity_sequence_grid(disp_predictions, disp_mask)
writer.add_image("sequence", pred_sequence_grid, step, dataformats="HWC")
if step % args.save_frequency == 0:
if not args.distributed or args.rank == 0:
model_without_ddp = (
model.module if isinstance(model, torch.nn.parallel.DistributedDataParallel) else model
)
checkpoint = {
"model": model_without_ddp.state_dict(),
"optimizer": optimizer.state_dict(),
"scheduler": scheduler.state_dict(),
"step": step,
"args": args,
}
os.makedirs(args.checkpoint_dir, exist_ok=True)
torch.save(checkpoint, Path(args.checkpoint_dir) / f"{args.name}_{step}.pth")
torch.save(checkpoint, Path(args.checkpoint_dir) / f"{args.name}.pth")
if step % args.valid_frequency == 0:
evaluate(model, val_loaders, args, writer, step)
model.train()
if args.freeze_batch_norm:
if isinstance(model, nn.parallel.DistributedDataParallel):
freeze_batch_norm(model.module)
else:
freeze_batch_norm(model)
# one final save at the end
if not args.distributed or args.rank == 0:
model_without_ddp = model.module if isinstance(model, torch.nn.parallel.DistributedDataParallel) else model
checkpoint = {
"model": model_without_ddp.state_dict(),
"optimizer": optimizer.state_dict(),
"scheduler": scheduler.state_dict(),
"step": step,
"args": args,
}
os.makedirs(args.checkpoint_dir, exist_ok=True)
torch.save(checkpoint, Path(args.checkpoint_dir) / f"{args.name}_{step}.pth")
torch.save(checkpoint, Path(args.checkpoint_dir) / f"{args.name}.pth")
def main(args):
args.total_iterations = sum(args.dataset_steps)
# initialize DDP setting
utils.setup_ddp(args)
print(args)
args.test_only = args.train_datasets is None
# set the appropriate devices
if args.distributed and args.device == "cpu":
raise ValueError("The device must be cuda if we want to run in distributed mode using torchrun")
device = torch.device(args.device)
# select model architecture
model = torchvision.prototype.models.depth.stereo.__dict__[args.model](weights=args.weights)
# convert to DDP if need be
if args.distributed:
model = model.to(args.gpu)
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
model_without_ddp = model.module
else:
model.to(device)
model_without_ddp = model
os.makedirs(args.checkpoint_dir, exist_ok=True)
val_loaders = {name: make_eval_loader(name, args) for name in args.test_datasets}
# EVAL ONLY configurations
if args.test_only:
evaluate(model, val_loaders, args)
return
# Sanity check for the parameter count
print(f"Parameter Count: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")
# Compose the training dataset
train_dataset = get_train_dataset(args.dataset_root, args)
# initialize the optimizer
if args.optimizer == "adam":
optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
elif args.optimizer == "sgd":
optimizer = torch.optim.SGD(model.parameters(), lr=args.lr, weight_decay=args.weight_decay, momentum=0.9)
else:
raise ValueError(f"Unknown optimizer {args.optimizer}. Please choose between adam and sgd")
# initialize the learning rate schedule
scheduler = make_lr_schedule(args, optimizer)
# load them from checkpoint if need
args.start_step = 0
if args.resume_path is not None:
checkpoint = torch.load(args.resume_path, map_location="cpu")
if "model" in checkpoint:
# this means the user requested to resume from a training checkpoint
model_without_ddp.load_state_dict(checkpoint["model"])
# this means the user wants to continue training from where it was left off
if args.resume_schedule:
optimizer.load_state_dict(checkpoint["optimizer"])
scheduler.load_state_dict(checkpoint["scheduler"])
args.start_step = checkpoint["step"] + 1
# modify the starting point of the data to skip already consumed samples
sample_start_step = args.start_step * args.batch_size * args.world_size
train_dataset = torch.utils.data.Subset(train_dataset, range(sample_start_step, len(train_dataset)))
else:
# this means the user wants to finetune on top of a model state dict
# and that no other changes are required
model_without_ddp.load_state_dict(checkpoint)
torch.backends.cudnn.benchmark = True
# enable training mode
model.train()
if args.freeze_batch_norm:
freeze_batch_norm(model_without_ddp)
# put dataloader on top of the dataset
# make sure to disable shuffling since the dataset is already shuffled
# in order to guarantee quasi randomness whilst retaining a deterministic
# dataset consumption order
if args.distributed:
# the train dataset is preshuffled in order to respect the iteration order
sampler = torch.utils.data.distributed.DistributedSampler(train_dataset, shuffle=False, drop_last=True)
else:
# the train dataset is already shuffled so we can use a simple SequentialSampler
sampler = torch.utils.data.SequentialSampler(train_dataset)
train_loader = torch.utils.data.DataLoader(
train_dataset,
sampler=sampler,
batch_size=args.batch_size,
pin_memory=True,
num_workers=args.workers,
)
# initialize the logger
if args.tensorboard_summaries:
from torch.utils.tensorboard import SummaryWriter
tensorboard_path = Path(args.checkpoint_dir) / "tensorboard"
os.makedirs(tensorboard_path, exist_ok=True)
tensorboard_run = tensorboard_path / f"{args.name}"
writer = SummaryWriter(tensorboard_run)
else:
writer = None
logger = utils.MetricLogger(delimiter=" ")
scaler = torch.cuda.amp.GradScaler() if args.mixed_precision else None
# run the training loop
# this will perform optimization, respectively logging and saving checkpoints
# when need be
run(
model=model,
optimizer=optimizer,
scheduler=scheduler,
train_loader=train_loader,
val_loaders=val_loaders,
logger=logger,
writer=writer,
scaler=scaler,
args=args,
)
def get_args_parser(add_help=True):
import argparse
parser = argparse.ArgumentParser(description="PyTorch Stereo Matching Training", add_help=add_help)
# checkpointing
parser.add_argument("--name", default="crestereo", help="name of the experiment")
parser.add_argument("--resume", type=str, default=None, help="from which checkpoint to resume")
parser.add_argument("--checkpoint-dir", type=str, default="checkpoints", help="path to the checkpoint directory")
# dataset
parser.add_argument("--dataset-root", type=str, default="", help="path to the dataset root directory")
parser.add_argument(
"--train-datasets",
type=str,
nargs="+",
default=["crestereo"],
help="dataset(s) to train on",
choices=list(VALID_DATASETS.keys()),
)
parser.add_argument(
"--dataset-steps", type=int, nargs="+", default=[300_000], help="number of steps for each dataset"
)
parser.add_argument(
"--steps-is-epochs", action="store_true", help="if set, dataset-steps are interpreted as epochs"
)
parser.add_argument(
"--test-datasets",
type=str,
nargs="+",
default=["middlebury2014-train"],
help="dataset(s) to test on",
choices=["middlebury2014-train"],
)
parser.add_argument("--dataset-shuffle", type=bool, help="shuffle the dataset", default=True)
parser.add_argument("--dataset-order-shuffle", type=bool, help="shuffle the dataset order", default=True)
parser.add_argument("--batch-size", type=int, default=2, help="batch size per GPU")
parser.add_argument("--workers", type=int, default=4, help="number of workers per GPU")
parser.add_argument(
"--threads",
type=int,
default=16,
help="number of CPU threads per GPU. This can be changed around to speed-up transforms if needed. This can lead to worker thread contention so use with care.",
)
# model architecture
parser.add_argument(
"--model",
type=str,
default="crestereo_base",
help="model architecture",
choices=["crestereo_base", "raft_stereo"],
)
parser.add_argument("--recurrent-updates", type=int, default=10, help="number of recurrent updates")
parser.add_argument("--freeze-batch-norm", action="store_true", help="freeze batch norm parameters")
# loss parameters
parser.add_argument("--gamma", type=float, default=0.8, help="gamma parameter for the flow sequence loss")
parser.add_argument("--flow-loss-weight", type=float, default=1.0, help="weight for the flow loss")
parser.add_argument(
"--flow-loss-exclude-large",
action="store_true",
help="exclude large flow values from the loss. A large value is defined as a value greater than the ground truth flow norm",
default=False,
)
parser.add_argument("--consistency-weight", type=float, default=0.0, help="consistency loss weight")
parser.add_argument(
"--consistency-resize-factor",
type=float,
default=0.25,
help="consistency loss resize factor to account for the fact that the flow is computed on a downsampled image",
)
parser.add_argument("--psnr-weight", type=float, default=0.0, help="psnr loss weight")
parser.add_argument("--smoothness-weight", type=float, default=0.0, help="smoothness loss weight")
parser.add_argument("--photometric-weight", type=float, default=0.0, help="photometric loss weight")
parser.add_argument(
"--photometric-max-displacement-ratio",
type=float,
default=0.15,
help="Only pixels with a displacement smaller than this ratio of the image width will be considered for the photometric loss",
)
parser.add_argument("--photometric-ssim-weight", type=float, default=0.85, help="photometric ssim loss weight")
# transforms parameters
parser.add_argument("--gpu-transforms", action="store_true", help="use GPU transforms")
parser.add_argument(
"--eval-size", type=int, nargs="+", default=[384, 512], help="size of the images for evaluation"
)
parser.add_argument("--resize-size", type=int, nargs=2, default=None, help="resize size")
parser.add_argument("--crop-size", type=int, nargs=2, default=[384, 512], help="crop size")
parser.add_argument("--scale-range", type=float, nargs=2, default=[0.6, 1.0], help="random scale range")
parser.add_argument("--rescale-prob", type=float, default=1.0, help="probability of resizing the image")
parser.add_argument(
"--scaling-type", type=str, default="linear", help="scaling type", choices=["exponential", "linear"]
)
parser.add_argument("--flip-prob", type=float, default=0.5, help="probability of flipping the image")
parser.add_argument(
"--norm-mean", type=float, nargs="+", default=[0.5, 0.5, 0.5], help="mean for image normalization"
)
parser.add_argument(
"--norm-std", type=float, nargs="+", default=[0.5, 0.5, 0.5], help="std for image normalization"
)
parser.add_argument(
"--use-grayscale", action="store_true", help="use grayscale images instead of RGB", default=False
)
parser.add_argument("--max-disparity", type=float, default=None, help="maximum disparity")
parser.add_argument(
"--interpolation-strategy",
type=str,
default="bilinear",
help="interpolation strategy",
choices=["bilinear", "bicubic", "mixed"],
)
parser.add_argument("--spatial-shift-prob", type=float, default=1.0, help="probability of shifting the image")
parser.add_argument(
"--spatial-shift-max-angle", type=float, default=0.1, help="maximum angle for the spatial shift"
)
parser.add_argument(
"--spatial-shift-max-displacement", type=float, default=2.0, help="maximum displacement for the spatial shift"
)
parser.add_argument("--gamma-range", type=float, nargs="+", default=[0.8, 1.2], help="range for gamma correction")
parser.add_argument(
"--brightness-range", type=float, nargs="+", default=[0.8, 1.2], help="range for brightness correction"
)
parser.add_argument(
"--contrast-range", type=float, nargs="+", default=[0.8, 1.2], help="range for contrast correction"
)
parser.add_argument(
"--saturation-range", type=float, nargs="+", default=0.0, help="range for saturation correction"
)
parser.add_argument("--hue-range", type=float, nargs="+", default=0.0, help="range for hue correction")
parser.add_argument(
"--asymmetric-jitter-prob",
type=float,
default=1.0,
help="probability of using asymmetric jitter instead of symmetric jitter",
)
parser.add_argument("--occlusion-prob", type=float, default=0.5, help="probability of occluding the rightimage")
parser.add_argument(
"--occlusion-px-range", type=int, nargs="+", default=[50, 100], help="range for the number of occluded pixels"
)
parser.add_argument("--erase-prob", type=float, default=0.0, help="probability of erasing in both images")
parser.add_argument(
"--erase-px-range", type=int, nargs="+", default=[50, 100], help="range for the number of erased pixels"
)
parser.add_argument(
"--erase-num-repeats", type=int, default=1, help="number of times to repeat the erase operation"
)
# optimizer parameters
parser.add_argument("--optimizer", type=str, default="adam", help="optimizer", choices=["adam", "sgd"])
parser.add_argument("--lr", type=float, default=4e-4, help="learning rate")
parser.add_argument("--weight-decay", type=float, default=0.0, help="weight decay")
parser.add_argument("--clip-grad-norm", type=float, default=0.0, help="clip grad norm")
# lr_scheduler parameters
parser.add_argument("--min-lr", type=float, default=2e-5, help="minimum learning rate")
parser.add_argument("--warmup-steps", type=int, default=6_000, help="number of warmup steps")
parser.add_argument(
"--decay-after-steps", type=int, default=180_000, help="number of steps after which to start decay the lr"
)
parser.add_argument(
"--lr-warmup-method", type=str, default="linear", help="warmup method", choices=["linear", "constant"]
)
parser.add_argument("--lr-warmup-factor", type=float, default=0.02, help="warmup factor for the learning rate")
parser.add_argument(
"--lr-decay-method",
type=str,
default="linear",
help="decay method",
choices=["linear", "cosine", "exponential"],
)
parser.add_argument("--lr-decay-gamma", type=float, default=0.8, help="decay factor for the learning rate")
# deterministic behaviour
parser.add_argument("--seed", type=int, default=42, help="seed for random number generators")
# mixed precision training
parser.add_argument("--mixed-precision", action="store_true", help="use mixed precision training")
# logging
parser.add_argument("--tensorboard-summaries", action="store_true", help="log to tensorboard")
parser.add_argument("--tensorboard-log-frequency", type=int, default=100, help="log frequency")
parser.add_argument("--save-frequency", type=int, default=1_000, help="save frequency")
parser.add_argument("--valid-frequency", type=int, default=1_000, help="validation frequency")
parser.add_argument(
"--metrics",
type=str,
nargs="+",
default=["mae", "rmse", "1px", "3px", "5px", "relepe"],
help="metrics to log",
choices=AVAILABLE_METRICS,
)
# distributed parameters
parser.add_argument("--world-size", type=int, default=8, help="number of distributed processes")
parser.add_argument("--dist-url", type=str, default="env://", help="url used to set up distributed training")
parser.add_argument("--device", type=str, default="cuda", help="device to use for training")
# weights API
parser.add_argument("--weights", type=str, default=None, help="weights API url")
parser.add_argument(
"--resume-path", type=str, default=None, help="a path from which to resume or start fine-tuning"
)
parser.add_argument("--resume-schedule", action="store_true", help="resume optimizer state")
# padder parameters
parser.add_argument("--padder-type", type=str, default="kitti", help="padder type", choices=["kitti", "sintel"])
return parser
if __name__ == "__main__":
args = get_args_parser().parse_args()
main(args)
from .losses import *
from .metrics import *
from .distributed import *
from .logger import *
from .padder import *
from .norm import *
import os
import torch
import torch.distributed as dist
def _redefine_print(is_main):
"""disables printing when not in main process"""
import builtins as __builtin__
builtin_print = __builtin__.print
def print(*args, **kwargs):
force = kwargs.pop("force", False)
if is_main or force:
builtin_print(*args, **kwargs)
__builtin__.print = print
def setup_ddp(args):
# Set the local_rank, rank, and world_size values as args fields
# This is done differently depending on how we're running the script. We
# currently support either torchrun or the custom run_with_submitit.py
# If you're confused (like I was), this might help a bit
# https://discuss.pytorch.org/t/what-is-the-difference-between-rank-and-local-rank/61940/2
if "RANK" in os.environ and "WORLD_SIZE" in os.environ:
args.rank = int(os.environ["RANK"])
args.world_size = int(os.environ["WORLD_SIZE"])
args.gpu = int(os.environ["LOCAL_RANK"])
elif "SLURM_PROCID" in os.environ:
args.rank = int(os.environ["SLURM_PROCID"])
args.gpu = args.rank % torch.cuda.device_count()
elif hasattr(args, "rank"):
pass
else:
print("Not using distributed mode")
args.distributed = False
args.world_size = 1
return
args.distributed = True
torch.cuda.set_device(args.gpu)
dist.init_process_group(
backend="nccl",
rank=args.rank,
world_size=args.world_size,
init_method=args.dist_url,
)
torch.distributed.barrier()
_redefine_print(is_main=(args.rank == 0))
def reduce_across_processes(val):
t = torch.tensor(val, device="cuda")
dist.barrier()
dist.all_reduce(t)
return t
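# Example (illustrative): every rank contributes its local value and all ranks
# receive the sum, e.g. with 8 ranks each passing 100:
#
#   total = reduce_across_processes(100)  # tensor(800) on every rank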
import datetime
import time
from collections import defaultdict, deque
import torch
from .distributed import reduce_across_processes
class SmoothedValue:
"""Track a series of values and provide access to smoothed values over a
window or the global series average.
"""
def __init__(self, window_size=20, fmt="{median:.4f} ({global_avg:.4f})"):
self.deque = deque(maxlen=window_size)
self.total = 0.0
self.count = 0
self.fmt = fmt
def update(self, value, n=1):
self.deque.append(value)
self.count += n
self.total += value * n
def synchronize_between_processes(self):
"""
Warning: does not synchronize the deque!
"""
t = reduce_across_processes([self.count, self.total])
t = t.tolist()
self.count = int(t[0])
self.total = t[1]
@property
def median(self):
d = torch.tensor(list(self.deque))
return d.median().item()
@property
def avg(self):
d = torch.tensor(list(self.deque), dtype=torch.float32)
return d.mean().item()
@property
def global_avg(self):
return self.total / self.count
@property
def max(self):
return max(self.deque)
@property
def value(self):
return self.deque[-1]
def __str__(self):
return self.fmt.format(
median=self.median, avg=self.avg, global_avg=self.global_avg, max=self.max, value=self.value
)
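# Example (illustrative): the deque smooths recent values while the running
# totals track the global average.
#
#   meter = SmoothedValue(window_size=20)
#   for v in (1.0, 2.0, 3.0):
#       meter.update(v)
#   meter.global_avg  # 2.0 == total / count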
class MetricLogger:
def __init__(self, delimiter="\t"):
self.meters = defaultdict(SmoothedValue)
self.delimiter = delimiter
def update(self, **kwargs):
for k, v in kwargs.items():
if isinstance(v, torch.Tensor):
v = v.item()
if not isinstance(v, (float, int)):
raise TypeError(
f"This method expects the value of the input arguments to be of type float or int, instead got {type(v)}"
)
self.meters[k].update(v)
def __getattr__(self, attr):
if attr in self.meters:
return self.meters[attr]
if attr in self.__dict__:
return self.__dict__[attr]
raise AttributeError(f"'{type(self).__name__}' object has no attribute '{attr}'")
def __str__(self):
loss_str = []
for name, meter in self.meters.items():
loss_str.append(f"{name}: {str(meter)}")
return self.delimiter.join(loss_str)
def synchronize_between_processes(self):
for meter in self.meters.values():
meter.synchronize_between_processes()
def add_meter(self, name, **kwargs):
self.meters[name] = SmoothedValue(**kwargs)
def log_every(self, iterable, print_freq=5, header=None):
i = 0
if not header:
header = ""
start_time = time.time()
end = time.time()
iter_time = SmoothedValue(fmt="{avg:.4f}")
data_time = SmoothedValue(fmt="{avg:.4f}")
space_fmt = ":" + str(len(str(len(iterable)))) + "d"
if torch.cuda.is_available():
log_msg = self.delimiter.join(
[
header,
"[{0" + space_fmt + "}/{1}]",
"eta: {eta}",
"{meters}",
"time: {time}",
"data: {data}",
"max mem: {memory:.0f}",
]
)
else:
log_msg = self.delimiter.join(
[header, "[{0" + space_fmt + "}/{1}]", "eta: {eta}", "{meters}", "time: {time}", "data: {data}"]
)
MB = 1024.0 * 1024.0
for obj in iterable:
data_time.update(time.time() - end)
yield obj
iter_time.update(time.time() - end)
if print_freq is not None and i % print_freq == 0:
eta_seconds = iter_time.global_avg * (len(iterable) - i)
eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
if torch.cuda.is_available():
print(
log_msg.format(
i,
len(iterable),
eta=eta_string,
meters=str(self),
time=str(iter_time),
data=str(data_time),
memory=torch.cuda.max_memory_allocated() / MB,
)
)
else:
print(
log_msg.format(
i, len(iterable), eta=eta_string, meters=str(self), time=str(iter_time), data=str(data_time)
)
)
i += 1
end = time.time()
total_time = time.time() - start_time
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
print(f"{header} Total time: {total_time_str}")
from typing import Dict, List, Optional, Tuple
from torch import Tensor
AVAILABLE_METRICS = ["mae", "rmse", "epe", "bad1", "bad2", "1px", "3px", "5px", "fl-all", "relepe"]
def compute_metrics(
flow_pred: Tensor, flow_gt: Tensor, valid_flow_mask: Optional[Tensor], metrics: List[str]
) -> Tuple[Dict[str, float], int]:
for m in metrics:
if m not in AVAILABLE_METRICS:
raise ValueError(f"Invalid metric: {m}. Valid metrics are: {AVAILABLE_METRICS}")
metrics_dict = {}
pixels_diffs = (flow_pred - flow_gt).abs()
    # there is no Y flow in Stereo Matching, therefore flow.abs() == flow.pow(2).sum(dim=1).sqrt()
flow_norm = flow_gt.abs()
if valid_flow_mask is not None:
valid_flow_mask = valid_flow_mask.unsqueeze(1)
pixels_diffs = pixels_diffs[valid_flow_mask]
flow_norm = flow_norm[valid_flow_mask]
num_pixels = pixels_diffs.numel()
if "bad1" in metrics:
metrics_dict["bad1"] = (pixels_diffs > 1).float().mean().item()
if "bad2" in metrics:
metrics_dict["bad2"] = (pixels_diffs > 2).float().mean().item()
if "mae" in metrics:
metrics_dict["mae"] = pixels_diffs.mean().item()
if "rmse" in metrics:
metrics_dict["rmse"] = pixels_diffs.pow(2).mean().sqrt().item()
if "epe" in metrics:
metrics_dict["epe"] = pixels_diffs.mean().item()
if "1px" in metrics:
metrics_dict["1px"] = (pixels_diffs < 1).float().mean().item()
if "3px" in metrics:
metrics_dict["3px"] = (pixels_diffs < 3).float().mean().item()
if "5px" in metrics:
metrics_dict["5px"] = (pixels_diffs < 5).float().mean().item()
if "fl-all" in metrics:
metrics_dict["fl-all"] = ((pixels_diffs < 3) & ((pixels_diffs / flow_norm) < 0.05)).float().mean().item() * 100
if "relepe" in metrics:
metrics_dict["relepe"] = (pixels_diffs / flow_norm).mean().item()
return metrics_dict, num_pixels
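# Example: computing EPE and 1px accuracy on a dummy batch. Shapes mirror the
# training loop: B x 1 x H x W disparities and a B x H x W boolean validity mask.
#
#   flow_pred = torch.rand(2, 1, 64, 64) * 10
#   flow_gt = torch.rand(2, 1, 64, 64) * 10
#   valid = torch.ones(2, 64, 64, dtype=torch.bool)
#   metrics_dict, num_pixels = compute_metrics(flow_pred, flow_gt, valid, ["epe", "1px"])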
import torch
def freeze_batch_norm(model):
for m in model.modules():
if isinstance(m, torch.nn.BatchNorm2d):
m.eval()
def unfreeze_batch_norm(model):
for m in model.modules():
if isinstance(m, torch.nn.BatchNorm2d):
m.train()
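# Example: a common fine-tuning recipe is to keep BatchNorm running statistics
# frozen while the rest of the model trains (the ordering below is a sketch).
#
#   model.train()              # enable training-mode behaviour everywhere
#   freeze_batch_norm(model)   # ...except BN layers, whose stats stay fixed
#   # ... run fine-tuning steps ...
#   unfreeze_batch_norm(model)  # restore BN updates if a full training stage follows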
import torch.nn.functional as F
class InputPadder:
"""Pads images such that dimensions are divisible by 8"""
# TODO: Ideally, this should be part of the eval transforms preset, instead
# of being part of the validation code. It's not obvious what a good
# solution would be, because we need to unpad the predicted flows according
# to the input images' size, and in some datasets (Kitti) images can have
# variable sizes.
def __init__(self, dims, mode="sintel"):
self.ht, self.wd = dims[-2:]
pad_ht = (((self.ht // 8) + 1) * 8 - self.ht) % 8
pad_wd = (((self.wd // 8) + 1) * 8 - self.wd) % 8
if mode == "sintel":
self._pad = [pad_wd // 2, pad_wd - pad_wd // 2, pad_ht // 2, pad_ht - pad_ht // 2]
else:
self._pad = [pad_wd // 2, pad_wd - pad_wd // 2, 0, pad_ht]
def pad(self, *inputs):
return [F.pad(x, self._pad, mode="replicate") for x in inputs]
def unpad(self, x):
ht, wd = x.shape[-2:]
c = [self._pad[2], ht - self._pad[3], self._pad[0], wd - self._pad[1]]
return x[..., c[0] : c[1], c[2] : c[3]]
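# Example: padding a stereo pair for inference and cropping the prediction
# back afterwards. Any mode other than "sintel" selects the bottom-only
# vertical padding; the Kitti-like shape and model call are illustrative.
#
#   left = torch.rand(1, 3, 370, 1226)
#   right = torch.rand(1, 3, 370, 1226)
#   padder = InputPadder(left.shape, mode="kitti")
#   left, right = padder.pad(left, right)  # H and W now divisible by 8
#   prediction = model(left, right)        # hypothetical forward pass
#   prediction = padder.unpad(prediction)  # back to 370 x 1226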
import os
from typing import List
import numpy as np
import torch
from torch import Tensor
from torchvision.utils import make_grid
@torch.no_grad()
def make_disparity_image(disparity: Tensor):
    # normalize the disparity map to [0, 1]
disparity = disparity.detach().cpu()
disparity = (disparity - disparity.min()) / (disparity.max() - disparity.min())
return disparity
@torch.no_grad()
def make_disparity_image_pairs(disparity: Tensor, image: Tensor):
disparity = make_disparity_image(disparity)
# image is in [-1, 1], bring it to [0, 1]
image = image.detach().cpu()
image = image * 0.5 + 0.5
return disparity, image
@torch.no_grad()
def make_disparity_sequence(disparities: List[Tensor]):
# convert each disparity to [0, 1]
for idx, disparity_batch in enumerate(disparities):
disparities[idx] = torch.stack(list(map(make_disparity_image, disparity_batch)))
# make the list into a batch
disparity_sequences = torch.stack(disparities)
return disparity_sequences
@torch.no_grad()
def make_pair_grid(*inputs, orientation="horizontal"):
# make a grid of images with the outputs and references side by side
if orientation == "horizontal":
# interleave the outputs and references
canvas = torch.zeros_like(inputs[0])
canvas = torch.cat([canvas] * len(inputs), dim=0)
size = len(inputs)
for idx, inp in enumerate(inputs):
canvas[idx::size, ...] = inp
grid = make_grid(canvas, nrow=len(inputs), padding=16, normalize=True, scale_each=True)
elif orientation == "vertical":
# interleave the outputs and references
canvas = torch.cat(inputs, dim=0)
size = len(inputs)
for idx, inp in enumerate(inputs):
canvas[idx::size, ...] = inp
grid = make_grid(canvas, nrow=len(inputs[0]), padding=16, normalize=True, scale_each=True)
else:
raise ValueError("Unknown orientation: {}".format(orientation))
return grid
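# Example: interleaving predictions with their ground-truth disparities into a
# single horizontal grid (random tensors used for illustration).
#
#   preds = torch.rand(4, 3, 64, 64)
#   gts = torch.rand(4, 3, 64, 64)
#   grid = make_pair_grid(preds, gts, orientation="horizontal")  # C x H x W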
@torch.no_grad()
def make_training_sample_grid(
left_images: Tensor,
right_images: Tensor,
disparities: Tensor,
masks: Tensor,
predictions: List[Tensor],
) -> np.ndarray:
# detach images and renormalize to [0, 1]
images_left = left_images.detach().cpu() * 0.5 + 0.5
images_right = right_images.detach().cpu() * 0.5 + 0.5
    # detach the disparities and predictions
disparities = disparities.detach().cpu()
predictions = predictions[-1].detach().cpu()
# keep only the first channel of pixels, and repeat it 3 times
disparities = disparities[:, :1, ...].repeat(1, 3, 1, 1)
predictions = predictions[:, :1, ...].repeat(1, 3, 1, 1)
# unsqueeze and repeat the masks
masks = masks.detach().cpu().unsqueeze(1).repeat(1, 3, 1, 1)
    # make a grid that will self-normalize across the batch
pred_grid = make_pair_grid(images_left, images_right, masks, disparities, predictions, orientation="horizontal")
pred_grid = pred_grid.permute(1, 2, 0).numpy()
pred_grid = (pred_grid * 255).astype(np.uint8)
return pred_grid
@torch.no_grad()
def make_disparity_sequence_grid(predictions: List[Tensor], disparities: Tensor) -> np.ndarray:
    # the rightmost column will contain the ground truth
seq_len = len(predictions) + 1
predictions = list(map(lambda x: x[:, :1, :, :].detach().cpu(), predictions + [disparities]))
sequence = make_disparity_sequence(predictions)
    # swap axes so that each batch sample's iterations are laid out in order
sequence = torch.swapaxes(sequence, 0, 1).contiguous().reshape(-1, 1, disparities.shape[-2], disparities.shape[-1])
sequence = make_grid(sequence, nrow=seq_len, padding=16, normalize=True, scale_each=True)
sequence = sequence.permute(1, 2, 0).numpy()
sequence = (sequence * 255).astype(np.uint8)
return sequence
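# Example: visualising iterative refinement. Each row shows one sample's
# per-iteration predictions, with the ground truth as the rightmost column.
#
#   preds = [torch.rand(2, 1, 32, 32) for _ in range(4)]  # 4 refinement steps
#   gt = torch.rand(2, 1, 32, 32)
#   image = make_disparity_sequence_grid(preds, gt)  # H x W x 3 uint8 array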
@torch.no_grad()
def make_prediction_image_side_to_side(
predictions: Tensor, disparities: Tensor, valid_mask: Tensor, save_path: str, prefix: str
) -> None:
import matplotlib.pyplot as plt
# normalize the predictions and disparities in [0, 1]
predictions = (predictions - predictions.min()) / (predictions.max() - predictions.min())
disparities = (disparities - disparities.min()) / (disparities.max() - disparities.min())
predictions = predictions * valid_mask
disparities = disparities * valid_mask
predictions = predictions.detach().cpu()
disparities = disparities.detach().cpu()
for idx, (pred, gt) in enumerate(zip(predictions, disparities)):
pred = pred.permute(1, 2, 0).numpy()
gt = gt.permute(1, 2, 0).numpy()
# plot pred and gt side by side
fig, ax = plt.subplots(1, 2, figsize=(10, 5))
ax[0].imshow(pred)
ax[0].set_title("Prediction")
ax[1].imshow(gt)
ax[1].set_title("Ground Truth")
save_name = os.path.join(save_path, "{}_{}.png".format(prefix, idx))
plt.savefig(save_name)
plt.close()
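# Example: dumping side-by-side prediction / ground-truth figures to disk
# (the 3-channel tensors, path and prefix below are placeholders).
#
#   preds = torch.rand(2, 3, 64, 64)
#   gts = torch.rand(2, 3, 64, 64)
#   mask = torch.ones(2, 1, 64, 64)
#   make_prediction_image_side_to_side(preds, gts, mask, "/tmp", prefix="val")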