OpenDAS / apex · Commits · 88914a50

Commit 88914a50, authored Apr 01, 2022 by Thor Johnsen

    Add halo correction kernel for bprop

Parent: 705aa35d
Showing 3 changed files with 760 additions and 80 deletions (+760 −80):

    apex/contrib/bottleneck/bottleneck.py                +119  −78
    apex/contrib/bottleneck/bottleneck_module_test.py      +1   −1
    apex/contrib/csrc/bottleneck/bottleneck.cpp          +640   −1
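The spatial-parallel bottleneck splits a convolution's H dimension across GPUs, so each rank only holds a horizontal slab of the activation. A 3x3 convolution then needs one row from each neighbour (a "halo") to compute its boundary rows; the forward pass already fixes those rows up after a masked convolution, and this commit adds the matching halo correction kernel for the backward pass (bprop). Where the diff below calls left_right_halo_exchange, ranks swap boundary rows with their neighbours. As a rough illustration of such an exchange (not apex's implementation, which has its own halo exchanger classes; the helper below is hypothetical), a minimal sketch using torch.distributed point-to-point ops:

# Sketch only: a minimal top/bottom halo exchange between ranks that split the
# H dimension. Assumes init_process_group has been called; the function name
# and tensor shapes are hypothetical.
import torch
import torch.distributed as dist

def left_right_halo_exchange_sketch(top_row, btm_row, top_halo, btm_halo, rank, size):
    # top_row/btm_row: this rank's own boundary rows (to send);
    # top_halo/btm_halo: preallocated buffers for the neighbours' rows (to receive)
    ops = []
    if rank > 0:
        ops.append(dist.P2POp(dist.isend, top_row.contiguous(), rank - 1))
        ops.append(dist.P2POp(dist.irecv, top_halo, rank - 1))
    if rank < size - 1:
        ops.append(dist.P2POp(dist.isend, btm_row.contiguous(), rank + 1))
        ops.append(dist.P2POp(dist.irecv, btm_halo, rank + 1))
    if ops:
        for req in dist.batch_isend_irecv(ops):
            req.wait()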
apex/contrib/bottleneck/bottleneck.py
...
@@ -268,17 +268,6 @@ class SpatialBottleneckFunction(torch.autograd.Function):
            spatial_halo_exchanger.left_right_halo_exchange(out1[:,:,:1,:], out1[:,:,Hs-1:,:], top_out1_halo, btm_out1_halo)
            if spatial_method == 1:
                # overlap mid convolution with halo transfer
                if spatial_group_rank > 0:
                    with torch.cuda.stream(stream1):
                        if explicit_nhwc:
                            top_fat_halo = torch.empty((N,3,W,C), dtype=out1.dtype, device=out1.device)
                            top_fat_halo[:,:1,:,:].copy_(top_out1_halo)
                            top_fat_halo[:,1:3,:,:].copy_(out1[:,:2,:,:])
                        else:
                            top_fat_halo = torch.empty((N,C,3,W), dtype=out1.dtype, device=out1.device)
                            top_fat_halo[:,:,:1,:].copy_(top_out1_halo)
                            top_fat_halo[:,:,1:3,:].copy_(out1[:,:,:2,:])
                        top_out2 = fast_bottleneck.forward_out2_halo(explicit_nhwc, top_fat_halo, args)
                if spatial_group_rank < spatial_group_size-1:
                    stream2.wait_stream(stream1)
                    with torch.cuda.stream(stream2):
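This hunk (and its NCHW branch) assembles a three-row "fat halo": the neighbour's row stacked on the rank's own first two rows, so forward_out2_halo can recompute the single boundary output row of the 3x3 mid convolution. A self-contained check of that arithmetic in plain PyTorch (NCHW, a hypothetical two-rank split; the real kernel operates on the bottleneck's internal tensors):

# Sketch: three input rows through a 3x3 convolution (zero-padded in W only)
# produce exactly the one output row that a rank cannot compute from its own
# slab. Pretend rows 0..2 live on rank 0 and rows 3..5 on rank 1.
import torch
import torch.nn.functional as F

N, C, H, W = 1, 8, 6, 6
x = torch.randn(N, C, H, W)                 # the full, unsplit input
w = torch.randn(C, C, 3, 3)                 # 3x3 mid-convolution weight
full = F.conv2d(x, w, padding=1)            # reference output

top_halo = x[:, :, 2:3, :]                  # row rank 1 receives from rank 0
fat_halo = torch.cat([top_halo, x[:, :, 3:5, :]], dim=2)    # 3-row strip
halo_out = F.conv2d(fat_halo, w, padding=(0, 1))            # one output row

print(torch.allclose(halo_out, full[:, :, 3:4, :], atol=1e-4))   # True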
...
@@ -291,6 +280,17 @@ class SpatialBottleneckFunction(torch.autograd.Function):
                            btm_fat_halo[:,:,0:2,:].copy_(out1[:,:,Hs-2:,:])
                            btm_fat_halo[:,:,2:,:].copy_(btm_out1_halo)
                        btm_out2 = fast_bottleneck.forward_out2_halo(explicit_nhwc, btm_fat_halo, args)
                if spatial_group_rank > 0:
                    with torch.cuda.stream(stream1):
                        if explicit_nhwc:
                            top_fat_halo = torch.empty((N,3,W,C), dtype=out1.dtype, device=out1.device)
                            top_fat_halo[:,:1,:,:].copy_(top_out1_halo)
                            top_fat_halo[:,1:3,:,:].copy_(out1[:,:2,:,:])
                        else:
                            top_fat_halo = torch.empty((N,C,3,W), dtype=out1.dtype, device=out1.device)
                            top_fat_halo[:,:,:1,:].copy_(top_out1_halo)
                            top_fat_halo[:,:,1:3,:].copy_(out1[:,:,:2,:])
                        top_out2 = fast_bottleneck.forward_out2_halo(explicit_nhwc, top_fat_halo, args)
                        inc.add_delay(10)
            elif spatial_method != 2 and spatial_method != 3:
                assert(False), "spatial_method must be 1, 2 or 3"
...
@@ -329,13 +329,6 @@ class SpatialBottleneckFunction(torch.autograd.Function):
            # to wait for out2_mask to finish, but itself has to finish before
            # the first kernel of _forward_rest can launch.
            # At least we can overlap the two halo correction kernels.
            if spatial_group_rank > 0:
                stream1.wait_stream(torch.cuda.current_stream())   # wait for *_out2_mask to finish
                with torch.cuda.stream(stream1):
                    w1by3 = args[2][:,:1,:,:].clone()
                    top_out1_halo = top_out1_halo.clone()
                    top_out2 = fast_bottleneck.forward_out2_halo_corr(explicit_nhwc, top_out1_halo, args, w1by3, top_out2_halo.clone())
                    top_out2_halo.copy_(top_out2)
            if spatial_group_rank < spatial_group_size-1:
                stream2.wait_stream(torch.cuda.current_stream())   # wait for *_out2_mask to finish
                with torch.cuda.stream(stream2):
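forward_out2_halo_corr receives the one-row halo, a one-row slice of the 3x3 weight (w1by3 = args[2][:,:1,:,:]), and the already-computed masked boundary row, and returns the corrected row. The algebra: a masked convolution treats the missing neighbour row as zero, so the boundary output lacks exactly one term, the matching kernel row applied to the halo. A self-contained check (plain PyTorch, NCHW; the kernel-height axis is dim 2 here, whereas the code above slices dim 1 of its own weight layout):

# Sketch: masked boundary row + (top kernel row * halo row) == full result.
import torch
import torch.nn.functional as F

N, C, H, W = 1, 8, 6, 6
x = torch.randn(N, C, H, W)
w = torch.randn(C, C, 3, 3)
full = F.conv2d(x, w, padding=1)                     # reference output

# rank 1 owns rows 3..5; its first output row with the halo treated as zero:
own = x[:, :, 3:6, :]
masked_row = F.conv2d(own, w, padding=1)[:, :, 0:1, :]

# correction: top kernel row applied to the halo row received from rank 0
w1by3 = w[:, :, :1, :]                               # (C, C, 1, 3) weight slice
top_halo = x[:, :, 2:3, :]
corr = F.conv2d(top_halo, w1by3, padding=(0, 1))     # one output row

print(torch.allclose(masked_row + corr, full[:, :, 3:4, :], atol=1e-4))  # True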
...
@@ -344,9 +337,16 @@ class SpatialBottleneckFunction(torch.autograd.Function):
                    btm_out2 = fast_bottleneck.forward_out2_halo_corr(explicit_nhwc, btm_out1_halo, args, w1by3, btm_out2_halo.clone())
                    btm_out2_halo.copy_(btm_out2)
            if spatial_group_rank > 0:
                torch.cuda.current_stream().wait_stream(stream1)
                stream1.wait_stream(torch.cuda.current_stream())   # wait for *_out2_mask to finish
                with torch.cuda.stream(stream1):
                    w1by3 = args[2][:,:1,:,:].clone()
                    top_out1_halo = top_out1_halo.clone()
                    top_out2 = fast_bottleneck.forward_out2_halo_corr(explicit_nhwc, top_out1_halo, args, w1by3, top_out2_halo.clone())
                    top_out2_halo.copy_(top_out2)
            if spatial_group_rank < spatial_group_size-1:
                torch.cuda.current_stream().wait_stream(stream2)
            if spatial_group_rank > 0:
                torch.cuda.current_stream().wait_stream(stream1)
        fast_bottleneck.forward_rest(explicit_nhwc, stride_1x1, args, outputs)
        # save halos for backward pass
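The wait_stream calls fence the side streams carrying the halo corrections against the default stream, so forward_rest launches only after the corrected rows are written. The underlying fork/join pattern in miniature (requires a CUDA device; production code must also respect caching-allocator stream semantics, e.g. Tensor.record_stream):

# Sketch of the fork/join stream pattern used above.
import torch

s1 = torch.cuda.Stream()
x = torch.ones(1 << 20, device="cuda")

s1.wait_stream(torch.cuda.current_stream())    # side stream waits for producers
with torch.cuda.stream(s1):
    y = x * 2                                  # runs on s1
torch.cuda.current_stream().wait_stream(s1)    # rejoin before consuming y
z = y + 1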
...
@@ -365,6 +365,8 @@ class SpatialBottleneckFunction(torch.autograd.Function):
        ctx.spatial_group_rank = spatial_group_rank
        ctx.spatial_halo_exchanger = spatial_halo_exchanger
        ctx.spatial_method = spatial_method
        ctx.thresholdTop = thresholdTop
        ctx.thresholdBottom = thresholdBottom
        ctx.stream1 = stream1
        ctx.stream2 = stream2
        ctx.stream3 = stream3
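The forward pass stashes the whole spatial-parallel configuration on ctx, now including the two mask thresholds as well as the CUDA streams, so backward can replay the same choreography. The general pattern for non-tensor state in a custom autograd Function:

# Minimal illustration (not apex code): non-tensor state is attached directly
# to ctx in forward and read back in backward.
import torch

class ScaleByH(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, h):
        ctx.h = h                      # plain attribute, not save_for_backward
        return x * h

    @staticmethod
    def backward(ctx, grad_out):
        return grad_out * ctx.h, None  # None: no gradient for h

x = torch.randn(4, requires_grad=True)
ScaleByH.apply(x, 3.0).sum().backward()
print(x.grad)                          # tensor([3., 3., 3., 3.])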
...
@@ -414,50 +416,55 @@ class SpatialBottleneckFunction(torch.autograd.Function):
            with torch.cuda.stream(ctx.stream1):
                top_halo, btm_halo = ctx.spatial_halo_exchanger.left_right_halo_exchange(grad_out2[:,:1,:,:], grad_out2[:,Hs-1:,:,:])
                # copy halos to send buffer
            if ctx.spatial_group_rank < ctx.spatial_group_size-1:
                ctx.stream2.wait_stream(ctx.stream1)
                with torch.cuda.stream(ctx.stream2):
                    if ctx.explicit_nhwc:
                        btm_fat_halo = torch.empty((N,3,W,C), dtype=grad_out2.dtype, device=grad_out2.device)
                        btm_relu_halo = torch.empty((N,3,W,C), dtype=grad_out2.dtype, device=grad_out2.device)
                        btm_fat_halo[:,:2,:,:].copy_(grad_out2[:,Hs-2:,:,:])
                        btm_fat_halo[:,2:,:,:].copy_(btm_halo)
                        btm_relu_halo[:,:2,:,:].copy_(relu1[:,Hs-2:,:,:])
                        btm_relu_halo[:,2:,:,:].zero_()
                    else:
                        btm_fat_halo = torch.empty((N,C,3,W), dtype=grad_out2.dtype, device=grad_out2.device)
                        btm_relu_halo = torch.empty((N,C,3,W), dtype=grad_out2.dtype, device=grad_out2.device)
                        btm_fat_halo[:,:,:2,:].copy_(grad_out2[:,:,Hs-2:,:])
                        btm_fat_halo[:,:,2:,:].copy_(btm_halo)
                        btm_relu_halo[:,:,:2,:].copy_(relu1[:,:,Hs-2:,:])
                        btm_relu_halo[:,:,2:,:].zero_()
                    btm_grad_out1_halo = fast_bottleneck.backward_grad_out1_halo(ctx.explicit_nhwc, ctx.stride_1x1, t_list, grads, btm_fat_halo, btm_relu_halo)
                    if ctx.explicit_nhwc:
                        btm_grad_out1_halo = btm_grad_out1_halo[:,1:2,:,:]
                    else:
                        btm_grad_out1_halo = btm_grad_out1_halo[:,:,1:2,:]
            if ctx.spatial_group_rank > 0:
                with torch.cuda.stream(ctx.stream1):
                    if ctx.explicit_nhwc:
                        top_fat_halo = torch.empty((N,3,W,C), dtype=grad_out2.dtype, device=grad_out2.device)
                        top_relu_halo = torch.empty((N,3,W,C), dtype=grad_out2.dtype, device=grad_out2.device)
                        top_fat_halo[:,:1,:,:].copy_(top_halo)
                        top_fat_halo[:,1:,:,:].copy_(grad_out2[:,:2,:,:])
                        top_relu_halo[:,:1,:,:].zero_()
                        top_relu_halo[:,1:,:,:].copy_(relu1[:,:2,:,:])
                    else:
                        top_fat_halo = torch.empty((N,C,3,W), dtype=grad_out2.dtype, device=grad_out2.device)
                        top_relu_halo = torch.empty((N,C,3,W), dtype=grad_out2.dtype, device=grad_out2.device)
                        top_fat_halo[:,:,:1,:].copy_(top_halo)
                        top_fat_halo[:,:,1:,:].copy_(grad_out2[:,:,:2,:])
                        top_relu_halo[:,:,:1,:].zero_()
                        top_relu_halo[:,:,1:,:].copy_(relu1[:,:,:2,:])
                    top_grad_out1_halo = fast_bottleneck.backward_grad_out1_halo(ctx.explicit_nhwc, ctx.stride_1x1, t_list, grads, top_fat_halo, top_relu_halo)
                    if ctx.explicit_nhwc:
                        top_grad_out1_halo = top_grad_out1_halo[:,1:2,:,:]
                    else:
                        top_grad_out1_halo = top_grad_out1_halo[:,:,1:2,:]
                    inc.add_delay(10)
            if ctx.spatial_method == 1 or ctx.spatial_method == 2:
                # 1 -> halo recompute approach
                # 2 -> wait for concatenated halos, then do single conv on full input (not implemented yet for bprop)
                if ctx.spatial_group_rank < ctx.spatial_group_size-1:
                    ctx.stream2.wait_stream(ctx.stream1)
                    with torch.cuda.stream(ctx.stream2):
                        if ctx.explicit_nhwc:
                            btm_fat_halo = torch.empty((N,3,W,C), dtype=grad_out2.dtype, device=grad_out2.device)
                            btm_fat_halo[:,:2,:,:].copy_(grad_out2[:,Hs-2:,:,:])
                            btm_fat_halo[:,2:,:,:].copy_(btm_halo)
                            btm_fat_relu_halo = torch.empty((N,3,W,C), dtype=grad_out2.dtype, device=grad_out2.device)
                            btm_fat_relu_halo[:,:2,:,:].copy_(relu1[:,Hs-2:,:,:])
                            btm_fat_relu_halo[:,2:,:,:].zero_()
                        else:
                            btm_fat_halo = torch.empty((N,C,3,W), dtype=grad_out2.dtype, device=grad_out2.device)
                            btm_fat_halo[:,:,:2,:].copy_(grad_out2[:,:,Hs-2:,:])
                            btm_fat_halo[:,:,2:,:].copy_(btm_halo)
                            btm_fat_relu_halo = torch.empty((N,C,3,W), dtype=grad_out2.dtype, device=grad_out2.device)
                            btm_fat_relu_halo[:,:,:2,:].copy_(relu1[:,:,Hs-2:,:])
                            btm_fat_relu_halo[:,:,2:,:].zero_()
                        btm_grad_out1_halo = fast_bottleneck.backward_grad_out1_halo(ctx.explicit_nhwc, ctx.stride_1x1, t_list, grads, btm_fat_halo, btm_fat_relu_halo)
                        if ctx.explicit_nhwc:
                            btm_grad_out1_halo = btm_grad_out1_halo[:,1:2,:,:]
                        else:
                            btm_grad_out1_halo = btm_grad_out1_halo[:,:,1:2,:]
                if ctx.spatial_group_rank > 0:
                    with torch.cuda.stream(ctx.stream1):
                        if ctx.explicit_nhwc:
                            top_fat_halo = torch.empty((N,3,W,C), dtype=grad_out2.dtype, device=grad_out2.device)
                            top_fat_halo[:,:1,:,:].copy_(top_halo)
                            top_fat_halo[:,1:,:,:].copy_(grad_out2[:,:2,:,:])
                            top_fat_relu_halo = torch.empty((N,3,W,C), dtype=grad_out2.dtype, device=grad_out2.device)
                            top_fat_relu_halo[:,:1,:,:].zero_()
                            top_fat_relu_halo[:,1:,:,:].copy_(relu1[:,:2,:,:])
                        else:
                            top_fat_halo = torch.empty((N,C,3,W), dtype=grad_out2.dtype, device=grad_out2.device)
                            top_fat_halo[:,:,:1,:].copy_(top_halo)
                            top_fat_halo[:,:,1:,:].copy_(grad_out2[:,:,:2,:])
                            top_fat_relu_halo = torch.empty((N,C,3,W), dtype=grad_out2.dtype, device=grad_out2.device)
                            top_fat_relu_halo[:,:,:1,:].zero_()
                            top_fat_relu_halo[:,:,1:,:].copy_(relu1[:,:,:2,:])
                        top_grad_out1_halo = fast_bottleneck.backward_grad_out1_halo(ctx.explicit_nhwc, ctx.stride_1x1, t_list, grads, top_fat_halo, top_fat_relu_halo)
                        if ctx.explicit_nhwc:
                            top_grad_out1_halo = top_grad_out1_halo[:,1:2,:,:]
                        else:
                            top_grad_out1_halo = top_grad_out1_halo[:,:,1:2,:]
                        inc.add_delay(10)
            elif ctx.spatial_method != 3:
                assert(False), "spatial_method must be 1, 2 or 3"
        with torch.cuda.stream(wgrad2_stream):
            if ctx.spatial_group_size > 1:
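Backward mirrors the forward recompute path: it builds three-row fat halos from grad_out2 together with a matching relu1 strip whose out-of-slab row is zeroed, runs backward_grad_out1_halo, and keeps only the middle output row (the [:,1:2,:,:] slices above). The middle row is exact because, for a 3x3 kernel, the data gradient of a row depends only on the grad rows directly above and below it. A check of that slicing in plain PyTorch (NCHW; the pointwise ReLU mask at the end stands in for the relu halo the real kernel consumes):

# Sketch: the middle row of a 3-row conv_transpose strip equals the boundary
# row of the full data gradient. Rank 0 owns rows 0..2; row 2 is its bottom
# boundary and btm_halo is the neighbour's first grad row.
import torch
import torch.nn.functional as F

N, C, H, W = 1, 8, 6, 6
w = torch.randn(C, C, 3, 3)
grad_out2 = torch.randn(N, C, H, W)
relu1 = (torch.randn(N, C, H, W) > 0).float()      # ReLU activation mask

full = F.conv_transpose2d(grad_out2, w, padding=1) * relu1   # reference dgrad

btm_halo = grad_out2[:, :, 3:4, :]
fat_halo = torch.cat([grad_out2[:, :, 1:3, :], btm_halo], dim=2)   # rows 1,2,3
strip = F.conv_transpose2d(fat_halo, w, padding=1)                 # 3 rows out
row2 = strip[:, :, 1:2, :] * relu1[:, :, 2:3, :]                   # middle row
print(torch.allclose(row2, full[:, :, 2:3, :], atol=1e-4))         # True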
...
@@ -466,7 +473,10 @@ class SpatialBottleneckFunction(torch.autograd.Function):
                 wgrad2 = fast_bottleneck.backward_wgrad2(ctx.explicit_nhwc, ctx.stride_1x1, t_list, grads, grad_out2)
         # compute grad_out1 for internal cells
-        grad_out1 = fast_bottleneck.backward_grad_out1(ctx.explicit_nhwc, ctx.stride_1x1, t_list, grads, grad_out2)
+        if ctx.spatial_group_size <= 1 or ctx.spatial_method == 1 or ctx.spatial_method == 2:
+            grad_out1 = fast_bottleneck.backward_grad_out1(ctx.explicit_nhwc, ctx.stride_1x1, t_list, grads, grad_out2)
+        elif ctx.spatial_group_size > 1 and ctx.spatial_method == 3:
+            grad_out1 = fast_bottleneck.backward_grad_out1_mask(ctx.explicit_nhwc, ctx.stride_1x1, t_list, grads, grad_out2, ctx.thresholdTop, ctx.thresholdBottom)
         # apply halo cells to grad_out1
         if ctx.spatial_group_size > 1:
...
@@ -474,20 +484,51 @@ class SpatialBottleneckFunction(torch.autograd.Function):
            z = t_list[4]
            relu1 = t_list[12]
            #print("w.shape = %s, z.shape = %s, relu1.shape = %s" % (str(list(w.shape)), str(list(z.shape)), str(list(relu1.shape))))
            if ctx.spatial_group_rank > 0:
                torch.cuda.current_stream().wait_stream(ctx.stream1)
                if ctx.explicit_nhwc:
                    grad_out1[:,:1,:,:].copy_(top_grad_out1_halo)
                else:
                    grad_out1[:,:,:1,:].copy_(top_grad_out1_halo)
                #print("ctx.spatial_group_rank = %d, apply grad_out1 top halo (grad_out1.shape = %s)" % (ctx.spatial_group_rank, str(list(grad_out1.shape))))
            if ctx.spatial_group_rank < ctx.spatial_group_size-1:
                torch.cuda.current_stream().wait_stream(ctx.stream2)
                if ctx.explicit_nhwc:
                    grad_out1[:,Hs-1:,:,:].copy_(btm_grad_out1_halo)
                else:
                    grad_out1[:,:,Hs-1:,:].copy_(btm_grad_out1_halo)
                #print("ctx.spatial_group_rank = %d, apply grad_out1 btm halo (grad_out1.shape = %s)" % (ctx.spatial_group_rank, str(list(grad_out1.shape))))
            if ctx.spatial_method == 1 or ctx.spatial_method == 2:
                if ctx.spatial_group_rank < ctx.spatial_group_size-1:
                    torch.cuda.current_stream().wait_stream(ctx.stream2)
                    if ctx.explicit_nhwc:
                        grad_out1[:,Hs-1:,:,:].copy_(btm_grad_out1_halo)
                    else:
                        grad_out1[:,:,Hs-1:,:].copy_(btm_grad_out1_halo)
                    #print("ctx.spatial_group_rank = %d, apply grad_out1 btm halo (grad_out1.shape = %s)" % (ctx.spatial_group_rank, str(list(grad_out1.shape))))
                if ctx.spatial_group_rank > 0:
                    torch.cuda.current_stream().wait_stream(ctx.stream1)
                    if ctx.explicit_nhwc:
                        grad_out1[:,:1,:,:].copy_(top_grad_out1_halo)
                    else:
                        grad_out1[:,:,:1,:].copy_(top_grad_out1_halo)
                    #print("ctx.spatial_group_rank = %d, apply grad_out1 top halo (grad_out1.shape = %s)" % (ctx.spatial_group_rank, str(list(grad_out1.shape))))
            elif ctx.spatial_method == 3:
                if ctx.spatial_group_rank < ctx.spatial_group_size-1:
                    if ctx.explicit_nhwc:
                        btm_relu_halo = relu1[:,Hs-1:,:,:].clone()
                        btm_grad_out1 = grad_out1[:,Hs-1:,:,:]
                    else:
                        btm_relu_halo = relu1[:,:,Hs-1:,:].clone()
                        btm_grad_out1 = grad_out1[:,:,Hs-1:,:]
                    w1by3 = w[:,:1,:,:].clone()
                    ctx.stream1.wait_stream(ctx.stream2)   # wait for halo transfers to finish
                    ctx.stream2.wait_stream(torch.cuda.current_stream())   # wait for backward_grad_out1_mask to finish before launching halo correction kernel
                    with torch.cuda.stream(ctx.stream1):
                        btm_grad_out1_halo = fast_bottleneck.backward_grad_out1_halo_corr(ctx.explicit_nhwc, ctx.stride_1x1, t_list, w1by3, grads, btm_halo, btm_relu_halo, btm_grad_out1.clone())
                        btm_grad_out1.copy_(btm_grad_out1_halo)
                if ctx.spatial_group_rank > 0:
                    if ctx.explicit_nhwc:
                        top_relu_halo = relu1[:,:1,:,:].clone()
                        top_grad_out1 = grad_out1[:,:1,:,:]
                    else:
                        top_relu_halo = relu1[:,:,:1,:].clone()
                        top_grad_out1 = grad_out1[:,:,:1,:]
                    w1by3 = w[:,2:,:,:].clone()
                    ctx.stream1.wait_stream(torch.cuda.current_stream())   # wait for backward_grad_out1_mask to finish before launching halo correction kernel
                    with torch.cuda.stream(ctx.stream1):
                        top_grad_out1_halo = fast_bottleneck.backward_grad_out1_halo_corr(ctx.explicit_nhwc, ctx.stride_1x1, t_list, w1by3, grads, top_halo, top_relu_halo, top_grad_out1.clone())
                        top_grad_out1.copy_(top_grad_out1_halo)
            if ctx.spatial_group_rank < ctx.spatial_group_size-1:
                torch.cuda.current_stream().wait_stream(ctx.stream2)   # wait for halo correction to finish
            if ctx.spatial_group_rank > 0:
                torch.cuda.current_stream().wait_stream(ctx.stream1)
        fast_bottleneck.backward_rest(ctx.explicit_nhwc, ctx.stride_1x1, t_list, grads, grad_out2, grad_out1, wgrad2)
        torch.cuda.current_stream().wait_stream(wgrad2_stream)
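backward_grad_out1_halo_corr is the kernel this commit adds. Under spatial_method 3, backward_grad_out1_mask computes the data gradient with the neighbours' rows treated as zero; the correction kernel then adds the one missing term to each boundary row, the neighbour's grad row pushed through a single kernel row (w[:,:1,:,:] for the bottom boundary, w[:,2:,:,:] for the top, in the weight layout used above), re-masked by the local ReLU pattern. The algebra in miniature (plain PyTorch, NCHW, where the kernel-height axis is dim 2):

# Sketch: masked dgrad boundary row + one-kernel-row correction == full dgrad.
import torch
import torch.nn.functional as F

N, C, H, W = 1, 8, 6, 6
w = torch.randn(C, C, 3, 3)
grad_out2 = torch.randn(N, C, H, W)
relu1 = (torch.randn(N, C, H, W) > 0).float()

full = F.conv_transpose2d(grad_out2, w, padding=1) * relu1   # reference dgrad

# rank 0 owns rows 0..2; the masked dgrad sees the neighbour's rows as zero
own = grad_out2[:, :, 0:3, :]
masked = F.conv_transpose2d(own, w, padding=1) * relu1[:, :, 0:3, :]

# correction for boundary row 2: neighbour's first grad row through the first
# kernel row, re-masked by the local ReLU pattern
btm_halo = grad_out2[:, :, 3:4, :]
w1by3 = w[:, :, :1, :]                                       # first kernel row
corr = F.conv_transpose2d(btm_halo, w1by3, padding=(0, 1)) * relu1[:, :, 2:3, :]

print(torch.allclose(masked[:, :, 2:3, :] + corr, full[:, :, 2:3, :], atol=1e-4))  # True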
...
apex/contrib/bottleneck/bottleneck_module_test.py
...
@@ -161,7 +161,7 @@ def n_way_spatial(halex, gt_bottleneck, gt, explicit_nhwc, world_size, rank, fp3
     spatial_group_rank = rank
     spatial_communicator = None
     spatial_halo_exchanger = halex
-    spatial_method = 2 # 1 -> overlap halo and main conv, 2 -> wait for halo, conv on padded x
+    spatial_method = 3 # 1 -> overlap halo and main conv, 2 -> wait for halo, conv on padded x
     spatial_parallel_args = (spatial_group_size, spatial_group_rank, spatial_communicator, spatial_halo_exchanger, spatial_method)
     spatial_bottleneck = spatial_parallel_bottleneck(C, dtype, explicit_nhwc, gt_bottleneck, spatial_parallel_args)
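This one-line change switches the module test from the padded-convolution path (spatial_method 2) to the masked-plus-correction path (spatial_method 3), so the test now exercises the new bprop halo correction kernel alongside the existing fprop one.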
...
apex/contrib/csrc/bottleneck/bottleneck.cpp
(Diff collapsed; 640 additions and 1 deletion, not shown.)