Unverified Commit 7a62a545 authored by YosuaMichael's avatar YosuaMichael Committed by GitHub
Browse files

Some fixes for crestereo (#6791)

parent 78fdaf3a
......@@ -763,7 +763,7 @@ class CREStereo(nn.Module):
return "1d" if iteration % 2 == 0 else "2d"
def forward(
self, left_image: Tensor, right_image: Tensor, flow_init: Optional[Tensor], num_iters: int = 10
self, left_image: Tensor, right_image: Tensor, flow_init: Optional[Tensor] = None, num_iters: int = 10
) -> List[Tensor]:
features = torch.cat([left_image, right_image], dim=0)
features = self.feature_encoder(features)
......@@ -781,10 +781,10 @@ class CREStereo(nn.Module):
ctx_pyramid = self.downsampling_pyramid(ctx)
# we store in reversed order because we process the pyramid from top to bottom
l_pyramid: Dict[str, Tensor] = {res: l_pyramid[idx] for idx, res in enumerate(self.resolutions)}
r_pyramid: Dict[str, Tensor] = {res: r_pyramid[idx] for idx, res in enumerate(self.resolutions)}
net_pyramid: Dict[str, Tensor] = {res: net_pyramid[idx] for idx, res in enumerate(self.resolutions)}
ctx_pyramid: Dict[str, Tensor] = {res: ctx_pyramid[idx] for idx, res in enumerate(self.resolutions)}
l_pyramid = {res: l_pyramid[idx] for idx, res in enumerate(self.resolutions)}
r_pyramid = {res: r_pyramid[idx] for idx, res in enumerate(self.resolutions)}
net_pyramid = {res: net_pyramid[idx] for idx, res in enumerate(self.resolutions)}
ctx_pyramid = {res: ctx_pyramid[idx] for idx, res in enumerate(self.resolutions)}
# offsets for sampling pixel candidates in the correlation ops
offsets: Dict[str, Tensor] = {}
......@@ -1425,6 +1425,9 @@ def crestereo_base(*, weights: Optional[CREStereo_Base_Weights] = None, progress
.. autoclass:: torchvision.prototype.models.depth.stereo.CREStereo_Base_Weights
:members:
"""
weights = CREStereo_Base_Weights.verify(weights)
return _crestereo(
weights=weights,
progress=progress,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment