Unverified Commit 1c3eedc4 authored by Ponku, committed by GitHub

add crestereo implementation (#6310)



* crestereo draft implementation

* minor model fixes. positional embedding changes.

* aligned base configuration with paper

* Addressing comments

* Broke down Adaptive Correlation Layer. Addressed some other comments.

* addressed some nits

* changed search size, added output channels to model attrs

* changed weights naming

* changed from iterations to num_iters

* removed _make_coords, addressed comments

* fixed jit test

* config nit

* Changed device arg to str
Co-authored-by: Joao Gomes <jdsgomes@fb.com>
Co-authored-by: YosuaMichael <yosuamichaelm@gmail.com>
parent 544a4070
File suppressed by a .gitattributes entry or the file's encoding is unsupported.
@@ -5,7 +5,7 @@ from common_utils import cpu_and_gpu, set_rng_seed
from torchvision.prototype import models
@pytest.mark.parametrize("model_fn", TM.list_model_fns(models.depth.stereo))
@pytest.mark.parametrize("model_fn", (models.depth.stereo.raft_stereo_base,))
@pytest.mark.parametrize("model_mode", ("standard", "scripted"))
@pytest.mark.parametrize("dev", cpu_and_gpu())
def test_raft_stereo(model_fn, model_mode, dev):
@@ -35,4 +35,50 @@ def test_raft_stereo(model_fn, model_mode, dev):
), f"The output shape of depth_pred should be [1, 1, 64, 64] but instead it is {preds[0].shape}"
# Test against expected file output
TM._assert_expected(depth_pred.cpu(), name=model_fn.__name__, atol=1e-2, rtol=1e-2)
TM._assert_expected(depth_pred, name=model_fn.__name__, atol=1e-2, rtol=1e-2)
@pytest.mark.parametrize("model_fn", (models.depth.stereo.crestereo_base,))
@pytest.mark.parametrize("model_mode", ("standard", "scripted"))
@pytest.mark.parametrize("dev", cpu_and_gpu())
def test_crestereo(model_fn, model_mode, dev):
    set_rng_seed(0)
    model = model_fn().eval().to(dev)

    if model_mode == "scripted":
        model = torch.jit.script(model)

    img1 = torch.rand(1, 3, 64, 64).to(dev)
    img2 = torch.rand(1, 3, 64, 64).to(dev)
    iterations = 3

    preds = model(img1, img2, flow_init=None, num_iters=iterations)
    disparity_pred = preds[-1]

    # all the pyramid levels except the highest res make only half the number of iterations
    expected_iterations = (iterations // 2) * (len(model.resolutions) - 1)
    expected_iterations += iterations
    assert (
        len(preds) == expected_iterations
    ), "Number of predictions should be the number of iterations multiplied by the number of pyramid levels"

    assert disparity_pred.shape == torch.Size(
        [1, 2, 64, 64]
    ), f"Predicted disparity should have the same spatial shape as the input. Inputs shape {img1.shape[2:]}, Prediction shape {disparity_pred.shape[2:]}"

    assert all(
        d.shape == torch.Size([1, 2, 64, 64]) for d in preds
    ), "All predicted disparities are expected to have the same shape"

    # test a backward pass with a dummy loss as well
    preds = torch.stack(preds, dim=0)
    targets = torch.ones_like(preds, requires_grad=False)
    loss = torch.nn.functional.mse_loss(preds, targets)

    try:
        loss.backward()
    except Exception as e:
        assert False, f"Backward pass failed with an unexpected exception: {e.__class__.__name__} {e}"

    TM._assert_expected(disparity_pred, name=model_fn.__name__, atol=1e-2, rtol=1e-2)
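Note: the iteration bookkeeping checked in the test above works out as follows. A minimal arithmetic sketch; the pyramid depth of 3 is an assumption (the actual value of `len(model.resolutions)` comes from the model config and is not shown in this diff):

```python
# Worked example of the expected prediction count, mirroring the test comment:
# the finest resolution runs the full number of iterations, every coarser
# pyramid level runs half as many.
iterations = 3
num_resolutions = 3  # assumed len(model.resolutions)

expected_iterations = (iterations // 2) * (num_resolutions - 1) + iterations
print(expected_iterations)  # (3 // 2) * 2 + 3 == 5 disparity maps returned
```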
@@ -19,8 +19,9 @@ def grid_sample(img: Tensor, absolute_grid: Tensor, mode: str = "bilinear", alig
    return F.grid_sample(img, normalized_grid, mode=mode, align_corners=align_corners)
def make_coords_grid(batch_size: int, h: int, w: int):
    coords = torch.meshgrid(torch.arange(h), torch.arange(w), indexing="ij")
def make_coords_grid(batch_size: int, h: int, w: int, device: str = "cpu"):
    device = torch.device(device)
    coords = torch.meshgrid(torch.arange(h, device=device), torch.arange(w, device=device), indexing="ij")
    coords = torch.stack(coords[::-1], dim=0).float()
    return coords[None].repeat(batch_size, 1, 1, 1)
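Note: a quick sketch of how the device-aware `make_coords_grid` behaves; the call values are illustrative, while the shape and channel layout follow from the stacking and repeat above:

```python
import torch

# Assumes make_coords_grid (defined above) is in scope; values are illustrative.
grid = make_coords_grid(batch_size=2, h=4, w=5, device="cpu")

assert grid.shape == (2, 2, 4, 5)       # (batch, 2, H, W); channel 0 = x, channel 1 = y
assert grid[0, 0, 0, -1].item() == 4.0  # x runs 0..w-1 along the width dimension
assert grid[0, 1, -1, 0].item() == 3.0  # y runs 0..h-1 along the height dimension
```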
@@ -27,7 +27,7 @@ __all__ = (
class ResidualBlock(nn.Module):
    """Slightly modified Residual block with extra relu and biases."""

    def __init__(self, in_channels, out_channels, *, norm_layer, stride=1):
    def __init__(self, in_channels, out_channels, *, norm_layer, stride=1, always_project: bool = False):
        super().__init__()

        # Note regarding bias=True:
@@ -43,7 +43,10 @@ class ResidualBlock(nn.Module):
            out_channels, out_channels, norm_layer=norm_layer, kernel_size=3, bias=True
        )

        if stride == 1:
        # make mypy happy
        self.downsample: nn.Module

        if stride == 1 and not always_project:
            self.downsample = nn.Identity()
        else:
            self.downsample = Conv2dNormActivation(
@@ -144,6 +147,10 @@ class FeatureEncoder(nn.Module):
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

        num_downsamples = len(list(filter(lambda s: s == 2, strides)))
        self.output_dim = layers[-1]
        self.downsample_factor = 2**num_downsamples

    def _make_2_blocks(self, block, in_channels, out_channels, norm_layer, first_stride):
        block1 = block(in_channels, out_channels, norm_layer=norm_layer, stride=first_stride)
        block2 = block(out_channels, out_channels, norm_layer=norm_layer, stride=1)
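Note: the new `output_dim` / `downsample_factor` attributes are plain bookkeeping over the encoder configuration: the factor is two raised to the number of stride-2 stages. A small sketch with illustrative strides and layer widths (not necessarily the shipped defaults):

```python
# Illustrative encoder config; the real default strides/layers may differ.
strides = (2, 1, 2, 2)
layers = (64, 64, 96, 128)

num_downsamples = len(list(filter(lambda s: s == 2, strides)))  # count stride-2 stages
output_dim = layers[-1]
downsample_factor = 2**num_downsamples

print(num_downsamples, output_dim, downsample_factor)  # 3 128 8 -> features at 1/8 input resolution
```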
from .raft_stereo import *
from .crestereo import *
This diff is collapsed.
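With the import wired up, the new builder can be exercised much like the test above does. A minimal sketch with random inputs and no pretrained weights (keyword names mirror the test call; everything else is illustrative):

```python
import torch
from torchvision.prototype import models

model = models.depth.stereo.crestereo_base().eval()

left = torch.rand(1, 3, 64, 64)
right = torch.rand(1, 3, 64, 64)

with torch.no_grad():
    # One disparity map per refinement iteration; the last entry is the final estimate.
    preds = model(left, right, flow_init=None, num_iters=3)

print(len(preds), preds[-1].shape)  # final prediction shaped [1, 2, 64, 64]
```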