OpenDAS / apex
Commit 0c20c455, authored Mar 31, 2022 by Thor Johnsen

Some fixes to better support native nhwc

Parent: 34df0f79
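For context on the commit message: "native nhwc" refers to PyTorch's channels_last memory format, as opposed to the explicit-NHWC tensors this function previously assumed; the diff accordingly renames the nhwc flag to explicit_nhwc and adds channels_last code paths. A minimal sketch of the two layouts in plain PyTorch (shapes are illustrative, not from the commit):

import torch

# Explicit NHWC: the sizes themselves are (N, H, W, C).
x_explicit = torch.empty(2, 8, 8, 64)

# Native NHWC (channels_last): sizes stay (N, C, H, W); only the strides
# change so that memory order becomes N, H, W, C.
x_native = torch.empty(2, 64, 8, 8).to(memory_format=torch.channels_last)

print(x_explicit.shape)    # torch.Size([2, 8, 8, 64])
print(x_native.shape)      # torch.Size([2, 64, 8, 8])
print(x_native.stride())   # (4096, 1, 512, 64): channel stride is 1
print(x_native.is_contiguous(memory_format=torch.channels_last))   # True

The practical consequence, visible throughout the diff: H is dim 1 of an explicit-NHWC tensor but dim 2 of a channels_last tensor, so every slice along H needs two variants.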
Showing 1 changed file with 97 additions and 46 deletions:

apex/contrib/bottleneck/bottleneck.py (+97, -46)
@@ -220,7 +220,7 @@ class Bottleneck(torch.nn.Module):
 class SpatialBottleneckFunction(torch.autograd.Function):
     @staticmethod
-    def forward(ctx, spatial_group_size, spatial_group_rank, spatial_communicator, spatial_halo_exchanger, spatial_method, nhwc, stride_1x1, scale, bias, x, *conv):
+    def forward(ctx, spatial_group_size, spatial_group_rank, spatial_communicator, spatial_halo_exchanger, spatial_method, explicit_nhwc, stride_1x1, scale, bias, x, *conv):
         if spatial_group_size > 1:
             stream1 = spatial_halo_exchanger.stream1
             stream2 = spatial_halo_exchanger.stream2
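A note on the new parameter list (editorial context; the splitting scheme is inferred from the indexing below, and the numbers are assumed for illustration): spatial parallelism divides the activation's H dimension across spatial_group_size ranks, so each rank holds a slab of Hs rows and must exchange one-row halos with its neighbors before the 3x3 convolution. Roughly:

# Assumed example numbers, not from the commit.
H, spatial_group_size = 32, 4
Hs = H // spatial_group_size   # 8 rows held by each rank
padded_rows = Hs + 2           # one halo row above and below, matching the
                               # out1_pad allocation in the next hunk
fat_halo_rows = 3              # input rows needed to recompute one output row
                               # of a 3x3 convolution, matching the (N, 3, W, C)
                               # halo buffers below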
@@ -234,64 +234,91 @@ class SpatialBottleneckFunction(torch.autograd.Function):
         args.append(scale[3])
         args.append(bias[3])
-        # weight buffers are always in nhwc while shape can be nhwc or channels_last
+        # weight buffers are always in explicit_nhwc while shape can be explicit_nhwc or channels_last
         # here we pass in flag and let c++ handle it
         # alternatively, we can put all sizes into a fixed format and pass it in
-        outputs = fast_bottleneck.forward_init(nhwc, stride_1x1, args)
-        fast_bottleneck.forward_out1(nhwc, stride_1x1, args, outputs)
+        outputs = fast_bottleneck.forward_init(explicit_nhwc, stride_1x1, args)
+        fast_bottleneck.forward_out1(explicit_nhwc, stride_1x1, args, outputs)
         if spatial_group_size > 1:
             out1 = outputs[0]
-            # TODO: This assumes explicit nhwc
-            N,Hs,W,C = list(out1.shape)
-            out1_pad = torch.empty([N,Hs+2,W,C], dtype=out1.dtype, device='cuda')
+            if explicit_nhwc:
+                N,Hs,W,C = list(out1.shape)
+                memory_format = torch.contiguous_format
+                out1_pad = torch.empty([N,Hs+2,W,C], dtype=out1.dtype, device='cuda')
+            else:
+                N,C,Hs,W = list(out1.shape)
+                memory_format = torch.channels_last if out1.is_contiguous(memory_format=torch.channels_last) else torch.contiguous_format
+                out1_pad = torch.empty([N,C,Hs+2,W], dtype=out1.dtype, device='cuda', memory_format=memory_format)
             stream1.wait_stream(torch.cuda.current_stream())
             stream3.wait_stream(torch.cuda.current_stream())
             with torch.cuda.stream(stream3):
-                out1_pad[:,1:Hs+1,:,:].copy_(out1)
+                if explicit_nhwc:
+                    out1_pad[:,1:Hs+1,:,:].copy_(out1)
+                else:
+                    out1_pad[:,:,1:Hs+1,:].copy_(out1)
             with torch.cuda.stream(stream1):
-                top_out1_halo = out1_pad[:,:1,:,:]
-                btm_out1_halo = out1_pad[:,Hs+1:Hs+2,:,:]
-                spatial_halo_exchanger.left_right_halo_exchange(out1[:,:1,:,:], out1[:,Hs-1:,:,:], top_out1_halo, btm_out1_halo)
+                if explicit_nhwc:
+                    top_out1_halo = out1_pad[:,:1,:,:]
+                    btm_out1_halo = out1_pad[:,Hs+1:Hs+2,:,:]
+                    spatial_halo_exchanger.left_right_halo_exchange(out1[:,:1,:,:], out1[:,Hs-1:,:,:], top_out1_halo, btm_out1_halo)
+                else:
+                    top_out1_halo = out1_pad[:,:,:1,:]
+                    btm_out1_halo = out1_pad[:,:,Hs+1:Hs+2,:]
+                    spatial_halo_exchanger.left_right_halo_exchange(out1[:,:,:1,:], out1[:,:,Hs-1:,:], top_out1_halo, btm_out1_halo)
             if spatial_method == 1:
                 # overlap mid convolution with halo transfer
                 if spatial_group_rank < spatial_group_size-1:
                     stream2.wait_stream(stream1)
                     with torch.cuda.stream(stream2):
                         btm_fat_halo = torch.empty((N,3,W,C),dtype=out1.dtype,device=out1.device)
-                        btm_fat_halo[:,0:2,:,:].copy_(out1[:,Hs-2:,:,:])
-                        btm_fat_halo[:,2:,:,:].copy_(btm_out1_halo)
-                        btm_out2 = fast_bottleneck.forward_out2_halo(nhwc, btm_fat_halo, args)
+                        if explicit_nhwc:
+                            btm_fat_halo[:,0:2,:,:].copy_(out1[:,Hs-2:,:,:])
+                            btm_fat_halo[:,2:,:,:].copy_(btm_out1_halo)
+                        else:
+                            btm_fat_halo[:,:,0:2,:].copy_(out1[:,Hs-2:,:,:])
+                            btm_fat_halo[:,:,2:,:].copy_(btm_out1_halo)
+                        btm_out2 = fast_bottleneck.forward_out2_halo(explicit_nhwc, btm_fat_halo, args)
                 if spatial_group_rank > 0:
                     with torch.cuda.stream(stream1):
                         top_fat_halo = torch.empty((N,3,W,C),dtype=out1.dtype,device=out1.device)
-                        top_fat_halo[:,:1,:,:].copy_(top_out1_halo)
-                        top_fat_halo[:,1:3,:,:].copy_(out1[:,:2,:,:])
-                        top_out2 = fast_bottleneck.forward_out2_halo(nhwc, top_fat_halo, args)
+                        if explicit_nhwc:
+                            top_fat_halo[:,:1,:,:].copy_(top_out1_halo)
+                            top_fat_halo[:,1:3,:,:].copy_(out1[:,:2,:,:])
+                        else:
+                            top_fat_halo[:,:,:1,:].copy_(top_out1_halo)
+                            top_fat_halo[:,:,1:3,:].copy_(out1[:,:2,:,:])
+                        top_out2 = fast_bottleneck.forward_out2_halo(explicit_nhwc, top_fat_halo, args)
                 inc.add_delay(10)
             elif spatial_method == 2:
                 # wait for halo transfer to finish before doing a full convolution of padded x
                 torch.cuda.current_stream().wait_stream(stream1)
                 torch.cuda.current_stream().wait_stream(stream3)
-                fast_bottleneck.forward_out2_pad(nhwc, stride_1x1, args, outputs, out1_pad)
+                fast_bottleneck.forward_out2_pad(explicit_nhwc, stride_1x1, args, outputs, out1_pad)
             else:
                 assert(False), "spatial_method must be 1 or 2"
         if spatial_group_size <= 1 or spatial_method == 1:
-            fast_bottleneck.forward_out2(nhwc, stride_1x1, args, outputs)
+            fast_bottleneck.forward_out2(explicit_nhwc, stride_1x1, args, outputs)
         # compute halo cells for outputs[1] (out2)
         if spatial_group_size > 1 and spatial_method == 1:
             out2 = outputs[1]
             if spatial_group_rank > 0:
                 torch.cuda.current_stream().wait_stream(stream1)
-                out2[:,:1,:,:].copy_(top_out2)
+                if explicit_nhwc:
+                    out2[:,:1,:,:].copy_(top_out2)
+                else:
+                    out2[:,:,:1,:].copy_(top_out2)
             if spatial_group_rank < spatial_group_size-1:
                 torch.cuda.current_stream().wait_stream(stream2)
-                out2[:,Hs-1:,:,:].copy_(btm_out2)
+                if explicit_nhwc:
+                    out2[:,Hs-1:,:,:].copy_(btm_out2)
+                else:
+                    out2[:,:,Hs-1:,:].copy_(btm_out2)
         torch.cuda.current_stream().wait_stream(stream3)
-        fast_bottleneck.forward_rest(nhwc, stride_1x1, args, outputs)
+        fast_bottleneck.forward_rest(explicit_nhwc, stride_1x1, args, outputs)
         # save halos for backward pass
         if spatial_group_size > 1:
             torch.cuda.current_stream().wait_stream(stream3)
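The repeated if explicit_nhwc/else pairs in the hunk above differ only in which dimension is sliced. A hypothetical helper, not part of the commit, that captures the pattern:

import torch

def height_slice(t, lo, hi, explicit_nhwc):
    # H is dim 1 for explicit NHWC, dim 2 for NCHW / channels_last.
    return t[:, lo:hi, :, :] if explicit_nhwc else t[:, :, lo:hi, :]

# The padding pattern from the hunk, in miniature (explicit NHWC case):
N, Hs, W, C = 2, 4, 8, 16
out1 = torch.randn(N, Hs, W, C)
out1_pad = torch.empty(N, Hs + 2, W, C)
height_slice(out1_pad, 1, Hs + 1, True).copy_(out1)        # interior rows
top_halo = height_slice(out1_pad, 0, 1, True)              # filled by halo exchange
btm_halo = height_slice(out1_pad, Hs + 1, Hs + 2, True)    # filled by halo exchange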
@@ -299,7 +326,7 @@ class SpatialBottleneckFunction(torch.autograd.Function):
         else:
             ctx.save_for_backward(*(args+outputs))
         # save relu outputs for drelu
-        ctx.nhwc = nhwc
+        ctx.explicit_nhwc = explicit_nhwc
         ctx.stride_1x1 = stride_1x1
         ctx.spatial_group_size = spatial_group_size
         if spatial_group_size > 1:
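The rename works because a torch.autograd.Function can stash arbitrary non-tensor attributes on ctx during forward and read them back during backward; only tensors need save_for_backward. A toy sketch of the mechanism (names are illustrative, not from the commit):

import torch

class Scale(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, factor, explicit_nhwc):
        ctx.factor = factor                  # plain attributes survive to backward
        ctx.explicit_nhwc = explicit_nhwc
        return x * factor

    @staticmethod
    def backward(ctx, grad_out):
        g = grad_out * ctx.factor
        # one return value per forward input; None for non-tensor args
        return g, None, None

y = Scale.apply(torch.randn(3, requires_grad=True), 2.0, True)
y.sum().backward()

The same convention explains the long tuple of Nones in the return statement at the end of this diff: one None per non-tensor forward argument.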
@@ -339,8 +366,8 @@ class SpatialBottleneckFunction(torch.autograd.Function):
         if ctx.downsample:
             t_list.append(ctx.saved_tensors[10])
-        grads = fast_bottleneck.backward_init(ctx.nhwc, ctx.stride_1x1, t_list)
-        grad_out2 = fast_bottleneck.backward_grad_out2(ctx.nhwc, ctx.stride_1x1, t_list, grads)
+        grads = fast_bottleneck.backward_init(ctx.explicit_nhwc, ctx.stride_1x1, t_list)
+        grad_out2 = fast_bottleneck.backward_grad_out2(ctx.explicit_nhwc, ctx.stride_1x1, t_list, grads)
         # do halo exchange of grad_out2 here
         # compute halo cells for grad_out1
         if ctx.spatial_group_size > 1:
@@ -355,48 +382,66 @@ class SpatialBottleneckFunction(torch.autograd.Function):
                 with torch.cuda.stream(ctx.stream2):
                     btm_fat_halo = torch.empty((N,3,W,C),dtype=grad_out2.dtype,device=grad_out2.device)
                     btm_relu_halo = torch.empty((N,3,W,C),dtype=grad_out2.dtype,device=grad_out2.device)
-                    btm_fat_halo[:,:2,:,:].copy_(grad_out2[:,Hs-2:,:,:])
-                    btm_fat_halo[:,2:,:,:].copy_(btm_halo)
-                    btm_relu_halo[:,:2,:,:].copy_(relu1[:,Hs-2:,:,:])
-                    btm_relu_halo[:,2:,:,:].zero_()
-                    btm_grad_out1_halo = fast_bottleneck.backward_grad_out1_halo(ctx.nhwc, ctx.stride_1x1, t_list, grads, btm_fat_halo, btm_relu_halo)
-                    btm_grad_out1_halo = btm_grad_out1_halo[:,1:2,:,:]
+                    if ctx.explicit_nhwc:
+                        btm_fat_halo[:,:2,:,:].copy_(grad_out2[:,Hs-2:,:,:])
+                        btm_fat_halo[:,2:,:,:].copy_(btm_halo)
+                        btm_relu_halo[:,:2,:,:].copy_(relu1[:,Hs-2:,:,:])
+                        btm_relu_halo[:,2:,:,:].zero_()
+                    else:
+                        btm_fat_halo[:,:,:2,:].copy_(grad_out2[:,Hs-2:,:,:])
+                        btm_fat_halo[:,:,2:,:].copy_(btm_halo)
+                        btm_relu_halo[:,:,:2,:].copy_(relu1[:,Hs-2:,:,:])
+                        btm_relu_halo[:,:,2:,:].zero_()
+                    btm_grad_out1_halo = fast_bottleneck.backward_grad_out1_halo(ctx.explicit_nhwc, ctx.stride_1x1, t_list, grads, btm_fat_halo, btm_relu_halo)
+                    if ctx.explicit_nhwc:
+                        btm_grad_out1_halo = btm_grad_out1_halo[:,1:2,:,:]
+                    else:
+                        btm_grad_out1_halo = btm_grad_out1_halo[:,:,1:2,:]
             if ctx.spatial_group_rank > 0:
                 with torch.cuda.stream(ctx.stream1):
                     top_fat_halo = torch.empty((N,3,W,C),dtype=grad_out2.dtype,device=grad_out2.device)
                     top_relu_halo = torch.empty((N,3,W,C),dtype=grad_out2.dtype,device=grad_out2.device)
-                    top_fat_halo[:,:1,:,:].copy_(top_halo)
-                    top_fat_halo[:,1:,:,:].copy_(grad_out2[:,:2,:,:])
-                    top_relu_halo[:,:1,:,:].zero_()
-                    top_relu_halo[:,1:,:,:].copy_(relu1[:,:2,:,:])
-                    top_grad_out1_halo = fast_bottleneck.backward_grad_out1_halo(ctx.nhwc, ctx.stride_1x1, t_list, grads, top_fat_halo, top_relu_halo)
-                    top_grad_out1_halo = top_grad_out1_halo[:,1:2,:,:]
+                    if ctx.explicit_nhwc:
+                        top_fat_halo[:,:1,:,:].copy_(top_halo)
+                        top_fat_halo[:,1:,:,:].copy_(grad_out2[:,:2,:,:])
+                        top_relu_halo[:,:1,:,:].zero_()
+                        top_relu_halo[:,1:,:,:].copy_(relu1[:,:2,:,:])
+                    else:
+                        top_fat_halo[:,:,:1,:].copy_(top_halo)
+                        top_fat_halo[:,:,1:,:].copy_(grad_out2[:,:2,:,:])
+                        top_relu_halo[:,:,:1,:].zero_()
+                        top_relu_halo[:,:,1:,:].copy_(relu1[:,:2,:,:])
+                    top_grad_out1_halo = fast_bottleneck.backward_grad_out1_halo(ctx.explicit_nhwc, ctx.stride_1x1, t_list, grads, top_fat_halo, top_relu_halo)
+                    if ctx.explicit_nhwc:
+                        top_grad_out1_halo = top_grad_out1_halo[:,1:2,:,:]
+                    else:
+                        top_grad_out1_halo = top_grad_out1_halo[:,:,1:2,:]
             inc.add_delay(10)
         wgrad2_stream = torch.cuda.Stream()
         wgrad2_stream.wait_stream(torch.cuda.current_stream())
         with torch.cuda.stream(wgrad2_stream):
             if ctx.spatial_group_size > 1:
-                wgrad2 = fast_bottleneck.backward_wgrad2_pad(ctx.nhwc, ctx.stride_1x1, t_list, grads, out1_pad, grad_out2)
+                wgrad2 = fast_bottleneck.backward_wgrad2_pad(ctx.explicit_nhwc, ctx.stride_1x1, t_list, grads, out1_pad, grad_out2)
             else:
-                wgrad2 = fast_bottleneck.backward_wgrad2(ctx.nhwc, ctx.stride_1x1, t_list, grads, grad_out2)
+                wgrad2 = fast_bottleneck.backward_wgrad2(ctx.explicit_nhwc, ctx.stride_1x1, t_list, grads, grad_out2)
         # compute wgrad2 for internal cells
-        #wgrad2 = fast_bottleneck.backward_wgrad2(ctx.nhwc, ctx.stride_1x1, t_list, grads, grad_out2)
+        #wgrad2 = fast_bottleneck.backward_wgrad2(ctx.explicit_nhwc, ctx.stride_1x1, t_list, grads, grad_out2)
         # apply wgrad2 halos
         #if ctx.spatial_group_size > 1:
         #    if ctx.spatial_group_rank > 0:
         #        top_grad2_halo = grad_out2[:,:1,:,:]
-        #        top_wgrad2_halo = fast_bottleneck.backward_wgrad2_halo(ctx.nhwc, ctx.stride_1x1, t_list, grads, top_out1_halo, top_grad2_halo)
+        #        top_wgrad2_halo = fast_bottleneck.backward_wgrad2_halo(ctx.explicit_nhwc, ctx.stride_1x1, t_list, grads, top_out1_halo, top_grad2_halo)
        #        wgrad2[:,:1,:,:].add_(top_wgrad2_halo)
         #    if ctx.spatial_group_rank < ctx.spatial_group_size-1:
         #        btm_grad2_halo = grad_out2[:,-1:,:,:]
-        #        btm_wgrad2_halo = fast_bottleneck.backward_wgrad2_halo(ctx.nhwc, ctx.stride_1x1, t_list, grads, btm_out1_halo, btm_grad2_halo)
+        #        btm_wgrad2_halo = fast_bottleneck.backward_wgrad2_halo(ctx.explicit_nhwc, ctx.stride_1x1, t_list, grads, btm_out1_halo, btm_grad2_halo)
         #        wgrad2[:,-1:,:,:].add_(btm_wgrad2_halo)
         # compute grad_out1 for internal cells
-        grad_out1 = fast_bottleneck.backward_grad_out1(ctx.nhwc, ctx.stride_1x1, t_list, grads, grad_out2)
+        grad_out1 = fast_bottleneck.backward_grad_out1(ctx.explicit_nhwc, ctx.stride_1x1, t_list, grads, grad_out2)
         # apply halo cells to grad_out1
         if ctx.spatial_group_size > 1:
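The wgrad2_stream code in the hunk above is the standard CUDA stream fork/join idiom: the side stream waits on the current stream before launching work, and the current stream waits on the side stream before consuming the result. In miniature (made-up tensors; requires a CUDA device):

import torch

a = torch.randn(1024, 1024, device='cuda')
b = torch.randn(1024, 1024, device='cuda')

side = torch.cuda.Stream()
side.wait_stream(torch.cuda.current_stream())    # fork: a and b are ready on side
with torch.cuda.stream(side):
    w = a @ b                                    # overlaps with main-stream work
c = a + 1.0                                      # main stream keeps going
torch.cuda.current_stream().wait_stream(side)    # join before using w
d = w + c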
@@ -406,14 +451,20 @@ class SpatialBottleneckFunction(torch.autograd.Function):
             #print("w.shape = %s, z.shape = %s, relu1.shape = %s" % (str(list(w.shape)), str(list(z.shape)), str(list(relu1.shape))))
             if ctx.spatial_group_rank > 0:
                 torch.cuda.current_stream().wait_stream(ctx.stream1)
-                grad_out1[:,:1,:,:].copy_(top_grad_out1_halo)
+                if ctx.explicit_nhwc:
+                    grad_out1[:,:1,:,:].copy_(top_grad_out1_halo)
+                else:
+                    grad_out1[:,:,:1,:].copy_(top_grad_out1_halo)
                 #print("ctx.spatial_group_rank = %d, apply grad_out1 top halo (grad_out1.shape = %s)" % (ctx.spatial_group_rank, str(list(grad_out1.shape))))
             if ctx.spatial_group_rank < ctx.spatial_group_size-1:
                 torch.cuda.current_stream().wait_stream(ctx.stream2)
-                grad_out1[:,Hs-1:,:,:].copy_(btm_grad_out1_halo)
+                if ctx.explicit_nhwc:
+                    grad_out1[:,Hs-1:,:,:].copy_(btm_grad_out1_halo)
+                else:
+                    grad_out1[:,:,Hs-1:,:].copy_(btm_grad_out1_halo)
                 #print("ctx.spatial_group_rank = %d, apply grad_out1 btm halo (grad_out1.shape = %s)" % (ctx.spatial_group_rank, str(list(grad_out1.shape))))
-        fast_bottleneck.backward_rest(ctx.nhwc, ctx.stride_1x1, t_list, grads, grad_out2, grad_out1, wgrad2)
+        fast_bottleneck.backward_rest(ctx.explicit_nhwc, ctx.stride_1x1, t_list, grads, grad_out2, grad_out1, wgrad2)
         torch.cuda.current_stream().wait_stream(wgrad2_stream)
         return (None, None, None, None, None, None, None, None, None, *grads)
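A detail the halo copies above rely on: Tensor.copy_ is layout-agnostic, so a destination keeps its own memory format whatever the source's layout is. A small self-contained check (shapes are illustrative):

import torch

dst = torch.empty(2, 16, 4, 8).to(memory_format=torch.channels_last)
src = torch.randn(2, 16, 4, 8)                   # default contiguous NCHW
dst.copy_(src)                                   # values copied; dst stays channels_last
assert dst.is_contiguous(memory_format=torch.channels_last)
assert torch.equal(dst, src)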