OpenDAS / apex · Commits

Commit 4aeb24cb
Authored Apr 07, 2022 by Thor Johnsen

Bug fix: thread a new use_delay_kernel flag through SpatialBottleneckFunction and SpatialBottleneck so the spatial-parallel delay kernel (previously a commented-out inc.add_delay(10)) can be enabled at runtime.

Parent: c70f0e32
Showing 1 changed file with 7 additions and 6 deletions:

apex/contrib/bottleneck/bottleneck.py (+7 −6)
apex/contrib/bottleneck/bottleneck.py (view file @ 4aeb24cb)
@@ -257,7 +257,7 @@ class Bottleneck(torch.nn.Module):
 class SpatialBottleneckFunction(torch.autograd.Function):
     @staticmethod
-    def forward(ctx, spatial_group_size, spatial_group_rank, spatial_communicator, spatial_halo_exchanger, spatial_method, explicit_nhwc, stride_1x1, scale, bias, thresholdTop, thresholdBottom, x, *conv):
+    def forward(ctx, spatial_group_size, spatial_group_rank, spatial_communicator, spatial_halo_exchanger, spatial_method, use_delay_kernel, explicit_nhwc, stride_1x1, scale, bias, thresholdTop, thresholdBottom, x, *conv):
         if spatial_group_size > 1:
             stream1 = spatial_halo_exchanger.stream1
             stream2 = spatial_halo_exchanger.stream2
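The new use_delay_kernel flag is inserted between spatial_method and explicit_nhwc, which changes the positional signature that every SpatialBottleneckFunction.apply call site must match. As a reminder of the contract this commit has to keep (and the reason for the extra None in the backward return further down), here is a minimal standalone torch.autograd.Function; this is a toy sketch, not apex code:

```python
import torch

# Toy illustration only (not apex code): a flag passed positionally through
# apply() reaches forward(), and backward() must return exactly one entry
# per forward input, with None for non-differentiable arguments.
class Scale(torch.autograd.Function):
    @staticmethod
    def forward(ctx, use_double, x):    # the flag precedes the tensor input
        ctx.use_double = use_double     # stash it for backward, like ctx.use_delay_kernel below
        return x * (2.0 if use_double else 1.0)

    @staticmethod
    def backward(ctx, grad_out):
        # one slot per forward input: None for the flag, a gradient for x
        return None, grad_out * (2.0 if ctx.use_double else 1.0)

y = Scale.apply(True, torch.ones(3, requires_grad=True))
y.sum().backward()
```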
@@ -328,7 +328,7 @@ class SpatialBottleneckFunction(torch.autograd.Function):
                     top_fat_halo[:,:,:1,:].copy_(top_out1_halo)
                     top_fat_halo[:,:,1:3,:].copy_(out1[:,:,:2,:])
                     top_out2 = fast_bottleneck.forward_out2_halo(explicit_nhwc, top_fat_halo, args)
-                #inc.add_delay(10)
+                if use_delay_kernel: inc.add_delay(10)
             elif spatial_method != 2 and spatial_method != 3:
                 assert(False), "spatial_method must be 1, 2 or 3"
@@ -416,6 +416,7 @@ class SpatialBottleneckFunction(torch.autograd.Function):
         ctx.spatial_group_rank = spatial_group_rank
         ctx.spatial_halo_exchanger = spatial_halo_exchanger
         ctx.spatial_method = spatial_method
+        ctx.use_delay_kernel = use_delay_kernel
         ctx.thresholdTop = thresholdTop
         ctx.thresholdBottom = thresholdBottom
         ctx.stream1 = stream1
@@ -518,7 +519,7 @@ class SpatialBottleneckFunction(torch.autograd.Function):
                     top_grad_out1_halo = top_grad_out1_halo[:,1:2,:,:]
                 else:
                     top_grad_out1_halo = top_grad_out1_halo[:,:,1:2,:]
-            #inc.add_delay(10)
+            if ctx.use_delay_kernel: inc.add_delay(10)
         elif ctx.spatial_method != 3:
             assert(False), "spatial_method must be 1, 2 or 3"
@@ -583,7 +584,7 @@ class SpatialBottleneckFunction(torch.autograd.Function):
         fast_bottleneck.backward_rest(ctx.explicit_nhwc, ctx.stride_1x1, t_list, grads, grad_out2, grad_out1, wgrad2)
         torch.cuda.current_stream().wait_stream(wgrad2_stream)

-        return (None, None, None, None, None, None, None, None, None, None, None, *grads)
+        return (None, None, None, None, None, None, None, None, None, None, None, None, *grads)

 spatial_bottleneck_function = SpatialBottleneckFunction.apply
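The extra None in the return is the mechanical consequence of the signature change: autograd requires backward to return one entry per forward input, so the tuple grows from eleven leading Nones to twelve, while *grads continues to carry the real gradients for x and the convolution tensors. Counting them out (illustrative only):

```python
# Illustrative count only: one leading None per non-differentiable argument
# of SpatialBottleneckFunction.forward, in order, followed by *grads.
non_diff_args = [
    "spatial_group_size", "spatial_group_rank", "spatial_communicator",
    "spatial_halo_exchanger", "spatial_method", "use_delay_kernel",  # new
    "explicit_nhwc", "stride_1x1", "scale", "bias",
    "thresholdTop", "thresholdBottom",
]
assert len(non_diff_args) == 12  # hence twelve Nones before *grads
```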
@@ -655,7 +656,7 @@ class SpatialBottleneck(torch.nn.Module):
         # spatial communicator
         if spatial_parallel_args is None:
-            self.spatial_parallel_args = (1,0,None,None,0)
+            self.spatial_parallel_args = (1,0,None,None,0,False)
         else:
             self.spatial_parallel_args = spatial_parallel_args
         return
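The default spatial_parallel_args tuple gains a sixth field so it stays in sync with the widened unpacking in SpatialBottleneck.forward below. The field names are not spelled out at this point in the file; the mapping below is inferred from the order of SpatialBottleneckFunction.forward's leading parameters and should be read as an assumption:

```python
# Inferred mapping (assumption: fields mirror forward's parameter order);
# the sixth field is the new use_delay_kernel flag, defaulting to False.
(spatial_group_size,      # 1    -> single rank, spatial parallelism off
 spatial_group_rank,      # 0
 spatial_communicator,    # None
 spatial_halo_exchanger,  # None
 spatial_method,          # 0    -> only validated when group size > 1
 use_delay_kernel) = (1, 0, None, None, 0, False)
```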
@@ -684,7 +685,7 @@ class SpatialBottleneck(torch.nn.Module):
     def forward(self, x):
         if self.use_cudnn:
             if self.thresholdTop is None:
-                spatial_group_size, spatial_group_rank, _, _, _ = self.spatial_parallel_args
+                spatial_group_size, spatial_group_rank, _, _, _, _ = self.spatial_parallel_args
                 if self.explicit_nhwc:
                     N,H,W,C = list(x.shape)
                 else: