Unverified Commit 97eddc5d authored by Nicolas Hug's avatar Nicolas Hug Committed by GitHub
Browse files

Change default weights of RAFT model builders (#5381)

* Change default weights of RAFT model builders

* update handle_legacy_interface input

* Oops, wrong default
parent dad6e6a4
......@@ -21,7 +21,7 @@ __all__ = (
_MODELS_URLS = {
"raft_large": "https://download.pytorch.org/models/raft_large_C_T_V2-1bb1363a.pth",
"raft_large": "https://download.pytorch.org/models/raft_large_C_T_SKHT_V2-ff5fadd5.pth",
"raft_small": "https://download.pytorch.org/models/raft_small_C_T_V2-01064c6d.pth",
}
......@@ -587,10 +587,16 @@ def raft_large(*, pretrained=False, progress=True, **kwargs):
`RAFT: Recurrent All Pairs Field Transforms for Optical Flow <https://arxiv.org/abs/2003.12039>`_.
Args:
pretrained (bool): Whether to use pretrained weights.
progress (bool): If True, displays a progress bar of the download to stderr
kwargs (dict): Parameters that will be passed to the :class:`~torchvision.models.optical_flow.RAFT` class
to override any default.
pretrained (bool): Whether to use weights that have been pre-trained on
:class:`~torchvsion.datasets.FlyingChairs` + :class:`~torchvsion.datasets.FlyingThings3D`
with two fine-tuning steps:
- one on :class:`~torchvsion.datasets.Sintel` + :class:`~torchvsion.datasets.FlyingThings3D`
- one on :class:`~torchvsion.datasets.KittiFlow`.
This corresponds to the ``C+T+S/K`` strategy in the paper.
progress (bool): If True, displays a progress bar of the download to stderr.
Returns:
nn.Module: The model.
......@@ -632,10 +638,9 @@ def raft_small(*, pretrained=False, progress=True, **kwargs):
`RAFT: Recurrent All Pairs Field Transforms for Optical Flow <https://arxiv.org/abs/2003.12039>`_.
Args:
pretrained (bool): Whether to use pretrained weights.
pretrained (bool): Whether to use weights that have been pre-trained on
:class:`~torchvsion.datasets.FlyingChairs` + :class:`~torchvsion.datasets.FlyingThings3D`.
progress (bool): If True, displays a progress bar of the download to stderr
kwargs (dict): Parameters that will be passed to the :class:`~torchvision.models.optical_flow.RAFT` class
to override any default.
Returns:
nn.Module: The model.
......
......@@ -115,7 +115,7 @@ class Raft_Large_Weights(WeightsEnum):
},
)
DEFAULT = C_T_V2
DEFAULT = C_T_SKHT_V2
class Raft_Small_Weights(WeightsEnum):
......@@ -151,7 +151,7 @@ class Raft_Small_Weights(WeightsEnum):
DEFAULT = C_T_V2
@handle_legacy_interface(weights=("pretrained", Raft_Large_Weights.C_T_V2))
@handle_legacy_interface(weights=("pretrained", Raft_Large_Weights.C_T_SKHT_V2))
def raft_large(*, weights: Optional[Raft_Large_Weights] = None, progress=True, **kwargs):
"""RAFT model from
`RAFT: Recurrent All Pairs Field Transforms for Optical Flow <https://arxiv.org/abs/2003.12039>`_.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment