OpenDAS / apex · Commits · 0c20c455

Commit 0c20c455, authored Mar 31, 2022 by Thor Johnsen

    Some fixes to better support native nhwc

Parent: 34df0f79
Changes: 1 changed file, with 97 additions and 46 deletions

apex/contrib/bottleneck/bottleneck.py (+97 −46)
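
Note on the commit's subject: "native nhwc" here means PyTorch's channels_last memory format, as opposed to the "explicit nhwc" convention apex also supports, where the tensor shape is literally [N, H, W, C]. The rename of the nhwc flag to explicit_nhwc in the diff below tracks that distinction. A minimal sketch of the two conventions (illustrative only, not apex code):

    import torch

    N, C, H, W = 2, 8, 6, 4

    # Explicit NHWC: the layout is encoded in the shape itself, so the
    # height dimension is dim 1 and halo rows are sliced there.
    x_explicit = torch.randn(N, H, W, C)
    top_halo = x_explicit[:, :1, :, :]

    # Native NHWC (channels_last): the logical shape stays [N, C, H, W];
    # only the strides are permuted, so height remains dim 2.
    x_native = torch.randn(N, C, H, W).contiguous(memory_format=torch.channels_last)
    top_halo = x_native[:, :, :1, :]

    assert x_native.is_contiguous(memory_format=torch.channels_last)
    assert x_native.shape == (N, C, H, W)

This is why nearly every slicing site below gains an if explicit_nhwc / else branch: the same halo rows live on different dimensions under the two conventions.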
@@ -220,7 +220,7 @@ class Bottleneck(torch.nn.Module):
 class SpatialBottleneckFunction(torch.autograd.Function):
     @staticmethod
-    def forward(ctx, spatial_group_size, spatial_group_rank, spatial_communicator, spatial_halo_exchanger, spatial_method, nhwc, stride_1x1, scale, bias, x, *conv):
+    def forward(ctx, spatial_group_size, spatial_group_rank, spatial_communicator, spatial_halo_exchanger, spatial_method, explicit_nhwc, stride_1x1, scale, bias, x, *conv):
         if spatial_group_size > 1:
             stream1 = spatial_halo_exchanger.stream1
             stream2 = spatial_halo_exchanger.stream2
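
The nine leading arguments of forward after ctx (spatial_group_size through bias) receive no gradients from this Function, which is why backward (end of this diff) returns nine Nones before *grads. A reduced sketch of the pattern (hypothetical, not apex code):

    import torch

    class ScaleRows(torch.autograd.Function):
        @staticmethod
        def forward(ctx, explicit_nhwc, s, x):
            ctx.explicit_nhwc = explicit_nhwc  # plain flag, stashed on ctx
            ctx.save_for_backward(s)           # tensors go through save_for_backward
            return x * s

        @staticmethod
        def backward(ctx, grad_out):
            (s,) = ctx.saved_tensors
            # one entry per forward input: None for the flag and for s
            return None, None, grad_out * s

    x = torch.randn(4, 3, requires_grad=True)
    y = ScaleRows.apply(True, torch.tensor(2.0), x)
    y.sum().backward()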
@@ -234,64 +234,91 @@ class SpatialBottleneckFunction(torch.autograd.Function):
         args.append(scale[3])
         args.append(bias[3])
-        # weight buffers are always in nhwc while shape can be nhwc or channels_last
+        # weight buffers are always in explicit_nhwc while shape can be explicit_nhwc or channels_last
         # here we pass in flag and let c++ handle it
         # alternatively, we can put all sizes into a fixed format and pass it in
-        outputs = fast_bottleneck.forward_init(nhwc, stride_1x1, args)
-        fast_bottleneck.forward_out1(nhwc, stride_1x1, args, outputs)
+        outputs = fast_bottleneck.forward_init(explicit_nhwc, stride_1x1, args)
+        fast_bottleneck.forward_out1(explicit_nhwc, stride_1x1, args, outputs)
         if spatial_group_size > 1:
             out1 = outputs[0]
-            # TODO: This assumes explicit nhwc
-            N,Hs,W,C = list(out1.shape)
-            out1_pad = torch.empty([N,Hs+2,W,C], dtype=out1.dtype, device='cuda')
+            if explicit_nhwc:
+                N,Hs,W,C = list(out1.shape)
+                memory_format = torch.contiguous_format
+                out1_pad = torch.empty([N,Hs+2,W,C], dtype=out1.dtype, device='cuda')
+            else:
+                N,C,Hs,W = list(out1.shape)
+                memory_format = torch.channels_last if out1.is_contiguous(memory_format=torch.channels_last) else torch.contiguous_format
+                out1_pad = torch.empty([N,C,Hs+2,W], dtype=out1.dtype, device='cuda', memory_format=memory_format)
             stream1.wait_stream(torch.cuda.current_stream())
             stream3.wait_stream(torch.cuda.current_stream())
             with torch.cuda.stream(stream3):
-                out1_pad[:,1:Hs+1,:,:].copy_(out1)
+                if explicit_nhwc:
+                    out1_pad[:,1:Hs+1,:,:].copy_(out1)
+                else:
+                    out1_pad[:,:,1:Hs+1,:].copy_(out1)
             with torch.cuda.stream(stream1):
-                top_out1_halo = out1_pad[:,:1,:,:]
-                btm_out1_halo = out1_pad[:,Hs+1:Hs+2,:,:]
-                spatial_halo_exchanger.left_right_halo_exchange(out1[:,:1,:,:], out1[:,Hs-1:,:,:], top_out1_halo, btm_out1_halo)
+                if explicit_nhwc:
+                    top_out1_halo = out1_pad[:,:1,:,:]
+                    btm_out1_halo = out1_pad[:,Hs+1:Hs+2,:,:]
+                    spatial_halo_exchanger.left_right_halo_exchange(out1[:,:1,:,:], out1[:,Hs-1:,:,:], top_out1_halo, btm_out1_halo)
+                else:
+                    top_out1_halo = out1_pad[:,:,:1,:]
+                    btm_out1_halo = out1_pad[:,:,Hs+1:Hs+2,:]
+                    spatial_halo_exchanger.left_right_halo_exchange(out1[:,:,:1,:], out1[:,:,Hs-1:,:], top_out1_halo, btm_out1_halo)
             if spatial_method == 1:
                 # overlap mid convolution with halo transfer
                 if spatial_group_rank < spatial_group_size-1:
                     stream2.wait_stream(stream1)
                     with torch.cuda.stream(stream2):
                         btm_fat_halo = torch.empty((N,3,W,C), dtype=out1.dtype, device=out1.device)
-                        btm_fat_halo[:,0:2,:,:].copy_(out1[:,Hs-2:,:,:])
-                        btm_fat_halo[:,2:,:,:].copy_(btm_out1_halo)
-                        btm_out2 = fast_bottleneck.forward_out2_halo(nhwc, btm_fat_halo, args)
+                        if explicit_nhwc:
+                            btm_fat_halo[:,0:2,:,:].copy_(out1[:,Hs-2:,:,:])
+                            btm_fat_halo[:,2:,:,:].copy_(btm_out1_halo)
+                        else:
+                            btm_fat_halo[:,:,0:2,:].copy_(out1[:,Hs-2:,:,:])
+                            btm_fat_halo[:,:,2:,:].copy_(btm_out1_halo)
+                        btm_out2 = fast_bottleneck.forward_out2_halo(explicit_nhwc, btm_fat_halo, args)
                 if spatial_group_rank > 0:
                     with torch.cuda.stream(stream1):
                         top_fat_halo = torch.empty((N,3,W,C), dtype=out1.dtype, device=out1.device)
-                        top_fat_halo[:,:1,:,:].copy_(top_out1_halo)
-                        top_fat_halo[:,1:3,:,:].copy_(out1[:,:2,:,:])
-                        top_out2 = fast_bottleneck.forward_out2_halo(nhwc, top_fat_halo, args)
+                        if explicit_nhwc:
+                            top_fat_halo[:,:1,:,:].copy_(top_out1_halo)
+                            top_fat_halo[:,1:3,:,:].copy_(out1[:,:2,:,:])
+                        else:
+                            top_fat_halo[:,:,:1,:].copy_(top_out1_halo)
+                            top_fat_halo[:,:,1:3,:].copy_(out1[:,:2,:,:])
+                        top_out2 = fast_bottleneck.forward_out2_halo(explicit_nhwc, top_fat_halo, args)
                 inc.add_delay(10)
             elif spatial_method == 2:
                 # wait for halo transfer to finish before doing a full convolution of padded x
                 torch.cuda.current_stream().wait_stream(stream1)
                 torch.cuda.current_stream().wait_stream(stream3)
-                fast_bottleneck.forward_out2_pad(nhwc, stride_1x1, args, outputs, out1_pad)
+                fast_bottleneck.forward_out2_pad(explicit_nhwc, stride_1x1, args, outputs, out1_pad)
             else:
                 assert(False), "spatial_method must be 1 or 2"
         if spatial_group_size <= 1 or spatial_method == 1:
-            fast_bottleneck.forward_out2(nhwc, stride_1x1, args, outputs)
+            fast_bottleneck.forward_out2(explicit_nhwc, stride_1x1, args, outputs)
         # compute halo cells for outputs[1] (out2)
         if spatial_group_size > 1 and spatial_method == 1:
             out2 = outputs[1]
             if spatial_group_rank > 0:
                 torch.cuda.current_stream().wait_stream(stream1)
-                out2[:,:1,:,:].copy_(top_out2)
+                if explicit_nhwc:
+                    out2[:,:1,:,:].copy_(top_out2)
+                else:
+                    out2[:,:,:1,:].copy_(top_out2)
             if spatial_group_rank < spatial_group_size-1:
                 torch.cuda.current_stream().wait_stream(stream2)
-                out2[:,Hs-1:,:,:].copy_(btm_out2)
+                if explicit_nhwc:
+                    out2[:,Hs-1:,:,:].copy_(btm_out2)
+                else:
+                    out2[:,:,Hs-1:,:].copy_(btm_out2)
             torch.cuda.current_stream().wait_stream(stream3)
-        fast_bottleneck.forward_rest(nhwc, stride_1x1, args, outputs)
+        fast_bottleneck.forward_rest(explicit_nhwc, stride_1x1, args, outputs)
         # save halos for backward pass
         if spatial_group_size > 1:
             torch.cuda.current_stream().wait_stream(stream3)
@@ -299,7 +326,7 @@ class SpatialBottleneckFunction(torch.autograd.Function):
         else:
             ctx.save_for_backward(*(args+outputs))
         # save relu outputs for drelu
-        ctx.nhwc = nhwc
+        ctx.explicit_nhwc = explicit_nhwc
         ctx.stride_1x1 = stride_1x1
         ctx.spatial_group_size = spatial_group_size
         if spatial_group_size > 1:
@@ -339,8 +366,8 @@ class SpatialBottleneckFunction(torch.autograd.Function):
         if ctx.downsample:
             t_list.append(ctx.saved_tensors[10])
-        grads = fast_bottleneck.backward_init(ctx.nhwc, ctx.stride_1x1, t_list)
-        grad_out2 = fast_bottleneck.backward_grad_out2(ctx.nhwc, ctx.stride_1x1, t_list, grads)
+        grads = fast_bottleneck.backward_init(ctx.explicit_nhwc, ctx.stride_1x1, t_list)
+        grad_out2 = fast_bottleneck.backward_grad_out2(ctx.explicit_nhwc, ctx.stride_1x1, t_list, grads)
         # do halo exchange of grad_out2 here
         # compute halo cells for grad_out1
         if ctx.spatial_group_size > 1:
@@ -355,48 +382,66 @@ class SpatialBottleneckFunction(torch.autograd.Function):
                     with torch.cuda.stream(ctx.stream2):
                         btm_fat_halo = torch.empty((N,3,W,C), dtype=grad_out2.dtype, device=grad_out2.device)
                         btm_relu_halo = torch.empty((N,3,W,C), dtype=grad_out2.dtype, device=grad_out2.device)
-                        btm_fat_halo[:,:2,:,:].copy_(grad_out2[:,Hs-2:,:,:])
-                        btm_fat_halo[:,2:,:,:].copy_(btm_halo)
-                        btm_relu_halo[:,:2,:,:].copy_(relu1[:,Hs-2:,:,:])
-                        btm_relu_halo[:,2:,:,:].zero_()
-                        btm_grad_out1_halo = fast_bottleneck.backward_grad_out1_halo(ctx.nhwc, ctx.stride_1x1, t_list, grads, btm_fat_halo, btm_relu_halo)
-                        btm_grad_out1_halo = btm_grad_out1_halo[:,1:2,:,:]
+                        if ctx.explicit_nhwc:
+                            btm_fat_halo[:,:2,:,:].copy_(grad_out2[:,Hs-2:,:,:])
+                            btm_fat_halo[:,2:,:,:].copy_(btm_halo)
+                            btm_relu_halo[:,:2,:,:].copy_(relu1[:,Hs-2:,:,:])
+                            btm_relu_halo[:,2:,:,:].zero_()
+                        else:
+                            btm_fat_halo[:,:,:2,:].copy_(grad_out2[:,Hs-2:,:,:])
+                            btm_fat_halo[:,:,2:,:].copy_(btm_halo)
+                            btm_relu_halo[:,:,:2,:].copy_(relu1[:,Hs-2:,:,:])
+                            btm_relu_halo[:,:,2:,:].zero_()
+                        btm_grad_out1_halo = fast_bottleneck.backward_grad_out1_halo(ctx.explicit_nhwc, ctx.stride_1x1, t_list, grads, btm_fat_halo, btm_relu_halo)
+                        if ctx.explicit_nhwc:
+                            btm_grad_out1_halo = btm_grad_out1_halo[:,1:2,:,:]
+                        else:
+                            btm_grad_out1_halo = btm_grad_out1_halo[:,:,1:2,:]
                 if ctx.spatial_group_rank > 0:
                     with torch.cuda.stream(ctx.stream1):
                         top_fat_halo = torch.empty((N,3,W,C), dtype=grad_out2.dtype, device=grad_out2.device)
                         top_relu_halo = torch.empty((N,3,W,C), dtype=grad_out2.dtype, device=grad_out2.device)
-                        top_fat_halo[:,:1,:,:].copy_(top_halo)
-                        top_fat_halo[:,1:,:,:].copy_(grad_out2[:,:2,:,:])
-                        top_relu_halo[:,:1,:,:].zero_()
-                        top_relu_halo[:,1:,:,:].copy_(relu1[:,:2,:,:])
-                        top_grad_out1_halo = fast_bottleneck.backward_grad_out1_halo(ctx.nhwc, ctx.stride_1x1, t_list, grads, top_fat_halo, top_relu_halo)
-                        top_grad_out1_halo = top_grad_out1_halo[:,1:2,:,:]
+                        if ctx.explicit_nhwc:
+                            top_fat_halo[:,:1,:,:].copy_(top_halo)
+                            top_fat_halo[:,1:,:,:].copy_(grad_out2[:,:2,:,:])
+                            top_relu_halo[:,:1,:,:].zero_()
+                            top_relu_halo[:,1:,:,:].copy_(relu1[:,:2,:,:])
+                        else:
+                            top_fat_halo[:,:,:1,:].copy_(top_halo)
+                            top_fat_halo[:,:,1:,:].copy_(grad_out2[:,:2,:,:])
+                            top_relu_halo[:,:,:1,:].zero_()
+                            top_relu_halo[:,:,1:,:].copy_(relu1[:,:2,:,:])
+                        top_grad_out1_halo = fast_bottleneck.backward_grad_out1_halo(ctx.explicit_nhwc, ctx.stride_1x1, t_list, grads, top_fat_halo, top_relu_halo)
+                        if ctx.explicit_nhwc:
+                            top_grad_out1_halo = top_grad_out1_halo[:,1:2,:,:]
+                        else:
+                            top_grad_out1_halo = top_grad_out1_halo[:,:,1:2,:]
                 inc.add_delay(10)
         wgrad2_stream = torch.cuda.Stream()
         wgrad2_stream.wait_stream(torch.cuda.current_stream())
         with torch.cuda.stream(wgrad2_stream):
             if ctx.spatial_group_size > 1:
-                wgrad2 = fast_bottleneck.backward_wgrad2_pad(ctx.nhwc, ctx.stride_1x1, t_list, grads, out1_pad, grad_out2)
+                wgrad2 = fast_bottleneck.backward_wgrad2_pad(ctx.explicit_nhwc, ctx.stride_1x1, t_list, grads, out1_pad, grad_out2)
             else:
-                wgrad2 = fast_bottleneck.backward_wgrad2(ctx.nhwc, ctx.stride_1x1, t_list, grads, grad_out2)
+                wgrad2 = fast_bottleneck.backward_wgrad2(ctx.explicit_nhwc, ctx.stride_1x1, t_list, grads, grad_out2)
         # compute wgrad2 for internal cells
-        #wgrad2 = fast_bottleneck.backward_wgrad2(ctx.nhwc, ctx.stride_1x1, t_list, grads, grad_out2)
+        #wgrad2 = fast_bottleneck.backward_wgrad2(ctx.explicit_nhwc, ctx.stride_1x1, t_list, grads, grad_out2)
         # apply wgrad2 halos
         #if ctx.spatial_group_size > 1:
         # if ctx.spatial_group_rank > 0:
         # top_grad2_halo = grad_out2[:,:1,:,:]
-        # top_wgrad2_halo = fast_bottleneck.backward_wgrad2_halo(ctx.nhwc, ctx.stride_1x1, t_list, grads, top_out1_halo, top_grad2_halo)
+        # top_wgrad2_halo = fast_bottleneck.backward_wgrad2_halo(ctx.explicit_nhwc, ctx.stride_1x1, t_list, grads, top_out1_halo, top_grad2_halo)
         # wgrad2[:,:1,:,:].add_(top_wgrad2_halo)
         # if ctx.spatial_group_rank < ctx.spatial_group_size-1:
         # btm_grad2_halo = grad_out2[:,-1:,:,:]
-        # btm_wgrad2_halo = fast_bottleneck.backward_wgrad2_halo(ctx.nhwc, ctx.stride_1x1, t_list, grads, btm_out1_halo, btm_grad2_halo)
+        # btm_wgrad2_halo = fast_bottleneck.backward_wgrad2_halo(ctx.explicit_nhwc, ctx.stride_1x1, t_list, grads, btm_out1_halo, btm_grad2_halo)
         # wgrad2[:,-1:,:,:].add_(btm_wgrad2_halo)
         # compute grad_out1 for internal cells
-        grad_out1 = fast_bottleneck.backward_grad_out1(ctx.nhwc, ctx.stride_1x1, t_list, grads, grad_out2)
+        grad_out1 = fast_bottleneck.backward_grad_out1(ctx.explicit_nhwc, ctx.stride_1x1, t_list, grads, grad_out2)
         # apply halo cells to grad_out1
         if ctx.spatial_group_size > 1:
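
The 3-row "fat halos" assembled in both passes come from the arithmetic of the 3x3 middle convolution: recomputing one border output row needs that row plus one neighbor row on each side. A self-contained check of that arithmetic (not apex code):

    import torch
    import torch.nn.functional as F

    x = torch.randn(1, 4, 8, 8)               # N, C, H, W
    w = torch.randn(4, 4, 3, 3)               # 3x3 convolution
    full = F.conv2d(x, w, padding=1)

    # Output row 3 depends only on input rows 2..4 (a "fat" 3-row strip).
    fat = x[:, :, 2:5, :]
    row3 = F.conv2d(fat, w, padding=(0, 1))   # valid in H, same in W
    assert torch.allclose(full[:, :, 3:4, :], row3, atol=1e-5)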
@@ -406,14 +451,20 @@ class SpatialBottleneckFunction(torch.autograd.Function):
             #print("w.shape = %s, z.shape = %s, relu1.shape = %s" % (str(list(w.shape)), str(list(z.shape)), str(list(relu1.shape))))
             if ctx.spatial_group_rank > 0:
                 torch.cuda.current_stream().wait_stream(ctx.stream1)
-                grad_out1[:,:1,:,:].copy_(top_grad_out1_halo)
+                if ctx.explicit_nhwc:
+                    grad_out1[:,:1,:,:].copy_(top_grad_out1_halo)
+                else:
+                    grad_out1[:,:,:1,:].copy_(top_grad_out1_halo)
                 #print("ctx.spatial_group_rank = %d, apply grad_out1 top halo (grad_out1.shape = %s)" % (ctx.spatial_group_rank, str(list(grad_out1.shape))))
             if ctx.spatial_group_rank < ctx.spatial_group_size-1:
                 torch.cuda.current_stream().wait_stream(ctx.stream2)
-                grad_out1[:,Hs-1:,:,:].copy_(btm_grad_out1_halo)
+                if ctx.explicit_nhwc:
+                    grad_out1[:,Hs-1:,:,:].copy_(btm_grad_out1_halo)
+                else:
+                    grad_out1[:,:,Hs-1:,:].copy_(btm_grad_out1_halo)
                 #print("ctx.spatial_group_rank = %d, apply grad_out1 btm halo (grad_out1.shape = %s)" % (ctx.spatial_group_rank, str(list(grad_out1.shape))))
-        fast_bottleneck.backward_rest(ctx.nhwc, ctx.stride_1x1, t_list, grads, grad_out2, grad_out1, wgrad2)
+        fast_bottleneck.backward_rest(ctx.explicit_nhwc, ctx.stride_1x1, t_list, grads, grad_out2, grad_out1, wgrad2)
         torch.cuda.current_stream().wait_stream(wgrad2_stream)
         return (None, None, None, None, None, None, None, None, None, *grads)
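
Throughout the diff the same stream-ordering idiom recurs (stream1/stream2/stream3 in forward, wgrad2_stream in backward): a side stream waits on the stream that produced its inputs, and the consuming stream waits on the side stream before using its results. Reduced to its minimal form (illustrative only, requires a CUDA device):

    import torch

    main = torch.cuda.current_stream()
    side = torch.cuda.Stream()

    x = torch.randn(1024, 1024, device='cuda')
    y = x @ x                      # produced on the main stream

    side.wait_stream(main)         # order side work after the producer
    with torch.cuda.stream(side):
        wgrad = y.t() @ y          # overlapped work on the side stream

    main.wait_stream(side)         # re-join before consuming the result
    print(wgrad.sum().item())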