Unverified Commit 11caf37a authored by YosuaMichael's avatar YosuaMichael Committed by GitHub
Browse files

Add raft-stereo model to prototype/models (#6107)

* Add rough raft-stereo implementation on prototype/models

* Add standard raft_stereo builder, and modify context_encoder to be more similar with original implementation

* Follow original implementation on pre-convolve context

* Fix to make sure we can load the original implementation's weights and get the same output

* reusing component from raft

* Make the raft_stereo_fast able to load the original implementation's weights

* Format with ufmt and update some comment

* Use raft FlowHead

* clean up comments

* Remove unnecessary import and use ufmt format

* Add __all__ and more docs for RaftStereo class

* Only accept param and not module for raft stereo builder

* Cleanup comment

* Adding typing to raft_stereo

* Update some of raft code and reuse on raft stereo

* Use bool instead of int

* Make standard raft_stereo model jit scriptable

* Make the function _make_out_layer using boolean with_block and init the block_layer with identity

* Separate corr_block into two modules for pyramid and building corr features

* Use tuple if input is not variable size, also remove default value if using List

* Format using ufmt and update ConvGRU to not inherit from raft in order to satisfy both jit script and mypy

* Change RaftStereo docs input type

* Ufmt format raft

* revert back convgru to see mypy errors, add test for jit and fx, make the model fx compatible

* ufmt format

* Specify device for new tensor, don't init module then overwrite; use if-else instead

* Ignore mypy problem on override, put back num_iters on forward

* Revert some effort to make it fx compatible but unnecessary now

* refactor code and remove num_iters from RaftStereo constructor

* Change to raft_stereo_realtime, and specify device directly for tensor creation

* Add description for raft_stereo_realtime

* Update the test for raft_stereo

* Fix raft stereo prototype test to properly test jit script

* Ufmt format

* Test against expected file, change name from raft_stereo to raft_stereo_builder to prevent import error

* Revert __init__.py changes

* Add default value for non-list param on model builder

* Add checking on out_with_block length, add more docs on the encoder

* Use base instead of basic since it is more commonly used

* rename expect file to base as well

* rename on test

* Revert the revert of __init__.py, also revert the adding default value to _raft_stereo to follow the standard pattern

* ufmt format __init__.py
parent 59c4de91
File suppressed by a .gitattributes entry or the file's encoding is unsupported.
File suppressed by a .gitattributes entry or the file's encoding is unsupported.
import pytest
import test_models as TM
import torch
import torchvision.prototype.models.depth.stereo.raft_stereo as raft_stereo
from common_utils import set_rng_seed, cpu_and_gpu
@pytest.mark.parametrize("model_builder", (raft_stereo.raft_stereo_base, raft_stereo.raft_stereo_realtime))
@pytest.mark.parametrize("model_mode", ("standard", "scripted"))
@pytest.mark.parametrize("dev", cpu_and_gpu())
def test_raft_stereo(model_builder, model_mode, dev):
    """Smoke-test the RaftStereo builders.

    Checks that the model (optionally jit-scripted) runs a forward pass,
    produces one prediction per iteration with the expected shape, and
    matches the stored expected-output file.
    """
    set_rng_seed(0)

    # Use corr_pyramid and corr_block with smaller num_levels and radius to prevent nan output
    # (same approach as test_models.test_raft)
    corr_pyramid = raft_stereo.CorrPyramid1d(num_levels=2)
    corr_block = raft_stereo.CorrBlock1d(num_levels=2, radius=2)
    model = model_builder(corr_pyramid=corr_pyramid, corr_block=corr_block).eval().to(dev)

    if model_mode == "scripted":
        model = torch.jit.script(model)

    img1 = torch.rand(1, 3, 64, 64).to(dev)
    img2 = torch.rand(1, 3, 64, 64).to(dev)
    num_iters = 3

    preds = model(img1, img2, num_iters=num_iters)
    depth_pred = preds[-1]

    # Fix: the message previously referenced model.num_iters, but the check uses the local num_iters.
    assert len(preds) == num_iters, "Number of predictions should be the same as num_iters"
    # Fix: on failure, report the shape of the tensor actually being asserted on
    # (depth_pred, i.e. preds[-1]), not preds[0].
    assert depth_pred.shape == torch.Size(
        [1, 1, 64, 64]
    ), f"The output shape of depth_pred should be [1, 1, 64, 64] but instead it is {depth_pred.shape}"

    # Test against expected file output
    TM._assert_expected(depth_pred, name=model_builder.__name__, atol=1e-2, rtol=1e-2)
...@@ -11,6 +11,8 @@ def grid_sample(img: Tensor, absolute_grid: Tensor, mode: str = "bilinear", alig ...@@ -11,6 +11,8 @@ def grid_sample(img: Tensor, absolute_grid: Tensor, mode: str = "bilinear", alig
xgrid, ygrid = absolute_grid.split([1, 1], dim=-1) xgrid, ygrid = absolute_grid.split([1, 1], dim=-1)
xgrid = 2 * xgrid / (w - 1) - 1 xgrid = 2 * xgrid / (w - 1) - 1
# Adding condition if h > 1 to enable this function be reused in raft-stereo
if h > 1:
ygrid = 2 * ygrid / (h - 1) - 1 ygrid = 2 * ygrid / (h - 1) - 1
normalized_grid = torch.cat([xgrid, ygrid], dim=-1) normalized_grid = torch.cat([xgrid, ygrid], dim=-1)
...@@ -23,23 +25,23 @@ def make_coords_grid(batch_size: int, h: int, w: int): ...@@ -23,23 +25,23 @@ def make_coords_grid(batch_size: int, h: int, w: int):
return coords[None].repeat(batch_size, 1, 1, 1) return coords[None].repeat(batch_size, 1, 1, 1)
def upsample_flow(flow, up_mask: Optional[Tensor] = None): def upsample_flow(flow, up_mask: Optional[Tensor] = None, factor: int = 8):
"""Upsample flow by a factor of 8. """Upsample flow by the input factor (default 8).
If up_mask is None we just interpolate. If up_mask is None we just interpolate.
If up_mask is specified, we upsample using a convex combination of its weights. See paper page 8 and appendix B. If up_mask is specified, we upsample using a convex combination of its weights. See paper page 8 and appendix B.
Note that in appendix B the picture assumes a downsample factor of 4 instead of 8. Note that in appendix B the picture assumes a downsample factor of 4 instead of 8.
""" """
batch_size, _, h, w = flow.shape batch_size, num_channels, h, w = flow.shape
new_h, new_w = h * 8, w * 8 new_h, new_w = h * factor, w * factor
if up_mask is None: if up_mask is None:
return 8 * F.interpolate(flow, size=(new_h, new_w), mode="bilinear", align_corners=True) return factor * F.interpolate(flow, size=(new_h, new_w), mode="bilinear", align_corners=True)
up_mask = up_mask.view(batch_size, 1, 9, 8, 8, h, w) up_mask = up_mask.view(batch_size, 1, 9, factor, factor, h, w)
up_mask = torch.softmax(up_mask, dim=2) # "convex" == weights sum to 1 up_mask = torch.softmax(up_mask, dim=2) # "convex" == weights sum to 1
upsampled_flow = F.unfold(8 * flow, kernel_size=3, padding=1).view(batch_size, 2, 9, 1, 1, h, w) upsampled_flow = F.unfold(factor * flow, kernel_size=3, padding=1).view(batch_size, num_channels, 9, 1, 1, h, w)
upsampled_flow = torch.sum(up_mask * upsampled_flow, dim=2) upsampled_flow = torch.sum(up_mask * upsampled_flow, dim=2)
return upsampled_flow.permute(0, 1, 4, 2, 5, 3).reshape(batch_size, 2, new_h, new_w) return upsampled_flow.permute(0, 1, 4, 2, 5, 3).reshape(batch_size, num_channels, new_h, new_w)
...@@ -116,7 +116,9 @@ class FeatureEncoder(nn.Module): ...@@ -116,7 +116,9 @@ class FeatureEncoder(nn.Module):
It must downsample its input by 8. It must downsample its input by 8.
""" """
def __init__(self, *, block=ResidualBlock, layers=(64, 64, 96, 128, 256), norm_layer=nn.BatchNorm2d): def __init__(
self, *, block=ResidualBlock, layers=(64, 64, 96, 128, 256), strides=(2, 1, 2, 2), norm_layer=nn.BatchNorm2d
):
super().__init__() super().__init__()
if len(layers) != 5: if len(layers) != 5:
...@@ -124,12 +126,12 @@ class FeatureEncoder(nn.Module): ...@@ -124,12 +126,12 @@ class FeatureEncoder(nn.Module):
# See note in ResidualBlock for the reason behind bias=True # See note in ResidualBlock for the reason behind bias=True
self.convnormrelu = Conv2dNormActivation( self.convnormrelu = Conv2dNormActivation(
3, layers[0], norm_layer=norm_layer, kernel_size=7, stride=2, bias=True 3, layers[0], norm_layer=norm_layer, kernel_size=7, stride=strides[0], bias=True
) )
self.layer1 = self._make_2_blocks(block, layers[0], layers[1], norm_layer=norm_layer, first_stride=1) self.layer1 = self._make_2_blocks(block, layers[0], layers[1], norm_layer=norm_layer, first_stride=strides[1])
self.layer2 = self._make_2_blocks(block, layers[1], layers[2], norm_layer=norm_layer, first_stride=2) self.layer2 = self._make_2_blocks(block, layers[1], layers[2], norm_layer=norm_layer, first_stride=strides[2])
self.layer3 = self._make_2_blocks(block, layers[2], layers[3], norm_layer=norm_layer, first_stride=2) self.layer3 = self._make_2_blocks(block, layers[2], layers[3], norm_layer=norm_layer, first_stride=strides[3])
self.conv = nn.Conv2d(layers[3], layers[4], kernel_size=1) self.conv = nn.Conv2d(layers[3], layers[4], kernel_size=1)
......
from . import datasets from . import datasets
from . import features from . import features
from . import models
from . import transforms from . import transforms
from . import utils from . import utils
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment