Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
apex
Commits
c70f0e32
Commit
c70f0e32
authored
Apr 07, 2022
by
Thor Johnsen
Browse files
Fix deadlock issue when peer memory halo exchanger is used with cuda graph
parent
d8db8c15
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
170 additions
and
58 deletions
+170
-58
apex/contrib/bottleneck/bottleneck.py
apex/contrib/bottleneck/bottleneck.py
+112
-36
apex/contrib/csrc/peer_memory/peer_memory.cpp
apex/contrib/csrc/peer_memory/peer_memory.cpp
+1
-0
apex/contrib/csrc/peer_memory/peer_memory_cuda.cu
apex/contrib/csrc/peer_memory/peer_memory_cuda.cu
+56
-22
apex/contrib/csrc/peer_memory/peer_memory_cuda.cuh
apex/contrib/csrc/peer_memory/peer_memory_cuda.cuh
+1
-0
No files found.
apex/contrib/bottleneck/bottleneck.py
View file @
c70f0e32
import
functools
as
func
import
torch
import
torch.distributed
as
dist
from
torch
import
nn
...
...
@@ -8,6 +9,17 @@ def kaiming_uniform_(tensor, a=0, mode='fan_in', nonlinearity='leaky_relu'):
weight_tensor_nchw
=
tensor
nn
.
init
.
kaiming_uniform_
(
weight_tensor_nchw
,
a
=
a
,
mode
=
mode
,
nonlinearity
=
nonlinearity
)
def compute_scale_bias_one(nhwc, weight, bias, running_mean, running_var, w_scale, w_bias):
    """Fold one batch-norm's statistics into a scale/bias pair, written in place.

    Computes ``weight * running_var.rsqrt()`` as the scale and
    ``bias - running_mean * scale`` as the bias, then copies both results into
    the caller-provided ``w_scale`` and ``w_bias`` tensors with ``copy_``.

    ``nhwc`` is part of the signature but is not read by this function.
    Nothing is returned; the outputs are the mutated ``w_scale`` / ``w_bias``.
    """
    inv_std = running_var.rsqrt()
    folded_scale = weight * inv_std
    folded_bias = bias - running_mean * folded_scale
    # Write into the existing tensors so that any views of them see the new values.
    w_scale.copy_(folded_scale)
    w_bias.copy_(folded_bias)
def compute_scale_bias_method(nhwc, args):
    """Recompute scale and bias for every batch-norm described in ``args``.

    ``args`` is an iterable whose entries are tuples of
    (weight, bias, running_mean, running_var, w_scale, w_bias); each entry is
    forwarded, together with ``nhwc``, to ``compute_scale_bias_one``.
    """
    for bn_tensors in args:
        compute_scale_bias_one(nhwc, *bn_tensors)
class
FrozenBatchNorm2d
(
torch
.
jit
.
ScriptModule
):
"""
BatchNorm2d where the batch statistics and the affine parameters are fixed
...
...
@@ -150,6 +162,7 @@ class Bottleneck(torch.nn.Module):
self
.
bn1
=
norm_func
(
bottleneck_channels
)
self
.
bn2
=
norm_func
(
bottleneck_channels
)
self
.
bn3
=
norm_func
(
out_channels
)
self
.
w_scale
=
None
self
.
use_cudnn
=
use_cudnn
...
...
@@ -173,23 +186,47 @@ class Bottleneck(torch.nn.Module):
for
p
in
self
.
parameters
():
with
torch
.
no_grad
():
p
.
data
=
p
.
data
.
permute
(
0
,
2
,
3
,
1
).
contiguous
()
return
def get_scale_bias_callable(self):
    """Return a single callable that recomputes scale and bias for all batch-norms.

    This method must be called before cuda graphing. The callable it returns
    can be called anytime. Calling this method stores the scale/bias tensors on
    ``self.w_scale`` / ``self.w_bias``, which prevents them from being
    recomputed on every forward call.

    The buffers are allocated with ``torch.empty_like`` and would hold
    uninitialized memory until the returned callable ran. To make sure a
    forward pass never consumes garbage values, the callable is invoked once
    here before it is handed back to the caller.

    Returns:
        A ``functools.partial`` wrapping ``compute_scale_bias_method``; calling
        it (with no arguments) refreshes every scale/bias buffer in place.
    """
    batch_norms = [self.bn1, self.bn2, self.bn3]
    if self.downsample is not None:
        batch_norms.append(self.downsample[1])

    # Channel dimension is last when explicit_nhwc is set, second otherwise.
    view_shape = (1, 1, 1, -1) if self.explicit_nhwc else (1, -1, 1, 1)

    self.w_scale, self.w_bias, args = [], [], []
    for bn in batch_norms:
        s = torch.empty_like(bn.weight)
        b = torch.empty_like(s)
        args.append((bn.weight, bn.bias, bn.running_mean, bn.running_var, s, b))
        # The reshaped tensors are what forward consumes; s and b are what the
        # callable writes into.
        self.w_scale.append(s.reshape(*view_shape))
        self.w_bias.append(b.reshape(*view_shape))

    recompute = func.partial(compute_scale_bias_method, self.explicit_nhwc, args)
    # Populate the freshly allocated buffers so they are valid immediately.
    recompute()
    return recompute
def
forward
(
self
,
x
):
if
self
.
use_cudnn
:
# calculate scale/bias from registered buffers
# TODO: make this better
s1
,
b1
=
self
.
bn1
.
get_scale_bias
(
self
.
explicit_nhwc
)
s2
,
b2
=
self
.
bn2
.
get_scale_bias
(
self
.
explicit_nhwc
)
s3
,
b3
=
self
.
bn3
.
get_scale_bias
(
self
.
explicit_nhwc
)
w_scale
=
[
s1
,
s2
,
s3
]
w_bias
=
[
b1
,
b2
,
b3
]
if
self
.
downsample
is
not
None
:
s4
,
b4
=
self
.
downsample
[
1
].
get_scale_bias
(
self
.
explicit_nhwc
)
w_scale
.
append
(
s4
)
w_bias
.
append
(
b4
)
out
=
bottleneck_function
(
self
.
explicit_nhwc
,
self
.
stride
,
w_scale
,
w_bias
,
x
,
*
self
.
w_conv
)
if
self
.
w_scale
is
None
:
# calculate scale/bias from registered buffers
# TODO: make this better
s1
,
b1
=
self
.
bn1
.
get_scale_bias
(
self
.
explicit_nhwc
)
s2
,
b2
=
self
.
bn2
.
get_scale_bias
(
self
.
explicit_nhwc
)
s3
,
b3
=
self
.
bn3
.
get_scale_bias
(
self
.
explicit_nhwc
)
w_scale
=
[
s1
,
s2
,
s3
]
w_bias
=
[
b1
,
b2
,
b3
]
if
self
.
downsample
is
not
None
:
s4
,
b4
=
self
.
downsample
[
1
].
get_scale_bias
(
self
.
explicit_nhwc
)
w_scale
.
append
(
s4
)
w_bias
.
append
(
b4
)
out
=
bottleneck_function
(
self
.
explicit_nhwc
,
self
.
stride
,
w_scale
,
w_bias
,
x
,
*
self
.
w_conv
)
else
:
out
=
bottleneck_function
(
self
.
explicit_nhwc
,
self
.
stride
,
self
.
w_scale
,
self
.
w_bias
,
x
,
*
self
.
w_conv
)
return
out
if
self
.
explicit_nhwc
:
...
...
@@ -251,7 +288,7 @@ class SpatialBottleneckFunction(torch.autograd.Function):
memory_format
=
torch
.
channels_last
if
out1
.
is_contiguous
(
memory_format
=
torch
.
channels_last
)
else
torch
.
contiguous_format
out1_pad
=
torch
.
empty
([
N
,
C
,
Hs
+
2
,
W
],
dtype
=
out1
.
dtype
,
device
=
'cuda'
,
memory_format
=
memory_format
)
stream1
.
wait_stream
(
torch
.
cuda
.
current_stream
())
stream3
.
wait_stream
(
torch
.
cuda
.
current_stream
())
if
spatial_method
!=
2
:
stream3
.
wait_stream
(
torch
.
cuda
.
current_stream
())
with
torch
.
cuda
.
stream
(
stream3
):
if
explicit_nhwc
:
out1_pad
[:,
1
:
Hs
+
1
,:,:].
copy_
(
out1
)
...
...
@@ -291,7 +328,7 @@ class SpatialBottleneckFunction(torch.autograd.Function):
top_fat_halo
[:,:,:
1
,:].
copy_
(
top_out1_halo
)
top_fat_halo
[:,:,
1
:
3
,:].
copy_
(
out1
[:,:,:
2
,:])
top_out2
=
fast_bottleneck
.
forward_out2_halo
(
explicit_nhwc
,
top_fat_halo
,
args
)
inc
.
add_delay
(
10
)
#
inc.add_delay(10)
elif
spatial_method
!=
2
and
spatial_method
!=
3
:
assert
(
False
),
"spatial_method must be 1, 2 or 3"
...
...
@@ -299,13 +336,26 @@ class SpatialBottleneckFunction(torch.autograd.Function):
fast_bottleneck
.
forward_out2
(
explicit_nhwc
,
stride_1x1
,
args
,
outputs
)
elif
spatial_method
==
1
:
fast_bottleneck
.
forward_out2
(
explicit_nhwc
,
stride_1x1
,
args
,
outputs
)
with
torch
.
cuda
.
stream
(
stream3
):
if
explicit_nhwc
:
out1_pad
[:,
1
:
Hs
+
1
,:,:].
copy_
(
out1
)
else
:
out1_pad
[:,:,
1
:
Hs
+
1
,:].
copy_
(
out1
)
elif
spatial_method
==
2
:
# wait for halo transfer to finish before doing a full convolution of padded x
torch
.
cuda
.
current_stream
().
wait_stream
(
stream1
)
torch
.
cuda
.
current_stream
().
wait_stream
(
stream3
)
if
explicit_nhwc
:
out1_pad
[:,
1
:
Hs
+
1
,:,:].
copy_
(
out1
)
else
:
out1_pad
[:,:,
1
:
Hs
+
1
,:].
copy_
(
out1
)
fast_bottleneck
.
forward_out2_pad
(
explicit_nhwc
,
stride_1x1
,
args
,
outputs
,
out1_pad
)
elif
spatial_method
==
3
:
fast_bottleneck
.
forward_out2_mask
(
explicit_nhwc
,
stride_1x1
,
args
,
outputs
,
thresholdTop
,
thresholdBottom
)
with
torch
.
cuda
.
stream
(
stream3
):
if
explicit_nhwc
:
out1_pad
[:,
1
:
Hs
+
1
,:,:].
copy_
(
out1
)
else
:
out1_pad
[:,:,
1
:
Hs
+
1
,:].
copy_
(
out1
)
# compute halo cells for outputs[1] (out2)
if
spatial_group_size
>
1
:
...
...
@@ -405,6 +455,11 @@ class SpatialBottleneckFunction(torch.autograd.Function):
grad_out2
=
fast_bottleneck
.
backward_grad_out2
(
ctx
.
explicit_nhwc
,
ctx
.
stride_1x1
,
t_list
,
grads
)
wgrad2_stream
=
torch
.
cuda
.
Stream
()
wgrad2_stream
.
wait_stream
(
torch
.
cuda
.
current_stream
())
with
torch
.
cuda
.
stream
(
wgrad2_stream
):
if
ctx
.
spatial_group_size
>
1
:
wgrad2
=
fast_bottleneck
.
backward_wgrad2_pad
(
ctx
.
explicit_nhwc
,
ctx
.
stride_1x1
,
t_list
,
grads
,
out1_pad
,
grad_out2
)
else
:
wgrad2
=
fast_bottleneck
.
backward_wgrad2
(
ctx
.
explicit_nhwc
,
ctx
.
stride_1x1
,
t_list
,
grads
,
grad_out2
)
# do halo exchange of grad_out2 here
# compute halo cells for grad_out1
if
ctx
.
spatial_group_size
>
1
:
...
...
@@ -463,16 +518,10 @@ class SpatialBottleneckFunction(torch.autograd.Function):
top_grad_out1_halo
=
top_grad_out1_halo
[:,
1
:
2
,:,:]
else
:
top_grad_out1_halo
=
top_grad_out1_halo
[:,:,
1
:
2
,:]
inc
.
add_delay
(
10
)
#
inc.add_delay(10)
elif
ctx
.
spatial_method
!=
3
:
assert
(
False
),
"spatial_method must be 1, 2 or 3"
with
torch
.
cuda
.
stream
(
wgrad2_stream
):
if
ctx
.
spatial_group_size
>
1
:
wgrad2
=
fast_bottleneck
.
backward_wgrad2_pad
(
ctx
.
explicit_nhwc
,
ctx
.
stride_1x1
,
t_list
,
grads
,
out1_pad
,
grad_out2
)
else
:
wgrad2
=
fast_bottleneck
.
backward_wgrad2
(
ctx
.
explicit_nhwc
,
ctx
.
stride_1x1
,
t_list
,
grads
,
grad_out2
)
# compute grad_out1 for internal cells
if
ctx
.
spatial_group_size
<=
1
or
ctx
.
spatial_method
==
1
or
ctx
.
spatial_method
==
2
:
grad_out1
=
fast_bottleneck
.
backward_grad_out1
(
ctx
.
explicit_nhwc
,
ctx
.
stride_1x1
,
t_list
,
grads
,
grad_out2
)
...
...
@@ -577,6 +626,7 @@ class SpatialBottleneck(torch.nn.Module):
self
.
bn1
=
norm_func
(
bottleneck_channels
)
self
.
bn2
=
norm_func
(
bottleneck_channels
)
self
.
bn3
=
norm_func
(
out_channels
)
self
.
w_scale
=
None
self
.
use_cudnn
=
use_cudnn
...
...
@@ -610,6 +660,27 @@ class SpatialBottleneck(torch.nn.Module):
self
.
spatial_parallel_args
=
spatial_parallel_args
return
def get_scale_bias_callable(self):
    """Return a single callable that recomputes scale and bias for all batch-norms.

    This method must be called before cuda graphing. The callable it returns
    can be called anytime. Calling this method stores the scale/bias tensors on
    ``self.w_scale`` / ``self.w_bias``, which prevents them from being
    recomputed on every forward call.

    The buffers are allocated with ``torch.empty_like`` and would hold
    uninitialized memory until the returned callable ran. To make sure a
    forward pass never consumes garbage values, the callable is invoked once
    here before it is handed back to the caller.

    Returns:
        A ``functools.partial`` wrapping ``compute_scale_bias_method``; calling
        it (with no arguments) refreshes every scale/bias buffer in place.
    """
    batch_norms = [self.bn1, self.bn2, self.bn3]
    if self.downsample is not None:
        batch_norms.append(self.downsample[1])

    # Channel dimension is last when explicit_nhwc is set, second otherwise.
    view_shape = (1, 1, 1, -1) if self.explicit_nhwc else (1, -1, 1, 1)

    self.w_scale, self.w_bias, args = [], [], []
    for bn in batch_norms:
        s = torch.empty_like(bn.weight)
        b = torch.empty_like(s)
        args.append((bn.weight, bn.bias, bn.running_mean, bn.running_var, s, b))
        # The reshaped tensors are what forward consumes; s and b are what the
        # callable writes into.
        self.w_scale.append(s.reshape(*view_shape))
        self.w_bias.append(b.reshape(*view_shape))

    recompute = func.partial(compute_scale_bias_method, self.explicit_nhwc, args)
    # Populate the freshly allocated buffers so they are valid immediately.
    recompute()
    return recompute
def
forward
(
self
,
x
):
if
self
.
use_cudnn
:
if
self
.
thresholdTop
is
None
:
...
...
@@ -620,19 +691,24 @@ class SpatialBottleneck(torch.nn.Module):
N
,
C
,
H
,
W
=
list
(
x
.
shape
)
self
.
thresholdTop
=
torch
.
tensor
([
1
if
spatial_group_rank
>
0
else
0
],
dtype
=
torch
.
int32
,
device
=
'cuda'
)
self
.
thresholdBottom
=
torch
.
tensor
([
H
-
2
if
spatial_group_rank
<
spatial_group_size
-
1
else
H
-
1
],
dtype
=
torch
.
int32
,
device
=
'cuda'
)
# calculate scale/bias from registered buffers
# TODO: make this better
s1
,
b1
=
self
.
bn1
.
get_scale_bias
(
self
.
explicit_nhwc
)
s2
,
b2
=
self
.
bn2
.
get_scale_bias
(
self
.
explicit_nhwc
)
s3
,
b3
=
self
.
bn3
.
get_scale_bias
(
self
.
explicit_nhwc
)
w_scale
=
[
s1
,
s2
,
s3
]
w_bias
=
[
b1
,
b2
,
b3
]
if
self
.
downsample
is
not
None
:
s4
,
b4
=
self
.
downsample
[
1
].
get_scale_bias
(
self
.
explicit_nhwc
)
w_scale
.
append
(
s4
)
w_bias
.
append
(
b4
)
out
=
spatial_bottleneck_function
(
*
self
.
spatial_parallel_args
,
self
.
explicit_nhwc
,
self
.
stride
,
w_scale
,
w_bias
,
self
.
thresholdTop
,
self
.
thresholdBottom
,
x
,
*
self
.
w_conv
)
if
self
.
w_scale
is
None
:
# calculate scale/bias from registered buffers
# TODO: make this better
s1
,
b1
=
self
.
bn1
.
get_scale_bias
(
self
.
explicit_nhwc
)
s2
,
b2
=
self
.
bn2
.
get_scale_bias
(
self
.
explicit_nhwc
)
s3
,
b3
=
self
.
bn3
.
get_scale_bias
(
self
.
explicit_nhwc
)
w_scale
=
[
s1
,
s2
,
s3
]
w_bias
=
[
b1
,
b2
,
b3
]
if
self
.
downsample
is
not
None
:
s4
,
b4
=
self
.
downsample
[
1
].
get_scale_bias
(
self
.
explicit_nhwc
)
w_scale
.
append
(
s4
)
w_bias
.
append
(
b4
)
self
.
w_scale
=
w_scale
self
.
w_bias
=
w_bias
out
=
spatial_bottleneck_function
(
*
self
.
spatial_parallel_args
,
self
.
explicit_nhwc
,
self
.
stride
,
w_scale
,
w_bias
,
self
.
thresholdTop
,
self
.
thresholdBottom
,
x
,
*
self
.
w_conv
)
else
:
out
=
spatial_bottleneck_function
(
*
self
.
spatial_parallel_args
,
self
.
explicit_nhwc
,
self
.
stride
,
self
.
w_scale
,
self
.
w_bias
,
self
.
thresholdTop
,
self
.
thresholdBottom
,
x
,
*
self
.
w_conv
)
return
out
if
self
.
explicit_nhwc
:
...
...
apex/contrib/csrc/peer_memory/peer_memory.cpp
View file @
c70f0e32
...
...
@@ -19,6 +19,7 @@
PYBIND11_MODULE
(
TORCH_EXTENSION_NAME
,
m
)
{
m
.
def
(
"allocate_raw"
,
&
apex
::
contrib
::
peer_memory
::
allocate_raw
,
"allocate_raw"
);
m
.
def
(
"free_raw"
,
&
apex
::
contrib
::
peer_memory
::
free_raw
,
"free_raw"
);
m
.
def
(
"zero"
,
&
apex
::
contrib
::
peer_memory
::
zero
,
"zero"
);
m
.
def
(
"get_raw_ipc_address"
,
&
apex
::
contrib
::
peer_memory
::
get_raw_ipc_address
,
"get_raw_ipc_address"
);
m
.
def
(
"get_raw_peers"
,
&
apex
::
contrib
::
peer_memory
::
get_raw_peers
,
"get_raw_peers"
);
m
.
def
(
"blob_view_half"
,
&
apex
::
contrib
::
peer_memory
::
blob_view_half
,
"blob_view_half"
);
...
...
apex/contrib/csrc/peer_memory/peer_memory_cuda.cu
View file @
c70f0e32
...
...
@@ -148,33 +148,58 @@ __device__ void strided_copy_kernel(
}
}
template
<
bool
wait
,
bool
clear
>
__device__
void
dual_signal_wait_clear
(
volatile
int
*
signal1_flag
,
volatile
int
*
wait1_flag
,
volatile
int
*
signal2_flag
,
volatile
int
*
wait2_flag
,
__device__
void
checked_signal
(
volatile
int
*
signal1_flag
,
volatile
int
*
signal2_flag
,
const
int
v1
,
const
int
v2
,
const
int
v3
,
const
int
v4
)
{
register
int
r1
,
r2
,
r3
,
r4
,
r5
,
r6
,
r7
,
r8
;
bool
is_main_thread
=
(
blockIdx
.
x
==
0
&&
threadIdx
.
x
==
0
)
?
true
:
false
;
// signal and wait
if
(
is_main_thread
)
{
asm
volatile
(
"st.volatile.global.v4.u32 [%0], {%1,%2,%3,%4};"
::
"l"
(
signal1_flag
),
"r"
(
v1
),
"r"
(
v2
),
"r"
(
v3
),
"r"
(
v4
)
:
"memory"
);
asm
volatile
(
"st.volatile.global.v4.u32 [%0], {%1,%2,%3,%4};"
::
"l"
(
signal2_flag
),
"r"
(
v1
),
"r"
(
v2
),
"r"
(
v3
),
"r"
(
v4
)
:
"memory"
);
if
(
wait
)
{
if
(
blockIdx
.
x
==
0
)
{
register
int
r1
,
r2
,
r3
,
r4
;
if
(
threadIdx
.
x
==
0
)
{
// wait for top neighbor to clear bottom signal (indicating ready for new input)
do
{
asm
volatile
(
"ld.volatile.global.v4.u32 {%0,%1,%2,%3}, [%4];"
:
"=r"
(
r1
),
"=r"
(
r2
),
"=r"
(
r3
),
"=r"
(
r4
)
:
"l"
(
signal1_flag
)
:
"memory"
);
}
while
(
r1
==
v1
&&
r2
==
v2
&&
r3
==
v3
&&
r4
==
v4
);
// signal to top neighbor my output is ready
asm
volatile
(
"st.volatile.global.v4.u32 [%0], {%1,%2,%3,%4};"
::
"l"
(
signal1_flag
),
"r"
(
v1
),
"r"
(
v2
),
"r"
(
v3
),
"r"
(
v4
)
:
"memory"
);
}
else
if
(
threadIdx
.
x
==
1
)
{
// wait for bottom neighbor to clear top signal (indicating ready for new input)
do
{
asm
volatile
(
"ld.volatile.global.v4.u32 {%0,%1,%2,%3}, [%4];"
:
"=r"
(
r1
),
"=r"
(
r2
),
"=r"
(
r3
),
"=r"
(
r4
)
:
"l"
(
wait1_flag
)
:
"memory"
);
asm
volatile
(
"ld.volatile.global.v4.u32 {%0,%1,%2,%3}, [%4];"
:
"=r"
(
r5
),
"=r"
(
r6
),
"=r"
(
r7
),
"=r"
(
r8
)
:
"l"
(
wait2_flag
)
:
"memory"
);
}
while
(
r1
!=
v1
||
r5
!=
v1
||
r2
!=
v2
||
r6
!=
v2
||
r3
!=
v3
||
r7
!=
v3
||
r4
!=
v4
||
r8
!=
v4
);
asm
volatile
(
"ld.volatile.global.v4.u32 {%0,%1,%2,%3}, [%4];"
:
"=r"
(
r1
),
"=r"
(
r2
),
"=r"
(
r3
),
"=r"
(
r4
)
:
"l"
(
signal2_flag
)
:
"memory"
);
}
while
(
r1
==
v1
&&
r2
==
v2
&&
r3
==
v3
&&
r4
==
v4
);
// signal to bottom neighbor my output is ready
asm
volatile
(
"st.volatile.global.v4.u32 [%0], {%1,%2,%3,%4};"
::
"l"
(
signal2_flag
),
"r"
(
v1
),
"r"
(
v2
),
"r"
(
v3
),
"r"
(
v4
)
:
"memory"
);
}
}
cg
::
this_grid
().
sync
();
if
(
clear
)
{
if
(
is_main_thread
)
{
r1
=
0
;
r2
=
0
;
r3
=
0
;
r4
=
0
;
asm
volatile
(
"st.volatile.global.v4.u32 [%0], {%1,%2,%3,%4};"
::
"l"
(
wait1_flag
),
"r"
(
r1
),
"r"
(
r2
),
"r"
(
r3
),
"r"
(
r4
)
:
"memory"
);
asm
volatile
(
"st.volatile.global.v4.u32 [%0], {%1,%2,%3,%4};"
::
"l"
(
wait2_flag
),
"r"
(
r1
),
"r"
(
r2
),
"r"
(
r3
),
"r"
(
r4
)
:
"memory"
);
}
}
// Spins until the four 32-bit words stored at wait_flag equal {v1, v2, v3, v4}.
//
// Only thread 0 of block 0 polls the flag (volatile vector load). Every thread,
// including the polling one, then calls cg::this_grid().sync(), so the other
// threads do not proceed past this function until the polling thread has seen
// the expected value.
//
// wait_flag : address of the 4-word flag to poll.
// v1..v4    : the word values that mark the flag as signalled.
__device__ void wait_for(
        volatile int* wait_flag,
        const int v1, const int v2, const int v3, const int v4
        )
{
    const bool is_main_thread = (blockIdx.x == 0 && threadIdx.x == 0);
    if (is_main_thread) {
        // Plain locals: the `register` storage class was removed in C++17.
        // They are written by the asm load before being read.
        int r1, r2, r3, r4;
        // wait for senders to signal their output is ready
        do {
            asm volatile("ld.volatile.global.v4.u32 {%0,%1,%2,%3}, [%4];" : "=r"(r1), "=r"(r2), "=r"(r3), "=r"(r4) : "l"(wait_flag) : "memory");
        } while (r1 != v1 || r2 != v2 || r3 != v3 || r4 != v4);
    }
    cg::this_grid().sync();  // all threads wait for main
}
// Resets the four 32-bit words stored at wait_flag to zero.
//
// All threads first meet at cg::this_grid().sync(); afterwards thread 0 of
// block 0 alone performs the volatile vector store that clears the flag.
//
// wait_flag : address of the 4-word flag to clear.
__device__ void clear_flag(
        volatile int* wait_flag
        )
{
    cg::this_grid().sync();  // wait for all threads in kernel to finish
    const bool is_main_thread = (blockIdx.x == 0 && threadIdx.x == 0);
    if (is_main_thread) {
        // Plain locals: the `register` storage class was removed in C++17.
        int r1 = 0;
        int r2 = 0;
        int r3 = 0;
        int r4 = 0;
        asm volatile("st.volatile.global.v4.u32 [%0], {%1,%2,%3,%4};" :: "l"(wait_flag), "r"(r1), "r"(r2), "r"(r3), "r"(r4) : "memory");
    }
}
...
...
@@ -208,11 +233,15 @@ __global__ void push_pull_halos_1d_kernel(
strided_copy_kernel
<
T
,
is_HWC
>
(
box
,
box_stride_C
,
box_stride_H
,
box_stride_W
,
boh
,
boh_stride_C
,
boh_stride_H
,
boh_stride_W
,
NC
,
NH
,
NW
);
// signal to top and btm neighbors that output halos are ready to be read
// the choice of values for v1-v4 is arbitrary and does not matter, as long as all ranks use the same values
dual_signal_wait_clear
<
true
,
true
>
(
signal1_flag
,
wait1_flag
,
signal2_flag
,
wait
2_flag
,
-
987751720
,
840868300
,
-
225529332
,
281513358
);
checked_signal
(
signal1_flag
,
signal
2_flag
,
-
987751720
,
840868300
,
-
225529332
,
281513358
);
// pull top halo from transfer buffer in peer memory to input
wait_for
(
wait1_flag
,
-
987751720
,
840868300
,
-
225529332
,
281513358
);
strided_copy_kernel
<
T
,
is_HWC
>
(
tih
,
tih_stride_C
,
tih_stride_H
,
tih_stride_W
,
tix
,
tix_stride_C
,
tix_stride_H
,
tix_stride_W
,
NC
,
NH
,
NW
);
clear_flag
(
wait1_flag
);
// pull btm halo from transfer buffer in peer memory to input
wait_for
(
wait2_flag
,
-
987751720
,
840868300
,
-
225529332
,
281513358
);
strided_copy_kernel
<
T
,
is_HWC
>
(
bih
,
bih_stride_C
,
bih_stride_H
,
bih_stride_W
,
bix
,
bix_stride_C
,
bix_stride_H
,
bix_stride_W
,
NC
,
NH
,
NW
);
clear_flag
(
wait2_flag
);
}
__global__
void
delay_kernel
(
int
delay_nanoseconds
,
int
*
counter
)
...
...
@@ -248,6 +277,11 @@ void free_raw(int64_t raw)
cudaFree
((
void
*
)
raw
);
}
// Sets `size` bytes to zero starting at the address held in `raw`, using cudaMemset.
// `raw` is an address carried as an integer and is converted back to a pointer here.
// The status returned by cudaMemset is not inspected.
void zero(int64_t raw, int64_t size)
{
    void* const dst = reinterpret_cast<void*>(raw);
    const size_t num_bytes = static_cast<size_t>(size);
    cudaMemset(dst, 0, num_bytes);
}
at
::
Tensor
get_raw_ipc_address
(
int64_t
raw
)
{
cudaIpcMemHandle_t
mem_handle
;
...
...
apex/contrib/csrc/peer_memory/peer_memory_cuda.cuh
View file @
c70f0e32
...
...
@@ -22,6 +22,7 @@
namespace
apex
{
namespace
contrib
{
namespace
peer_memory
{
int64_t
allocate_raw
(
int64_t
size
);
void
free_raw
(
int64_t
raw
);
void
zero
(
int64_t
raw
,
int64_t
size
);
at
::
Tensor
get_raw_ipc_address
(
int64_t
raw
);
std
::
vector
<
int64_t
>
get_raw_peers
(
at
::
Tensor
ipc_addresses
,
int
peer_rank
,
int64_t
raw
);
at
::
Tensor
blob_view_half
(
int64_t
raw
,
std
::
vector
<
int64_t
>
shape
,
bool
channels_last
);
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment