"...text-generation-inference.git" did not exist on "678b2f39000f638e0099af0d84a98d409feca428"
Unverified Commit 7a62a545 authored by YosuaMichael's avatar YosuaMichael Committed by GitHub
Browse files

Some fixes for crestereo (#6791)

parent 78fdaf3a
...@@ -763,7 +763,7 @@ class CREStereo(nn.Module): ...@@ -763,7 +763,7 @@ class CREStereo(nn.Module):
return "1d" if iteration % 2 == 0 else "2d" return "1d" if iteration % 2 == 0 else "2d"
def forward( def forward(
self, left_image: Tensor, right_image: Tensor, flow_init: Optional[Tensor], num_iters: int = 10 self, left_image: Tensor, right_image: Tensor, flow_init: Optional[Tensor] = None, num_iters: int = 10
) -> List[Tensor]: ) -> List[Tensor]:
features = torch.cat([left_image, right_image], dim=0) features = torch.cat([left_image, right_image], dim=0)
features = self.feature_encoder(features) features = self.feature_encoder(features)
...@@ -781,10 +781,10 @@ class CREStereo(nn.Module): ...@@ -781,10 +781,10 @@ class CREStereo(nn.Module):
ctx_pyramid = self.downsampling_pyramid(ctx) ctx_pyramid = self.downsampling_pyramid(ctx)
# we store in reversed order because we process the pyramid from top to bottom # we store in reversed order because we process the pyramid from top to bottom
l_pyramid: Dict[str, Tensor] = {res: l_pyramid[idx] for idx, res in enumerate(self.resolutions)} l_pyramid = {res: l_pyramid[idx] for idx, res in enumerate(self.resolutions)}
r_pyramid: Dict[str, Tensor] = {res: r_pyramid[idx] for idx, res in enumerate(self.resolutions)} r_pyramid = {res: r_pyramid[idx] for idx, res in enumerate(self.resolutions)}
net_pyramid: Dict[str, Tensor] = {res: net_pyramid[idx] for idx, res in enumerate(self.resolutions)} net_pyramid = {res: net_pyramid[idx] for idx, res in enumerate(self.resolutions)}
ctx_pyramid: Dict[str, Tensor] = {res: ctx_pyramid[idx] for idx, res in enumerate(self.resolutions)} ctx_pyramid = {res: ctx_pyramid[idx] for idx, res in enumerate(self.resolutions)}
# offsets for sampling pixel candidates in the correlation ops # offsets for sampling pixel candidates in the correlation ops
offsets: Dict[str, Tensor] = {} offsets: Dict[str, Tensor] = {}
...@@ -1425,6 +1425,9 @@ def crestereo_base(*, weights: Optional[CREStereo_Base_Weights] = None, progress ...@@ -1425,6 +1425,9 @@ def crestereo_base(*, weights: Optional[CREStereo_Base_Weights] = None, progress
.. autoclass:: torchvision.prototype.models.depth.stereo.CREStereo_Base_Weights .. autoclass:: torchvision.prototype.models.depth.stereo.CREStereo_Base_Weights
:members: :members:
""" """
weights = CREStereo_Base_Weights.verify(weights)
return _crestereo( return _crestereo(
weights=weights, weights=weights,
progress=progress, progress=progress,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment