OpenDAS / apex / Commits / 0c20c455

Commit 0c20c455, authored Mar 31, 2022 by Thor Johnsen
Parent: 34df0f79

Some fixes to better support native nhwc

Showing 1 changed file (apex/contrib/bottleneck/bottleneck.py) with 97 additions and 46 deletions.
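The diff renames the layout flag nhwc to explicit_nhwc throughout SpatialBottleneckFunction and adds channels_last ("native nhwc") code paths alongside the existing explicit-NHWC ones. For orientation, a minimal sketch of the two layouts the flag distinguishes (the tensor names are illustrative, not from the patch):

    import torch

    # Explicit NHWC: the tensor's logical shape is already (N, H, W, C).
    x_explicit = torch.empty(2, 8, 8, 64)

    # Native NHWC (channels_last): the logical shape stays (N, C, H, W); only
    # the strides lay the data out in NHWC order, so indexing keeps NCHW
    # semantics while kernels see channels-last memory.
    x_native = torch.empty(2, 64, 8, 8, memory_format=torch.channels_last)
    assert x_native.is_contiguous(memory_format=torch.channels_last)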
@@ -220,7 +220,7 @@ class Bottleneck(torch.nn.Module):
 class SpatialBottleneckFunction(torch.autograd.Function):
     @staticmethod
-    def forward(ctx, spatial_group_size, spatial_group_rank, spatial_communicator, spatial_halo_exchanger, spatial_method, nhwc, stride_1x1, scale, bias, x, *conv):
+    def forward(ctx, spatial_group_size, spatial_group_rank, spatial_communicator, spatial_halo_exchanger, spatial_method, explicit_nhwc, stride_1x1, scale, bias, x, *conv):
         if spatial_group_size > 1:
             stream1 = spatial_halo_exchanger.stream1
             stream2 = spatial_halo_exchanger.stream2
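Only the flag's name and the new layout branches change; the call protocol stays the standard autograd one. A hypothetical call site matching the new signature (the argument values are placeholders, not from the patch):

    # autograd.Function is invoked through .apply, which passes everything on
    # to forward(ctx, ...); gradients flow back only to the tensor arguments.
    out = SpatialBottleneckFunction.apply(
        spatial_group_size, spatial_group_rank, spatial_communicator,
        spatial_halo_exchanger, spatial_method, explicit_nhwc,
        stride_1x1, scale, bias, x, *conv)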
@@ -234,64 +234,91 @@ class SpatialBottleneckFunction(torch.autograd.Function):
         args.append(scale[3])
         args.append(bias[3])
-        # weight buffers are always in nhwc while shape can be nhwc or channels_last
+        # weight buffers are always in explicit_nhwc while shape can be explicit_nhwc or channels_last
         # here we pass in flag and let c++ handle it
         # alternatively, we can put all sizes into a fixed format and pass it in
-        outputs = fast_bottleneck.forward_init(nhwc, stride_1x1, args)
-        fast_bottleneck.forward_out1(nhwc, stride_1x1, args, outputs)
+        outputs = fast_bottleneck.forward_init(explicit_nhwc, stride_1x1, args)
+        fast_bottleneck.forward_out1(explicit_nhwc, stride_1x1, args, outputs)
         if spatial_group_size > 1:
             out1 = outputs[0]
-            # TODO: This assumes explicit nhwc
-            N,Hs,W,C = list(out1.shape)
-            out1_pad = torch.empty([N,Hs+2,W,C], dtype=out1.dtype, device='cuda')
+            if explicit_nhwc:
+                N,Hs,W,C = list(out1.shape)
+                memory_format = torch.contiguous_format
+                out1_pad = torch.empty([N,Hs+2,W,C], dtype=out1.dtype, device='cuda')
+            else:
+                N,C,Hs,W = list(out1.shape)
+                memory_format = torch.channels_last if out1.is_contiguous(memory_format=torch.channels_last) else torch.contiguous_format
+                out1_pad = torch.empty([N,C,Hs+2,W], dtype=out1.dtype, device='cuda', memory_format=memory_format)
             stream1.wait_stream(torch.cuda.current_stream())
             stream3.wait_stream(torch.cuda.current_stream())
             with torch.cuda.stream(stream3):
-                out1_pad[:,1:Hs+1,:,:].copy_(out1)
+                if explicit_nhwc:
+                    out1_pad[:,1:Hs+1,:,:].copy_(out1)
+                else:
+                    out1_pad[:,:,1:Hs+1,:].copy_(out1)
             with torch.cuda.stream(stream1):
-                top_out1_halo = out1_pad[:,:1,:,:]
-                btm_out1_halo = out1_pad[:,Hs+1:Hs+2,:,:]
-                spatial_halo_exchanger.left_right_halo_exchange(out1[:,:1,:,:], out1[:,Hs-1:,:,:], top_out1_halo, btm_out1_halo)
+                if explicit_nhwc:
+                    top_out1_halo = out1_pad[:,:1,:,:]
+                    btm_out1_halo = out1_pad[:,Hs+1:Hs+2,:,:]
+                    spatial_halo_exchanger.left_right_halo_exchange(out1[:,:1,:,:], out1[:,Hs-1:,:,:], top_out1_halo, btm_out1_halo)
+                else:
+                    top_out1_halo = out1_pad[:,:,:1,:]
+                    btm_out1_halo = out1_pad[:,:,Hs+1:Hs+2,:]
+                    spatial_halo_exchanger.left_right_halo_exchange(out1[:,:,:1,:], out1[:,:,Hs-1:,:], top_out1_halo, btm_out1_halo)
             if spatial_method == 1:
                 # overlap mid convolution with halo transfer
                 if spatial_group_rank < spatial_group_size-1:
                     stream2.wait_stream(stream1)
                     with torch.cuda.stream(stream2):
                         btm_fat_halo = torch.empty((N,3,W,C), dtype=out1.dtype, device=out1.device)
-                        btm_fat_halo[:,0:2,:,:].copy_(out1[:,Hs-2:,:,:])
-                        btm_fat_halo[:,2:,:,:].copy_(btm_out1_halo)
-                        btm_out2 = fast_bottleneck.forward_out2_halo(nhwc, btm_fat_halo, args)
+                        if explicit_nhwc:
+                            btm_fat_halo[:,0:2,:,:].copy_(out1[:,Hs-2:,:,:])
+                            btm_fat_halo[:,2:,:,:].copy_(btm_out1_halo)
+                        else:
+                            btm_fat_halo[:,:,0:2,:].copy_(out1[:,Hs-2:,:,:])
+                            btm_fat_halo[:,:,2:,:].copy_(btm_out1_halo)
+                        btm_out2 = fast_bottleneck.forward_out2_halo(explicit_nhwc, btm_fat_halo, args)
                 if spatial_group_rank > 0:
                     with torch.cuda.stream(stream1):
                         top_fat_halo = torch.empty((N,3,W,C), dtype=out1.dtype, device=out1.device)
-                        top_fat_halo[:,:1,:,:].copy_(top_out1_halo)
-                        top_fat_halo[:,1:3,:,:].copy_(out1[:,:2,:,:])
-                        top_out2 = fast_bottleneck.forward_out2_halo(nhwc, top_fat_halo, args)
+                        if explicit_nhwc:
+                            top_fat_halo[:,:1,:,:].copy_(top_out1_halo)
+                            top_fat_halo[:,1:3,:,:].copy_(out1[:,:2,:,:])
+                        else:
+                            top_fat_halo[:,:,:1,:].copy_(top_out1_halo)
+                            top_fat_halo[:,:,1:3,:].copy_(out1[:,:2,:,:])
+                        top_out2 = fast_bottleneck.forward_out2_halo(explicit_nhwc, top_fat_halo, args)
                 inc.add_delay(10)
             elif spatial_method == 2:
                 # wait for halo transfer to finish before doing a full convolution of padded x
                 torch.cuda.current_stream().wait_stream(stream1)
                 torch.cuda.current_stream().wait_stream(stream3)
-                fast_bottleneck.forward_out2_pad(nhwc, stride_1x1, args, outputs, out1_pad)
+                fast_bottleneck.forward_out2_pad(explicit_nhwc, stride_1x1, args, outputs, out1_pad)
             else:
                 assert(False), "spatial_method must be 1 or 2"
         if spatial_group_size <= 1 or spatial_method == 1:
-            fast_bottleneck.forward_out2(nhwc, stride_1x1, args, outputs)
+            fast_bottleneck.forward_out2(explicit_nhwc, stride_1x1, args, outputs)
         # compute halo cells for outputs[1] (out2)
         if spatial_group_size > 1 and spatial_method == 1:
             out2 = outputs[1]
             if spatial_group_rank > 0:
                 torch.cuda.current_stream().wait_stream(stream1)
-                out2[:,:1,:,:].copy_(top_out2)
+                if explicit_nhwc:
+                    out2[:,:1,:,:].copy_(top_out2)
+                else:
+                    out2[:,:,:1,:].copy_(top_out2)
             if spatial_group_rank < spatial_group_size-1:
                 torch.cuda.current_stream().wait_stream(stream2)
-                out2[:,Hs-1:,:,:].copy_(btm_out2)
+                if explicit_nhwc:
+                    out2[:,Hs-1:,:,:].copy_(btm_out2)
+                else:
+                    out2[:,:,Hs-1:,:].copy_(btm_out2)
             torch.cuda.current_stream().wait_stream(stream3)
-        fast_bottleneck.forward_rest(nhwc, stride_1x1, args, outputs)
+        fast_bottleneck.forward_rest(explicit_nhwc, stride_1x1, args, outputs)
         # save halos for backward pass
         if spatial_group_size > 1:
             torch.cuda.current_stream().wait_stream(stream3)
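Most of the new branches above differ only in which dimension holds the halo rows: H is dim 1 under explicit NHWC and dim 2 under channels_last. A small self-contained illustration of the paired slices (shapes invented for the example):

    import torch

    N, Hs, W, C = 2, 8, 8, 64
    pad_nhwc = torch.empty(N, Hs + 2, W, C)   # explicit NHWC: (N, Hs+2, W, C)
    pad_nchw = torch.empty(N, C, Hs + 2, W,   # channels_last: (N, C, Hs+2, W)
                           memory_format=torch.channels_last)

    interior_nhwc = pad_nhwc[:, 1:Hs + 1, :, :]   # H lives in dim 1
    interior_nchw = pad_nchw[:, :, 1:Hs + 1, :]   # H lives in dim 2
    assert interior_nhwc.shape[1] == interior_nchw.shape[2] == Hs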
@@ -299,7 +326,7 @@ class SpatialBottleneckFunction(torch.autograd.Function):
         else:
             ctx.save_for_backward(*(args+outputs))
         # save relu outputs for drelu
-        ctx.nhwc = nhwc
+        ctx.explicit_nhwc = explicit_nhwc
         ctx.stride_1x1 = stride_1x1
         ctx.spatial_group_size = spatial_group_size
         if spatial_group_size > 1:
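The ctx bookkeeping above follows the standard torch.autograd.Function contract that the rest of the diff relies on: non-tensor flags become plain ctx attributes, tensors go through save_for_backward, and backward returns one slot per forward argument (None for the non-tensor ones, which is why backward in this file ends with nine Nones). A minimal, hypothetical example of the pattern:

    import torch

    class Scale(torch.autograd.Function):
        @staticmethod
        def forward(ctx, flag, x, w):
            ctx.flag = flag                # plain attribute, like ctx.explicit_nhwc
            ctx.save_for_backward(x, w)    # tensors must use save_for_backward
            return x * w

        @staticmethod
        def backward(ctx, grad):
            x, w = ctx.saved_tensors
            # one slot per forward argument: None for flag, grads for x and w
            return None, grad * w, grad * x

    out = Scale.apply(True, torch.ones(3, requires_grad=True),
                      torch.full((3,), 2.0, requires_grad=True))
    out.sum().backward()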
@@ -339,8 +366,8 @@ class SpatialBottleneckFunction(torch.autograd.Function):
         if ctx.downsample:
             t_list.append(ctx.saved_tensors[10])
-        grads = fast_bottleneck.backward_init(ctx.nhwc, ctx.stride_1x1, t_list)
-        grad_out2 = fast_bottleneck.backward_grad_out2(ctx.nhwc, ctx.stride_1x1, t_list, grads)
+        grads = fast_bottleneck.backward_init(ctx.explicit_nhwc, ctx.stride_1x1, t_list)
+        grad_out2 = fast_bottleneck.backward_grad_out2(ctx.explicit_nhwc, ctx.stride_1x1, t_list, grads)
         # do halo exchange of grad_out2 here
         # compute halo cells for grad_out1
         if ctx.spatial_group_size > 1:
@@ -355,48 +382,66 @@ class SpatialBottleneckFunction(torch.autograd.Function):
                 with torch.cuda.stream(ctx.stream2):
                     btm_fat_halo = torch.empty((N,3,W,C), dtype=grad_out2.dtype, device=grad_out2.device)
                     btm_relu_halo = torch.empty((N,3,W,C), dtype=grad_out2.dtype, device=grad_out2.device)
-                    btm_fat_halo[:,:2,:,:].copy_(grad_out2[:,Hs-2:,:,:])
-                    btm_fat_halo[:,2:,:,:].copy_(btm_halo)
-                    btm_relu_halo[:,:2,:,:].copy_(relu1[:,Hs-2:,:,:])
-                    btm_relu_halo[:,2:,:,:].zero_()
-                    btm_grad_out1_halo = fast_bottleneck.backward_grad_out1_halo(ctx.nhwc, ctx.stride_1x1, t_list, grads, btm_fat_halo, btm_relu_halo)
-                    btm_grad_out1_halo = btm_grad_out1_halo[:,1:2,:,:]
+                    if ctx.explicit_nhwc:
+                        btm_fat_halo[:,:2,:,:].copy_(grad_out2[:,Hs-2:,:,:])
+                        btm_fat_halo[:,2:,:,:].copy_(btm_halo)
+                        btm_relu_halo[:,:2,:,:].copy_(relu1[:,Hs-2:,:,:])
+                        btm_relu_halo[:,2:,:,:].zero_()
+                    else:
+                        btm_fat_halo[:,:,:2,:].copy_(grad_out2[:,Hs-2:,:,:])
+                        btm_fat_halo[:,:,2:,:].copy_(btm_halo)
+                        btm_relu_halo[:,:,:2,:].copy_(relu1[:,Hs-2:,:,:])
+                        btm_relu_halo[:,:,2:,:].zero_()
+                    btm_grad_out1_halo = fast_bottleneck.backward_grad_out1_halo(ctx.explicit_nhwc, ctx.stride_1x1, t_list, grads, btm_fat_halo, btm_relu_halo)
+                    if ctx.explicit_nhwc:
+                        btm_grad_out1_halo = btm_grad_out1_halo[:,1:2,:,:]
+                    else:
+                        btm_grad_out1_halo = btm_grad_out1_halo[:,:,1:2,:]
             if ctx.spatial_group_rank > 0:
                 with torch.cuda.stream(ctx.stream1):
                     top_fat_halo = torch.empty((N,3,W,C), dtype=grad_out2.dtype, device=grad_out2.device)
                     top_relu_halo = torch.empty((N,3,W,C), dtype=grad_out2.dtype, device=grad_out2.device)
-                    top_fat_halo[:,:1,:,:].copy_(top_halo)
-                    top_fat_halo[:,1:,:,:].copy_(grad_out2[:,:2,:,:])
-                    top_relu_halo[:,:1,:,:].zero_()
-                    top_relu_halo[:,1:,:,:].copy_(relu1[:,:2,:,:])
-                    top_grad_out1_halo = fast_bottleneck.backward_grad_out1_halo(ctx.nhwc, ctx.stride_1x1, t_list, grads, top_fat_halo, top_relu_halo)
-                    top_grad_out1_halo = top_grad_out1_halo[:,1:2,:,:]
+                    if ctx.explicit_nhwc:
+                        top_fat_halo[:,:1,:,:].copy_(top_halo)
+                        top_fat_halo[:,1:,:,:].copy_(grad_out2[:,:2,:,:])
+                        top_relu_halo[:,:1,:,:].zero_()
+                        top_relu_halo[:,1:,:,:].copy_(relu1[:,:2,:,:])
+                    else:
+                        top_fat_halo[:,:,:1,:].copy_(top_halo)
+                        top_fat_halo[:,:,1:,:].copy_(grad_out2[:,:2,:,:])
+                        top_relu_halo[:,:,:1,:].zero_()
+                        top_relu_halo[:,:,1:,:].copy_(relu1[:,:2,:,:])
+                    top_grad_out1_halo = fast_bottleneck.backward_grad_out1_halo(ctx.explicit_nhwc, ctx.stride_1x1, t_list, grads, top_fat_halo, top_relu_halo)
+                    if ctx.explicit_nhwc:
+                        top_grad_out1_halo = top_grad_out1_halo[:,1:2,:,:]
+                    else:
+                        top_grad_out1_halo = top_grad_out1_halo[:,:,1:2,:]
             inc.add_delay(10)
         wgrad2_stream = torch.cuda.Stream()
         wgrad2_stream.wait_stream(torch.cuda.current_stream())
         with torch.cuda.stream(wgrad2_stream):
             if ctx.spatial_group_size > 1:
-                wgrad2 = fast_bottleneck.backward_wgrad2_pad(ctx.nhwc, ctx.stride_1x1, t_list, grads, out1_pad, grad_out2)
+                wgrad2 = fast_bottleneck.backward_wgrad2_pad(ctx.explicit_nhwc, ctx.stride_1x1, t_list, grads, out1_pad, grad_out2)
             else:
-                wgrad2 = fast_bottleneck.backward_wgrad2(ctx.nhwc, ctx.stride_1x1, t_list, grads, grad_out2)
+                wgrad2 = fast_bottleneck.backward_wgrad2(ctx.explicit_nhwc, ctx.stride_1x1, t_list, grads, grad_out2)
         # compute wgrad2 for internal cells
-        #wgrad2 = fast_bottleneck.backward_wgrad2(ctx.nhwc, ctx.stride_1x1, t_list, grads, grad_out2)
+        #wgrad2 = fast_bottleneck.backward_wgrad2(ctx.explicit_nhwc, ctx.stride_1x1, t_list, grads, grad_out2)
         # apply wgrad2 halos
         #if ctx.spatial_group_size > 1:
         #    if ctx.spatial_group_rank > 0:
         #        top_grad2_halo = grad_out2[:,:1,:,:]
-        #        top_wgrad2_halo = fast_bottleneck.backward_wgrad2_halo(ctx.nhwc, ctx.stride_1x1, t_list, grads, top_out1_halo, top_grad2_halo)
+        #        top_wgrad2_halo = fast_bottleneck.backward_wgrad2_halo(ctx.explicit_nhwc, ctx.stride_1x1, t_list, grads, top_out1_halo, top_grad2_halo)
         #        wgrad2[:,:1,:,:].add_(top_wgrad2_halo)
         #    if ctx.spatial_group_rank < ctx.spatial_group_size-1:
         #        btm_grad2_halo = grad_out2[:,-1:,:,:]
-        #        btm_wgrad2_halo = fast_bottleneck.backward_wgrad2_halo(ctx.nhwc, ctx.stride_1x1, t_list, grads, btm_out1_halo, btm_grad2_halo)
+        #        btm_wgrad2_halo = fast_bottleneck.backward_wgrad2_halo(ctx.explicit_nhwc, ctx.stride_1x1, t_list, grads, btm_out1_halo, btm_grad2_halo)
         #        wgrad2[:,-1:,:,:].add_(btm_wgrad2_halo)
         # compute grad_out1 for internal cells
-        grad_out1 = fast_bottleneck.backward_grad_out1(ctx.nhwc, ctx.stride_1x1, t_list, grads, grad_out2)
+        grad_out1 = fast_bottleneck.backward_grad_out1(ctx.explicit_nhwc, ctx.stride_1x1, t_list, grads, grad_out2)
         # apply halo cells to grad_out1
         if ctx.spatial_group_size > 1:
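The wgrad2_stream block above uses the usual CUDA fork/join idiom to overlap the weight-gradient convolution with the rest of backward. A stripped-down sketch of that idiom (the matmul and add stand in for the real kernels):

    import torch

    if torch.cuda.is_available():
        a = torch.randn(1024, 1024, device='cuda')
        side = torch.cuda.Stream()
        side.wait_stream(torch.cuda.current_stream())  # see all work queued so far
        with torch.cuda.stream(side):
            b = a @ a          # runs on 'side', overlapping the default stream
        c = a + 1              # default-stream work proceeds concurrently
        torch.cuda.current_stream().wait_stream(side)  # join before consuming b
        d = b + c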
@@ -406,14 +451,20 @@ class SpatialBottleneckFunction(torch.autograd.Function):
                 #print("w.shape = %s, z.shape = %s, relu1.shape = %s" % (str(list(w.shape)), str(list(z.shape)), str(list(relu1.shape))))
             if ctx.spatial_group_rank > 0:
                 torch.cuda.current_stream().wait_stream(ctx.stream1)
-                grad_out1[:,:1,:,:].copy_(top_grad_out1_halo)
+                if ctx.explicit_nhwc:
+                    grad_out1[:,:1,:,:].copy_(top_grad_out1_halo)
+                else:
+                    grad_out1[:,:,:1,:].copy_(top_grad_out1_halo)
                 #print("ctx.spatial_group_rank = %d, apply grad_out1 top halo (grad_out1.shape = %s)" % (ctx.spatial_group_rank, str(list(grad_out1.shape))))
             if ctx.spatial_group_rank < ctx.spatial_group_size-1:
                 torch.cuda.current_stream().wait_stream(ctx.stream2)
-                grad_out1[:,Hs-1:,:,:].copy_(btm_grad_out1_halo)
+                if ctx.explicit_nhwc:
+                    grad_out1[:,Hs-1:,:,:].copy_(btm_grad_out1_halo)
+                else:
+                    grad_out1[:,:,Hs-1:,:].copy_(btm_grad_out1_halo)
                 #print("ctx.spatial_group_rank = %d, apply grad_out1 btm halo (grad_out1.shape = %s)" % (ctx.spatial_group_rank, str(list(grad_out1.shape))))
-        fast_bottleneck.backward_rest(ctx.nhwc, ctx.stride_1x1, t_list, grads, grad_out2, grad_out1, wgrad2)
+        fast_bottleneck.backward_rest(ctx.explicit_nhwc, ctx.stride_1x1, t_list, grads, grad_out2, grad_out1, wgrad2)
         torch.cuda.current_stream().wait_stream(wgrad2_stream)
         return (None, None, None, None, None, None, None, None, None, *grads)