Unverified Commit 97eddc5d authored by Nicolas Hug's avatar Nicolas Hug Committed by GitHub
Browse files

Change default weights of RAFT model builders (#5381)

* Change default weights of RAFT model builders

* update handle_legacy_interface input

* Oops, wrong default
parent dad6e6a4
...@@ -21,7 +21,7 @@ __all__ = ( ...@@ -21,7 +21,7 @@ __all__ = (
_MODELS_URLS = { _MODELS_URLS = {
"raft_large": "https://download.pytorch.org/models/raft_large_C_T_V2-1bb1363a.pth", "raft_large": "https://download.pytorch.org/models/raft_large_C_T_SKHT_V2-ff5fadd5.pth",
"raft_small": "https://download.pytorch.org/models/raft_small_C_T_V2-01064c6d.pth", "raft_small": "https://download.pytorch.org/models/raft_small_C_T_V2-01064c6d.pth",
} }
...@@ -587,10 +587,16 @@ def raft_large(*, pretrained=False, progress=True, **kwargs): ...@@ -587,10 +587,16 @@ def raft_large(*, pretrained=False, progress=True, **kwargs):
`RAFT: Recurrent All Pairs Field Transforms for Optical Flow <https://arxiv.org/abs/2003.12039>`_. `RAFT: Recurrent All Pairs Field Transforms for Optical Flow <https://arxiv.org/abs/2003.12039>`_.
Args: Args:
pretrained (bool): Whether to use pretrained weights. pretrained (bool): Whether to use weights that have been pre-trained on
progress (bool): If True, displays a progress bar of the download to stderr :class:`~torchvsion.datasets.FlyingChairs` + :class:`~torchvsion.datasets.FlyingThings3D`
kwargs (dict): Parameters that will be passed to the :class:`~torchvision.models.optical_flow.RAFT` class with two fine-tuning steps:
to override any default.
- one on :class:`~torchvsion.datasets.Sintel` + :class:`~torchvsion.datasets.FlyingThings3D`
- one on :class:`~torchvsion.datasets.KittiFlow`.
This corresponds to the ``C+T+S/K`` strategy in the paper.
progress (bool): If True, displays a progress bar of the download to stderr.
Returns: Returns:
nn.Module: The model. nn.Module: The model.
...@@ -632,10 +638,9 @@ def raft_small(*, pretrained=False, progress=True, **kwargs): ...@@ -632,10 +638,9 @@ def raft_small(*, pretrained=False, progress=True, **kwargs):
`RAFT: Recurrent All Pairs Field Transforms for Optical Flow <https://arxiv.org/abs/2003.12039>`_. `RAFT: Recurrent All Pairs Field Transforms for Optical Flow <https://arxiv.org/abs/2003.12039>`_.
Args: Args:
pretrained (bool): Whether to use pretrained weights. pretrained (bool): Whether to use weights that have been pre-trained on
:class:`~torchvsion.datasets.FlyingChairs` + :class:`~torchvsion.datasets.FlyingThings3D`.
progress (bool): If True, displays a progress bar of the download to stderr progress (bool): If True, displays a progress bar of the download to stderr
kwargs (dict): Parameters that will be passed to the :class:`~torchvision.models.optical_flow.RAFT` class
to override any default.
Returns: Returns:
nn.Module: The model. nn.Module: The model.
......
...@@ -115,7 +115,7 @@ class Raft_Large_Weights(WeightsEnum): ...@@ -115,7 +115,7 @@ class Raft_Large_Weights(WeightsEnum):
}, },
) )
DEFAULT = C_T_V2 DEFAULT = C_T_SKHT_V2
class Raft_Small_Weights(WeightsEnum): class Raft_Small_Weights(WeightsEnum):
...@@ -151,7 +151,7 @@ class Raft_Small_Weights(WeightsEnum): ...@@ -151,7 +151,7 @@ class Raft_Small_Weights(WeightsEnum):
DEFAULT = C_T_V2 DEFAULT = C_T_V2
@handle_legacy_interface(weights=("pretrained", Raft_Large_Weights.C_T_V2)) @handle_legacy_interface(weights=("pretrained", Raft_Large_Weights.C_T_SKHT_V2))
def raft_large(*, weights: Optional[Raft_Large_Weights] = None, progress=True, **kwargs): def raft_large(*, weights: Optional[Raft_Large_Weights] = None, progress=True, **kwargs):
"""RAFT model from """RAFT model from
`RAFT: Recurrent All Pairs Field Transforms for Optical Flow <https://arxiv.org/abs/2003.12039>`_. `RAFT: Recurrent All Pairs Field Transforms for Optical Flow <https://arxiv.org/abs/2003.12039>`_.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment