OpenDAS / apex · Commits

Commit 705aa35d, authored Apr 01, 2022 by Thor Johnsen
Fix halo correction kernel
Parent: 60000f73
Showing 1 changed file with 33 additions and 19 deletions.

apex/contrib/bottleneck/bottleneck.py (+33 / -19)
@@ -268,6 +268,17 @@ class SpatialBottleneckFunction(torch.autograd.Function):
     spatial_halo_exchanger.left_right_halo_exchange(out1[:,:,:1,:], out1[:,:,Hs-1:,:], top_out1_halo, btm_out1_halo)
     if spatial_method == 1:
         # overlap mid convolution with halo transfer
+        if spatial_group_rank > 0:
+            with torch.cuda.stream(stream1):
+                if explicit_nhwc:
+                    top_fat_halo = torch.empty((N,3,W,C),dtype=out1.dtype,device=out1.device)
+                    top_fat_halo[:,:1,:,:].copy_(top_out1_halo)
+                    top_fat_halo[:,1:3,:,:].copy_(out1[:,:2,:,:])
+                else:
+                    top_fat_halo = torch.empty((N,C,3,W),dtype=out1.dtype,device=out1.device)
+                    top_fat_halo[:,:,:1,:].copy_(top_out1_halo)
+                    top_fat_halo[:,:,1:3,:].copy_(out1[:,:,:2,:])
+                top_out2 = fast_bottleneck.forward_out2_halo(explicit_nhwc, top_fat_halo, args)
         if spatial_group_rank < spatial_group_size-1:
             stream2.wait_stream(stream1)
             with torch.cuda.stream(stream2):
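The lines added here assemble a "fat halo" for the top edge of the rank's slab: the row received from the neighbor above plus the first two local rows of out1, stacked into a 3-row tensor, so that a single 3x3 convolution over it reproduces the boundary output row the local slab could not compute on its own. Below is a minimal plain-PyTorch sketch of that idea, assuming NCHW layout and a bare 3x3 convolution; the real fast_bottleneck.forward_out2_halo kernel operates on the fused bottleneck's second convolution, so shapes and names here are illustrative only:

    # Sketch: a 3-row "fat halo" is enough to recompute the first output row
    # of a rank's slab under spatial (H-dimension) parallelism.
    import torch
    import torch.nn.functional as F

    torch.manual_seed(0)
    N, C, H, W = 1, 4, 8, 8
    x = torch.randn(N, C, H, W)
    w = torch.randn(C, C, 3, 3)

    # Reference: 3x3 convolution over the full tensor.
    full = F.conv2d(x, w, padding=1)

    # Split H across two "ranks"; the second rank owns rows 4..7.
    local = x[:, :, 4:, :]
    top_halo = x[:, :, 3:4, :]  # last row of the neighbor's slab

    # fat halo = received halo row + first two local rows (3 rows total).
    fat = torch.cat([top_halo, local[:, :, :2, :]], dim=2)

    # A 3x3 convolution padded only in W turns the 3 input rows into the one
    # output row that was missing correct neighbor data.
    corrected = F.conv2d(fat, w, padding=(0, 1))
    assert torch.allclose(corrected, full[:, :, 4:5, :], atol=1e-5)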
@@ -280,17 +291,6 @@ class SpatialBottleneckFunction(torch.autograd.Function):
                     btm_fat_halo[:,:,0:2,:].copy_(out1[:,:,Hs-2:,:])
                     btm_fat_halo[:,:,2:,:].copy_(btm_out1_halo)
                 btm_out2 = fast_bottleneck.forward_out2_halo(explicit_nhwc, btm_fat_halo, args)
-        if spatial_group_rank > 0:
-            with torch.cuda.stream(stream1):
-                if explicit_nhwc:
-                    top_fat_halo = torch.empty((N,3,W,C),dtype=out1.dtype,device=out1.device)
-                    top_fat_halo[:,:1,:,:].copy_(top_out1_halo)
-                    top_fat_halo[:,1:3,:,:].copy_(out1[:,:2,:,:])
-                else:
-                    top_fat_halo = torch.empty((N,C,3,W),dtype=out1.dtype,device=out1.device)
-                    top_fat_halo[:,:,:1,:].copy_(top_out1_halo)
-                    top_fat_halo[:,:,1:3,:].copy_(out1[:,:,:2,:])
-                top_out2 = fast_bottleneck.forward_out2_halo(explicit_nhwc, top_fat_halo, args)
         inc.add_delay(10)
     elif spatial_method != 2 and spatial_method != 3:
         assert(False), "spatial_method must be 1, 2 or 3"
@@ -324,21 +324,35 @@ class SpatialBottleneckFunction(torch.autograd.Function):
         torch.cuda.current_stream().wait_stream(stream2)
         btm_out2_halo.copy_(btm_out2)
     elif spatial_method == 3:
+        # Note
+        # out2 halo correction cannot overlap with anything since it has
+        # to wait for out2_mask to finish, but itself has to finish before
+        # the first kernel of _forward_rest can launch.
+        # At least we can overlap the two halo correction kernels.
         if spatial_group_rank > 0:
-            w1by3 = args[2][:,:,2:3,:].contiguous(memory_format=torch.preserve_format)
-            top_out1_halo = top_out1_halo.contiguous(memory_format=memory_format)
-            top_out2 = fast_bottleneck.forward_out2_halo_corr(explicit_nhwc, top_out1_halo, args, w1by3, top_out2_halo.contiguous(memory_format=memory_format))
-            top_out2_halo.copy_(top_out2)
+            stream1.wait_stream(torch.cuda.current_stream()) # wait for *_out2_mask to finish
+            with torch.cuda.stream(stream1):
+                w1by3 = args[2][:,:1,:,:].clone()
+                top_out1_halo = top_out1_halo.clone()
+                top_out2 = fast_bottleneck.forward_out2_halo_corr(explicit_nhwc, top_out1_halo, args, w1by3, top_out2_halo.clone())
+                top_out2_halo.copy_(top_out2)
         if spatial_group_rank < spatial_group_size-1:
-            w1by3 = args[2][:,:,:1,:].contiguous(memory_format=torch.preserve_format)
-            btm_out1_halo = btm_out1_halo.contiguous(memory_format=memory_format)
-            btm_out2 = fast_bottleneck.forward_out2_halo_corr(explicit_nhwc, btm_out1_halo, args, w1by3, btm_out2_halo.contiguous(memory_format=memory_format))
-            btm_out2_halo.copy_(btm_out2)
+            stream2.wait_stream(torch.cuda.current_stream()) # wait for *_out2_mask to finish
+            with torch.cuda.stream(stream2):
+                w1by3 = args[2][:,2:3,:,:].clone()
+                btm_out1_halo = btm_out1_halo.clone()
+                btm_out2 = fast_bottleneck.forward_out2_halo_corr(explicit_nhwc, btm_out1_halo, args, w1by3, btm_out2_halo.clone())
+                btm_out2_halo.copy_(btm_out2)
+        if spatial_group_rank > 0:
+            torch.cuda.current_stream().wait_stream(stream1)
+        if spatial_group_rank < spatial_group_size-1:
+            torch.cuda.current_stream().wait_stream(stream2)
 fast_bottleneck.forward_rest(explicit_nhwc, stride_1x1, args, outputs)
 # save halos for backward pass
 if spatial_group_size > 1:
     if spatial_method != 2:
+        # make sure copy of mid-section of out1 into out1_pad is done before exiting
         torch.cuda.current_stream().wait_stream(stream3)
         ctx.save_for_backward(*(args+outputs+[out1_pad,]))
     else:
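This hunk is the fix the commit message names: the 1x3 weight slice w1by3 fed to forward_out2_halo_corr now slices dimension 1 of the weight tensor instead of dimension 2 (args[2][:,:1,:,:] for the top halo, args[2][:,2:3,:,:] for the bottom, i.e. the first and last kernel row respectively), the inputs are materialized with clone() rather than contiguous(), and the two correction kernels are queued on side streams so they can at least overlap each other. The correction itself adds the halo row's missing contribution, convolved with the matching row of the 3x3 weights, to an output row that was computed without halos. A plain-PyTorch sketch of that arithmetic, assuming NCHW layout and a bare convolution (layouts and names here are illustrative, not the apex API):

    # Sketch: fix a boundary output row after the fact by adding the halo
    # row convolved with the matching 1x3 slice of the 3x3 weights.
    import torch
    import torch.nn.functional as F

    torch.manual_seed(0)
    N, C, H, W = 1, 4, 8, 8
    x = torch.randn(N, C, H, W)
    w = torch.randn(C, C, 3, 3)

    full = F.conv2d(x, w, padding=1)

    # The rank owning rows 4..7 convolves its slab with zero padding, so its
    # first output row lacks the top neighbor's contribution.
    local_out = F.conv2d(x[:, :, 4:, :], w, padding=1)

    # Correction: halo row (row 3) times the top row of the kernel, w[:,:,:1,:].
    halo = x[:, :, 3:4, :]
    corr = F.conv2d(halo, w[:, :, :1, :], padding=(0, 1))
    fixed_row = local_out[:, :, :1, :] + corr
    assert torch.allclose(fixed_row, full[:, :, 4:5, :], atol=1e-5)

Slicing the wrong kernel row here (e.g. the last row for the top halo) yields an incorrect boundary row, which is consistent with the kind of error the old slicing could produce.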