OpenDAS / apex / Commits / 9c16d945

Commit 9c16d945, authored Mar 31, 2022 by Thor Johnsen

Bug fixes

parent 0c20c455

Showing 1 changed file with 20 additions and 27 deletions (+20 −27):

apex/contrib/bottleneck/bottleneck.py (view file @ 9c16d945)
...
@@ -276,7 +276,7 @@ class SpatialBottleneckFunction(torch.autograd.Function):
                     btm_fat_halo[:,0:2,:,:].copy_(out1[:,Hs-2:,:,:])
                     btm_fat_halo[:,2:,:,:].copy_(btm_out1_halo)
                 else:
-                    btm_fat_halo[:,:,0:2,:].copy_(out1[:,Hs-2:,:,:])
+                    btm_fat_halo[:,:,0:2,:].copy_(out1[:,:,Hs-2:,:])
                     btm_fat_halo[:,:,2:,:].copy_(btm_out1_halo)
                 btm_out2 = fast_bottleneck.forward_out2_halo(explicit_nhwc, btm_fat_halo, args)
             if spatial_group_rank > 0:
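The change above fixes the channels-first (NCHW) branch of the forward bottom-halo path: in NCHW the height axis is dim 2, so the bottom two rows of out1 must be sliced as [:,:,Hs-2:,:], not [:,Hs-2:,:,:] as in NHWC. A minimal standalone sketch of the corrected indexing, with made-up sizes (none of these tensor values come from the diff):

import torch

# Sketch: the same "bottom two rows" slice in both layouts, dummy sizes.
N, C, Hs, W = 1, 4, 8, 8
out1_nhwc = torch.randn(N, Hs, W, C)                     # NHWC: height is dim 1
out1_nchw = out1_nhwc.permute(0, 3, 1, 2).contiguous()   # NCHW: height is dim 2

rows_nhwc = out1_nhwc[:, Hs-2:, :, :]   # pre-existing NHWC slice
rows_nchw = out1_nchw[:, :, Hs-2:, :]   # the fixed NCHW slice

# Both select the same cells once the layouts are aligned.
assert torch.equal(rows_nhwc.permute(0, 3, 1, 2), rows_nchw)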
...
@@ -287,7 +287,7 @@ class SpatialBottleneckFunction(torch.autograd.Function):
                     top_fat_halo[:,1:3,:,:].copy_(out1[:,:2,:,:])
                 else:
                     top_fat_halo[:,:,:1,:].copy_(top_out1_halo)
-                    top_fat_halo[:,:,1:3,:].copy_(out1[:,:2,:,:])
+                    top_fat_halo[:,:,1:3,:].copy_(out1[:,:,:2,:])
                 top_out2 = fast_bottleneck.forward_out2_halo(explicit_nhwc, top_fat_halo, args)
             inc.add_delay(10)
         elif spatial_method == 2:
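This hunk is the symmetric fix for the top halo: the first two rows of out1 are [:,:,:2,:] in NCHW. A sketch of the fat-halo assembly this branch performs, on dummy tensors (all shapes and the top_out1_halo contents are illustrative only):

import torch

# Sketch: build a 3-row "fat" halo in NCHW from the neighbour's row plus
# this rank's first two rows, dummy sizes.
N, C, Hs, W = 1, 4, 8, 8
out1 = torch.randn(N, C, Hs, W)
top_out1_halo = torch.randn(N, C, 1, W)        # 1-row halo from the rank above

top_fat_halo = torch.empty(N, C, 3, W)
top_fat_halo[:, :, :1, :].copy_(top_out1_halo)        # neighbour's row
top_fat_halo[:, :, 1:3, :].copy_(out1[:, :, :2, :])   # own first two rows (fixed slice)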
...
@@ -368,10 +368,15 @@ class SpatialBottleneckFunction(torch.autograd.Function):
         grads = fast_bottleneck.backward_init(ctx.explicit_nhwc, ctx.stride_1x1, t_list)
         grad_out2 = fast_bottleneck.backward_grad_out2(ctx.explicit_nhwc, ctx.stride_1x1, t_list, grads)
+        wgrad2_stream = torch.cuda.Stream()
+        wgrad2_stream.wait_stream(torch.cuda.current_stream())
         # do halo exchange of grad_out2 here
         # compute halo cells for grad_out1
         if ctx.spatial_group_size > 1:
-            N,Hs,W,C = list(grad_out2.shape)
+            if ctx.explicit_nhwc:
+                N,Hs,W,C = list(grad_out2.shape)
+            else:
+                N,C,Hs,W = list(grad_out2.shape)
             relu1 = t_list[12]
             ctx.stream1.wait_stream(torch.cuda.current_stream())
             with torch.cuda.stream(ctx.stream1):
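This hunk does two things: the wgrad2_stream setup is hoisted to the top of backward (its old location is deleted in the last hunk), and the shape unpack becomes layout-aware instead of always assuming NHWC order. A self-contained sketch of the unpack logic, assuming nothing beyond torch:

import torch

# Sketch: unpack N, C, Hs, W from either memory layout.
def unpack_shape(grad_out2: torch.Tensor, explicit_nhwc: bool):
    if explicit_nhwc:
        N, Hs, W, C = list(grad_out2.shape)
    else:
        N, C, Hs, W = list(grad_out2.shape)
    return N, C, Hs, W

assert unpack_shape(torch.empty(2, 8, 16, 32), explicit_nhwc=True) == (2, 32, 8, 16)
assert unpack_shape(torch.empty(2, 32, 8, 16), explicit_nhwc=False) == (2, 32, 8, 16)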
...
@@ -380,17 +385,19 @@ class SpatialBottleneckFunction(torch.autograd.Function):
             if ctx.spatial_group_rank < ctx.spatial_group_size-1:
                 ctx.stream2.wait_stream(ctx.stream1)
                 with torch.cuda.stream(ctx.stream2):
-                    btm_fat_halo = torch.empty((N,3,W,C),dtype=grad_out2.dtype,device=grad_out2.device)
-                    btm_relu_halo = torch.empty((N,3,W,C),dtype=grad_out2.dtype,device=grad_out2.device)
                     if ctx.explicit_nhwc:
+                        btm_fat_halo = torch.empty((N,3,W,C),dtype=grad_out2.dtype,device=grad_out2.device)
+                        btm_relu_halo = torch.empty((N,3,W,C),dtype=grad_out2.dtype,device=grad_out2.device)
                         btm_fat_halo[:,:2,:,:].copy_(grad_out2[:,Hs-2:,:,:])
                         btm_fat_halo[:,2:,:,:].copy_(btm_halo)
                         btm_relu_halo[:,:2,:,:].copy_(relu1[:,Hs-2:,:,:])
                         btm_relu_halo[:,2:,:,:].zero_()
                     else:
-                        btm_fat_halo[:,:,:2,:].copy_(grad_out2[:,Hs-2:,:,:])
+                        btm_fat_halo = torch.empty((N,C,3,W),dtype=grad_out2.dtype,device=grad_out2.device)
+                        btm_relu_halo = torch.empty((N,C,3,W),dtype=grad_out2.dtype,device=grad_out2.device)
+                        btm_fat_halo[:,:,:2,:].copy_(grad_out2[:,:,Hs-2:,:])
                         btm_fat_halo[:,:,2:,:].copy_(btm_halo)
-                        btm_relu_halo[:,:,:2,:].copy_(relu1[:,Hs-2:,:,:])
+                        btm_relu_halo[:,:,:2,:].copy_(relu1[:,:,Hs-2:,:])
                         btm_relu_halo[:,:,2:,:].zero_()
                     btm_grad_out1_halo = fast_bottleneck.backward_grad_out1_halo(ctx.explicit_nhwc, ctx.stride_1x1, t_list, grads, btm_fat_halo, btm_relu_halo)
                     if ctx.explicit_nhwc:
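Here the allocation of the bottom-halo scratch buffers moves inside the layout branches, so the NCHW path gets (N,C,3,W) buffers instead of reusing NHWC-shaped ones, and the grad_out2/relu1 slices now index height on dim 2. A sketch of the per-layout allocation (make_halo_buffers is a hypothetical helper, not part of the diff):

import torch

# Sketch: allocate the pair of 3-row halo buffers with a layout-dependent shape.
def make_halo_buffers(ref: torch.Tensor, N, C, W, explicit_nhwc: bool):
    shape = (N, 3, W, C) if explicit_nhwc else (N, C, 3, W)
    fat_halo = torch.empty(shape, dtype=ref.dtype, device=ref.device)
    relu_halo = torch.empty(shape, dtype=ref.dtype, device=ref.device)
    return fat_halo, relu_halo

grad_out2 = torch.randn(2, 32, 8, 16)  # dummy NCHW tensor
fat, relu = make_halo_buffers(grad_out2, 2, 32, 16, explicit_nhwc=False)
assert fat.shape == (2, 32, 3, 16) and relu.shape == (2, 32, 3, 16)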
...
@@ -399,18 +406,20 @@ class SpatialBottleneckFunction(torch.autograd.Function):
                         btm_grad_out1_halo = btm_grad_out1_halo[:,:,1:2,:]
             if ctx.spatial_group_rank > 0:
                 with torch.cuda.stream(ctx.stream1):
-                    top_fat_halo = torch.empty((N,3,W,C),dtype=grad_out2.dtype,device=grad_out2.device)
-                    top_relu_halo = torch.empty((N,3,W,C),dtype=grad_out2.dtype,device=grad_out2.device)
                     if ctx.explicit_nhwc:
+                        top_fat_halo = torch.empty((N,3,W,C),dtype=grad_out2.dtype,device=grad_out2.device)
+                        top_relu_halo = torch.empty((N,3,W,C),dtype=grad_out2.dtype,device=grad_out2.device)
                         top_fat_halo[:,:1,:,:].copy_(top_halo)
                         top_fat_halo[:,1:,:,:].copy_(grad_out2[:,:2,:,:])
                         top_relu_halo[:,:1,:,:].zero_()
                         top_relu_halo[:,1:,:,:].copy_(relu1[:,:2,:,:])
                     else:
+                        top_fat_halo = torch.empty((N,C,3,W),dtype=grad_out2.dtype,device=grad_out2.device)
+                        top_relu_halo = torch.empty((N,C,3,W),dtype=grad_out2.dtype,device=grad_out2.device)
                         top_fat_halo[:,:,:1,:].copy_(top_halo)
-                        top_fat_halo[:,:,1:,:].copy_(grad_out2[:,:2,:,:])
+                        top_fat_halo[:,:,1:,:].copy_(grad_out2[:,:,:2,:])
                         top_relu_halo[:,:,:1,:].zero_()
-                        top_relu_halo[:,:,1:,:].copy_(relu1[:,:2,:,:])
+                        top_relu_halo[:,:,1:,:].copy_(relu1[:,:,:2,:])
                     top_grad_out1_halo = fast_bottleneck.backward_grad_out1_halo(ctx.explicit_nhwc, ctx.stride_1x1, t_list, grads, top_fat_halo, top_relu_halo)
                     if ctx.explicit_nhwc:
                         top_grad_out1_halo = top_grad_out1_halo[:,1:2,:,:]
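The top-halo branch of the backward pass is restructured the same way. The zeroed row in the relu buffer is worth noting: a zero activation presumably masks the ReLU gradient for the neighbour's row, which this rank does not own. An illustrative sketch in NCHW with dummy sizes:

import torch

# Sketch: top relu-halo assembly in NCHW, dummy sizes.
N, C, Hs, W = 1, 4, 8, 8
relu1 = torch.randn(N, C, Hs, W)

top_relu_halo = torch.empty(N, C, 3, W)
top_relu_halo[:, :, :1, :].zero_()                    # neighbour's row: masked out
top_relu_halo[:, :, 1:, :].copy_(relu1[:, :, :2, :])  # own first two rows (fixed slice)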
...
@@ -418,28 +427,12 @@ class SpatialBottleneckFunction(torch.autograd.Function):
                         top_grad_out1_halo = top_grad_out1_halo[:,:,1:2,:]
             inc.add_delay(10)
-        wgrad2_stream = torch.cuda.Stream()
-        wgrad2_stream.wait_stream(torch.cuda.current_stream())
         with torch.cuda.stream(wgrad2_stream):
             if ctx.spatial_group_size > 1:
                 wgrad2 = fast_bottleneck.backward_wgrad2_pad(ctx.explicit_nhwc, ctx.stride_1x1, t_list, grads, out1_pad, grad_out2)
             else:
                 wgrad2 = fast_bottleneck.backward_wgrad2(ctx.explicit_nhwc, ctx.stride_1x1, t_list, grads, grad_out2)
-        # compute wgrad2 for internal cells
-        #wgrad2 = fast_bottleneck.backward_wgrad2(ctx.explicit_nhwc, ctx.stride_1x1, t_list, grads, grad_out2)
-        # apply wgrad2 halos
-        #if ctx.spatial_group_size > 1:
-        #    if ctx.spatial_group_rank > 0:
-        #        top_grad2_halo = grad_out2[:,:1,:,:]
-        #        top_wgrad2_halo = fast_bottleneck.backward_wgrad2_halo(ctx.explicit_nhwc, ctx.stride_1x1, t_list, grads, top_out1_halo, top_grad2_halo)
-        #        wgrad2[:,:1,:,:].add_(top_wgrad2_halo)
-        #    if ctx.spatial_group_rank < ctx.spatial_group_size-1:
-        #        btm_grad2_halo = grad_out2[:,-1:,:,:]
-        #        btm_wgrad2_halo = fast_bottleneck.backward_wgrad2_halo(ctx.explicit_nhwc, ctx.stride_1x1, t_list, grads, btm_out1_halo, btm_grad2_halo)
-        #        wgrad2[:,-1:,:,:].add_(btm_wgrad2_halo)
         # compute grad_out1 for internal cells
         grad_out1 = fast_bottleneck.backward_grad_out1(ctx.explicit_nhwc, ctx.stride_1x1, t_list, grads, grad_out2)
...
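The last hunk deletes the dead, commented-out wgrad2-halo code and drops the wgrad2_stream setup from this point, now that the stream is created before the halo work begins (see the third hunk). The pattern itself is the standard CUDA side-stream recipe; a self-contained sketch with a dummy matmul standing in for the fast_bottleneck wgrad2 kernel:

import torch

# Sketch: run one piece of the backward pass on a side stream, then join.
def wgrad2_on_side_stream(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    wgrad2_stream = torch.cuda.Stream()
    wgrad2_stream.wait_stream(torch.cuda.current_stream())  # see inputs produced so far
    with torch.cuda.stream(wgrad2_stream):
        wgrad2 = a @ b          # overlaps with whatever the main stream does next
    torch.cuda.current_stream().wait_stream(wgrad2_stream)  # join before using wgrad2
    return wgrad2

if torch.cuda.is_available():
    x = torch.randn(64, 64, device="cuda")
    print(wgrad2_on_side_stream(x, x.t()).shape)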