Unverified Commit 48e2f23c authored by Nicolas Hug's avatar Nicolas Hug Committed by GitHub
Browse files

Add pretrained weights for raft_small from original paper (#5070)

parent 4cacf5a1
...@@ -20,7 +20,11 @@ __all__ = ( ...@@ -20,7 +20,11 @@ __all__ = (
) )
_MODELS_URLS = {"raft_large": "https://download.pytorch.org/models/raft_large_C_T_V2-1bb1363a.pth"} _MODELS_URLS = {
"raft_large": "https://download.pytorch.org/models/raft_large_C_T_V2-1bb1363a.pth",
# TODO: change to V2 once we upload our own weights
"raft_small": "https://download.pytorch.org/models/raft_small_C_T_V1-ad48884c.pth",
}
class ResidualBlock(nn.Module): class ResidualBlock(nn.Module):
...@@ -641,8 +645,6 @@ def raft_small(*, pretrained=False, progress=True, **kwargs): ...@@ -641,8 +645,6 @@ def raft_small(*, pretrained=False, progress=True, **kwargs):
nn.Module: The model. nn.Module: The model.
""" """
if pretrained:
raise ValueError("No checkpoint is available for raft_small")
return _raft( return _raft(
arch="raft_small", arch="raft_small",
......
...@@ -78,16 +78,19 @@ class Raft_Large_Weights(WeightsEnum): ...@@ -78,16 +78,19 @@ class Raft_Large_Weights(WeightsEnum):
class Raft_Small_Weights(WeightsEnum): class Raft_Small_Weights(WeightsEnum):
pass C_T_V1 = Weights(
# C_T_V1 = Weights( # Chairs + Things, ported from original paper repo (raft-small.pth)
# url="", # TODO url="https://download.pytorch.org/models/raft_small_C_T_V1-ad48884c.pth",
# transforms=RaftEval, transforms=RaftEval,
# meta={ meta={
# "recipe": "", **_COMMON_META,
# "epe": -1234, "recipe": "https://github.com/princeton-vl/RAFT",
# }, "sintel_train_cleanpass_epe": 2.1231,
# ) "sintel_train_finalpass_epe": 3.2790,
# default = C_T_V1 },
)
default = C_T_V1 # TODO: Change to V2 once we upload our own weights
@handle_legacy_interface(weights=("pretrained", Raft_Large_Weights.C_T_V2)) @handle_legacy_interface(weights=("pretrained", Raft_Large_Weights.C_T_V2))
...@@ -140,7 +143,8 @@ def raft_large(*, weights: Optional[Raft_Large_Weights] = None, progress=True, * ...@@ -140,7 +143,8 @@ def raft_large(*, weights: Optional[Raft_Large_Weights] = None, progress=True, *
return model return model
@handle_legacy_interface(weights=("pretrained", None)) # TODO: change to V2 once we upload our own weights
@handle_legacy_interface(weights=("pretrained", Raft_Small_Weights.C_T_V1))
def raft_small(*, weights: Optional[Raft_Small_Weights] = None, progress=True, **kwargs): def raft_small(*, weights: Optional[Raft_Small_Weights] = None, progress=True, **kwargs):
"""RAFT "small" model from """RAFT "small" model from
`RAFT: Recurrent All Pairs Field Transforms for Optical Flow <https://arxiv.org/abs/2003.12039>`_. `RAFT: Recurrent All Pairs Field Transforms for Optical Flow <https://arxiv.org/abs/2003.12039>`_.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment