Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
apex
Commits
c70f0e32
Commit
c70f0e32
authored
Apr 07, 2022
by
Thor Johnsen
Browse files
Fix deadlock issue when peer memory halo exchanger is used with cuda graph
parent
d8db8c15
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
170 additions
and
58 deletions
+170
-58
apex/contrib/bottleneck/bottleneck.py
apex/contrib/bottleneck/bottleneck.py
+112
-36
apex/contrib/csrc/peer_memory/peer_memory.cpp
apex/contrib/csrc/peer_memory/peer_memory.cpp
+1
-0
apex/contrib/csrc/peer_memory/peer_memory_cuda.cu
apex/contrib/csrc/peer_memory/peer_memory_cuda.cu
+56
-22
apex/contrib/csrc/peer_memory/peer_memory_cuda.cuh
apex/contrib/csrc/peer_memory/peer_memory_cuda.cuh
+1
-0
No files found.
apex/contrib/bottleneck/bottleneck.py
View file @
c70f0e32
import
functools
as
func
import
torch
import
torch
import
torch.distributed
as
dist
import
torch.distributed
as
dist
from
torch
import
nn
from
torch
import
nn
...
@@ -8,6 +9,17 @@ def kaiming_uniform_(tensor, a=0, mode='fan_in', nonlinearity='leaky_relu'):
...
@@ -8,6 +9,17 @@ def kaiming_uniform_(tensor, a=0, mode='fan_in', nonlinearity='leaky_relu'):
weight_tensor_nchw
=
tensor
weight_tensor_nchw
=
tensor
nn
.
init
.
kaiming_uniform_
(
weight_tensor_nchw
,
a
=
a
,
mode
=
mode
,
nonlinearity
=
nonlinearity
)
nn
.
init
.
kaiming_uniform_
(
weight_tensor_nchw
,
a
=
a
,
mode
=
mode
,
nonlinearity
=
nonlinearity
)
def
compute_scale_bias_one
(
nhwc
,
weight
,
bias
,
running_mean
,
running_var
,
w_scale
,
w_bias
):
scale
=
weight
*
running_var
.
rsqrt
()
bias
=
bias
-
running_mean
*
scale
w_scale
.
copy_
(
scale
)
w_bias
.
copy_
(
bias
)
def
compute_scale_bias_method
(
nhwc
,
args
):
for
arg
in
args
:
# arg is tuple of (weight, bias, running_mean, running_var, w_scale, w_bias)
compute_scale_bias_one
(
nhwc
,
*
arg
)
class
FrozenBatchNorm2d
(
torch
.
jit
.
ScriptModule
):
class
FrozenBatchNorm2d
(
torch
.
jit
.
ScriptModule
):
"""
"""
BatchNorm2d where the batch statistics and the affine parameters are fixed
BatchNorm2d where the batch statistics and the affine parameters are fixed
...
@@ -150,6 +162,7 @@ class Bottleneck(torch.nn.Module):
...
@@ -150,6 +162,7 @@ class Bottleneck(torch.nn.Module):
self
.
bn1
=
norm_func
(
bottleneck_channels
)
self
.
bn1
=
norm_func
(
bottleneck_channels
)
self
.
bn2
=
norm_func
(
bottleneck_channels
)
self
.
bn2
=
norm_func
(
bottleneck_channels
)
self
.
bn3
=
norm_func
(
out_channels
)
self
.
bn3
=
norm_func
(
out_channels
)
self
.
w_scale
=
None
self
.
use_cudnn
=
use_cudnn
self
.
use_cudnn
=
use_cudnn
...
@@ -173,23 +186,47 @@ class Bottleneck(torch.nn.Module):
...
@@ -173,23 +186,47 @@ class Bottleneck(torch.nn.Module):
for
p
in
self
.
parameters
():
for
p
in
self
.
parameters
():
with
torch
.
no_grad
():
with
torch
.
no_grad
():
p
.
data
=
p
.
data
.
permute
(
0
,
2
,
3
,
1
).
contiguous
()
p
.
data
=
p
.
data
.
permute
(
0
,
2
,
3
,
1
).
contiguous
()
return
return
# Returns single callable that recomputes scale and bias for all frozen batch-norms.
# This method must be called before cuda graphing.
# The callable it returns can be called anytime.
# Calling this method will prevent these from being computed every forward call.
def
get_scale_bias_callable
(
self
):
self
.
w_scale
,
self
.
w_bias
,
args
=
[],
[],
[]
batch_norms
=
[
self
.
bn1
,
self
.
bn2
,
self
.
bn3
]
if
self
.
downsample
is
not
None
:
batch_norms
.
append
(
self
.
downsample
[
1
])
for
bn
in
batch_norms
:
s
=
torch
.
empty_like
(
bn
.
weight
)
b
=
torch
.
empty_like
(
s
)
args
.
append
(
(
bn
.
weight
,
bn
.
bias
,
bn
.
running_mean
,
bn
.
running_var
,
s
,
b
)
)
if
self
.
explicit_nhwc
:
self
.
w_scale
.
append
(
s
.
reshape
(
1
,
1
,
1
,
-
1
)
)
self
.
w_bias
.
append
(
b
.
reshape
(
1
,
1
,
1
,
-
1
)
)
else
:
self
.
w_scale
.
append
(
s
.
reshape
(
1
,
-
1
,
1
,
1
)
)
self
.
w_bias
.
append
(
b
.
reshape
(
1
,
-
1
,
1
,
1
)
)
return
func
.
partial
(
compute_scale_bias_method
,
self
.
explicit_nhwc
,
args
)
def
forward
(
self
,
x
):
def
forward
(
self
,
x
):
if
self
.
use_cudnn
:
if
self
.
use_cudnn
:
# calculate scale/bias from registered buffers
if
self
.
w_scale
is
None
:
# TODO: make this better
# calculate scale/bias from registered buffers
s1
,
b1
=
self
.
bn1
.
get_scale_bias
(
self
.
explicit_nhwc
)
# TODO: make this better
s2
,
b2
=
self
.
bn2
.
get_scale_bias
(
self
.
explicit_nhwc
)
s1
,
b1
=
self
.
bn1
.
get_scale_bias
(
self
.
explicit_nhwc
)
s3
,
b3
=
self
.
bn3
.
get_scale_bias
(
self
.
explicit_nhwc
)
s2
,
b2
=
self
.
bn2
.
get_scale_bias
(
self
.
explicit_nhwc
)
w_scale
=
[
s1
,
s2
,
s3
]
s3
,
b3
=
self
.
bn3
.
get_scale_bias
(
self
.
explicit_nhwc
)
w_bias
=
[
b1
,
b2
,
b3
]
w_scale
=
[
s1
,
s2
,
s3
]
if
self
.
downsample
is
not
None
:
w_bias
=
[
b1
,
b2
,
b3
]
s4
,
b4
=
self
.
downsample
[
1
].
get_scale_bias
(
self
.
explicit_nhwc
)
if
self
.
downsample
is
not
None
:
w_scale
.
append
(
s4
)
s4
,
b4
=
self
.
downsample
[
1
].
get_scale_bias
(
self
.
explicit_nhwc
)
w_bias
.
append
(
b4
)
w_scale
.
append
(
s4
)
w_bias
.
append
(
b4
)
out
=
bottleneck_function
(
self
.
explicit_nhwc
,
self
.
stride
,
w_scale
,
w_bias
,
x
,
*
self
.
w_conv
)
out
=
bottleneck_function
(
self
.
explicit_nhwc
,
self
.
stride
,
w_scale
,
w_bias
,
x
,
*
self
.
w_conv
)
else
:
out
=
bottleneck_function
(
self
.
explicit_nhwc
,
self
.
stride
,
self
.
w_scale
,
self
.
w_bias
,
x
,
*
self
.
w_conv
)
return
out
return
out
if
self
.
explicit_nhwc
:
if
self
.
explicit_nhwc
:
...
@@ -251,7 +288,7 @@ class SpatialBottleneckFunction(torch.autograd.Function):
...
@@ -251,7 +288,7 @@ class SpatialBottleneckFunction(torch.autograd.Function):
memory_format
=
torch
.
channels_last
if
out1
.
is_contiguous
(
memory_format
=
torch
.
channels_last
)
else
torch
.
contiguous_format
memory_format
=
torch
.
channels_last
if
out1
.
is_contiguous
(
memory_format
=
torch
.
channels_last
)
else
torch
.
contiguous_format
out1_pad
=
torch
.
empty
([
N
,
C
,
Hs
+
2
,
W
],
dtype
=
out1
.
dtype
,
device
=
'cuda'
,
memory_format
=
memory_format
)
out1_pad
=
torch
.
empty
([
N
,
C
,
Hs
+
2
,
W
],
dtype
=
out1
.
dtype
,
device
=
'cuda'
,
memory_format
=
memory_format
)
stream1
.
wait_stream
(
torch
.
cuda
.
current_stream
())
stream1
.
wait_stream
(
torch
.
cuda
.
current_stream
())
stream3
.
wait_stream
(
torch
.
cuda
.
current_stream
())
if
spatial_method
!=
2
:
stream3
.
wait_stream
(
torch
.
cuda
.
current_stream
())
with
torch
.
cuda
.
stream
(
stream3
):
with
torch
.
cuda
.
stream
(
stream3
):
if
explicit_nhwc
:
if
explicit_nhwc
:
out1_pad
[:,
1
:
Hs
+
1
,:,:].
copy_
(
out1
)
out1_pad
[:,
1
:
Hs
+
1
,:,:].
copy_
(
out1
)
...
@@ -291,7 +328,7 @@ class SpatialBottleneckFunction(torch.autograd.Function):
...
@@ -291,7 +328,7 @@ class SpatialBottleneckFunction(torch.autograd.Function):
top_fat_halo
[:,:,:
1
,:].
copy_
(
top_out1_halo
)
top_fat_halo
[:,:,:
1
,:].
copy_
(
top_out1_halo
)
top_fat_halo
[:,:,
1
:
3
,:].
copy_
(
out1
[:,:,:
2
,:])
top_fat_halo
[:,:,
1
:
3
,:].
copy_
(
out1
[:,:,:
2
,:])
top_out2
=
fast_bottleneck
.
forward_out2_halo
(
explicit_nhwc
,
top_fat_halo
,
args
)
top_out2
=
fast_bottleneck
.
forward_out2_halo
(
explicit_nhwc
,
top_fat_halo
,
args
)
inc
.
add_delay
(
10
)
#
inc.add_delay(10)
elif
spatial_method
!=
2
and
spatial_method
!=
3
:
elif
spatial_method
!=
2
and
spatial_method
!=
3
:
assert
(
False
),
"spatial_method must be 1, 2 or 3"
assert
(
False
),
"spatial_method must be 1, 2 or 3"
...
@@ -299,13 +336,26 @@ class SpatialBottleneckFunction(torch.autograd.Function):
...
@@ -299,13 +336,26 @@ class SpatialBottleneckFunction(torch.autograd.Function):
fast_bottleneck
.
forward_out2
(
explicit_nhwc
,
stride_1x1
,
args
,
outputs
)
fast_bottleneck
.
forward_out2
(
explicit_nhwc
,
stride_1x1
,
args
,
outputs
)
elif
spatial_method
==
1
:
elif
spatial_method
==
1
:
fast_bottleneck
.
forward_out2
(
explicit_nhwc
,
stride_1x1
,
args
,
outputs
)
fast_bottleneck
.
forward_out2
(
explicit_nhwc
,
stride_1x1
,
args
,
outputs
)
with
torch
.
cuda
.
stream
(
stream3
):
if
explicit_nhwc
:
out1_pad
[:,
1
:
Hs
+
1
,:,:].
copy_
(
out1
)
else
:
out1_pad
[:,:,
1
:
Hs
+
1
,:].
copy_
(
out1
)
elif
spatial_method
==
2
:
elif
spatial_method
==
2
:
# wait for halo transfer to finish before doing a full convolution of padded x
# wait for halo transfer to finish before doing a full convolution of padded x
torch
.
cuda
.
current_stream
().
wait_stream
(
stream1
)
torch
.
cuda
.
current_stream
().
wait_stream
(
stream1
)
torch
.
cuda
.
current_stream
().
wait_stream
(
stream3
)
if
explicit_nhwc
:
out1_pad
[:,
1
:
Hs
+
1
,:,:].
copy_
(
out1
)
else
:
out1_pad
[:,:,
1
:
Hs
+
1
,:].
copy_
(
out1
)
fast_bottleneck
.
forward_out2_pad
(
explicit_nhwc
,
stride_1x1
,
args
,
outputs
,
out1_pad
)
fast_bottleneck
.
forward_out2_pad
(
explicit_nhwc
,
stride_1x1
,
args
,
outputs
,
out1_pad
)
elif
spatial_method
==
3
:
elif
spatial_method
==
3
:
fast_bottleneck
.
forward_out2_mask
(
explicit_nhwc
,
stride_1x1
,
args
,
outputs
,
thresholdTop
,
thresholdBottom
)
fast_bottleneck
.
forward_out2_mask
(
explicit_nhwc
,
stride_1x1
,
args
,
outputs
,
thresholdTop
,
thresholdBottom
)
with
torch
.
cuda
.
stream
(
stream3
):
if
explicit_nhwc
:
out1_pad
[:,
1
:
Hs
+
1
,:,:].
copy_
(
out1
)
else
:
out1_pad
[:,:,
1
:
Hs
+
1
,:].
copy_
(
out1
)
# compute halo cells for outputs[1] (out2)
# compute halo cells for outputs[1] (out2)
if
spatial_group_size
>
1
:
if
spatial_group_size
>
1
:
...
@@ -405,6 +455,11 @@ class SpatialBottleneckFunction(torch.autograd.Function):
...
@@ -405,6 +455,11 @@ class SpatialBottleneckFunction(torch.autograd.Function):
grad_out2
=
fast_bottleneck
.
backward_grad_out2
(
ctx
.
explicit_nhwc
,
ctx
.
stride_1x1
,
t_list
,
grads
)
grad_out2
=
fast_bottleneck
.
backward_grad_out2
(
ctx
.
explicit_nhwc
,
ctx
.
stride_1x1
,
t_list
,
grads
)
wgrad2_stream
=
torch
.
cuda
.
Stream
()
wgrad2_stream
=
torch
.
cuda
.
Stream
()
wgrad2_stream
.
wait_stream
(
torch
.
cuda
.
current_stream
())
wgrad2_stream
.
wait_stream
(
torch
.
cuda
.
current_stream
())
with
torch
.
cuda
.
stream
(
wgrad2_stream
):
if
ctx
.
spatial_group_size
>
1
:
wgrad2
=
fast_bottleneck
.
backward_wgrad2_pad
(
ctx
.
explicit_nhwc
,
ctx
.
stride_1x1
,
t_list
,
grads
,
out1_pad
,
grad_out2
)
else
:
wgrad2
=
fast_bottleneck
.
backward_wgrad2
(
ctx
.
explicit_nhwc
,
ctx
.
stride_1x1
,
t_list
,
grads
,
grad_out2
)
# do halo exchange of grad_out2 here
# do halo exchange of grad_out2 here
# compute halo cells for grad_out1
# compute halo cells for grad_out1
if
ctx
.
spatial_group_size
>
1
:
if
ctx
.
spatial_group_size
>
1
:
...
@@ -463,16 +518,10 @@ class SpatialBottleneckFunction(torch.autograd.Function):
...
@@ -463,16 +518,10 @@ class SpatialBottleneckFunction(torch.autograd.Function):
top_grad_out1_halo
=
top_grad_out1_halo
[:,
1
:
2
,:,:]
top_grad_out1_halo
=
top_grad_out1_halo
[:,
1
:
2
,:,:]
else
:
else
:
top_grad_out1_halo
=
top_grad_out1_halo
[:,:,
1
:
2
,:]
top_grad_out1_halo
=
top_grad_out1_halo
[:,:,
1
:
2
,:]
inc
.
add_delay
(
10
)
#
inc.add_delay(10)
elif
ctx
.
spatial_method
!=
3
:
elif
ctx
.
spatial_method
!=
3
:
assert
(
False
),
"spatial_method must be 1, 2 or 3"
assert
(
False
),
"spatial_method must be 1, 2 or 3"
with
torch
.
cuda
.
stream
(
wgrad2_stream
):
if
ctx
.
spatial_group_size
>
1
:
wgrad2
=
fast_bottleneck
.
backward_wgrad2_pad
(
ctx
.
explicit_nhwc
,
ctx
.
stride_1x1
,
t_list
,
grads
,
out1_pad
,
grad_out2
)
else
:
wgrad2
=
fast_bottleneck
.
backward_wgrad2
(
ctx
.
explicit_nhwc
,
ctx
.
stride_1x1
,
t_list
,
grads
,
grad_out2
)
# compute grad_out1 for internal cells
# compute grad_out1 for internal cells
if
ctx
.
spatial_group_size
<=
1
or
ctx
.
spatial_method
==
1
or
ctx
.
spatial_method
==
2
:
if
ctx
.
spatial_group_size
<=
1
or
ctx
.
spatial_method
==
1
or
ctx
.
spatial_method
==
2
:
grad_out1
=
fast_bottleneck
.
backward_grad_out1
(
ctx
.
explicit_nhwc
,
ctx
.
stride_1x1
,
t_list
,
grads
,
grad_out2
)
grad_out1
=
fast_bottleneck
.
backward_grad_out1
(
ctx
.
explicit_nhwc
,
ctx
.
stride_1x1
,
t_list
,
grads
,
grad_out2
)
...
@@ -577,6 +626,7 @@ class SpatialBottleneck(torch.nn.Module):
...
@@ -577,6 +626,7 @@ class SpatialBottleneck(torch.nn.Module):
self
.
bn1
=
norm_func
(
bottleneck_channels
)
self
.
bn1
=
norm_func
(
bottleneck_channels
)
self
.
bn2
=
norm_func
(
bottleneck_channels
)
self
.
bn2
=
norm_func
(
bottleneck_channels
)
self
.
bn3
=
norm_func
(
out_channels
)
self
.
bn3
=
norm_func
(
out_channels
)
self
.
w_scale
=
None
self
.
use_cudnn
=
use_cudnn
self
.
use_cudnn
=
use_cudnn
...
@@ -610,6 +660,27 @@ class SpatialBottleneck(torch.nn.Module):
...
@@ -610,6 +660,27 @@ class SpatialBottleneck(torch.nn.Module):
self
.
spatial_parallel_args
=
spatial_parallel_args
self
.
spatial_parallel_args
=
spatial_parallel_args
return
return
# Returns single callable that recomputes scale and bias for all frozen batch-norms.
# This method must be called before cuda graphing.
# The callable it returns can be called anytime.
# Calling this method will prevent these from being computed every forward call.
def
get_scale_bias_callable
(
self
):
self
.
w_scale
,
self
.
w_bias
,
args
=
[],
[],
[]
batch_norms
=
[
self
.
bn1
,
self
.
bn2
,
self
.
bn3
]
if
self
.
downsample
is
not
None
:
batch_norms
.
append
(
self
.
downsample
[
1
])
for
bn
in
batch_norms
:
s
=
torch
.
empty_like
(
bn
.
weight
)
b
=
torch
.
empty_like
(
s
)
args
.
append
(
(
bn
.
weight
,
bn
.
bias
,
bn
.
running_mean
,
bn
.
running_var
,
s
,
b
)
)
if
self
.
explicit_nhwc
:
self
.
w_scale
.
append
(
s
.
reshape
(
1
,
1
,
1
,
-
1
)
)
self
.
w_bias
.
append
(
b
.
reshape
(
1
,
1
,
1
,
-
1
)
)
else
:
self
.
w_scale
.
append
(
s
.
reshape
(
1
,
-
1
,
1
,
1
)
)
self
.
w_bias
.
append
(
b
.
reshape
(
1
,
-
1
,
1
,
1
)
)
return
func
.
partial
(
compute_scale_bias_method
,
self
.
explicit_nhwc
,
args
)
def
forward
(
self
,
x
):
def
forward
(
self
,
x
):
if
self
.
use_cudnn
:
if
self
.
use_cudnn
:
if
self
.
thresholdTop
is
None
:
if
self
.
thresholdTop
is
None
:
...
@@ -620,19 +691,24 @@ class SpatialBottleneck(torch.nn.Module):
...
@@ -620,19 +691,24 @@ class SpatialBottleneck(torch.nn.Module):
N
,
C
,
H
,
W
=
list
(
x
.
shape
)
N
,
C
,
H
,
W
=
list
(
x
.
shape
)
self
.
thresholdTop
=
torch
.
tensor
([
1
if
spatial_group_rank
>
0
else
0
],
dtype
=
torch
.
int32
,
device
=
'cuda'
)
self
.
thresholdTop
=
torch
.
tensor
([
1
if
spatial_group_rank
>
0
else
0
],
dtype
=
torch
.
int32
,
device
=
'cuda'
)
self
.
thresholdBottom
=
torch
.
tensor
([
H
-
2
if
spatial_group_rank
<
spatial_group_size
-
1
else
H
-
1
],
dtype
=
torch
.
int32
,
device
=
'cuda'
)
self
.
thresholdBottom
=
torch
.
tensor
([
H
-
2
if
spatial_group_rank
<
spatial_group_size
-
1
else
H
-
1
],
dtype
=
torch
.
int32
,
device
=
'cuda'
)
# calculate scale/bias from registered buffers
# TODO: make this better
if
self
.
w_scale
is
None
:
s1
,
b1
=
self
.
bn1
.
get_scale_bias
(
self
.
explicit_nhwc
)
# calculate scale/bias from registered buffers
s2
,
b2
=
self
.
bn2
.
get_scale_bias
(
self
.
explicit_nhwc
)
# TODO: make this better
s3
,
b3
=
self
.
bn3
.
get_scale_bias
(
self
.
explicit_nhwc
)
s1
,
b1
=
self
.
bn1
.
get_scale_bias
(
self
.
explicit_nhwc
)
w_scale
=
[
s1
,
s2
,
s3
]
s2
,
b2
=
self
.
bn2
.
get_scale_bias
(
self
.
explicit_nhwc
)
w_bias
=
[
b1
,
b2
,
b3
]
s3
,
b3
=
self
.
bn3
.
get_scale_bias
(
self
.
explicit_nhwc
)
if
self
.
downsample
is
not
None
:
w_scale
=
[
s1
,
s2
,
s3
]
s4
,
b4
=
self
.
downsample
[
1
].
get_scale_bias
(
self
.
explicit_nhwc
)
w_bias
=
[
b1
,
b2
,
b3
]
w_scale
.
append
(
s4
)
if
self
.
downsample
is
not
None
:
w_bias
.
append
(
b4
)
s4
,
b4
=
self
.
downsample
[
1
].
get_scale_bias
(
self
.
explicit_nhwc
)
w_scale
.
append
(
s4
)
out
=
spatial_bottleneck_function
(
*
self
.
spatial_parallel_args
,
self
.
explicit_nhwc
,
self
.
stride
,
w_scale
,
w_bias
,
self
.
thresholdTop
,
self
.
thresholdBottom
,
x
,
*
self
.
w_conv
)
w_bias
.
append
(
b4
)
self
.
w_scale
=
w_scale
self
.
w_bias
=
w_bias
out
=
spatial_bottleneck_function
(
*
self
.
spatial_parallel_args
,
self
.
explicit_nhwc
,
self
.
stride
,
w_scale
,
w_bias
,
self
.
thresholdTop
,
self
.
thresholdBottom
,
x
,
*
self
.
w_conv
)
else
:
out
=
spatial_bottleneck_function
(
*
self
.
spatial_parallel_args
,
self
.
explicit_nhwc
,
self
.
stride
,
self
.
w_scale
,
self
.
w_bias
,
self
.
thresholdTop
,
self
.
thresholdBottom
,
x
,
*
self
.
w_conv
)
return
out
return
out
if
self
.
explicit_nhwc
:
if
self
.
explicit_nhwc
:
...
...
apex/contrib/csrc/peer_memory/peer_memory.cpp
View file @
c70f0e32
...
@@ -19,6 +19,7 @@
...
@@ -19,6 +19,7 @@
PYBIND11_MODULE
(
TORCH_EXTENSION_NAME
,
m
)
{
PYBIND11_MODULE
(
TORCH_EXTENSION_NAME
,
m
)
{
m
.
def
(
"allocate_raw"
,
&
apex
::
contrib
::
peer_memory
::
allocate_raw
,
"allocate_raw"
);
m
.
def
(
"allocate_raw"
,
&
apex
::
contrib
::
peer_memory
::
allocate_raw
,
"allocate_raw"
);
m
.
def
(
"free_raw"
,
&
apex
::
contrib
::
peer_memory
::
free_raw
,
"free_raw"
);
m
.
def
(
"free_raw"
,
&
apex
::
contrib
::
peer_memory
::
free_raw
,
"free_raw"
);
m
.
def
(
"zero"
,
&
apex
::
contrib
::
peer_memory
::
zero
,
"zero"
);
m
.
def
(
"get_raw_ipc_address"
,
&
apex
::
contrib
::
peer_memory
::
get_raw_ipc_address
,
"get_raw_ipc_address"
);
m
.
def
(
"get_raw_ipc_address"
,
&
apex
::
contrib
::
peer_memory
::
get_raw_ipc_address
,
"get_raw_ipc_address"
);
m
.
def
(
"get_raw_peers"
,
&
apex
::
contrib
::
peer_memory
::
get_raw_peers
,
"get_raw_peers"
);
m
.
def
(
"get_raw_peers"
,
&
apex
::
contrib
::
peer_memory
::
get_raw_peers
,
"get_raw_peers"
);
m
.
def
(
"blob_view_half"
,
&
apex
::
contrib
::
peer_memory
::
blob_view_half
,
"blob_view_half"
);
m
.
def
(
"blob_view_half"
,
&
apex
::
contrib
::
peer_memory
::
blob_view_half
,
"blob_view_half"
);
...
...
apex/contrib/csrc/peer_memory/peer_memory_cuda.cu
View file @
c70f0e32
...
@@ -148,33 +148,58 @@ __device__ void strided_copy_kernel(
...
@@ -148,33 +148,58 @@ __device__ void strided_copy_kernel(
}
}
}
}
template
<
bool
wait
,
bool
clear
>
__device__
void
checked_signal
(
__device__
void
dual_signal_wait_clear
(
volatile
int
*
signal1_flag
,
volatile
int
*
signal2_flag
,
volatile
int
*
signal1_flag
,
volatile
int
*
wait1_flag
,
volatile
int
*
signal2_flag
,
volatile
int
*
wait2_flag
,
const
int
v1
,
const
int
v2
,
const
int
v3
,
const
int
v4
const
int
v1
,
const
int
v2
,
const
int
v3
,
const
int
v4
)
)
{
{
register
int
r1
,
r2
,
r3
,
r4
,
r5
,
r6
,
r7
,
r8
;
if
(
blockIdx
.
x
==
0
)
{
bool
is_main_thread
=
(
blockIdx
.
x
==
0
&&
threadIdx
.
x
==
0
)
?
true
:
false
;
register
int
r1
,
r2
,
r3
,
r4
;
// signal and wait
if
(
threadIdx
.
x
==
0
)
{
if
(
is_main_thread
)
{
// wait for top neighbor to clear bottom signal (indicating ready for new input)
asm
volatile
(
"st.volatile.global.v4.u32 [%0], {%1,%2,%3,%4};"
::
"l"
(
signal1_flag
),
"r"
(
v1
),
"r"
(
v2
),
"r"
(
v3
),
"r"
(
v4
)
:
"memory"
);
do
{
asm
volatile
(
"st.volatile.global.v4.u32 [%0], {%1,%2,%3,%4};"
::
"l"
(
signal2_flag
),
"r"
(
v1
),
"r"
(
v2
),
"r"
(
v3
),
"r"
(
v4
)
:
"memory"
);
asm
volatile
(
"ld.volatile.global.v4.u32 {%0,%1,%2,%3}, [%4];"
:
"=r"
(
r1
),
"=r"
(
r2
),
"=r"
(
r3
),
"=r"
(
r4
)
:
"l"
(
signal1_flag
)
:
"memory"
);
if
(
wait
)
{
}
while
(
r1
==
v1
&&
r2
==
v2
&&
r3
==
v3
&&
r4
==
v4
);
// signal to top neighbor my output is ready
asm
volatile
(
"st.volatile.global.v4.u32 [%0], {%1,%2,%3,%4};"
::
"l"
(
signal1_flag
),
"r"
(
v1
),
"r"
(
v2
),
"r"
(
v3
),
"r"
(
v4
)
:
"memory"
);
}
else
if
(
threadIdx
.
x
==
1
)
{
// wait for bottom neighbor to clear top signal (indicating ready for new input)
do
{
do
{
asm
volatile
(
"ld.volatile.global.v4.u32 {%0,%1,%2,%3}, [%4];"
:
"=r"
(
r1
),
"=r"
(
r2
),
"=r"
(
r3
),
"=r"
(
r4
)
:
"l"
(
wait1_flag
)
:
"memory"
);
asm
volatile
(
"ld.volatile.global.v4.u32 {%0,%1,%2,%3}, [%4];"
:
"=r"
(
r1
),
"=r"
(
r2
),
"=r"
(
r3
),
"=r"
(
r4
)
:
"l"
(
signal2_flag
)
:
"memory"
);
asm
volatile
(
"ld.volatile.global.v4.u32 {%0,%1,%2,%3}, [%4];"
:
"=r"
(
r5
),
"=r"
(
r6
),
"=r"
(
r7
),
"=r"
(
r8
)
:
"l"
(
wait2_flag
)
:
"memory"
);
}
while
(
r1
==
v1
&&
r2
==
v2
&&
r3
==
v3
&&
r4
==
v4
);
}
while
(
r1
!=
v1
||
r5
!=
v1
||
r2
!=
v2
||
r6
!=
v2
||
r3
!=
v3
||
r7
!=
v3
||
r4
!=
v4
||
r8
!=
v4
);
// signal to bottom neighbor my output is ready
asm
volatile
(
"st.volatile.global.v4.u32 [%0], {%1,%2,%3,%4};"
::
"l"
(
signal2_flag
),
"r"
(
v1
),
"r"
(
v2
),
"r"
(
v3
),
"r"
(
v4
)
:
"memory"
);
}
}
}
}
cg
::
this_grid
().
sync
();
}
if
(
clear
)
{
if
(
is_main_thread
)
{
__device__
void
wait_for
(
r1
=
0
;
r2
=
0
;
r3
=
0
;
r4
=
0
;
volatile
int
*
wait_flag
,
asm
volatile
(
"st.volatile.global.v4.u32 [%0], {%1,%2,%3,%4};"
::
"l"
(
wait1_flag
),
"r"
(
r1
),
"r"
(
r2
),
"r"
(
r3
),
"r"
(
r4
)
:
"memory"
);
const
int
v1
,
const
int
v2
,
const
int
v3
,
const
int
v4
asm
volatile
(
"st.volatile.global.v4.u32 [%0], {%1,%2,%3,%4};"
::
"l"
(
wait2_flag
),
"r"
(
r1
),
"r"
(
r2
),
"r"
(
r3
),
"r"
(
r4
)
:
"memory"
);
)
}
{
bool
is_main_thread
=
(
blockIdx
.
x
==
0
&&
threadIdx
.
x
==
0
)
?
true
:
false
;
if
(
is_main_thread
)
{
register
int
r1
,
r2
,
r3
,
r4
;
// wait for senders to signal their output is read
do
{
asm
volatile
(
"ld.volatile.global.v4.u32 {%0,%1,%2,%3}, [%4];"
:
"=r"
(
r1
),
"=r"
(
r2
),
"=r"
(
r3
),
"=r"
(
r4
)
:
"l"
(
wait_flag
)
:
"memory"
);
}
while
(
r1
!=
v1
||
r2
!=
v2
||
r3
!=
v3
||
r4
!=
v4
);
}
cg
::
this_grid
().
sync
();
// all threads wait for main
}
__device__
void
clear_flag
(
volatile
int
*
wait_flag
)
{
cg
::
this_grid
().
sync
();
// wait for all threads in kernel to finish
bool
is_main_thread
=
(
blockIdx
.
x
==
0
&&
threadIdx
.
x
==
0
)
?
true
:
false
;
if
(
is_main_thread
)
{
register
int
r1
,
r2
,
r3
,
r4
;
r1
=
0
;
r2
=
0
;
r3
=
0
;
r4
=
0
;
asm
volatile
(
"st.volatile.global.v4.u32 [%0], {%1,%2,%3,%4};"
::
"l"
(
wait_flag
),
"r"
(
r1
),
"r"
(
r2
),
"r"
(
r3
),
"r"
(
r4
)
:
"memory"
);
}
}
}
}
...
@@ -208,11 +233,15 @@ __global__ void push_pull_halos_1d_kernel(
...
@@ -208,11 +233,15 @@ __global__ void push_pull_halos_1d_kernel(
strided_copy_kernel
<
T
,
is_HWC
>
(
box
,
box_stride_C
,
box_stride_H
,
box_stride_W
,
boh
,
boh_stride_C
,
boh_stride_H
,
boh_stride_W
,
NC
,
NH
,
NW
);
strided_copy_kernel
<
T
,
is_HWC
>
(
box
,
box_stride_C
,
box_stride_H
,
box_stride_W
,
boh
,
boh_stride_C
,
boh_stride_H
,
boh_stride_W
,
NC
,
NH
,
NW
);
// signal to top and btm neigbhbors that output halos are ready to be read
// signal to top and btm neigbhbors that output halos are ready to be read
// the choice of values for v1-v4 is arbitrary and does not matter, as long as all ranks use the same values
// the choice of values for v1-v4 is arbitrary and does not matter, as long as all ranks use the same values
dual_signal_wait_clear
<
true
,
true
>
(
signal1_flag
,
wait1_flag
,
signal2_flag
,
wait
2_flag
,
-
987751720
,
840868300
,
-
225529332
,
281513358
);
checked_signal
(
signal1_flag
,
signal
2_flag
,
-
987751720
,
840868300
,
-
225529332
,
281513358
);
// pull top halo from transfer buffer in peer memory to input
// pull top halo from transfer buffer in peer memory to input
wait_for
(
wait1_flag
,
-
987751720
,
840868300
,
-
225529332
,
281513358
);
strided_copy_kernel
<
T
,
is_HWC
>
(
tih
,
tih_stride_C
,
tih_stride_H
,
tih_stride_W
,
tix
,
tix_stride_C
,
tix_stride_H
,
tix_stride_W
,
NC
,
NH
,
NW
);
strided_copy_kernel
<
T
,
is_HWC
>
(
tih
,
tih_stride_C
,
tih_stride_H
,
tih_stride_W
,
tix
,
tix_stride_C
,
tix_stride_H
,
tix_stride_W
,
NC
,
NH
,
NW
);
clear_flag
(
wait1_flag
);
// pull btm halo from transfer buffer in peer memory to input
// pull btm halo from transfer buffer in peer memory to input
wait_for
(
wait2_flag
,
-
987751720
,
840868300
,
-
225529332
,
281513358
);
strided_copy_kernel
<
T
,
is_HWC
>
(
bih
,
bih_stride_C
,
bih_stride_H
,
bih_stride_W
,
bix
,
bix_stride_C
,
bix_stride_H
,
bix_stride_W
,
NC
,
NH
,
NW
);
strided_copy_kernel
<
T
,
is_HWC
>
(
bih
,
bih_stride_C
,
bih_stride_H
,
bih_stride_W
,
bix
,
bix_stride_C
,
bix_stride_H
,
bix_stride_W
,
NC
,
NH
,
NW
);
clear_flag
(
wait2_flag
);
}
}
__global__
void
delay_kernel
(
int
delay_nanoseconds
,
int
*
counter
)
__global__
void
delay_kernel
(
int
delay_nanoseconds
,
int
*
counter
)
...
@@ -248,6 +277,11 @@ void free_raw(int64_t raw)
...
@@ -248,6 +277,11 @@ void free_raw(int64_t raw)
cudaFree
((
void
*
)
raw
);
cudaFree
((
void
*
)
raw
);
}
}
void
zero
(
int64_t
raw
,
int64_t
size
)
{
cudaMemset
((
void
*
)
raw
,
0
,
size
);
}
at
::
Tensor
get_raw_ipc_address
(
int64_t
raw
)
at
::
Tensor
get_raw_ipc_address
(
int64_t
raw
)
{
{
cudaIpcMemHandle_t
mem_handle
;
cudaIpcMemHandle_t
mem_handle
;
...
...
apex/contrib/csrc/peer_memory/peer_memory_cuda.cuh
View file @
c70f0e32
...
@@ -22,6 +22,7 @@
...
@@ -22,6 +22,7 @@
namespace
apex
{
namespace
contrib
{
namespace
peer_memory
{
namespace
apex
{
namespace
contrib
{
namespace
peer_memory
{
int64_t
allocate_raw
(
int64_t
size
);
int64_t
allocate_raw
(
int64_t
size
);
void
free_raw
(
int64_t
raw
);
void
free_raw
(
int64_t
raw
);
void
zero
(
int64_t
raw
,
int64_t
size
);
at
::
Tensor
get_raw_ipc_address
(
int64_t
raw
);
at
::
Tensor
get_raw_ipc_address
(
int64_t
raw
);
std
::
vector
<
int64_t
>
get_raw_peers
(
at
::
Tensor
ipc_addresses
,
int
peer_rank
,
int64_t
raw
);
std
::
vector
<
int64_t
>
get_raw_peers
(
at
::
Tensor
ipc_addresses
,
int
peer_rank
,
int64_t
raw
);
at
::
Tensor
blob_view_half
(
int64_t
raw
,
std
::
vector
<
int64_t
>
shape
,
bool
channels_last
);
at
::
Tensor
blob_view_half
(
int64_t
raw
,
std
::
vector
<
int64_t
>
shape
,
bool
channels_last
);
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment