Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
apex
Commits
9e295728
Commit
9e295728
authored
Sep 02, 2021
by
Thor Johnsen
Browse files
Bug fix in wgrad
parent
8c4a0075
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
8 additions
and
7 deletions
+8
-7
apex/contrib/bottleneck/bottleneck.py
apex/contrib/bottleneck/bottleneck.py
+8
-7
No files found.
apex/contrib/bottleneck/bottleneck.py
View file @
9e295728
...
@@ -318,13 +318,14 @@ class SpatialBottleneckFunction(torch.autograd.Function):
...
@@ -318,13 +318,14 @@ class SpatialBottleneckFunction(torch.autograd.Function):
# apply wgrad2 halos
# apply wgrad2 halos
if
ctx
.
spatial_group_size
>
1
:
if
ctx
.
spatial_group_size
>
1
:
top_grad2_halo
=
grad_out2
[:,:
1
,:,:]
if
ctx
.
local_rank
>
0
:
btm_grad2_halo
=
grad_out2
[:,
-
1
:,:,:]
top_grad2_halo
=
grad_out2
[:,:
1
,:,:]
top_wgrad2_halo
=
fast_bottleneck
.
backward_wgrad2_halo
(
ctx
.
nhwc
,
ctx
.
stride_1x1
,
t_list
,
grads
,
top_out1_halo
,
top_grad2_halo
)
top_wgrad2_halo
=
fast_bottleneck
.
backward_wgrad2_halo
(
ctx
.
nhwc
,
ctx
.
stride_1x1
,
t_list
,
grads
,
top_out1_halo
,
top_grad2_halo
)
btm_wgrad2_halo
=
fast_bottleneck
.
backward_wgrad2_halo
(
ctx
.
nhwc
,
ctx
.
stride_1x1
,
t_list
,
grads
,
btm_out1_halo
,
btm_grad2_halo
)
wgrad2
[:,:
1
,:,:].
add_
(
top_wgrad2_halo
)
#print("wgrad2.shape = %s, top_wgrad2_halo.shape = %s, btm_wgrad2_halo = %s" % (str(list(wgrad2.shape)), str(list(top_wgrad2_halo.shape)), str(list(btm_wgrad2_halo.shape))))
if
ctx
.
local_rank
<
ctx
.
spatial_group_size
-
1
:
wgrad2
[:,:
1
,:,:].
add_
(
top_wgrad2_halo
)
btm_grad2_halo
=
grad_out2
[:,
-
1
:,:,:]
wgrad2
[:,
-
1
:,:,:].
add_
(
btm_wgrad2_halo
)
btm_wgrad2_halo
=
fast_bottleneck
.
backward_wgrad2_halo
(
ctx
.
nhwc
,
ctx
.
stride_1x1
,
t_list
,
grads
,
btm_out1_halo
,
btm_grad2_halo
)
wgrad2
[:,
-
1
:,:,:].
add_
(
btm_wgrad2_halo
)
# do halo exchange of grad_out2 here
# do halo exchange of grad_out2 here
# compute halo cells for grad_out1
# compute halo cells for grad_out1
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment