Unverified Commit 355b2788 authored by YosuaMichael's avatar YosuaMichael Committed by GitHub
Browse files

Update RAFT Stereo to be more sync with CREStereo implementation (#6575)

* Update raft_stereo to sync forward param with CREStereo and have output_channel

* Add flow_init param to docstring

* Use output_channels instead of output_channel

* Replace depth with disparity since what we predict is actually disparity, not actual depth
parent 1ea73f58
...@@ -204,7 +204,7 @@ class MultiLevelUpdateBlock(nn.Module): ...@@ -204,7 +204,7 @@ class MultiLevelUpdateBlock(nn.Module):
hidden_states: List[Tensor], hidden_states: List[Tensor],
contexts: List[List[Tensor]], contexts: List[List[Tensor]],
corr_features: Tensor, corr_features: Tensor,
depth: Tensor, disparity: Tensor,
level_processed: List[bool], level_processed: List[bool],
) -> List[Tensor]: ) -> List[Tensor]:
# We call it reverse_i because it has a reversed ordering compared to hidden_states # We call it reverse_i because it has a reversed ordering compared to hidden_states
...@@ -215,7 +215,7 @@ class MultiLevelUpdateBlock(nn.Module): ...@@ -215,7 +215,7 @@ class MultiLevelUpdateBlock(nn.Module):
# X is concatenation of 2x downsampled hidden_dim (or motion_features if no bigger dim) with # X is concatenation of 2x downsampled hidden_dim (or motion_features if no bigger dim) with
# upsampled hidden_dim (or nothing if not exist). # upsampled hidden_dim (or nothing if not exist).
if i == 0: if i == 0:
features = self.motion_encoder(depth, corr_features) features = self.motion_encoder(disparity, corr_features)
else: else:
# 2x downsampled features from larger hidden states # 2x downsampled features from larger hidden states
features = F.avg_pool2d(hidden_states[i - 1], kernel_size=3, stride=2, padding=1) features = F.avg_pool2d(hidden_states[i - 1], kernel_size=3, stride=2, padding=1)
...@@ -235,14 +235,14 @@ class MultiLevelUpdateBlock(nn.Module): ...@@ -235,14 +235,14 @@ class MultiLevelUpdateBlock(nn.Module):
hidden_states[i] = gru(hidden_states[i], features, contexts[i]) hidden_states[i] = gru(hidden_states[i], features, contexts[i])
# NOTE: For slow-fast gru, we don't always want to calculate delta depth for every call on UpdateBlock # NOTE: For slow-fast gru, we don't always want to calculate delta disparity for every call on UpdateBlock
# Hence we move the delta depth calculation to the RAFT-Stereo main forward # Hence we move the delta disparity calculation to the RAFT-Stereo main forward
return hidden_states return hidden_states
class MaskPredictor(raft.MaskPredictor): class MaskPredictor(raft.MaskPredictor):
"""Mask predictor to be used when upsampling the predicted depth.""" """Mask predictor to be used when upsampling the predicted disparity."""
# We add out_channels compared to raft.MaskPredictor # We add out_channels compared to raft.MaskPredictor
def __init__(self, *, in_channels: int, hidden_size: int, out_channels: int, multiplier: float = 0.25): def __init__(self, *, in_channels: int, hidden_size: int, out_channels: int, multiplier: float = 0.25):
...@@ -346,7 +346,7 @@ class RaftStereo(nn.Module): ...@@ -346,7 +346,7 @@ class RaftStereo(nn.Module):
corr_pyramid: CorrPyramid1d, corr_pyramid: CorrPyramid1d,
corr_block: CorrBlock1d, corr_block: CorrBlock1d,
update_block: MultiLevelUpdateBlock, update_block: MultiLevelUpdateBlock,
depth_head: nn.Module, disparity_head: nn.Module,
mask_predictor: Optional[nn.Module] = None, mask_predictor: Optional[nn.Module] = None,
slow_fast: bool = False, slow_fast: bool = False,
): ):
...@@ -354,8 +354,8 @@ class RaftStereo(nn.Module): ...@@ -354,8 +354,8 @@ class RaftStereo(nn.Module):
`RAFT-Stereo: Multilevel Recurrent Field Transforms for Stereo Matching <https://arxiv.org/abs/2109.07547>`_. `RAFT-Stereo: Multilevel Recurrent Field Transforms for Stereo Matching <https://arxiv.org/abs/2109.07547>`_.
args: args:
feature_encoder (FeatureEncoder): The feature encoder. Its input is the concatenation of ``image1`` and ``image2``. feature_encoder (FeatureEncoder): The feature encoder. Its input is the concatenation of ``left_image`` and ``right_image``.
context_encoder (MultiLevelContextEncoder): The context encoder. Its input is ``image1``. context_encoder (MultiLevelContextEncoder): The context encoder. Its input is ``left_image``.
It has multi-level output and each level will have 2 parts: It has multi-level output and each level will have 2 parts:
- one part will be used as the actual "context", passed to the recurrent unit of the ``update_block`` - one part will be used as the actual "context", passed to the recurrent unit of the ``update_block``
...@@ -370,8 +370,8 @@ class RaftStereo(nn.Module): ...@@ -370,8 +370,8 @@ class RaftStereo(nn.Module):
update_block (MultiLevelUpdateBlock): The update block, which contains the motion encoder, and the recurrent unit. update_block (MultiLevelUpdateBlock): The update block, which contains the motion encoder, and the recurrent unit.
It takes as input the hidden state of its recurrent unit, the context, the correlation It takes as input the hidden state of its recurrent unit, the context, the correlation
features, and the current predicted depth. It outputs an updated hidden state features, and the current predicted disparity. It outputs an updated hidden state
depth_head (nn.Module): The depth head block will convert from the hidden state into changes in depth. disparity_head (nn.Module): The disparity head block will convert from the hidden state into changes in disparity.
mask_predictor (nn.Module, optional): Predicts the mask that will be used to upsample the predicted flow. mask_predictor (nn.Module, optional): Predicts the mask that will be used to upsample the predicted flow.
If ``None`` (default), the flow is upsampled using interpolation. If ``None`` (default), the flow is upsampled using interpolation.
slow_fast (bool): A boolean that specifies whether we should use slow-fast GRU or not. See RAFT-Stereo paper slow_fast (bool): A boolean that specifies whether we should use slow-fast GRU or not. See RAFT-Stereo paper
...@@ -380,6 +380,10 @@ class RaftStereo(nn.Module): ...@@ -380,6 +380,10 @@ class RaftStereo(nn.Module):
super().__init__() super().__init__()
_log_api_usage_once(self) _log_api_usage_once(self)
# This indicates that the disparity output will only have 1 channel (representing the horizontal axis).
# We need this because some stereo matching models like CREStereo might have 2 channels in the output
self.output_channels = 1
self.feature_encoder = feature_encoder self.feature_encoder = feature_encoder
self.context_encoder = context_encoder self.context_encoder = context_encoder
...@@ -388,7 +392,7 @@ class RaftStereo(nn.Module): ...@@ -388,7 +392,7 @@ class RaftStereo(nn.Module):
self.corr_pyramid = corr_pyramid self.corr_pyramid = corr_pyramid
self.corr_block = corr_block self.corr_block = corr_block
self.update_block = update_block self.update_block = update_block
self.depth_head = depth_head self.disparity_head = disparity_head
self.mask_predictor = mask_predictor self.mask_predictor = mask_predictor
hidden_dims = self.update_block.hidden_dims hidden_dims = self.update_block.hidden_dims
...@@ -399,18 +403,21 @@ class RaftStereo(nn.Module): ...@@ -399,18 +403,21 @@ class RaftStereo(nn.Module):
) )
self.slow_fast = slow_fast self.slow_fast = slow_fast
def forward(self, image1: Tensor, image2: Tensor, num_iters: int = 12) -> List[Tensor]: def forward(
self, left_image: Tensor, right_image: Tensor, flow_init: Optional[Tensor] = None, num_iters: int = 12
) -> List[Tensor]:
""" """
Return dept predictions on every iterations as a list of Tensor. Return disparity predictions on every iterations as a list of Tensor.
args: args:
image1 (Tensor): The input left image with layout B, C, H, W left_image (Tensor): The input left image with layout B, C, H, W
image2 (Tensor): The input right image with layout B, C, H, W right_image (Tensor): The input right image with layout B, C, H, W
flow_init (Optional[Tensor]): Initial estimate for the disparity. Default: None
num_iters (int): Number of update block iteration on the largest resolution. Default: 12 num_iters (int): Number of update block iteration on the largest resolution. Default: 12
""" """
batch_size, _, h, w = image1.shape batch_size, _, h, w = left_image.shape
torch._assert( torch._assert(
(h, w) == image2.shape[-2:], (h, w) == right_image.shape[-2:],
f"input images should have the same shape, instead got ({h}, {w}) != {image2.shape[-2:]}", f"input images should have the same shape, instead got ({h}, {w}) != {right_image.shape[-2:]}",
) )
torch._assert( torch._assert(
...@@ -418,7 +425,7 @@ class RaftStereo(nn.Module): ...@@ -418,7 +425,7 @@ class RaftStereo(nn.Module):
f"input image H and W should be divisible by {self.base_downsampling_ratio}, insted got H={h} and W={w}", f"input image H and W should be divisible by {self.base_downsampling_ratio}, insted got H={h} and W={w}",
) )
fmaps = self.feature_encoder(torch.cat([image1, image2], dim=0)) fmaps = self.feature_encoder(torch.cat([left_image, right_image], dim=0))
fmap1, fmap2 = torch.chunk(fmaps, chunks=2, dim=0) fmap1, fmap2 = torch.chunk(fmaps, chunks=2, dim=0)
torch._assert( torch._assert(
fmap1.shape[-2:] == (h // self.base_downsampling_ratio, w // self.base_downsampling_ratio), fmap1.shape[-2:] == (h // self.base_downsampling_ratio, w // self.base_downsampling_ratio),
...@@ -428,7 +435,7 @@ class RaftStereo(nn.Module): ...@@ -428,7 +435,7 @@ class RaftStereo(nn.Module):
corr_pyramid = self.corr_pyramid(fmap1, fmap2) corr_pyramid = self.corr_pyramid(fmap1, fmap2)
# Multi level contexts # Multi level contexts
context_outs = self.context_encoder(image1) context_outs = self.context_encoder(left_image)
hidden_dims = self.update_block.hidden_dims hidden_dims = self.update_block.hidden_dims
context_out_channels = [context_outs[i].shape[1] - hidden_dims[i] for i in range(len(context_outs))] context_out_channels = [context_outs[i].shape[1] - hidden_dims[i] for i in range(len(context_outs))]
...@@ -448,35 +455,41 @@ class RaftStereo(nn.Module): ...@@ -448,35 +455,41 @@ class RaftStereo(nn.Module):
coords0 = make_coords_grid(batch_size, Hf, Wf).to(fmap1.device) coords0 = make_coords_grid(batch_size, Hf, Wf).to(fmap1.device)
coords1 = make_coords_grid(batch_size, Hf, Wf).to(fmap1.device) coords1 = make_coords_grid(batch_size, Hf, Wf).to(fmap1.device)
depth_predictions = [] # We use flow_init for cascade inference
if flow_init is not None:
coords1 = coords1 + flow_init
disparity_predictions = []
for _ in range(num_iters): for _ in range(num_iters):
coords1 = coords1.detach() # Don't backpropagate gradients through this branch, see paper coords1 = coords1.detach() # Don't backpropagate gradients through this branch, see paper
corr_features = self.corr_block(centroids_coords=coords1, corr_pyramid=corr_pyramid) corr_features = self.corr_block(centroids_coords=coords1, corr_pyramid=corr_pyramid)
depth = coords1 - coords0 disparity = coords1 - coords0
if self.slow_fast: if self.slow_fast:
# Using slow_fast GRU (see paper section 3.4). The lower resolution are processed more often # Using slow_fast GRU (see paper section 3.4). The lower resolution are processed more often
for i in range(1, self.num_level): for i in range(1, self.num_level):
# We only processed the smallest i levels # We only processed the smallest i levels
level_processed = [False] * (self.num_level - i) + [True] * i level_processed = [False] * (self.num_level - i) + [True] * i
hidden_states = self.update_block( hidden_states = self.update_block(
hidden_states, contexts, corr_features, depth, level_processed=level_processed hidden_states, contexts, corr_features, disparity, level_processed=level_processed
) )
hidden_states = self.update_block( hidden_states = self.update_block(
hidden_states, contexts, corr_features, depth, level_processed=[True] * self.num_level hidden_states, contexts, corr_features, disparity, level_processed=[True] * self.num_level
) )
# Take the largest hidden_state to get the depth # Take the largest hidden_state to get the disparity
hidden_state = hidden_states[0] hidden_state = hidden_states[0]
delta_depth = self.depth_head(hidden_state) delta_disparity = self.disparity_head(hidden_state)
# in stereo mode, project depth onto epipolar # in stereo mode, project disparity onto epipolar
delta_depth[:, 1] = 0.0 delta_disparity[:, 1] = 0.0
coords1 = coords1 + delta_depth coords1 = coords1 + delta_disparity
up_mask = None if self.mask_predictor is None else self.mask_predictor(hidden_state) up_mask = None if self.mask_predictor is None else self.mask_predictor(hidden_state)
upsampled_depth = upsample_flow((coords1 - coords0), up_mask=up_mask, factor=self.base_downsampling_ratio) upsampled_disparity = upsample_flow(
depth_predictions.append(upsampled_depth[:, :1]) (coords1 - coords0), up_mask=up_mask, factor=self.base_downsampling_ratio
)
disparity_predictions.append(upsampled_disparity[:, :1])
return depth_predictions return disparity_predictions
def _raft_stereo( def _raft_stereo(
...@@ -576,8 +589,8 @@ def _raft_stereo( ...@@ -576,8 +589,8 @@ def _raft_stereo(
motion_encoder=motion_encoder, hidden_dims=update_block_hidden_dims motion_encoder=motion_encoder, hidden_dims=update_block_hidden_dims
) )
# We use the largest scale hidden_dims of update_block to get the predicted depth # We use the largest scale hidden_dims of update_block to get the predicted disparity
depth_head = kwargs.pop("depth_head", None) or FlowHead( disparity_head = kwargs.pop("disparity_head", None) or FlowHead(
in_channels=update_block_hidden_dims[0], in_channels=update_block_hidden_dims[0],
hidden_size=flow_head_hidden_size, hidden_size=flow_head_hidden_size,
) )
...@@ -598,7 +611,7 @@ def _raft_stereo( ...@@ -598,7 +611,7 @@ def _raft_stereo(
corr_pyramid=corr_pyramid, corr_pyramid=corr_pyramid,
corr_block=corr_block, corr_block=corr_block,
update_block=update_block, update_block=update_block,
depth_head=depth_head, disparity_head=disparity_head,
mask_predictor=mask_predictor, mask_predictor=mask_predictor,
slow_fast=slow_fast, slow_fast=slow_fast,
**kwargs, # not really needed, all params should be consumed by now **kwargs, # not really needed, all params should be consumed by now
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment