Unverified Commit 673838f5 authored by YosuaMichael, committed by GitHub

Removing prototype related things from release/0.14 branch (#6687)

* Remove test related to prototype

* Remove torchvision/prototype dir

* Remove references/depth/stereo because it depends on prototype

* Remove prototype related entries on mypy.ini

* Remove things related to prototype in pytest.ini

* clean setup.py from prototype

* Clean CI from prototype

* Remove unused expect file
parent 07ae61bf
@@ -152,15 +152,6 @@ commands:
args: --no-build-isolation <<# parameters.editable >> --editable <</ parameters.editable >> .
descr: Install torchvision <<# parameters.editable >> in editable mode <</ parameters.editable >>
install_prototype_dependencies:
steps:
- pip_install:
args: iopath
descr: Install third-party dependencies
- pip_install:
args: --pre torchdata --extra-index-url https://download.pytorch.org/whl/nightly/cpu
descr: Install torchdata from nightly releases
# Most of the test suite is handled by the `unittest` jobs, with completely different workflow and setup.
# This command can be used if only a selection of tests need to be run, for ad-hoc files.
run_tests_selective:
@@ -326,7 +317,6 @@ jobs:
- checkout
- install_torchvision:
editable: true
- install_prototype_dependencies
- pip_install:
args: mypy
descr: Install Python type check utilities
name: tests
on:
pull_request:
jobs:
prototype:
strategy:
matrix:
os:
- ubuntu-latest
- windows-latest
- macos-latest
fail-fast: false
runs-on: ${{ matrix.os }}
steps:
- name: Set up python
uses: actions/setup-python@v3
with:
python-version: 3.7
- name: Upgrade system packages
run: python -m pip install --upgrade pip setuptools wheel
- name: Checkout repository
uses: actions/checkout@v3
- name: Install PyTorch nightly builds
run: pip install --progress-bar=off --pre torch torchdata --extra-index-url https://download.pytorch.org/whl/nightly/cpu/
- name: Install torchvision
run: pip install --progress-bar=off --no-build-isolation --editable .
- name: Install other prototype dependencies
run: pip install --progress-bar=off scipy pycocotools h5py iopath
- name: Install test requirements
run: pip install --progress-bar=off pytest pytest-mock pytest-cov
- name: Mark setup as complete
id: setup
run: exit 0
- name: Run prototype features tests
shell: bash
run: |
pytest \
--durations=20 \
--cov=torchvision/prototype/features \
--cov-report=term-missing \
test/test_prototype_features*.py
- name: Run prototype datasets tests
if: success() || ( failure() && steps.setup.conclusion == 'success' )
shell: bash
run: |
pytest \
--durations=20 \
--cov=torchvision/prototype/datasets \
--cov-report=term-missing \
test/test_prototype_datasets*.py
- name: Run prototype transforms tests
if: success() || ( failure() && steps.setup.conclusion == 'success' )
shell: bash
run: |
pytest \
--durations=20 \
--cov=torchvision/prototype/transforms \
--cov-report=term-missing \
test/test_prototype_transforms*.py
- name: Run prototype models tests
if: success() || ( failure() && steps.setup.conclusion == 'success' )
shell: bash
run: |
pytest \
--durations=20 \
--cov=torchvision/prototype/models \
--cov-report=term-missing \
test/test_prototype_models*.py
@@ -7,52 +7,6 @@ allow_redefinition = True
no_implicit_optional = True
warn_redundant_casts = True
[mypy-torchvision.prototype.features.*]
; untyped definitions and calls
disallow_untyped_defs = True
; None and Optional handling
no_implicit_optional = True
; warnings
warn_unused_ignores = True
warn_return_any = True
; miscellaneous strictness flags
allow_redefinition = True
[mypy-torchvision.prototype.transforms.*]
; untyped definitions and calls
disallow_untyped_defs = True
; None and Optional handling
no_implicit_optional = True
; warnings
warn_unused_ignores = True
warn_return_any = True
; miscellaneous strictness flags
allow_redefinition = True
[mypy-torchvision.prototype.datasets.*]
; untyped definitions and calls
disallow_untyped_defs = True
; None and Optional handling
no_implicit_optional = True
; warnings
warn_unused_ignores = True
warn_return_any = True
warn_unreachable = True
; miscellaneous strictness flags
allow_redefinition = True
[mypy-torchvision.io.image.*]
ignore_errors = True
@@ -149,10 +103,6 @@ ignore_missing_imports = True
ignore_missing_imports = True
[mypy-torchdata.*]
ignore_missing_imports = True
[mypy-h5py.*]
ignore_missing_imports = True
@@ -7,7 +7,6 @@ addopts =
# enable all warnings
-Wd
--ignore=test/test_datasets_download.py
--ignore-glob=test/test_prototype_*.py
testpaths =
test
xfail_strict = True
# Stereo Matching reference training scripts
This folder contains reference training scripts for Stereo Matching.
They serve as a log of how to train specific models and provide baseline
training and evaluation scripts to quickly bootstrap research.
### CREStereo
The CREStereo model was trained on a dataset mixture between **CREStereo**, **ETH3D** and the additional split from **Middlebury2014**.
A ratio of **88-6-6** was used in order to train a baseline weight set. We provide a multi-set variant as well.
Both runs used 8 A100 GPUs and a batch size of 2 (so the effective batch size is 16). The
rest of the hyper-parameters loosely follow the recipe from https://github.com/megvii-research/CREStereo.
The original recipe trains for **300000** updates (or steps) on the dataset mixture. We modify the learning rate
schedule to one that starts decaying the learning rate much sooner. Throughout our experiments we found that this reduces overfitting
at evaluation time, and that gradient clipping helps stabilize the loss during a premature learning rate change.
```
torchrun --nproc_per_node 8 --nnodes 1 train.py \
--dataset-root $dataset_root \
--name $name_cre \
--model crestereo_base \
--train-datasets crestereo eth3d-train middlebury2014-other \
    --dataset-steps 264000 18000 18000 \
--batch-size 2 \
--lr 0.0004 \
--min-lr 0.00002 \
--lr-decay-method cosine \
--warmup-steps 6000 \
--decay-after-steps 30000 \
--clip-grad-norm 1.0 \
```
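For intuition, the schedule described above (linear warmup, a hold phase, then cosine decay towards ``--min-lr``) can be sketched as below. This is only an illustration derived from the flags in the command; the helper name and the exact warmup/hold/decay split are assumptions, not the actual `train.py` implementation.

```
import math

def lr_at_step(step, base_lr=0.0004, min_lr=0.00002,
               warmup_steps=6000, decay_after_steps=30000, total_steps=300000):
    # Illustrative only: linear warmup, constant hold, then cosine decay to min_lr.
    if step < warmup_steps:
        return base_lr * step / warmup_steps
    if step < decay_after_steps:
        return base_lr
    progress = (step - decay_after_steps) / max(1, total_steps - decay_after_steps)
    return min_lr + 0.5 * (base_lr - min_lr) * (1 + math.cos(math.pi * progress))
```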
We employ a multi-set fine-tuning stage where we uniformly sample from multiple datasets. Given that some of these datasets have extremely large images (``2048x2048`` or more), we opt for a very aggressive scale range of ``[0.2 - 0.8]`` so that as much of the original frame composition as possible is captured inside the ``384x512`` crop. (A sketch of the uniform sampling scheme follows the command below.)
```
torchrun --nproc_per_node 8 --nnodes 1 train.py \
--dataset-root $dataset_root \
--name $name_things \
--model crestereo_base \
--train-datasets crestereo eth3d-train middlebury2014-other instereo2k fallingthings carla-highres sintel sceneflow-monkaa sceneflow-driving \
    --dataset-steps 12000 12000 12000 12000 12000 12000 12000 12000 12000 \
--batch-size 2 \
--scale-range 0.2 0.8 \
--lr 0.0004 \
--lr-decay-method cosine \
--decay-after-steps 0 \
--warmup-steps 0 \
--min-lr 0.00002 \
--resume-path $checkpoint_dir/$name_cre.pth
```
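Conceptually, giving every dataset the same ``--dataset-steps`` budget amounts to uniform sampling across datasets. A rough sketch of such a sampling loop is shown below; the helper name and structure are assumptions for illustration, not the actual `train.py` code.

```
import random

def uniform_multiset_batches(loaders, total_steps):
    # Illustrative only: at every step pick one dataset at random and draw the
    # next batch from it, restarting a loader once it is exhausted.
    iterators = [iter(loader) for loader in loaders]
    for _ in range(total_steps):
        idx = random.randrange(len(loaders))
        try:
            yield next(iterators[idx])
        except StopIteration:
            iterators[idx] = iter(loaders[idx])
            yield next(iterators[idx])
```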
### Evaluation
Evaluating the base weights
```
torchrun --nproc_per_node 1 --nnodes 1 cascade_evaluation.py --dataset middlebury2014-train --batch-size 1 --dataset-root $dataset_root --model crestereo_base --weights CREStereo_Base_Weights.CRESTEREO_ETH_MBL_V1
```
This should give an **mae of about 1.416** on the train set of `Middlebury2014`. Results may vary slightly depending on the batch size and the number of GPUs. For the most accurate results use 1 GPU and `--batch-size 1`. The created log file should look like this, where the first key is the number of cascades and the nested key is the number of recursive iterations:
```
Dataset: middlebury2014-train @size: [384, 512]:
{
1: {
2: {'mae': 2.363, 'rmse': 4.352, '1px': 0.611, '3px': 0.828, '5px': 0.891, 'relepe': 0.176, 'fl-all': 64.511}
5: {'mae': 1.618, 'rmse': 3.71, '1px': 0.761, '3px': 0.879, '5px': 0.918, 'relepe': 0.154, 'fl-all': 77.128}
10: {'mae': 1.416, 'rmse': 3.53, '1px': 0.777, '3px': 0.896, '5px': 0.933, 'relepe': 0.148, 'fl-all': 78.388}
20: {'mae': 1.448, 'rmse': 3.583, '1px': 0.771, '3px': 0.893, '5px': 0.931, 'relepe': 0.145, 'fl-all': 77.7}
},
}
{
2: {
2: {'mae': 1.972, 'rmse': 4.125, '1px': 0.73, '3px': 0.865, '5px': 0.908, 'relepe': 0.169, 'fl-all': 74.396}
5: {'mae': 1.403, 'rmse': 3.448, '1px': 0.793, '3px': 0.905, '5px': 0.937, 'relepe': 0.151, 'fl-all': 80.186}
10: {'mae': 1.312, 'rmse': 3.368, '1px': 0.799, '3px': 0.912, '5px': 0.943, 'relepe': 0.148, 'fl-all': 80.379}
20: {'mae': 1.376, 'rmse': 3.542, '1px': 0.796, '3px': 0.91, '5px': 0.942, 'relepe': 0.149, 'fl-all': 80.054}
},
}
```
You can also evaluate the fine-tuned weights:
```
torchrun --nproc_per_node 1 --nnodes 1 cascade_evaluation.py --dataset middlebury2014-train --batch-size 1 --dataset-root $dataset_root --model crestereo_base --weights CREStereo_Base_Weights.CRESTEREO_FINETUNE_MULTI_V1
```
```
Dataset: middlebury2014-train @size: [384, 512]:
{
1: {
2: {'mae': 1.85, 'rmse': 3.797, '1px': 0.673, '3px': 0.862, '5px': 0.917, 'relepe': 0.171, 'fl-all': 69.736}
5: {'mae': 1.111, 'rmse': 3.166, '1px': 0.838, '3px': 0.93, '5px': 0.957, 'relepe': 0.134, 'fl-all': 84.596}
10: {'mae': 1.02, 'rmse': 3.073, '1px': 0.854, '3px': 0.938, '5px': 0.96, 'relepe': 0.129, 'fl-all': 86.042}
20: {'mae': 0.993, 'rmse': 3.059, '1px': 0.855, '3px': 0.942, '5px': 0.967, 'relepe': 0.126, 'fl-all': 85.784}
},
}
{
2: {
2: {'mae': 1.667, 'rmse': 3.867, '1px': 0.78, '3px': 0.891, '5px': 0.922, 'relepe': 0.165, 'fl-all': 78.89}
5: {'mae': 1.158, 'rmse': 3.278, '1px': 0.843, '3px': 0.926, '5px': 0.955, 'relepe': 0.135, 'fl-all': 84.556}
10: {'mae': 1.046, 'rmse': 3.13, '1px': 0.85, '3px': 0.934, '5px': 0.96, 'relepe': 0.13, 'fl-all': 85.464}
20: {'mae': 1.021, 'rmse': 3.102, '1px': 0.85, '3px': 0.935, '5px': 0.963, 'relepe': 0.129, 'fl-all': 85.417}
},
}
```
Evaluating the author-provided weights:
```
torchrun --nproc_per_node 1 --nnodes 1 cascade_evaluation.py --dataset middlebury2014-train --batch-size 1 --dataset-root $dataset_root --model crestereo_base --weights CREStereo_Base_Weights.MEGVII_V1
```
```
Dataset: middlebury2014-train @size: [384, 512]:
{
1: {
2: {'mae': 1.704, 'rmse': 3.738, '1px': 0.738, '3px': 0.896, '5px': 0.933, 'relepe': 0.157, 'fl-all': 76.464}
5: {'mae': 0.956, 'rmse': 2.963, '1px': 0.88, '3px': 0.948, '5px': 0.965, 'relepe': 0.124, 'fl-all': 88.186}
10: {'mae': 0.792, 'rmse': 2.765, '1px': 0.905, '3px': 0.958, '5px': 0.97, 'relepe': 0.114, 'fl-all': 90.429}
20: {'mae': 0.749, 'rmse': 2.706, '1px': 0.907, '3px': 0.961, '5px': 0.972, 'relepe': 0.113, 'fl-all': 90.807}
},
}
{
2: {
2: {'mae': 1.702, 'rmse': 3.784, '1px': 0.784, '3px': 0.894, '5px': 0.924, 'relepe': 0.172, 'fl-all': 80.313}
5: {'mae': 0.932, 'rmse': 2.907, '1px': 0.877, '3px': 0.944, '5px': 0.963, 'relepe': 0.125, 'fl-all': 87.979}
10: {'mae': 0.773, 'rmse': 2.768, '1px': 0.901, '3px': 0.958, '5px': 0.972, 'relepe': 0.117, 'fl-all': 90.43}
20: {'mae': 0.854, 'rmse': 2.971, '1px': 0.9, '3px': 0.957, '5px': 0.97, 'relepe': 0.122, 'fl-all': 90.269}
},
}
```
# Concerns when training
We encourage users to be aware of the **aspect-ratio** and **disparity scale** they are targeting when doing any sort of training or fine-tuning. The model is highly sensitive to these two factors; as a consequence, with naive multi-set fine-tuning one can achieve `0.2 mae` relatively fast. We recommend that users pay close attention to how they **balance dataset sizing** when training such networks.
Ideally, dataset scaling should be treated on a per-dataset basis, and a thorough **EDA** of the disparity distribution in random crops at the desired training / inference size should be performed prior to any large compute investments.
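A minimal sketch of such an EDA is shown below. It assumes the disparity maps are available as a list of tensors with invalid pixels set to ``0``; the helper is hypothetical and only meant to show the kind of statistics worth looking at before committing to a recipe.

```
import torch

def disparity_crop_stats(disparity_maps, crop_hw=(384, 512), n_crops=1000):
    # Sample random crops at the training size and summarise the disparity
    # values that actually end up inside them.
    ch, cw = crop_hw
    samples = []
    for _ in range(n_crops):
        disp = disparity_maps[torch.randint(len(disparity_maps), (1,)).item()]
        h, w = disp.shape[-2:]
        top = torch.randint(0, max(1, h - ch), (1,)).item()
        left = torch.randint(0, max(1, w - cw), (1,)).item()
        crop = disp[..., top : top + ch, left : left + cw]
        samples.append(crop[crop > 0].flatten())  # keep valid disparities only
    values = torch.cat(samples)
    return {
        "mean": values.mean().item(),
        "p50": values.median().item(),
        "p95": values.quantile(0.95).item(),
        "max": values.max().item(),
    }
```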
### Disparity scaling
##### Sample A
The top row contains a sample from `Sintel` whereas the bottom row one from `Middlebury`.
![Disparity1](assets/disparity-domain-drift.jpg)
From left to right: (`left_image`, `right_image`, `valid_mask`, `valid_mask & ground_truth`, `prediction`). **Darker is further away, lighter is closer**. In the case of `Sintel`, which is more closely aligned with the original distribution of `CREStereo`, we notice that the model accurately predicts the background scale, whereas in the case of `Middlebury2014` it cannot correctly estimate the continuous disparity. Notice that the frame composition is similar for both examples. The blue skybox in the `Sintel` scene behaves similarly to the `Middlebury` black background. However, because the `Middlebury` sample comes from an extremely large scene, the crop size of `384x512` does not correctly capture the general training distribution.
##### Sample B
The top row contains a scene from `Sceneflow` using the `Monkaa` split whilst the bottom row is a scene from `Middlebury`. This sample exhibits the same issues when it comes to **background estimation**. Given the exaggerated size of the `Middlebury` samples, the model **collapses the smooth background** of the sample to what it considers to be a mean background disparity value.
![Disparity2](assets/disparity-background-mode-collapse.jpg)
For more detail on why this behaviour occurs based on the training distribution proportions you can read more about the network at: https://github.com/pytorch/vision/pull/6629#discussion_r978160493
### Metric overfitting
##### Learning is critical in the beginning
We also advise users to make use of faster training schedules, as the performance gain over long periods of time is marginal. Here we exhibit the difference between a faster decay schedule and a later decay schedule.
![Loss1](assets/Loss.jpg)
In **grey** we set the lr decay to begin after `30000` steps whilst in **orange** we opt for a very late learning rate decay at around `180000` steps. Although it exhibits stronger variance, we can notice that starting the learning rate decay earlier whilst employing `gradient-norm` clipping outperforms the default configuration.
##### Gradient norm saves time
![Loss2](assets/gradient-norm-removal.jpg)
In **grey** we keep ``gradient norm`` enabled whilst in **orange** we do not. We can notice that removing the gradient norm exacerbates the performance decrease in the early stages, whilst also showcasing an almost complete collapse around the `60000` steps mark where we started decaying the lr for **orange**.
Although both runs achieve an improvement of about ``0.1`` mae after the lr decay starts, the benefits are observable much sooner when ``gradient norm`` is employed, as there is no recovery period to account for.
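For reference, the ``gradient norm`` mentioned above is plain gradient clipping; in PyTorch it is typically applied between the backward pass and the optimizer step via `torch.nn.utils.clip_grad_norm_`. The snippet below is a generic, self-contained sketch (toy model and data), not the exact `train.py` loop:

```
import torch

model = torch.nn.Linear(8, 1)
optimizer = torch.optim.AdamW(model.parameters(), lr=4e-4)

x, y = torch.randn(4, 8), torch.randn(4, 1)
loss = torch.nn.functional.l1_loss(model(x), y)
loss.backward()
# clip the global gradient norm before the optimizer step (matches --clip-grad-norm 1.0)
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
optimizer.step()
optimizer.zero_grad()
```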
import os
import warnings
import torch
import torchvision
import torchvision.prototype.models.depth.stereo
import utils
from torch.nn import functional as F
from train import make_eval_loader
from utils.metrics import AVAILABLE_METRICS
from vizualization import make_prediction_image_side_to_side
def get_args_parser(add_help=True):
import argparse
parser = argparse.ArgumentParser(description="PyTorch Stereo Matching Evaluation", add_help=add_help)
parser.add_argument("--dataset", type=str, default="middlebury2014-train", help="dataset to use")
parser.add_argument("--dataset-root", type=str, default="", help="root of the dataset")
parser.add_argument("--checkpoint", type=str, default="", help="path to weights")
parser.add_argument("--weights", type=str, default=None, help="torchvision API weight")
parser.add_argument(
"--model",
type=str,
default="crestereo_base",
help="which model to use if not speciffying a training checkpoint",
)
parser.add_argument("--img-folder", type=str, default="images")
parser.add_argument("--batch-size", type=int, default=1, help="batch size")
parser.add_argument("--workers", type=int, default=0, help="number of workers")
parser.add_argument("--eval-size", type=int, nargs="+", default=[384, 512], help="resize size")
parser.add_argument(
"--norm-mean", type=float, nargs="+", default=[0.5, 0.5, 0.5], help="mean for image normalization"
)
parser.add_argument(
"--norm-std", type=float, nargs="+", default=[0.5, 0.5, 0.5], help="std for image normalization"
)
parser.add_argument(
"--use-grayscale", action="store_true", help="use grayscale images instead of RGB", default=False
)
parser.add_argument("--max-disparity", type=float, default=None, help="maximum disparity")
parser.add_argument(
"--interpolation-strategy",
type=str,
default="bilinear",
help="interpolation strategy",
choices=["bilinear", "bicubic", "mixed"],
)
parser.add_argument("--n_iterations", nargs="+", type=int, default=[10], help="number of recurent iterations")
parser.add_argument("--n_cascades", nargs="+", type=int, default=[1], help="number of cascades")
parser.add_argument(
"--metrics",
type=str,
nargs="+",
default=["mae", "rmse", "1px", "3px", "5px", "relepe"],
help="metrics to log",
choices=AVAILABLE_METRICS,
)
parser.add_argument("--mixed-precision", action="store_true", help="use mixed precision training")
parser.add_argument("--world-size", type=int, default=1, help="number of distributed processes")
parser.add_argument("--dist-url", type=str, default="env://", help="url used to set up distributed training")
parser.add_argument("--device", type=str, default="cuda", help="device to use for training")
parser.add_argument("--save-images", action="store_true", help="save images of the predictions")
parser.add_argument("--padder-type", type=str, default="kitti", help="padder type", choices=["kitti", "sintel"])
return parser
def cascade_inference(model, image_left, image_right, iterations, cascades):
    # check that the image size is divisible by 16 * (2 ** (cascades - 1))
    size_divisor = 16 * (2 ** (cascades - 1))
    for image in [image_left, image_right]:
        if image.shape[-2] % size_divisor != 0:
            raise ValueError(
                f"image height is not divisible by {size_divisor}. Image shape: {image.shape[-2]}"
            )
        if image.shape[-1] % size_divisor != 0:
            raise ValueError(
                f"image width is not divisible by {size_divisor}. Image shape: {image.shape[-1]}"
            )
left_image_pyramid = [image_left]
right_image_pyramid = [image_right]
for idx in range(0, cascades - 1):
ds_factor = int(2 ** (idx + 1))
ds_shape = (image_left.shape[-2] // ds_factor, image_left.shape[-1] // ds_factor)
left_image_pyramid += F.interpolate(image_left, size=ds_shape, mode="bilinear", align_corners=True).unsqueeze(0)
right_image_pyramid += F.interpolate(image_right, size=ds_shape, mode="bilinear", align_corners=True).unsqueeze(
0
)
flow_init = None
for left_image, right_image in zip(reversed(left_image_pyramid), reversed(right_image_pyramid)):
flow_pred = model(left_image, right_image, flow_init, num_iters=iterations)
# flow pred is a list
flow_init = flow_pred[-1]
return flow_init
@torch.inference_mode()
def _evaluate(
model,
args,
val_loader,
*,
padder_mode,
print_freq=10,
writter=None,
step=None,
iterations=10,
cascades=1,
batch_size=None,
header=None,
save_images=False,
save_path="",
):
"""Helper function to compute various metrics (epe, etc.) for a model on a given dataset.
We process as many samples as possible with ddp.
"""
model.eval()
header = header or "Test:"
device = torch.device(args.device)
metric_logger = utils.MetricLogger(delimiter=" ")
iterations = iterations or args.recurrent_updates
logger = utils.MetricLogger()
for meter_name in args.metrics:
logger.add_meter(meter_name, fmt="{global_avg:.4f}")
if "fl-all" not in args.metrics:
logger.add_meter("fl-all", fmt="{global_avg:.4f}")
num_processed_samples = 0
with torch.cuda.amp.autocast(enabled=args.mixed_precision, dtype=torch.float16):
batch_idx = 0
for blob in metric_logger.log_every(val_loader, print_freq, header):
image_left, image_right, disp_gt, valid_disp_mask = (x.to(device) for x in blob)
padder = utils.InputPadder(image_left.shape, mode=padder_mode)
image_left, image_right = padder.pad(image_left, image_right)
disp_pred = cascade_inference(model, image_left, image_right, iterations, cascades)
disp_pred = disp_pred[:, :1, :, :]
disp_pred = padder.unpad(disp_pred)
if save_images:
if args.distributed:
rank_prefix = args.rank
else:
rank_prefix = 0
make_prediction_image_side_to_side(
disp_pred, disp_gt, valid_disp_mask, save_path, prefix=f"batch_{rank_prefix}_{batch_idx}"
)
metrics, _ = utils.compute_metrics(disp_pred, disp_gt, valid_disp_mask, metrics=logger.meters.keys())
num_processed_samples += image_left.shape[0]
for name in metrics:
logger.meters[name].update(metrics[name], n=1)
batch_idx += 1
num_processed_samples = utils.reduce_across_processes(num_processed_samples) / args.world_size
print("Num_processed_samples: ", num_processed_samples)
if (
hasattr(val_loader.dataset, "__len__")
and len(val_loader.dataset) != num_processed_samples
and torch.distributed.get_rank() == 0
):
        warnings.warn(
            f"Number of processed samples {num_processed_samples} is different "
            f"from the dataset size {len(val_loader.dataset)}. This may happen if "
            "the dataset is not divisible by the batch size. Try lowering the batch size for more accurate results."
        )
if writter is not None and args.rank == 0:
for meter_name, meter_value in logger.meters.items():
scalar_name = f"{meter_name} {header}"
writter.add_scalar(scalar_name, meter_value.avg, step)
logger.synchronize_between_processes()
print(header, logger)
logger_metrics = {k: v.global_avg for k, v in logger.meters.items()}
return logger_metrics
def evaluate(model, loader, args, writter=None, step=None):
os.makedirs(args.img_folder, exist_ok=True)
checkpoint_name = os.path.basename(args.checkpoint) or args.weights
image_checkpoint_folder = os.path.join(args.img_folder, checkpoint_name)
metrics = {}
base_image_folder = os.path.join(image_checkpoint_folder, args.dataset)
os.makedirs(base_image_folder, exist_ok=True)
for n_cascades in args.n_cascades:
for n_iters in args.n_iterations:
config = f"{n_cascades}c_{n_iters}i"
config_image_folder = os.path.join(base_image_folder, config)
os.makedirs(config_image_folder, exist_ok=True)
metrics[config] = _evaluate(
model,
args,
loader,
padder_mode=args.padder_type,
header=f"{args.dataset} evaluation@ size:{args.eval_size} n_cascades:{n_cascades} n_iters:{n_iters}",
batch_size=args.batch_size,
writter=writter,
step=step,
iterations=n_iters,
cascades=n_cascades,
save_path=config_image_folder,
save_images=args.save_images,
)
metric_log = []
metric_log_dict = {}
# print the final results
for config in metrics:
config_tokens = config.split("_")
config_iters = config_tokens[1][:-1]
config_cascades = config_tokens[0][:-1]
metric_log_dict[config_cascades] = metric_log_dict.get(config_cascades, {})
metric_log_dict[config_cascades][config_iters] = metrics[config]
evaluation_str = f"{args.dataset} evaluation@ size:{args.eval_size} n_cascades:{config_cascades} recurrent_updates:{config_iters}"
metrics_str = f"Metrics: {metrics[config]}"
metric_log.extend([evaluation_str, metrics_str])
print(evaluation_str)
print(metrics_str)
eval_log_name = f"{checkpoint_name.replace('.pth', '')}_eval.log"
print("Saving eval log to: ", eval_log_name)
with open(eval_log_name, "w") as f:
f.write(f"Dataset: {args.dataset} @size: {args.eval_size}:\n")
# write the dict line by line for each key, and each value in the keys
for config_cascades in metric_log_dict:
f.write("{\n")
f.write(f"\t{config_cascades}: {{\n")
for config_iters in metric_log_dict[config_cascades]:
                # round every metric to 3 decimal places
metrics = metric_log_dict[config_cascades][config_iters]
metrics = {k: float(f"{v:.3f}") for k, v in metrics.items()}
f.write(f"\t\t{config_iters}: {metrics}\n")
f.write("\t},\n")
f.write("}\n")
def load_checkpoint(args):
utils.setup_ddp(args)
if not args.weights:
checkpoint = torch.load(args.checkpoint, map_location=torch.device("cpu"))
if "model" in checkpoint:
experiment_args = checkpoint["args"]
model = torchvision.prototype.models.depth.stereo.__dict__[experiment_args.model](weights=None)
model.load_state_dict(checkpoint["model"])
else:
model = torchvision.prototype.models.depth.stereo.__dict__[args.model](weights=None)
model.load_state_dict(checkpoint)
    else:
        model = torchvision.prototype.models.depth.stereo.__dict__[args.model](weights=args.weights)

    # set the appropriate device (needed for both branches above)
    if args.distributed and args.device == "cpu":
        raise ValueError("The device must be cuda if we want to run in distributed mode using torchrun")
    device = torch.device(args.device)

    # convert to DDP if need be
    if args.distributed:
        model = model.to(args.device)
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
    else:
        model.to(device)
return model
def main(args):
model = load_checkpoint(args)
loader = make_eval_loader(args.dataset, args)
evaluate(model, loader, args)
if __name__ == "__main__":
args = get_args_parser().parse_args()
main(args)
import argparse
from functools import partial
import torch
from presets import StereoMatchingEvalPreset, StereoMatchingTrainPreset
from torchvision.datasets import (
CarlaStereo,
CREStereo,
ETH3DStereo,
FallingThingsStereo,
InStereo2k,
Kitti2012Stereo,
Kitti2015Stereo,
Middlebury2014Stereo,
SceneFlowStereo,
SintelStereo,
)
VALID_DATASETS = {
"crestereo": partial(CREStereo),
"carla-highres": partial(CarlaStereo),
"instereo2k": partial(InStereo2k),
"sintel": partial(SintelStereo),
"sceneflow-monkaa": partial(SceneFlowStereo, variant="Monkaa", pass_name="both"),
"sceneflow-flyingthings": partial(SceneFlowStereo, variant="FlyingThings3D", pass_name="both"),
"sceneflow-driving": partial(SceneFlowStereo, variant="Driving", pass_name="both"),
"fallingthings": partial(FallingThingsStereo, variant="both"),
"eth3d-train": partial(ETH3DStereo, split="train"),
"eth3d-test": partial(ETH3DStereo, split="test"),
"kitti2015-train": partial(Kitti2015Stereo, split="train"),
"kitti2015-test": partial(Kitti2015Stereo, split="test"),
"kitti2012-train": partial(Kitti2012Stereo, split="train"),
"kitti2012-test": partial(Kitti2012Stereo, split="train"),
"middlebury2014-other": partial(
Middlebury2014Stereo, split="additional", use_ambient_view=True, calibration="both"
),
"middlebury2014-train": partial(Middlebury2014Stereo, split="train", calibration="perfect"),
"middlebury2014-test": partial(Middlebury2014Stereo, split="test", calibration=None),
"middlebury2014-train-ambient": partial(
Middlebury2014Stereo, split="train", use_ambient_views=True, calibrartion="perfect"
),
}
def make_train_transform(args: argparse.Namespace) -> torch.nn.Module:
return StereoMatchingTrainPreset(
resize_size=args.resize_size,
crop_size=args.crop_size,
rescale_prob=args.rescale_prob,
scaling_type=args.scaling_type,
scale_range=args.scale_range,
scale_interpolation_type=args.interpolation_strategy,
use_grayscale=args.use_grayscale,
mean=args.norm_mean,
std=args.norm_std,
horizontal_flip_prob=args.flip_prob,
gpu_transforms=args.gpu_transforms,
max_disparity=args.max_disparity,
spatial_shift_prob=args.spatial_shift_prob,
spatial_shift_max_angle=args.spatial_shift_max_angle,
spatial_shift_max_displacement=args.spatial_shift_max_displacement,
spatial_shift_interpolation_type=args.interpolation_strategy,
gamma_range=args.gamma_range,
brightness=args.brightness_range,
contrast=args.contrast_range,
saturation=args.saturation_range,
hue=args.hue_range,
asymmetric_jitter_prob=args.asymmetric_jitter_prob,
)
def make_eval_transform(args: argparse.Namespace) -> torch.nn.Module:
if args.eval_size is None:
resize_size = args.crop_size
else:
resize_size = args.eval_size
return StereoMatchingEvalPreset(
mean=args.norm_mean,
std=args.norm_std,
use_grayscale=args.use_grayscale,
resize_size=resize_size,
interpolation_type=args.interpolation_strategy,
)
def make_dataset(dataset_name: str, dataset_root: str, transforms: torch.nn.Module) -> torch.utils.data.Dataset:
return VALID_DATASETS[dataset_name](root=dataset_root, transforms=transforms)
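# Example usage (hypothetical paths/arguments), combining the registry above
# with the transform builders:
#   transform = make_eval_transform(args)
#   dataset = make_dataset("middlebury2014-train", "/path/to/datasets", transform)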
from typing import Optional, Tuple, Union
import torch
import transforms as T
class StereoMatchingEvalPreset(torch.nn.Module):
def __init__(
self,
mean: float = 0.5,
std: float = 0.5,
resize_size: Optional[Tuple[int, ...]] = None,
max_disparity: Optional[float] = None,
interpolation_type: str = "bilinear",
use_grayscale: bool = False,
) -> None:
super().__init__()
transforms = [
T.ToTensor(),
T.ConvertImageDtype(torch.float32),
]
if use_grayscale:
transforms.append(T.ConvertToGrayscale())
if resize_size is not None:
transforms.append(T.Resize(resize_size, interpolation_type=interpolation_type))
transforms.extend(
[
T.Normalize(mean=mean, std=std),
T.MakeValidDisparityMask(max_disparity=max_disparity),
T.ValidateModelInput(),
]
)
self.transforms = T.Compose(transforms)
def forward(self, images, disparities, masks):
return self.transforms(images, disparities, masks)
class StereoMatchingTrainPreset(torch.nn.Module):
def __init__(
self,
*,
resize_size: Optional[Tuple[int, ...]],
resize_interpolation_type: str = "bilinear",
# RandomResizeAndCrop params
crop_size: Tuple[int, int],
rescale_prob: float = 1.0,
scaling_type: str = "exponential",
scale_range: Tuple[float, float] = (-0.2, 0.5),
scale_interpolation_type: str = "bilinear",
# convert to grayscale
use_grayscale: bool = False,
# normalization params
mean: float = 0.5,
std: float = 0.5,
# processing device
gpu_transforms: bool = False,
# masking
max_disparity: Optional[int] = 256,
# SpatialShift params
spatial_shift_prob: float = 0.5,
spatial_shift_max_angle: float = 0.5,
spatial_shift_max_displacement: float = 0.5,
spatial_shift_interpolation_type: str = "bilinear",
        # AsymmetricColorJitter
gamma_range: Tuple[float, float] = (0.8, 1.2),
brightness: Union[int, Tuple[int, int]] = (0.8, 1.2),
contrast: Union[int, Tuple[int, int]] = (0.8, 1.2),
saturation: Union[int, Tuple[int, int]] = 0.0,
hue: Union[int, Tuple[int, int]] = 0.0,
asymmetric_jitter_prob: float = 1.0,
# RandomHorizontalFlip
horizontal_flip_prob: float = 0.5,
# RandomOcclusion
occlusion_prob: float = 0.0,
occlusion_px_range: Tuple[int, int] = (50, 100),
# RandomErase
erase_prob: float = 0.0,
erase_px_range: Tuple[int, int] = (50, 100),
erase_num_repeats: int = 1,
) -> None:
if scaling_type not in ["linear", "exponential"]:
raise ValueError(f"Unknown scaling type: {scaling_type}. Available types: linear, exponential")
super().__init__()
transforms = [T.ToTensor()]
# when fixing size across multiple datasets, we ensure
# that the same size is used for all datasets when cropping
if resize_size is not None:
transforms.append(T.Resize(resize_size, interpolation_type=resize_interpolation_type))
if gpu_transforms:
transforms.append(T.ToGPU())
# color handling
color_transforms = [
T.AsymmetricColorJitter(
brightness=brightness, contrast=contrast, saturation=saturation, hue=hue, p=asymmetric_jitter_prob
),
T.AsymetricGammaAdjust(p=asymmetric_jitter_prob, gamma_range=gamma_range),
]
if use_grayscale:
color_transforms.append(T.ConvertToGrayscale())
transforms.extend(color_transforms)
transforms.extend(
[
T.RandomSpatialShift(
p=spatial_shift_prob,
max_angle=spatial_shift_max_angle,
max_px_shift=spatial_shift_max_displacement,
interpolation_type=spatial_shift_interpolation_type,
),
T.ConvertImageDtype(torch.float32),
T.RandomRescaleAndCrop(
crop_size=crop_size,
scale_range=scale_range,
rescale_prob=rescale_prob,
scaling_type=scaling_type,
interpolation_type=scale_interpolation_type,
),
T.RandomHorizontalFlip(horizontal_flip_prob),
# occlusion after flip, otherwise we're occluding the reference image
T.RandomOcclusion(p=occlusion_prob, occlusion_px_range=occlusion_px_range),
T.RandomErase(p=erase_prob, erase_px_range=erase_px_range, max_erase=erase_num_repeats),
T.Normalize(mean=mean, std=std),
T.MakeValidDisparityMask(max_disparity),
T.ValidateModelInput(),
]
)
self.transforms = T.Compose(transforms)
    def forward(self, images, disparities, mask):
        return self.transforms(images, disparities, mask)
from .losses import *
from .metrics import *
from .distributed import *
from .logger import *
from .padder import *
from .norm import *
import os
import torch
import torch.distributed as dist
def _redefine_print(is_main):
"""disables printing when not in main process"""
import builtins as __builtin__
builtin_print = __builtin__.print
def print(*args, **kwargs):
force = kwargs.pop("force", False)
if is_main or force:
builtin_print(*args, **kwargs)
__builtin__.print = print
def setup_ddp(args):
# Set the local_rank, rank, and world_size values as args fields
# This is done differently depending on how we're running the script. We
# currently support either torchrun or the custom run_with_submitit.py
# If you're confused (like I was), this might help a bit
# https://discuss.pytorch.org/t/what-is-the-difference-between-rank-and-local-rank/61940/2
if "RANK" in os.environ and "WORLD_SIZE" in os.environ:
args.rank = int(os.environ["RANK"])
args.world_size = int(os.environ["WORLD_SIZE"])
args.gpu = int(os.environ["LOCAL_RANK"])
elif "SLURM_PROCID" in os.environ:
args.rank = int(os.environ["SLURM_PROCID"])
args.gpu = args.rank % torch.cuda.device_count()
elif hasattr(args, "rank"):
pass
else:
print("Not using distributed mode")
args.distributed = False
args.world_size = 1
return
args.distributed = True
torch.cuda.set_device(args.gpu)
dist.init_process_group(
backend="nccl",
rank=args.rank,
world_size=args.world_size,
init_method=args.dist_url,
)
torch.distributed.barrier()
_redefine_print(is_main=(args.rank == 0))
def reduce_across_processes(val):
t = torch.tensor(val, device="cuda")
dist.barrier()
dist.all_reduce(t)
return t
import datetime
import time
from collections import defaultdict, deque
import torch
from .distributed import reduce_across_processes
class SmoothedValue:
"""Track a series of values and provide access to smoothed values over a
window or the global series average.
"""
def __init__(self, window_size=20, fmt="{median:.4f} ({global_avg:.4f})"):
self.deque = deque(maxlen=window_size)
self.total = 0.0
self.count = 0
self.fmt = fmt
def update(self, value, n=1):
self.deque.append(value)
self.count += n
self.total += value * n
def synchronize_between_processes(self):
"""
Warning: does not synchronize the deque!
"""
t = reduce_across_processes([self.count, self.total])
t = t.tolist()
self.count = int(t[0])
self.total = t[1]
@property
def median(self):
d = torch.tensor(list(self.deque))
return d.median().item()
@property
def avg(self):
d = torch.tensor(list(self.deque), dtype=torch.float32)
return d.mean().item()
@property
def global_avg(self):
return self.total / self.count
@property
def max(self):
return max(self.deque)
@property
def value(self):
return self.deque[-1]
def __str__(self):
return self.fmt.format(
median=self.median, avg=self.avg, global_avg=self.global_avg, max=self.max, value=self.value
)
class MetricLogger:
def __init__(self, delimiter="\t"):
self.meters = defaultdict(SmoothedValue)
self.delimiter = delimiter
def update(self, **kwargs):
for k, v in kwargs.items():
if isinstance(v, torch.Tensor):
v = v.item()
if not isinstance(v, (float, int)):
raise TypeError(
f"This method expects the value of the input arguments to be of type float or int, instead got {type(v)}"
)
self.meters[k].update(v)
def __getattr__(self, attr):
if attr in self.meters:
return self.meters[attr]
if attr in self.__dict__:
return self.__dict__[attr]
raise AttributeError(f"'{type(self).__name__}' object has no attribute '{attr}'")
def __str__(self):
loss_str = []
for name, meter in self.meters.items():
loss_str.append(f"{name}: {str(meter)}")
return self.delimiter.join(loss_str)
def synchronize_between_processes(self):
for meter in self.meters.values():
meter.synchronize_between_processes()
def add_meter(self, name, **kwargs):
self.meters[name] = SmoothedValue(**kwargs)
def log_every(self, iterable, print_freq=5, header=None):
i = 0
if not header:
header = ""
start_time = time.time()
end = time.time()
iter_time = SmoothedValue(fmt="{avg:.4f}")
data_time = SmoothedValue(fmt="{avg:.4f}")
space_fmt = ":" + str(len(str(len(iterable)))) + "d"
if torch.cuda.is_available():
log_msg = self.delimiter.join(
[
header,
"[{0" + space_fmt + "}/{1}]",
"eta: {eta}",
"{meters}",
"time: {time}",
"data: {data}",
"max mem: {memory:.0f}",
]
)
else:
log_msg = self.delimiter.join(
[header, "[{0" + space_fmt + "}/{1}]", "eta: {eta}", "{meters}", "time: {time}", "data: {data}"]
)
MB = 1024.0 * 1024.0
for obj in iterable:
data_time.update(time.time() - end)
yield obj
iter_time.update(time.time() - end)
if print_freq is not None and i % print_freq == 0:
eta_seconds = iter_time.global_avg * (len(iterable) - i)
eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
if torch.cuda.is_available():
print(
log_msg.format(
i,
len(iterable),
eta=eta_string,
meters=str(self),
time=str(iter_time),
data=str(data_time),
memory=torch.cuda.max_memory_allocated() / MB,
)
)
else:
print(
log_msg.format(
i, len(iterable), eta=eta_string, meters=str(self), time=str(iter_time), data=str(data_time)
)
)
i += 1
end = time.time()
total_time = time.time() - start_time
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
print(f"{header} Total time: {total_time_str}")
from typing import List, Optional
import torch
from torch import nn, Tensor
from torch.nn import functional as F
from torchvision.prototype.models.depth.stereo.raft_stereo import grid_sample, make_coords_grid
def make_gaussian_kernel(kernel_size: int, sigma: float) -> torch.Tensor:
"""Function to create a 2D Gaussian kernel."""
x = torch.arange(kernel_size, dtype=torch.float32)
y = torch.arange(kernel_size, dtype=torch.float32)
x = x - (kernel_size - 1) / 2
y = y - (kernel_size - 1) / 2
x, y = torch.meshgrid(x, y)
grid = (x**2 + y**2) / (2 * sigma**2)
kernel = torch.exp(-grid)
kernel = kernel / kernel.sum()
return kernel
def _sequence_loss_fn(
flow_preds: List[Tensor],
flow_gt: Tensor,
valid_flow_mask: Optional[Tensor],
gamma: Tensor,
max_flow: int = 256,
exclude_large: bool = False,
weights: Optional[Tensor] = None,
):
"""Loss function defined over sequence of flow predictions"""
torch._assert(
gamma < 1,
"sequence_loss: `gamma` must be lower than 1, but got {}".format(gamma),
)
if exclude_large:
        # exclude invalid pixels and extremely large displacements
flow_norm = torch.sum(flow_gt**2, dim=1).sqrt()
if valid_flow_mask is not None:
valid_flow_mask = valid_flow_mask & (flow_norm < max_flow)
else:
valid_flow_mask = flow_norm < max_flow
if valid_flow_mask is not None:
valid_flow_mask = valid_flow_mask.unsqueeze(1)
flow_preds = torch.stack(flow_preds) # shape = (num_flow_updates, batch_size, 2, H, W)
abs_diff = (flow_preds - flow_gt).abs()
if valid_flow_mask is not None:
abs_diff = abs_diff * valid_flow_mask.unsqueeze(0)
abs_diff = abs_diff.mean(axis=(1, 2, 3, 4))
num_predictions = flow_preds.shape[0]
    # allocating on CPU and moving to device during run-time can force
# an unwanted GPU synchronization that produces a large overhead
if weights is None or len(weights) != num_predictions:
weights = gamma ** torch.arange(num_predictions - 1, -1, -1, device=flow_preds.device, dtype=flow_preds.dtype)
flow_loss = (abs_diff * weights).sum()
return flow_loss, weights
class SequenceLoss(nn.Module):
def __init__(self, gamma: float = 0.8, max_flow: int = 256, exclude_large_flows: bool = False) -> None:
"""
Args:
gamma: value for the exponential weighting of the loss across frames
max_flow: maximum flow value to exclude
exclude_large_flows: whether to exclude large flows
"""
super().__init__()
self.max_flow = max_flow
self.excluding_large = exclude_large_flows
self.register_buffer("gamma", torch.tensor([gamma]))
# cache the scale factor for the loss
self._weights = None
def forward(self, flow_preds: List[Tensor], flow_gt: Tensor, valid_flow_mask: Optional[Tensor]) -> Tensor:
"""
Args:
flow_preds: list of flow predictions of shape (batch_size, C, H, W)
flow_gt: ground truth flow of shape (batch_size, C, H, W)
valid_flow_mask: mask of valid flow pixels of shape (batch_size, H, W)
"""
loss, weights = _sequence_loss_fn(
flow_preds, flow_gt, valid_flow_mask, self.gamma, self.max_flow, self.excluding_large, self._weights
)
self._weights = weights
return loss
def set_gamma(self, gamma: float) -> None:
self.gamma.fill_(gamma)
# reset the cached scale factor
self._weights = None
def _ssim_loss_fn(
source: Tensor,
reference: Tensor,
kernel: Tensor,
eps: float = 1e-8,
c1: float = 0.01**2,
c2: float = 0.03**2,
use_padding: bool = False,
) -> Tensor:
# ref: Algorithm section: https://en.wikipedia.org/wiki/Structural_similarity
# ref: Alternative implementation: https://kornia.readthedocs.io/en/latest/_modules/kornia/metrics/ssim.html#ssim
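    # For reference, with window means mu_x, mu_y, variances sigma_x^2, sigma_y^2
    # and covariance sigma_xy, the standard SSIM formula is:
    #   SSIM(x, y) = ((2 * mu_x * mu_y + c1) * (2 * sigma_xy + c2))
    #                / ((mu_x**2 + mu_y**2 + c1) * (sigma_x**2 + sigma_y**2 + c2))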
torch._assert(
source.ndim == reference.ndim == 4,
"SSIM: `source` and `reference` must be 4-dimensional tensors",
)
torch._assert(
source.shape == reference.shape,
"SSIM: `source` and `reference` must have the same shape, but got {} and {}".format(
source.shape, reference.shape
),
)
B, C, H, W = source.shape
kernel = kernel.unsqueeze(0).unsqueeze(0).repeat(C, 1, 1, 1)
if use_padding:
pad_size = kernel.shape[2] // 2
source = F.pad(source, (pad_size, pad_size, pad_size, pad_size), "reflect")
reference = F.pad(reference, (pad_size, pad_size, pad_size, pad_size), "reflect")
mu1 = F.conv2d(source, kernel, groups=C)
mu2 = F.conv2d(reference, kernel, groups=C)
mu1_sq = mu1.pow(2)
mu2_sq = mu2.pow(2)
mu1_mu2 = mu1 * mu2
mu_img1_sq = F.conv2d(source.pow(2), kernel, groups=C)
mu_img2_sq = F.conv2d(reference.pow(2), kernel, groups=C)
mu_img1_mu2 = F.conv2d(source * reference, kernel, groups=C)
sigma1_sq = mu_img1_sq - mu1_sq
sigma2_sq = mu_img2_sq - mu2_sq
sigma12 = mu_img1_mu2 - mu1_mu2
numerator = (2 * mu1_mu2 + c1) * (2 * sigma12 + c2)
denominator = (mu1_sq + mu2_sq + c1) * (sigma1_sq + sigma2_sq + c2)
ssim = numerator / (denominator + eps)
# doing 1 - ssim because we want to maximize the ssim
return 1 - ssim.mean(dim=(1, 2, 3))
class SSIM(nn.Module):
def __init__(
self,
kernel_size: int = 11,
max_val: float = 1.0,
sigma: float = 1.5,
eps: float = 1e-12,
use_padding: bool = True,
) -> None:
"""SSIM loss function.
Args:
kernel_size: size of the Gaussian kernel
max_val: constant scaling factor
sigma: sigma of the Gaussian kernel
eps: constant for division by zero
use_padding: whether to pad the input tensor such that we have a score for each pixel
"""
super().__init__()
self.kernel_size = kernel_size
self.max_val = max_val
self.sigma = sigma
gaussian_kernel = make_gaussian_kernel(kernel_size, sigma)
self.register_buffer("gaussian_kernel", gaussian_kernel)
self.c1 = (0.01 * self.max_val) ** 2
self.c2 = (0.03 * self.max_val) ** 2
self.use_padding = use_padding
self.eps = eps
def forward(self, source: torch.Tensor, reference: torch.Tensor) -> torch.Tensor:
"""
Args:
source: source image of shape (batch_size, C, H, W)
reference: reference image of shape (batch_size, C, H, W)
Returns:
SSIM loss of shape (batch_size,)
"""
return _ssim_loss_fn(
source,
reference,
kernel=self.gaussian_kernel,
c1=self.c1,
c2=self.c2,
use_padding=self.use_padding,
eps=self.eps,
)
def _smoothness_loss_fn(img_gx: Tensor, img_gy: Tensor, val_gx: Tensor, val_gy: Tensor):
# ref: https://github.com/nianticlabs/monodepth2/blob/b676244e5a1ca55564eb5d16ab521a48f823af31/layers.py#L202
torch._assert(
img_gx.ndim >= 3,
"smoothness_loss: `img_gx` must be at least 3-dimensional tensor of shape (..., C, H, W)",
)
torch._assert(
img_gx.ndim == val_gx.ndim,
"smoothness_loss: `img_gx` and `depth_gx` must have the same dimensionality, but got {} and {}".format(
img_gx.ndim, val_gx.ndim
),
)
for idx in range(img_gx.ndim):
torch._assert(
(img_gx.shape[idx] == val_gx.shape[idx] or (img_gx.shape[idx] == 1 or val_gx.shape[idx] == 1)),
"smoothness_loss: `img_gx` and `depth_gx` must have either the same shape or broadcastable shape, but got {} and {}".format(
img_gx.shape, val_gx.shape
),
)
# -3 is channel dimension
weights_x = torch.exp(-torch.mean(torch.abs(val_gx), axis=-3, keepdim=True))
weights_y = torch.exp(-torch.mean(torch.abs(val_gy), axis=-3, keepdim=True))
smoothness_x = img_gx * weights_x
smoothness_y = img_gy * weights_y
smoothness = (torch.abs(smoothness_x) + torch.abs(smoothness_y)).mean(axis=(-3, -2, -1))
return smoothness
class SmoothnessLoss(nn.Module):
def __init__(self) -> None:
super().__init__()
def _x_gradient(self, img: Tensor) -> Tensor:
if img.ndim > 4:
original_shape = img.shape
is_reshaped = True
img = img.reshape(-1, *original_shape[-3:])
else:
is_reshaped = False
padded = F.pad(img, (0, 1, 0, 0), mode="replicate")
grad = padded[..., :, :-1] - padded[..., :, 1:]
if is_reshaped:
grad = grad.reshape(original_shape)
return grad
def _y_gradient(self, x: torch.Tensor) -> torch.Tensor:
if x.ndim > 4:
original_shape = x.shape
is_reshaped = True
x = x.reshape(-1, *original_shape[-3:])
else:
is_reshaped = False
padded = F.pad(x, (0, 0, 0, 1), mode="replicate")
grad = padded[..., :-1, :] - padded[..., 1:, :]
if is_reshaped:
grad = grad.reshape(original_shape)
return grad
def forward(self, images: Tensor, vals: Tensor) -> Tensor:
"""
Args:
images: tensor of shape (D1, D2, ..., DN, C, H, W)
            vals: tensor of shape (D1, D2, ..., DN, 1, H, W)
Returns:
smoothness loss of shape (D1, D2, ..., DN)
"""
img_gx = self._x_gradient(images)
img_gy = self._y_gradient(images)
val_gx = self._x_gradient(vals)
val_gy = self._y_gradient(vals)
return _smoothness_loss_fn(img_gx, img_gy, val_gx, val_gy)
def _flow_sequence_consistency_loss_fn(
flow_preds: List[Tensor],
gamma: float = 0.8,
resize_factor: float = 0.25,
rescale_factor: float = 0.25,
rescale_mode: str = "bilinear",
weights: Optional[Tensor] = None,
):
"""Loss function defined over sequence of flow predictions"""
# Simplified version of ref: https://arxiv.org/pdf/2006.11242.pdf
# In the original paper, an additional refinement network is used to refine a flow prediction.
    # Each step performed by the recurrent module in RAFT or CREStereo is a refinement step using a delta_flow update,
    # which should be consistent with the previous step. In this implementation, we simplify the overall loss
    # term and ignore the left-right consistency loss or photometric loss, which can be treated separately.
torch._assert(
rescale_factor <= 1.0,
"sequence_consistency_loss: `rescale_factor` must be less than or equal to 1, but got {}".format(
rescale_factor
),
)
flow_preds = torch.stack(flow_preds) # shape = (num_flow_updates, batch_size, 2, H, W)
N, B, C, H, W = flow_preds.shape
# rescale flow predictions to account for bilinear upsampling artifacts
if rescale_factor:
flow_preds = (
F.interpolate(
flow_preds.view(N * B, C, H, W), scale_factor=resize_factor, mode=rescale_mode, align_corners=True
)
) * rescale_factor
flow_preds = torch.stack(torch.chunk(flow_preds, N, dim=0), dim=0)
# force the next prediction to be similar to the previous prediction
abs_diff = (flow_preds[1:] - flow_preds[:-1]).square()
abs_diff = abs_diff.mean(axis=(1, 2, 3, 4))
num_predictions = flow_preds.shape[0] - 1 # because we are comparing differences
if weights is None or len(weights) != num_predictions:
weights = gamma ** torch.arange(num_predictions - 1, -1, -1, device=flow_preds.device, dtype=flow_preds.dtype)
flow_loss = (abs_diff * weights).sum()
return flow_loss, weights
class FlowSequenceConsistencyLoss(nn.Module):
def __init__(
self,
gamma: float = 0.8,
resize_factor: float = 0.25,
rescale_factor: float = 0.25,
rescale_mode: str = "bilinear",
) -> None:
super().__init__()
self.gamma = gamma
self.resize_factor = resize_factor
self.rescale_factor = rescale_factor
self.rescale_mode = rescale_mode
self._weights = None
def forward(self, flow_preds: List[Tensor]) -> Tensor:
"""
Args:
flow_preds: list of tensors of shape (batch_size, C, H, W)
Returns:
sequence consistency loss of shape (batch_size,)
"""
loss, weights = _flow_sequence_consistency_loss_fn(
flow_preds,
gamma=self.gamma,
resize_factor=self.resize_factor,
rescale_factor=self.rescale_factor,
rescale_mode=self.rescale_mode,
weights=self._weights,
)
self._weights = weights
return loss
def set_gamma(self, gamma: float) -> None:
        self.gamma = gamma
# reset the cached scale factor
self._weights = None
def _psnr_loss_fn(source: torch.Tensor, target: torch.Tensor, max_val: float) -> torch.Tensor:
torch._assert(
source.shape == target.shape,
"psnr_loss: source and target must have the same shape, but got {} and {}".format(source.shape, target.shape),
)
# ref https://en.wikipedia.org/wiki/Peak_signal-to-noise_ratio
return 10 * torch.log10(max_val**2 / ((source - target).pow(2).mean(axis=(-3, -2, -1))))
class PSNRLoss(nn.Module):
def __init__(self, max_val: float = 256) -> None:
"""
Args:
max_val: maximum value of the input tensor. This refers to the maximum domain value of the input tensor.
"""
super().__init__()
self.max_val = max_val
def forward(self, source: Tensor, target: Tensor) -> Tensor:
"""
Args:
source: tensor of shape (D1, D2, ..., DN, C, H, W)
target: tensor of shape (D1, D2, ..., DN, C, H, W)
Returns:
psnr loss of shape (D1, D2, ..., DN)
"""
# multiply by -1 as we want to maximize the psnr
return -1 * _psnr_loss_fn(source, target, self.max_val)
class FlowPhotoMetricLoss(nn.Module):
def __init__(
self,
ssim_weight: float = 0.85,
ssim_window_size: int = 11,
ssim_max_val: float = 1.0,
ssim_sigma: float = 1.5,
ssim_eps: float = 1e-12,
ssim_use_padding: bool = True,
max_displacement_ratio: float = 0.15,
) -> None:
super().__init__()
self._ssim_loss = SSIM(
kernel_size=ssim_window_size,
max_val=ssim_max_val,
sigma=ssim_sigma,
eps=ssim_eps,
use_padding=ssim_use_padding,
)
self._L1_weight = 1 - ssim_weight
self._SSIM_weight = ssim_weight
self._max_displacement_ratio = max_displacement_ratio
def forward(
self,
source: Tensor,
reference: Tensor,
flow_pred: Tensor,
valid_mask: Optional[Tensor] = None,
):
"""
Args:
source: tensor of shape (B, C, H, W)
reference: tensor of shape (B, C, H, W)
flow_pred: tensor of shape (B, 2, H, W)
valid_mask: tensor of shape (B, H, W) or None
Returns:
            scalar photometric loss averaged over the batch
"""
torch._assert(
source.ndim == 4,
"FlowPhotoMetricLoss: source must have 4 dimensions, but got {}".format(source.ndim),
)
torch._assert(
reference.ndim == source.ndim,
"FlowPhotoMetricLoss: source and other must have the same number of dimensions, but got {} and {}".format(
source.ndim, reference.ndim
),
)
torch._assert(
flow_pred.shape[1] == 2,
"FlowPhotoMetricLoss: flow_pred must have 2 channels, but got {}".format(flow_pred.shape[1]),
)
torch._assert(
flow_pred.ndim == 4,
"FlowPhotoMetricLoss: flow_pred must have 4 dimensions, but got {}".format(flow_pred.ndim),
)
B, C, H, W = source.shape
flow_channels = flow_pred.shape[1]
max_displacements = []
for dim in range(flow_channels):
shape_index = -1 - dim
max_displacements.append(int(self._max_displacement_ratio * source.shape[shape_index]))
# mask out all pixels that have larger flow than the max flow allowed
max_flow_mask = torch.logical_and(
*[flow_pred[:, dim, :, :] < max_displacements[dim] for dim in range(flow_channels)]
)
if valid_mask is not None:
valid_mask = torch.logical_and(valid_mask, max_flow_mask).unsqueeze(1)
else:
valid_mask = max_flow_mask.unsqueeze(1)
grid = make_coords_grid(B, H, W, device=str(source.device))
resampled_grids = grid - flow_pred
resampled_grids = resampled_grids.permute(0, 2, 3, 1)
resampled_source = grid_sample(reference, resampled_grids, mode="bilinear")
# compute SSIM loss
ssim_loss = self._ssim_loss(resampled_source * valid_mask, source * valid_mask)
l1_loss = (resampled_source * valid_mask - source * valid_mask).abs().mean(axis=(-3, -2, -1))
loss = self._L1_weight * l1_loss + self._SSIM_weight * ssim_loss
return loss.mean()