Unverified Commit 5dc1e20b authored by Nicolas Hug's avatar Nicolas Hug Committed by GitHub
Browse files

Add raft builders and presets in prototypes (#5043)


Co-authored-by: default avatarVasilis Vryniotis <datumbox@users.noreply.github.com>
parent a57e45c8
...@@ -3,7 +3,7 @@ import importlib ...@@ -3,7 +3,7 @@ import importlib
import pytest import pytest
import test_models as TM import test_models as TM
import torch import torch
from common_utils import cpu_and_gpu, run_on_env_var from common_utils import cpu_and_gpu, run_on_env_var, needs_cuda
from torchvision.prototype import models from torchvision.prototype import models
from torchvision.prototype.models._api import WeightsEnum, Weights from torchvision.prototype.models._api import WeightsEnum, Weights
from torchvision.prototype.models._utils import handle_legacy_interface from torchvision.prototype.models._utils import handle_legacy_interface
...@@ -75,10 +75,12 @@ def test_get_weight(name, weight): ...@@ -75,10 +75,12 @@ def test_get_weight(name, weight):
+ TM.get_models_from_module(models.detection) + TM.get_models_from_module(models.detection)
+ TM.get_models_from_module(models.quantization) + TM.get_models_from_module(models.quantization)
+ TM.get_models_from_module(models.segmentation) + TM.get_models_from_module(models.segmentation)
+ TM.get_models_from_module(models.video), + TM.get_models_from_module(models.video)
+ TM.get_models_from_module(models.optical_flow),
) )
def test_naming_conventions(model_fn): def test_naming_conventions(model_fn):
weights_enum = _get_model_weights(model_fn) weights_enum = _get_model_weights(model_fn)
print(weights_enum)
assert weights_enum is not None assert weights_enum is not None
assert len(weights_enum) == 0 or hasattr(weights_enum, "default") assert len(weights_enum) == 0 or hasattr(weights_enum, "default")
...@@ -149,13 +151,22 @@ def test_video_model(model_fn, dev): ...@@ -149,13 +151,22 @@ def test_video_model(model_fn, dev):
TM.test_video_model(model_fn, dev) TM.test_video_model(model_fn, dev)
@needs_cuda
@pytest.mark.parametrize("model_builder", TM.get_models_from_module(models.optical_flow))
@pytest.mark.parametrize("scripted", (False, True))
@run_if_test_with_prototype
def test_raft(model_builder, scripted):
TM.test_raft(model_builder, scripted)
@pytest.mark.parametrize( @pytest.mark.parametrize(
"model_fn", "model_fn",
TM.get_models_from_module(models) TM.get_models_from_module(models)
+ TM.get_models_from_module(models.detection) + TM.get_models_from_module(models.detection)
+ TM.get_models_from_module(models.quantization) + TM.get_models_from_module(models.quantization)
+ TM.get_models_from_module(models.segmentation) + TM.get_models_from_module(models.segmentation)
+ TM.get_models_from_module(models.video), + TM.get_models_from_module(models.video)
+ TM.get_models_from_module(models.optical_flow),
) )
@pytest.mark.parametrize("dev", cpu_and_gpu()) @pytest.mark.parametrize("dev", cpu_and_gpu())
@run_if_test_with_prototype @run_if_test_with_prototype
...@@ -177,6 +188,9 @@ def test_old_vs_new_factory(model_fn, dev): ...@@ -177,6 +188,9 @@ def test_old_vs_new_factory(model_fn, dev):
"video": { "video": {
"input_shape": (1, 3, 4, 112, 112), "input_shape": (1, 3, 4, 112, 112),
}, },
"optical_flow": {
"input_shape": (1, 3, 128, 128),
},
} }
model_name = model_fn.__name__ model_name = model_fn.__name__
module_name = model_fn.__module__.split(".")[-2] module_name = model_fn.__module__.split(".")[-2]
......
...@@ -585,7 +585,7 @@ def raft_large(*, pretrained=False, progress=True, **kwargs): ...@@ -585,7 +585,7 @@ def raft_large(*, pretrained=False, progress=True, **kwargs):
""" """
if pretrained: if pretrained:
raise ValueError("Pretrained weights aren't available yet") raise ValueError("No checkpoint is available for raft_large")
return _raft( return _raft(
# Feature encoder # Feature encoder
...@@ -631,7 +631,7 @@ def raft_small(*, pretrained=False, progress=True, **kwargs): ...@@ -631,7 +631,7 @@ def raft_small(*, pretrained=False, progress=True, **kwargs):
""" """
if pretrained: if pretrained:
raise ValueError("Pretrained weights aren't available yet") raise ValueError("No checkpoint is available for raft_small")
return _raft( return _raft(
# Feature encoder # Feature encoder
......
...@@ -12,6 +12,7 @@ from .squeezenet import * ...@@ -12,6 +12,7 @@ from .squeezenet import *
from .vgg import * from .vgg import *
from .vision_transformer import * from .vision_transformer import *
from . import detection from . import detection
from . import optical_flow
from . import quantization from . import quantization
from . import segmentation from . import segmentation
from . import video from . import video
......
from .raft import RAFT, raft_large, raft_small, Raft_Large_Weights, Raft_Small_Weights
from typing import Optional
from torch.nn.modules.batchnorm import BatchNorm2d
from torch.nn.modules.instancenorm import InstanceNorm2d
from torchvision.models.optical_flow import RAFT
from torchvision.models.optical_flow.raft import _raft, BottleneckBlock, ResidualBlock
# from torchvision.prototype.transforms import RaftEval
from .._api import WeightsEnum
# from .._api import Weights
from .._utils import handle_legacy_interface
# Public API of the prototype optical-flow models module.
__all__ = (
    "RAFT",
    "raft_large",
    "raft_small",
    "Raft_Large_Weights",
    "Raft_Small_Weights",
)
class Raft_Large_Weights(WeightsEnum):
    """Pretrained weights for :func:`raft_large`.

    Currently empty: no checkpoint has been released yet. The commented-out
    entries below sketch the planned weights (Chairs+Things pretraining and the
    Sintel/Kitti fine-tuned variants); URLs and EPE metrics are still TODO.
    """

    pass
    # C_T_V1 = Weights(
    #     # Chairs + Things
    #     url="",
    #     transforms=RaftEval,
    #     meta={
    #         "recipe": "",
    #         "epe": -1234,
    #     },
    # )
    # C_T_SKHT_V1 = Weights(
    #     # Chairs + Things + Sintel fine-tuning, i.e.:
    #     # Chairs + Things + (Sintel + Kitti + HD1K + Things_clean)
    #     # Corresponds to the C+T+S+K+H on paper with fine-tuning on Sintel
    #     url="",
    #     transforms=RaftEval,
    #     meta={
    #         "recipe": "",
    #         "epe": -1234,
    #     },
    # )
    # C_T_SKHT_K_V1 = Weights(
    #     # Chairs + Things + Sintel fine-tuning + Kitti fine-tuning i.e.:
    #     # Chairs + Things + (Sintel + Kitti + HD1K + Things_clean) + Kitti
    #     # Same as CT_SKHT with extra fine-tuning on Kitti
    #     # Corresponds to the C+T+S+K+H on paper with fine-tuning on Sintel and then on Kitti
    #     url="",
    #     transforms=RaftEval,
    #     meta={
    #         "recipe": "",
    #         "epe": -1234,
    #     },
    # )
    # default = C_T_V1
class Raft_Small_Weights(WeightsEnum):
    """Pretrained weights for :func:`raft_small`.

    Currently empty: no checkpoint has been released yet. The commented-out
    entry below is the planned Chairs+Things weights; URL and EPE metric are
    still TODO.
    """

    pass
    # C_T_V1 = Weights(
    #     url="",  # TODO
    #     transforms=RaftEval,
    #     meta={
    #         "recipe": "",
    #         "epe": -1234,
    #     },
    # )
    # default = C_T_V1
@handle_legacy_interface(weights=("pretrained", None))
def raft_large(*, weights: Optional[Raft_Large_Weights] = None, progress=True, **kwargs):
    """RAFT model from
    `RAFT: Recurrent All Pairs Field Transforms for Optical Flow <https://arxiv.org/abs/2003.12039>`_.

    Args:
        weights (Raft_Large_Weights, optional): The pretrained weights to use.
            No checkpoint is available yet, so only ``None`` is currently
            accepted.
        progress (bool): If True, displays a progress bar of the download to stderr.
        kwargs (dict): Parameters that will be passed to the
            :class:`~torchvision.models.optical_flow.RAFT` class to override any default.

    Returns:
        nn.Module: The model.
    """
    # Raises on unknown/unsupported values; returns the verified enum member (or None).
    weights = Raft_Large_Weights.verify(weights)

    return _raft(
        # Feature encoder
        feature_encoder_layers=(64, 64, 96, 128, 256),
        feature_encoder_block=ResidualBlock,
        feature_encoder_norm_layer=InstanceNorm2d,
        # Context encoder
        context_encoder_layers=(64, 64, 96, 128, 256),
        context_encoder_block=ResidualBlock,
        context_encoder_norm_layer=BatchNorm2d,
        # Correlation block
        corr_block_num_levels=4,
        corr_block_radius=4,
        # Motion encoder
        motion_encoder_corr_layers=(256, 192),
        motion_encoder_flow_layers=(128, 64),
        motion_encoder_out_channels=128,
        # Recurrent block
        recurrent_block_hidden_state_size=128,
        recurrent_block_kernel_size=((1, 5), (5, 1)),
        recurrent_block_padding=((0, 2), (2, 0)),
        # Flow head
        flow_head_hidden_size=256,
        # Mask predictor
        use_mask_predictor=True,
        **kwargs,
    )
@handle_legacy_interface(weights=("pretrained", None))
def raft_small(*, weights: Optional[Raft_Small_Weights] = None, progress=True, **kwargs):
    """RAFT "small" model from
    `RAFT: Recurrent All Pairs Field Transforms for Optical Flow <https://arxiv.org/abs/2003.12039>`_.

    Args:
        weights (Raft_Small_Weights, optional): The pretrained weights to use.
            No checkpoint is available yet, so only ``None`` is currently
            accepted.
        progress (bool): If True, displays a progress bar of the download to stderr.
        kwargs (dict): Parameters that will be passed to the
            :class:`~torchvision.models.optical_flow.RAFT` class to override any default.

    Returns:
        nn.Module: The model.
    """
    # Raises on unknown/unsupported values; returns the verified enum member (or None).
    weights = Raft_Small_Weights.verify(weights)

    return _raft(
        # Feature encoder
        feature_encoder_layers=(32, 32, 64, 96, 128),
        feature_encoder_block=BottleneckBlock,
        feature_encoder_norm_layer=InstanceNorm2d,
        # Context encoder
        context_encoder_layers=(32, 32, 64, 96, 160),
        context_encoder_block=BottleneckBlock,
        context_encoder_norm_layer=None,
        # Correlation block
        corr_block_num_levels=4,
        corr_block_radius=3,
        # Motion encoder
        motion_encoder_corr_layers=(96,),
        motion_encoder_flow_layers=(64, 32),
        motion_encoder_out_channels=82,
        # Recurrent block
        recurrent_block_hidden_state_size=96,
        recurrent_block_kernel_size=(3,),
        recurrent_block_padding=(1,),
        # Flow head
        flow_head_hidden_size=128,
        # Mask predictor
        use_mask_predictor=False,
        **kwargs,
    )
...@@ -3,4 +3,4 @@ from ._container import Compose, RandomApply, RandomChoice, RandomOrder # usort ...@@ -3,4 +3,4 @@ from ._container import Compose, RandomApply, RandomChoice, RandomOrder # usort
from ._geometry import Resize, RandomResize, HorizontalFlip, Crop, CenterCrop, RandomCrop from ._geometry import Resize, RandomResize, HorizontalFlip, Crop, CenterCrop, RandomCrop
from ._misc import Identity, Normalize from ._misc import Identity, Normalize
from ._presets import CocoEval, ImageNetEval, VocEval, Kinect400Eval from ._presets import CocoEval, ImageNetEval, VocEval, Kinect400Eval, RaftEval
...@@ -6,7 +6,7 @@ from torch import Tensor, nn ...@@ -6,7 +6,7 @@ from torch import Tensor, nn
from ...transforms import functional as F, InterpolationMode from ...transforms import functional as F, InterpolationMode
__all__ = ["CocoEval", "ImageNetEval", "Kinect400Eval", "VocEval"] __all__ = ["CocoEval", "ImageNetEval", "Kinect400Eval", "VocEval", "RaftEval"]
class CocoEval(nn.Module): class CocoEval(nn.Module):
...@@ -97,3 +97,38 @@ class VocEval(nn.Module): ...@@ -97,3 +97,38 @@ class VocEval(nn.Module):
target = F.pil_to_tensor(target) target = F.pil_to_tensor(target)
target = target.squeeze(0).to(torch.int64) target = target.squeeze(0).to(torch.int64)
return img, target return img, target
class RaftEval(nn.Module):
    """Evaluation preprocessing preset for RAFT optical-flow models.

    Takes an image pair plus (optional) ground-truth flow and valid-flow mask,
    converts everything to tensors, and rescales the images to float32 in the
    ``[-1, 1]`` range expected by RAFT. Flow and mask values are left untouched.
    """

    def forward(
        self, img1: Tensor, img2: Tensor, flow: Optional[Tensor], valid_flow_mask: Optional[Tensor]
    ) -> Tuple[Tensor, Tensor, Optional[Tensor], Optional[Tensor]]:
        img1, img2, flow, valid_flow_mask = self._pil_or_numpy_to_tensor(img1, img2, flow, valid_flow_mask)
        return self._normalize_image(img1), self._normalize_image(img2), flow, valid_flow_mask

    @staticmethod
    def _normalize_image(img: Tensor) -> Tensor:
        # Convert to float32 in [0, 1], then map [0, 1] into [-1, 1].
        img = F.convert_image_dtype(img, torch.float32)
        img = F.normalize(img, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
        return img.contiguous()

    def _pil_or_numpy_to_tensor(
        self, img1: Tensor, img2: Tensor, flow: Optional[Tensor], valid_flow_mask: Optional[Tensor]
    ) -> Tuple[Tensor, Tensor, Optional[Tensor], Optional[Tensor]]:
        # Images may arrive as PIL images; flow/mask may arrive as numpy arrays.
        # Inputs that are already tensors are passed through unchanged.
        img1 = img1 if isinstance(img1, Tensor) else F.pil_to_tensor(img1)
        img2 = img2 if isinstance(img2, Tensor) else F.pil_to_tensor(img2)
        if flow is not None and not isinstance(flow, Tensor):
            flow = torch.from_numpy(flow)
        if valid_flow_mask is not None and not isinstance(valid_flow_mask, Tensor):
            valid_flow_mask = torch.from_numpy(valid_flow_mask)
        return img1, img2, flow, valid_flow_mask
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment