Unverified Commit e23542da authored by YosuaMichael's avatar YosuaMichael Committed by GitHub
Browse files

Add raft_stereo weights (#6786)

* Add raft_stereo weights

* Update the metrics layout
parent 0610b13a
from functools import partial
from typing import Callable, List, Optional, Tuple from typing import Callable, List, Optional, Tuple
import torch import torch
...@@ -5,11 +6,12 @@ import torch.nn as nn ...@@ -5,11 +6,12 @@ import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
import torchvision.models.optical_flow.raft as raft import torchvision.models.optical_flow.raft as raft
from torch import Tensor from torch import Tensor
from torchvision.models._api import register_model, WeightsEnum from torchvision.models._api import register_model, Weights, WeightsEnum
from torchvision.models._utils import handle_legacy_interface from torchvision.models._utils import handle_legacy_interface
from torchvision.models.optical_flow._utils import grid_sample, make_coords_grid, upsample_flow from torchvision.models.optical_flow._utils import grid_sample, make_coords_grid, upsample_flow
from torchvision.models.optical_flow.raft import FlowHead, MotionEncoder, ResidualBlock from torchvision.models.optical_flow.raft import FlowHead, MotionEncoder, ResidualBlock
from torchvision.ops import Conv2dNormActivation from torchvision.ops import Conv2dNormActivation
from torchvision.prototype.transforms._presets import StereoMatching
from torchvision.utils import _log_api_usage_once from torchvision.utils import _log_api_usage_once
...@@ -624,11 +626,97 @@ def _raft_stereo( ...@@ -624,11 +626,97 @@ def _raft_stereo(
class Raft_Stereo_Realtime_Weights(WeightsEnum):
    """Pretrained weight entries for the realtime RAFT-Stereo variant.

    Each member is a ``Weights`` record carrying the checkpoint ``url``, a
    ``transforms`` factory (``StereoMatching`` with a fixed ``resize_size``)
    and a ``meta`` dict with the parameter count, the recipe link and the
    reported metrics.
    """

    SCENEFLOW_V1 = Weights(
        # Weights ported from https://github.com/princeton-vl/RAFT-Stereo
        url="https://download.pytorch.org/models/raft_stereo_realtime-cf345ccb.pth",
        transforms=partial(StereoMatching, resize_size=(224, 224)),
        meta={
            "num_params": 8077152,
            "recipe": "https://github.com/princeton-vl/RAFT-Stereo",
            "_metrics": {
                # Following metrics from paper: https://arxiv.org/abs/2109.07547
                # NOTE(review): the key spelling "Kitty2015" is kept verbatim because it is
                # a runtime dict key; presumably it refers to KITTI 2015 -- confirm before renaming.
                "Kitty2015": {
                    # Ratio of pixels with difference less than 3px from ground truth
                    "3px": 0.9409,
                }
            },
        },
    )

    # DEFAULT refers to the same entry as SCENEFLOW_V1 (the only checkpoint listed here).
    DEFAULT = SCENEFLOW_V1
class Raft_Stereo_Base_Weights(WeightsEnum):
    """Pretrained weight entries for the base RAFT-Stereo variant.

    Three checkpoints are listed (``SCENEFLOW_V1``, ``MIDDLEBURY_V1`` and
    ``ETH3D_V1``). Each member is a ``Weights`` record carrying the checkpoint
    ``url``, a ``transforms`` factory (``StereoMatching`` with a fixed
    ``resize_size``) and a ``meta`` dict with the parameter count, the recipe
    link and the reported metrics. ``DEFAULT`` refers to ``MIDDLEBURY_V1``.
    """

    SCENEFLOW_V1 = Weights(
        # Weights ported from https://github.com/princeton-vl/RAFT-Stereo
        url="https://download.pytorch.org/models/raft_stereo_base_sceneflow-eff3f2e6.pth",
        transforms=partial(StereoMatching, resize_size=(224, 224)),
        meta={
            "num_params": 11116176,
            "recipe": "https://github.com/princeton-vl/RAFT-Stereo",
            "_metrics": {
                # Following metrics from paper: https://arxiv.org/abs/2109.07547
                # Using the standard metrics for each dataset
                # NOTE(review): the key spelling "Kitty2015" is kept verbatim because it is
                # a runtime dict key; presumably it refers to KITTI 2015 -- confirm before renaming.
                "Kitty2015": {
                    # Ratio of pixels with difference less than 3px from ground truth
                    "3px": 0.9426,
                },
                # For Middlebury, ratio of pixels with difference less than 2px from ground truth
                # on full, half, and quarter image resolution
                "Middlebury2014-val-full": {
                    "2px": 0.8167,
                },
                "Middlebury2014-val-half": {
                    "2px": 0.8741,
                },
                "Middlebury2014-val-quarter": {
                    "2px": 0.9064,
                },
                "ETH3D-val": {
                    # Ratio of pixels with difference less than 1px from ground truth
                    "1px": 0.9672,
                },
            },
        },
    )

    MIDDLEBURY_V1 = Weights(
        # Weights ported from https://github.com/princeton-vl/RAFT-Stereo
        url="https://download.pytorch.org/models/raft_stereo_base_middlebury-afa9d252.pth",
        transforms=partial(StereoMatching, resize_size=(224, 224)),
        meta={
            "num_params": 11116176,
            "recipe": "https://github.com/princeton-vl/RAFT-Stereo",
            "_metrics": {
                # Following metrics from paper: https://arxiv.org/abs/2109.07547
                "Middlebury-test": {
                    "mae": 1.27,
                    "1px": 0.9063,
                    "2px": 0.9526,
                    "5px": 0.9725,
                }
            },
        },
    )

    ETH3D_V1 = Weights(
        # Weights ported from https://github.com/princeton-vl/RAFT-Stereo
        url="https://download.pytorch.org/models/raft_stereo_base_eth3d-d4830f22.pth",
        transforms=partial(StereoMatching, resize_size=(224, 224)),
        meta={
            "num_params": 11116176,
            "recipe": "https://github.com/princeton-vl/RAFT-Stereo",
            "_metrics": {
                # Following metrics from paper: https://arxiv.org/abs/2109.07547
                "ETH3D-test": {
                    "mae": 0.18,
                    "1px": 0.9756,
                    "2px": 0.9956,
                }
            },
        },
    )

    # DEFAULT refers to the Middlebury checkpoint, not the first entry listed above.
    DEFAULT = MIDDLEBURY_V1
@register_model() @register_model()
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment