Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
vision
Commits
7a62a545
"...text-generation-inference.git" did not exist on "678b2f39000f638e0099af0d84a98d409feca428"
Unverified
Commit
7a62a545
authored
Oct 19, 2022
by
YosuaMichael
Committed by
GitHub
Oct 19, 2022
Browse files
Some fixes for crestereo (#6791)
parent
78fdaf3a
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
8 additions
and
5 deletions
+8
-5
torchvision/prototype/models/depth/stereo/crestereo.py
torchvision/prototype/models/depth/stereo/crestereo.py
+8
-5
No files found.
torchvision/prototype/models/depth/stereo/crestereo.py
View file @
7a62a545
...
@@ -763,7 +763,7 @@ class CREStereo(nn.Module):
...
@@ -763,7 +763,7 @@ class CREStereo(nn.Module):
return
"1d"
if
iteration
%
2
==
0
else
"2d"
return
"1d"
if
iteration
%
2
==
0
else
"2d"
def
forward
(
def
forward
(
self
,
left_image
:
Tensor
,
right_image
:
Tensor
,
flow_init
:
Optional
[
Tensor
],
num_iters
:
int
=
10
self
,
left_image
:
Tensor
,
right_image
:
Tensor
,
flow_init
:
Optional
[
Tensor
]
=
None
,
num_iters
:
int
=
10
)
->
List
[
Tensor
]:
)
->
List
[
Tensor
]:
features
=
torch
.
cat
([
left_image
,
right_image
],
dim
=
0
)
features
=
torch
.
cat
([
left_image
,
right_image
],
dim
=
0
)
features
=
self
.
feature_encoder
(
features
)
features
=
self
.
feature_encoder
(
features
)
...
@@ -781,10 +781,10 @@ class CREStereo(nn.Module):
...
@@ -781,10 +781,10 @@ class CREStereo(nn.Module):
ctx_pyramid
=
self
.
downsampling_pyramid
(
ctx
)
ctx_pyramid
=
self
.
downsampling_pyramid
(
ctx
)
# we store in reversed order because we process the pyramid from top to bottom
# we store in reversed order because we process the pyramid from top to bottom
l_pyramid
:
Dict
[
str
,
Tensor
]
=
{
res
:
l_pyramid
[
idx
]
for
idx
,
res
in
enumerate
(
self
.
resolutions
)}
l_pyramid
=
{
res
:
l_pyramid
[
idx
]
for
idx
,
res
in
enumerate
(
self
.
resolutions
)}
r_pyramid
:
Dict
[
str
,
Tensor
]
=
{
res
:
r_pyramid
[
idx
]
for
idx
,
res
in
enumerate
(
self
.
resolutions
)}
r_pyramid
=
{
res
:
r_pyramid
[
idx
]
for
idx
,
res
in
enumerate
(
self
.
resolutions
)}
net_pyramid
:
Dict
[
str
,
Tensor
]
=
{
res
:
net_pyramid
[
idx
]
for
idx
,
res
in
enumerate
(
self
.
resolutions
)}
net_pyramid
=
{
res
:
net_pyramid
[
idx
]
for
idx
,
res
in
enumerate
(
self
.
resolutions
)}
ctx_pyramid
:
Dict
[
str
,
Tensor
]
=
{
res
:
ctx_pyramid
[
idx
]
for
idx
,
res
in
enumerate
(
self
.
resolutions
)}
ctx_pyramid
=
{
res
:
ctx_pyramid
[
idx
]
for
idx
,
res
in
enumerate
(
self
.
resolutions
)}
# offsets for sampling pixel candidates in the correlation ops
# offsets for sampling pixel candidates in the correlation ops
offsets
:
Dict
[
str
,
Tensor
]
=
{}
offsets
:
Dict
[
str
,
Tensor
]
=
{}
...
@@ -1425,6 +1425,9 @@ def crestereo_base(*, weights: Optional[CREStereo_Base_Weights] = None, progress
...
@@ -1425,6 +1425,9 @@ def crestereo_base(*, weights: Optional[CREStereo_Base_Weights] = None, progress
.. autoclass:: torchvision.prototype.models.depth.stereo.CREStereo_Base_Weights
.. autoclass:: torchvision.prototype.models.depth.stereo.CREStereo_Base_Weights
:members:
:members:
"""
"""
weights
=
CREStereo_Base_Weights
.
verify
(
weights
)
return
_crestereo
(
return
_crestereo
(
weights
=
weights
,
weights
=
weights
,
progress
=
progress
,
progress
=
progress
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment