Unverified Commit 4c66712f authored by Ponku, committed by GitHub

Add CREStereo weights. (#6629)

* added crestereo author weights

* update weights name

* synced fl-all metric

* changed resize size config param name

* extended weight sets

* changed weight link
parent 10dafd9b
@@ -8,9 +8,11 @@ import torch.nn as nn
import torch.nn.functional as F
import torchvision.models.optical_flow.raft as raft
from torch import Tensor
from torchvision.models._api import WeightsEnum
from torchvision.models._api import register_model, Weights, WeightsEnum
from torchvision.models._utils import handle_legacy_interface
from torchvision.models.optical_flow._utils import grid_sample, make_coords_grid, upsample_flow
from torchvision.ops import Conv2dNormActivation
from torchvision.prototype.transforms._presets import StereoMatching
__all__ = (
"CREStereo",
@@ -66,7 +68,7 @@ def get_correlation(
right_padded = F.pad(right_feature, (pad_x, pad_x, pad_y, pad_y), mode="replicate")
# in order to vectorize the correlation computation over all pixel candidates
# we create multiple shifted right images which we stack on an extra dimension
right_padded = F.unfold(right_padded, kernel_size=(H, W), dilation=dilate).detach()
right_padded = F.unfold(right_padded, kernel_size=(H, W), dilation=dilate)
# torch unfold returns a tensor of shape [B, flattened_values, n_selections]
right_padded = right_padded.permute(0, 2, 1)
# we then reshape back into [B, n_views, C, H, W]
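# Illustrative sketch (not the library implementation) of the unfold trick described in the
# comments above: unfolding a padded map with a kernel the size of the original map yields
# one shifted copy per candidate offset. The names here (feat, views, corr) are made up.
import torch
import torch.nn.functional as F

B, C, H, W = 1, 2, 4, 4
pad_x = pad_y = 1
feat = torch.arange(B * C * H * W, dtype=torch.float32).reshape(B, C, H, W)

padded = F.pad(feat, (pad_x, pad_x, pad_y, pad_y), mode="replicate")
# unfold returns [B, C*H*W, n_views] with n_views = (2*pad_y + 1) * (2*pad_x + 1)
views = F.unfold(padded, kernel_size=(H, W))
n_views = views.shape[-1]  # 9 for a 1-pixel pad in each direction
views = views.permute(0, 2, 1).reshape(B, n_views, C, H, W)

# one plausible correlation: channel-wise dot product of the left features with every shifted view
left = torch.rand(B, C, H, W)
corr = (left.unsqueeze(1) * views).mean(dim=2)  # [B, n_views, H, W]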
@@ -1055,11 +1057,374 @@ def _crestereo(
return model
_COMMON_META = {
"resize_size": (384, 512),
}
class CREStereo_Base_Weights(WeightsEnum):
pass
"""The metrics reported here are as follows.
``mae`` is the "mean-average-error" and indicates how far (in pixels) the
predicted disparity is from its true value (equivalent to ``epe``). This is averaged over all pixels
of all images. ``1px``, ``3px``, ``5px`` and indicate the percentage of pixels that have a lower
error than that of the ground truth. ``relepe`` is the "relative-end-point-error" and is the
average ``epe`` divided by the average ground truth disparity. ``fl-all`` corresponds to the average of pixels whose epe
is either <3px, or whom's ``relepe`` is lower than 0.05 (therefore higher is better).
"""
MEGVII_V1 = Weights(
# Weights ported from https://github.com/megvii-research/CREStereo
url="https://download.pytorch.org/models/crestereo-756c8b0f.pth",
transforms=StereoMatching,
meta={
**_COMMON_META,
"num_params": 5432948,
"recipe": "https://github.com/megvii-research/CREStereo",
"_metrics": {
"Middlebury2014-train": {
# metrics for 10 refinement iterations and 1 cascade
"mae": 0.792,
"rmse": 2.765,
"1px": 0.905,
"3px": 0.958,
"5px": 0.97,
"relepe": 0.114,
"fl-all": 90.429,
"_detailed": {
# 1 is the number of cascades
1: {
# 2 is the number of refinement iterations
2: {
"mae": 1.704,
"rmse": 3.738,
"1px": 0.738,
"3px": 0.896,
"5px": 0.933,
"relepe": 0.157,
"fl-all": 76.464,
},
5: {
"mae": 0.956,
"rmse": 2.963,
"1px": 0.88,
"3px": 0.948,
"5px": 0.965,
"relepe": 0.124,
"fl-all": 88.186,
},
10: {
"mae": 0.792,
"rmse": 2.765,
"1px": 0.905,
"3px": 0.958,
"5px": 0.97,
"relepe": 0.114,
"fl-all": 90.429,
},
20: {
"mae": 0.749,
"rmse": 2.706,
"1px": 0.907,
"3px": 0.961,
"5px": 0.972,
"relepe": 0.113,
"fl-all": 90.807,
},
},
2: {
2: {
"mae": 1.702,
"rmse": 3.784,
"1px": 0.784,
"3px": 0.894,
"5px": 0.924,
"relepe": 0.172,
"fl-all": 80.313,
},
5: {
"mae": 0.932,
"rmse": 2.907,
"1px": 0.877,
"3px": 0.944,
"5px": 0.963,
"relepe": 0.125,
"fl-all": 87.979,
},
10: {
"mae": 0.773,
"rmse": 2.768,
"1px": 0.901,
"3px": 0.958,
"5px": 0.972,
"relepe": 0.117,
"fl-all": 90.43,
},
20: {
"mae": 0.854,
"rmse": 2.971,
"1px": 0.9,
"3px": 0.957,
"5px": 0.97,
"relepe": 0.122,
"fl-all": 90.269,
},
},
},
}
},
"_docs": """These weights were ported from the original paper. They
are trained on a dataset mixture of the author's choice.""",
},
)
CRESTEREO_ETH_MBL_V1 = Weights(
# Weights ported from https://github.com/megvii-research/CREStereo
url="https://download.pytorch.org/models/crestereo-8f0e0e9a.pth",
transforms=StereoMatching,
meta={
**_COMMON_META,
"num_params": 5432948,
"recipe": "https://github.com/pytorch/vision/tree/main/references/depth/stereo",
"_metrics": {
"Middlebury2014-train": {
# metrics for 10 refinement iterations and 1 cascade
"mae": 1.416,
"rmse": 3.53,
"1px": 0.777,
"3px": 0.896,
"5px": 0.933,
"relepe": 0.148,
"fl-all": 78.388,
"_detailed": {
# 1 is the number of cascades
1: {
# 2 is the number of refinement iterations
2: {
"mae": 2.363,
"rmse": 4.352,
"1px": 0.611,
"3px": 0.828,
"5px": 0.891,
"relepe": 0.176,
"fl-all": 64.511,
},
5: {
"mae": 1.618,
"rmse": 3.71,
"1px": 0.761,
"3px": 0.879,
"5px": 0.918,
"relepe": 0.154,
"fl-all": 77.128,
},
10: {
"mae": 1.416,
"rmse": 3.53,
"1px": 0.777,
"3px": 0.896,
"5px": 0.933,
"relepe": 0.148,
"fl-all": 78.388,
},
20: {
"mae": 1.448,
"rmse": 3.583,
"1px": 0.771,
"3px": 0.893,
"5px": 0.931,
"relepe": 0.145,
"fl-all": 77.7,
},
},
2: {
2: {
"mae": 1.972,
"rmse": 4.125,
"1px": 0.73,
"3px": 0.865,
"5px": 0.908,
"relepe": 0.169,
"fl-all": 74.396,
},
5: {
"mae": 1.403,
"rmse": 3.448,
"1px": 0.793,
"3px": 0.905,
"5px": 0.937,
"relepe": 0.151,
"fl-all": 80.186,
},
10: {
"mae": 1.312,
"rmse": 3.368,
"1px": 0.799,
"3px": 0.912,
"5px": 0.943,
"relepe": 0.148,
"fl-all": 80.379,
},
20: {
"mae": 1.376,
"rmse": 3.542,
"1px": 0.796,
"3px": 0.91,
"5px": 0.942,
"relepe": 0.149,
"fl-all": 80.054,
},
},
},
}
},
"_docs": """These weights were trained from scratch on
:class:`~torchvision.datasets._stereo_matching.CREStereo` +
:class:`~torchvision.datasets._stereo_matching.Middlebury2014Stereo` +
:class:`~torchvision.datasets._stereo_matching.ETH3DStereo`.""",
},
)
CRESTEREO_FINETUNE_MULTI_V1 = Weights(
# Weights ported from https://github.com/megvii-research/CREStereo
url="https://download.pytorch.org/models/crestereo-697c38f4.pth ",
transforms=StereoMatching,
meta={
**_COMMON_META,
"num_params": 5432948,
"recipe": "https://github.com/pytorch/vision/tree/main/references/depth/stereo",
"_metrics": {
"Middlebury2014-train": {
# metrics for 10 refinement iterations and 1 cascade
"mae": 1.038,
"rmse": 3.108,
"1px": 0.852,
"3px": 0.942,
"5px": 0.963,
"relepe": 0.129,
"fl-all": 85.522,
"_detailed": {
# 1 is the number of cascades
1: {
# 2 is the number of refinement iterations
2: {
"mae": 1.85,
"rmse": 3.797,
"1px": 0.673,
"3px": 0.862,
"5px": 0.917,
"relepe": 0.171,
"fl-all": 69.736,
},
5: {
"mae": 1.111,
"rmse": 3.166,
"1px": 0.838,
"3px": 0.93,
"5px": 0.957,
"relepe": 0.134,
"fl-all": 84.596,
},
10: {
"mae": 1.02,
"rmse": 3.073,
"1px": 0.854,
"3px": 0.938,
"5px": 0.96,
"relepe": 0.129,
"fl-all": 86.042,
},
20: {
"mae": 0.993,
"rmse": 3.059,
"1px": 0.855,
"3px": 0.942,
"5px": 0.967,
"relepe": 0.126,
"fl-all": 85.784,
},
},
2: {
2: {
"mae": 1.667,
"rmse": 3.867,
"1px": 0.78,
"3px": 0.891,
"5px": 0.922,
"relepe": 0.165,
"fl-all": 78.89,
},
5: {
"mae": 1.158,
"rmse": 3.278,
"1px": 0.843,
"3px": 0.926,
"5px": 0.955,
"relepe": 0.135,
"fl-all": 84.556,
},
10: {
"mae": 1.046,
"rmse": 3.13,
"1px": 0.85,
"3px": 0.934,
"5px": 0.96,
"relepe": 0.13,
"fl-all": 85.464,
},
20: {
"mae": 1.021,
"rmse": 3.102,
"1px": 0.85,
"3px": 0.935,
"5px": 0.963,
"relepe": 0.129,
"fl-all": 85.417,
},
},
},
},
},
"_docs": """These weights were finetuned on a mixture of
:class:`~torchvision.datasets._stereo_matching.CREStereo` +
:class:`~torchvision.datasets._stereo_matching.Middlebury2014Stereo` +
:class:`~torchvision.datasets._stereo_matching.ETH3DStereo` +
:class:`~torchvision.datasets._stereo_matching.InStereo2k` +
:class:`~torchvision.datasets._stereo_matching.CarlaStereo` +
:class:`~torchvision.datasets._stereo_matching.SintelStereo` +
:class:`~torchvision.datasets._stereo_matching.FallingThingsStereo`.""",
},
)
DEFAULT = MEGVII_V1
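# Sketch of how the nested "_detailed" entries defined above can be read: they map the number
# of cascades to the number of refinement iterations to the metric dict measured for that
# configuration (values below are taken from MEGVII_V1 above).
_detailed_example = CREStereo_Base_Weights.MEGVII_V1.meta["_metrics"]["Middlebury2014-train"]["_detailed"]
print(_detailed_example[1][10]["mae"])     # 0.792  -> 1 cascade, 10 refinement iterations
print(_detailed_example[2][20]["fl-all"])  # 90.269 -> 2 cascades, 20 refinement iterations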
@register_model()
@handle_legacy_interface(weights=("pretrained", CREStereo_Base_Weights.MEGVII_V1))
def crestereo_base(*, weights: Optional[CREStereo_Base_Weights] = None, progress=True, **kwargs) -> CREStereo:
"""CREStereo model from
`Practical Stereo Matching via Cascaded Recurrent Network
With Adaptive Correlation <https://openaccess.thecvf.com/content/CVPR2022/papers/Li_Practical_Stereo_Matching_via_Cascaded_Recurrent_Network_With_Adaptive_Correlation_CVPR_2022_paper.pdf>`_.
Please see the example below for a tutorial on how to use this model.
Args:
weights (:class:`~torchvision.prototype.models.depth.stereo.CREStereo_Base_Weights`, optional): The
pretrained weights to use. See
:class:`~torchvision.prototype.models.depth.stereo.CREStereo_Base_Weights`
below for more details, and possible values. By default, no
pre-trained weights are used.
progress (bool): If True, displays a progress bar of the download to stderr. Default is True.
**kwargs: parameters passed to the ``torchvision.prototype.models.depth.stereo.raft_stereo.RaftStereo``
base class. Please refer to the `source code
<https://github.com/pytorch/vision/blob/main/torchvision/models/optical_flow/crestereo.py>`_
for more details about this class.
.. autoclass:: torchvision.prototype.models.depth.stereo.CREStereo_Base_Weights
:members:
"""
return _crestereo(
weights=weights,
progress=progress,
......
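A rough usage sketch for the builder above, under stated assumptions: the prototype module path from the docstring, a resize_size argument for the StereoMatching preset (the value comes from _COMMON_META), and an output consisting of one prediction per refinement iteration. None of these details are guaranteed by this diff.

import torch
from torchvision.prototype.models.depth.stereo import crestereo_base, CREStereo_Base_Weights

weights = CREStereo_Base_Weights.MEGVII_V1
model = crestereo_base(weights=weights).eval()

# assumed preset signature: StereoMatching(resize_size=...) called on a (left, right) pair
transform = weights.transforms(resize_size=weights.meta["resize_size"])
left, right = transform(torch.rand(1, 3, 384, 512), torch.rand(1, 3, 384, 512))

with torch.no_grad():
    # assumed output: a list with one disparity/flow estimate per refinement iteration
    predictions = model(left, right)
    disparity = predictions[-1]
print(disparity.shape)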
@@ -3,6 +3,7 @@ from torchvision.transforms import AutoAugmentPolicy, InterpolationMode # usort
from . import functional # usort: skip
from ._transform import Transform # usort: skip
from ._presets import StereoMatching # usort: skip
from ._augment import RandomCutmix, RandomErasing, RandomMixup, SimpleCopyPaste
from ._auto_augment import AugMix, AutoAugment, RandAugment, TrivialAugmentWide
......