Commit 3ade5b26 authored Mar 24, 2022 by Thor Johnsen

Add bottleneck block

Parent: b48898fb

Showing 5 changed files with 187 additions and 107 deletions.
apex/contrib/bottleneck/bottleneck.py (+144 -95)
apex/contrib/csrc/peer_memory/peer_memory.cpp (+8 -8)
apex/contrib/csrc/peer_memory/peer_memory_cuda.cu (+19 -2)
apex/contrib/csrc/peer_memory/peer_memory_cuda.cuh (+2 -2)
setup.py (+14 -0)
apex/contrib/bottleneck/bottleneck.py

 import torch
 import torch.distributed as dist
 from torch import nn
-import fast_bottleneck
+from maskrcnn_benchmark.utils.registry import Registry
+import maskrcnn_benchmark.SpatialBottleneck as fast_bottleneck
+import nccl_p2p as inc

 def kaiming_uniform_(tensor, a=0, mode='fan_in', nonlinearity='leaky_relu'):
     weight_tensor_nchw = tensor
     nn.init.kaiming_uniform_(weight_tensor_nchw, a=a, mode=mode, nonlinearity=nonlinearity)

-class FrozenBatchNorm2d(torch.nn.Module):
+class FrozenBatchNorm2d(torch.jit.ScriptModule):
     """
     BatchNorm2d where the batch statistics and the affine parameters are fixed
     """
@@ -18,7 +20,9 @@ class FrozenBatchNorm2d(torch.nn.Module):
         self.register_buffer("running_mean", torch.zeros(n))
         self.register_buffer("running_var", torch.ones(n))

-    def get_scale_bias(self, nhwc=False):
+    @torch.jit.script_method
+    def get_scale_bias(self, nhwc):
+        # type: (bool) -> List[torch.Tensor]
         scale = self.weight * self.running_var.rsqrt()
         bias = self.bias - self.running_mean * scale
         if nhwc:
@@ -29,11 +33,11 @@ class FrozenBatchNorm2d(torch.nn.Module):
             bias = bias.reshape(1, -1, 1, 1)
         return scale, bias

+    @torch.jit.script_method
     def forward(self, x):
-        scale, bias = self.get_scale_bias()
+        scale, bias = self.get_scale_bias(False)
         return x * scale + bias

 @torch.jit.script
 def drelu_dscale1(grad_o, output, scale1):
     relu_mask = (output > 0)
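For readers tracking the FrozenBatchNorm2d change: get_scale_bias folds the frozen statistics into a single affine transform, x * scale + bias', with scale = weight * rsqrt(running_var) and bias' = bias - running_mean * scale (no eps term, matching the class). A minimal CPU check of that algebra, ours and not part of the commit:

import torch

n = 4
weight, bias = torch.randn(n), torch.randn(n)
running_mean, running_var = torch.randn(n), torch.rand(n) + 0.5

# folded form, as get_scale_bias computes it
scale = weight * running_var.rsqrt()
shift = bias - running_mean * scale

x = torch.randn(2, n, 5, 5)
c = lambda t: t.reshape(1, -1, 1, 1)            # broadcast per channel (NCHW)
y_folded = x * c(scale) + c(shift)
y_ref = (x - c(running_mean)) * c(running_var.rsqrt()) * c(weight) + c(bias)
print(torch.allclose(y_folded, y_ref, atol=1e-6))  # True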
@@ -217,7 +221,11 @@ class Bottleneck(torch.nn.Module):

 class SpatialBottleneckFunction(torch.autograd.Function):
     @staticmethod
-    def forward(ctx, spatial_group_size, local_rank, comm, stream1, nhwc, stride_1x1, scale, bias, x, *conv):
+    def forward(ctx, spatial_group_size, spatial_group_rank, spatial_communicator, spatial_halo_exchanger, spatial_stream, nhwc, stride_1x1, scale, bias, x, *conv):
+        if spatial_group_size > 1:
+            stream1 = spatial_halo_exchanger.stream1
+            stream2 = spatial_halo_exchanger.stream2
         # TODO: clean up order of tensors
         args = [x, *conv[0:3], *scale[0:3], *bias[0:3]]
         ctx.downsample = len(conv) > 3
@@ -232,38 +240,38 @@ class SpatialBottleneckFunction(torch.autograd.Function):
         outputs = fast_bottleneck.forward_init(nhwc, stride_1x1, args)
         fast_bottleneck.forward_out1(nhwc, stride_1x1, args, outputs)
-        fast_bottleneck.forward_out2(nhwc, stride_1x1, args, outputs)

         # do halo exchange for outputs[0] (out1)
+        # compute halo cells for outputs[1]
         if spatial_group_size > 1:
             out1 = outputs[0]
             N,Hs,W,C = list(out1.shape)
             stream1.wait_stream(torch.cuda.current_stream())
             with torch.cuda.stream(stream1):
-                # copy halos to send buffer
-                send_halos = torch.empty((N,2,W,C),dtype=out1.dtype,device=out1.device)
-                send_halos[:,:1,:,:].copy_(out1[:,:1,:,:])
-                send_halos[:,1:,:,:].copy_(out1[:,Hs-1:,:,:])
-                all_halos = torch.empty((N,2*spatial_group_size,W,C),dtype=out1.dtype,device=out1.device)
-                all_halos = [all_halos[:,i*2:(i+1)*2,:,:] for i in range(spatial_group_size)]
-                dist.all_gather(all_halos,send_halos,group=comm)
-                fat_halo = torch.empty((N,3,W,C),dtype=out1.dtype,device=out1.device)
-                if local_rank > 0:
-                    top_out1_halo = all_halos[(spatial_group_size+local_rank-1)%spatial_group_size][:,1:,:,:]
-                    fat_halo[:,:1,:,:].copy_(top_out1_halo)
-                    fat_halo[:,1:3,:,:].copy_(out1[:,:2,:,:])
-                    top_out2 = fast_bottleneck.forward_out2_halo(nhwc, fat_halo, args)
-                if local_rank < spatial_group_size-1:
-                    btm_out1_halo = all_halos[(local_rank+1)%spatial_group_size][:,:1,:,:]
-                    fat_halo[:,0:2,:,:].copy_(out1[:,Hs-2:,:,:])
-                    fat_halo[:,2:,:,:].copy_(btm_out1_halo)
-                    btm_out2 = fast_bottleneck.forward_out2_halo(nhwc, fat_halo, args)
-            torch.cuda.current_stream().wait_stream(stream1)
-            out2 = outputs[1]
-            if local_rank > 0:
-                out2[:,:1,:,:].copy_(top_out2)
-            if local_rank < spatial_group_size-1:
-                out2[:,Hs-1:,:,:].copy_(btm_out2)
+                top_out1_halo, btm_out1_halo = spatial_halo_exchanger.left_right_halo_exchange(out1[:,:1,:,:], out1[:,Hs-1:,:,:])
+                if spatial_group_rank < spatial_group_size-1:
+                    stream2.wait_stream(stream1)
+                    with torch.cuda.stream(stream2):
+                        btm_fat_halo = torch.empty((N,3,W,C),dtype=out1.dtype,device=out1.device)
+                        btm_fat_halo[:,0:2,:,:].copy_(out1[:,Hs-2:,:,:])
+                        btm_fat_halo[:,2:,:,:].copy_(btm_out1_halo)
+                        btm_out2 = fast_bottleneck.forward_out2_halo(nhwc, btm_fat_halo, args)
+                if spatial_group_rank > 0:
+                    with torch.cuda.stream(stream1):
+                        top_fat_halo = torch.empty((N,3,W,C),dtype=out1.dtype,device=out1.device)
+                        top_fat_halo[:,:1,:,:].copy_(top_out1_halo)
+                        top_fat_halo[:,1:3,:,:].copy_(out1[:,:2,:,:])
+                        top_out2 = fast_bottleneck.forward_out2_halo(nhwc, top_fat_halo, args)
+                inc.add_delay(10)
+
+        fast_bottleneck.forward_out2(nhwc, stride_1x1, args, outputs)
+
+        # compute halo cells for outputs[1] (out2)
+        if spatial_group_size > 1:
+            out2 = outputs[1]
+            if spatial_group_rank > 0:
+                torch.cuda.current_stream().wait_stream(stream1)
+                out2[:,:1,:,:].copy_(top_out2)
+            if spatial_group_rank < spatial_group_size-1:
+                torch.cuda.current_stream().wait_stream(stream2)
+                out2[:,Hs-1:,:,:].copy_(btm_out2)

         fast_bottleneck.forward_rest(nhwc, stride_1x1, args, outputs)
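The restructuring lets forward_out2 for the interior rows run on the default stream while the one-row halos are turned into out2 boundary rows on stream1/stream2. The reason a 3-row fat halo suffices is the usual spatial-parallelism argument for a 3x3 convolution: each output row needs one input row of context on either side. A minimal CPU sketch of H-split convolution with one-row halo exchange, ours and not part of the commit, with F.conv2d standing in for the fused fast_bottleneck kernels:

import torch
import torch.nn.functional as F

torch.manual_seed(0)
G, N, C, H, W = 4, 1, 8, 16, 16          # simulated group size; H divisible by G
x = torch.randn(N, C, H, W)
w = torch.randn(C, C, 3, 3)
full = F.conv2d(x, w, padding=1)          # reference: the single-GPU result

shards = x.chunk(G, dim=2)                # one H-slice per simulated rank
outs = []
for r in range(G):
    # halo exchange: one boundary row from each neighbor (left_right_halo_exchange)
    top = shards[r - 1][:, :, -1:, :] if r > 0 else None
    btm = shards[r + 1][:, :, :1, :] if r < G - 1 else None
    padded = torch.cat([t for t in (top, shards[r], btm) if t is not None], dim=2)
    # zero-pad W everywhere; zero-pad H only at the global top/bottom edges
    padded = F.pad(padded, (1, 1, 1 if r == 0 else 0, 1 if r == G - 1 else 0))
    outs.append(F.conv2d(padded, w))
print(torch.allclose(torch.cat(outs, dim=2), full, atol=1e-4))  # True

Each simulated rank reproduces exactly its slice of the full result, which is what lets out2's interior and boundary rows be computed independently and stitched together.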
@@ -276,9 +284,11 @@ class SpatialBottleneckFunction(torch.autograd.Function):
         ctx.nhwc = nhwc
         ctx.stride_1x1 = stride_1x1
         ctx.spatial_group_size = spatial_group_size
-        ctx.local_rank = local_rank
-        ctx.comm = comm
-        ctx.stream1 = stream1
+        if spatial_group_size > 1:
+            ctx.spatial_group_rank = spatial_group_rank
+            ctx.spatial_halo_exchanger = spatial_halo_exchanger
+            ctx.stream1 = stream1
+            ctx.stream2 = stream2
         return outputs[2]

     # backward relu is not exposed, MUL with mask used now
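The comment refers to implementing ReLU's backward as a multiply with a zero/one mask, which is what drelu_dscale1's relu_mask = (output > 0) builds on. A tiny CPU check of the identity, ours:

import torch

x = torch.randn(5, requires_grad=True)
y = torch.relu(x)
y.backward(torch.ones_like(y))
mask = (y > 0).to(x.dtype)                  # the "MUL with mask" formulation
print(torch.equal(x.grad, mask * torch.ones_like(y)))  # True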
@@ -312,54 +322,52 @@ class SpatialBottleneckFunction(torch.autograd.Function):
         grads = fast_bottleneck.backward_init(ctx.nhwc, ctx.stride_1x1, t_list)
         grad_out2 = fast_bottleneck.backward_grad_out2(ctx.nhwc, ctx.stride_1x1, t_list, grads)

+        # do halo exchange of grad_out2 here
+        # compute halo cells for grad_out1
+        if ctx.spatial_group_size > 1:
+            N,Hs,W,C = list(grad_out2.shape)
+            relu1 = t_list[12]
+            ctx.stream1.wait_stream(torch.cuda.current_stream())
+            with torch.cuda.stream(ctx.stream1):
+                top_halo, btm_halo = ctx.spatial_halo_exchanger.left_right_halo_exchange(grad_out2[:,:1,:,:], grad_out2[:,Hs-1:,:,:])
+                # copy halos to send buffer
+                if ctx.spatial_group_rank < ctx.spatial_group_size-1:
+                    ctx.stream2.wait_stream(ctx.stream1)
+                    with torch.cuda.stream(ctx.stream2):
+                        btm_fat_halo = torch.empty((N,3,W,C),dtype=grad_out2.dtype,device=grad_out2.device)
+                        btm_relu_halo = torch.empty((N,3,W,C),dtype=grad_out2.dtype,device=grad_out2.device)
+                        btm_fat_halo[:,:2,:,:].copy_(grad_out2[:,Hs-2:,:,:])
+                        btm_fat_halo[:,2:,:,:].copy_(btm_halo)
+                        btm_relu_halo[:,:2,:,:].copy_(relu1[:,Hs-2:,:,:])
+                        btm_relu_halo[:,2:,:,:].zero_()
+                        btm_grad_out1_halo = fast_bottleneck.backward_grad_out1_halo(ctx.nhwc, ctx.stride_1x1, t_list, grads, btm_fat_halo, btm_relu_halo)
+                        btm_grad_out1_halo = btm_grad_out1_halo[:,1:2,:,:]
+                if ctx.spatial_group_rank > 0:
+                    with torch.cuda.stream(ctx.stream1):
+                        top_fat_halo = torch.empty((N,3,W,C),dtype=grad_out2.dtype,device=grad_out2.device)
+                        top_relu_halo = torch.empty((N,3,W,C),dtype=grad_out2.dtype,device=grad_out2.device)
+                        top_fat_halo[:,:1,:,:].copy_(top_halo)
+                        top_fat_halo[:,1:,:,:].copy_(grad_out2[:,:2,:,:])
+                        top_relu_halo[:,:1,:,:].zero_()
+                        top_relu_halo[:,1:,:,:].copy_(relu1[:,:2,:,:])
+                        top_grad_out1_halo = fast_bottleneck.backward_grad_out1_halo(ctx.nhwc, ctx.stride_1x1, t_list, grads, top_fat_halo, top_relu_halo)
+                        top_grad_out1_halo = top_grad_out1_halo[:,1:2,:,:]
+                inc.add_delay(10)

         # compute wgrad2 for internal cells
         wgrad2 = fast_bottleneck.backward_wgrad2(ctx.nhwc, ctx.stride_1x1, t_list, grads, grad_out2)

         # apply wgrad2 halos
         if ctx.spatial_group_size > 1:
-            if ctx.local_rank > 0:
+            if ctx.spatial_group_rank > 0:
                 top_grad2_halo = grad_out2[:,:1,:,:]
                 top_wgrad2_halo = fast_bottleneck.backward_wgrad2_halo(ctx.nhwc, ctx.stride_1x1, t_list, grads, top_out1_halo, top_grad2_halo)
                 wgrad2[:,:1,:,:].add_(top_wgrad2_halo)
-            if ctx.local_rank < ctx.spatial_group_size-1:
+            if ctx.spatial_group_rank < ctx.spatial_group_size-1:
                 btm_grad2_halo = grad_out2[:,-1:,:,:]
                 btm_wgrad2_halo = fast_bottleneck.backward_wgrad2_halo(ctx.nhwc, ctx.stride_1x1, t_list, grads, btm_out1_halo, btm_grad2_halo)
                 wgrad2[:,-1:,:,:].add_(btm_wgrad2_halo)

-        # do halo exchange of grad_out2 here
-        # compute halo cells for grad_out1
-        if ctx.spatial_group_size > 1:
-            N,Hs,W,C = list(grad_out2.shape)
-            ctx.stream1.wait_stream(torch.cuda.current_stream())
-            with torch.cuda.stream(ctx.stream1):
-                # copy halos to send buffer
-                send_halos = torch.empty((N,2,W,C),dtype=grad_out2.dtype,device=grad_out2.device)
-                send_halos[:,:1,:,:].copy_(grad_out2[:,:1,:,:])
-                send_halos[:,1:,:,:].copy_(grad_out2[:,Hs-1:,:,:])
-                all_halos = torch.empty((N,2*ctx.spatial_group_size,W,C),dtype=grad_out2.dtype,device=grad_out2.device)
-                all_halos = [all_halos[:,i*2:(i+1)*2,:,:] for i in range(ctx.spatial_group_size)]
-                dist.all_gather(all_halos,send_halos,group=ctx.comm)
-                relu1 = t_list[12]
-                fat_halo = torch.empty((N,3,W,C),dtype=grad_out2.dtype,device=grad_out2.device)
-                relu_halo = torch.empty((N,3,W,C),dtype=grad_out2.dtype,device=grad_out2.device)
-                if ctx.local_rank > 0:
-                    top_halo = all_halos[ctx.local_rank-1][:,1:,:,:]
-                    fat_halo[:,:1,:,:].copy_(top_halo)
-                    fat_halo[:,1:,:,:].copy_(grad_out2[:,:2,:,:])
-                    relu_halo[:,:1,:,:].zero_()
-                    relu_halo[:,1:,:,:].copy_(relu1[:,:2,:,:])
-                    top_grad_out1_halo = fast_bottleneck.backward_grad_out1_halo(ctx.nhwc, ctx.stride_1x1, t_list, grads, fat_halo, relu_halo)
-                    top_grad_out1_halo = top_grad_out1_halo[:,1:2,:,:]
-                if ctx.local_rank < ctx.spatial_group_size-1:
-                    btm_halo = all_halos[ctx.local_rank+1][:,:1,:,:]
-                    fat_halo[:,:2,:,:].copy_(grad_out2[:,Hs-2:,:,:])
-                    fat_halo[:,2:,:,:].copy_(btm_halo)
-                    relu_halo[:,:2,:,:].copy_(relu1[:,Hs-2:,:,:])
-                    relu_halo[:,2:,:,:].zero_()
-                    btm_grad_out1_halo = fast_bottleneck.backward_grad_out1_halo(ctx.nhwc, ctx.stride_1x1, t_list, grads, fat_halo, relu_halo)
-                    btm_grad_out1_halo = btm_grad_out1_halo[:,1:2,:,:]

         # compute grad_out1 for internal cells
         grad_out1 = fast_bottleneck.backward_grad_out1(ctx.nhwc, ctx.stride_1x1, t_list, grads, grad_out2)
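The wgrad2 halo terms exist because weight gradients are additive across output rows: boundary rows computed from neighbor data contribute wgrad pieces that must be add_()-ed into the locally computed wgrad2. A CPU sketch of that additivity for a plain 3x3 convolution, ours and not part of the commit, reusing the halo-padded shards from the forward sketch:

import torch
import torch.nn.functional as F

torch.manual_seed(0)
G, N, C, H, W = 4, 1, 4, 16, 16
x = torch.randn(N, C, H, W)
w = torch.randn(C, C, 3, 3, requires_grad=True)
g = torch.randn(N, C, H, W)                      # incoming grad for the conv output

F.conv2d(x, w, padding=1).backward(g)
full_wgrad = w.grad.clone()                      # reference weight gradient

w.grad = None
shards, gshards = x.chunk(G, dim=2), g.chunk(G, dim=2)
for r in range(G):
    top = shards[r - 1][:, :, -1:, :] if r > 0 else None
    btm = shards[r + 1][:, :, :1, :] if r < G - 1 else None
    padded = torch.cat([t for t in (top, shards[r], btm) if t is not None], dim=2)
    padded = F.pad(padded, (1, 1, 1 if r == 0 else 0, 1 if r == G - 1 else 0))
    F.conv2d(padded, w).backward(gshards[r])     # per-shard pieces accumulate into w.grad
print(torch.allclose(w.grad, full_wgrad, atol=1e-3))  # True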
@@ -369,20 +377,70 @@ class SpatialBottleneckFunction(torch.autograd.Function):
             z = t_list[4]
             relu1 = t_list[12]
             #print("w.shape = %s, z.shape = %s, relu1.shape = %s" % (str(list(w.shape)), str(list(z.shape)), str(list(relu1.shape))))

-            torch.cuda.current_stream().wait_stream(ctx.stream1)
-            if ctx.local_rank > 0:
+            if ctx.spatial_group_rank > 0:
+                torch.cuda.current_stream().wait_stream(ctx.stream1)
                 grad_out1[:,:1,:,:].copy_(top_grad_out1_halo)
-                #print("ctx.local_rank = %d, apply grad_out1 top halo (grad_out1.shape = %s)" % (ctx.local_rank, str(list(grad_out1.shape))))
-            if ctx.local_rank < ctx.spatial_group_size-1:
+                #print("ctx.spatial_group_rank = %d, apply grad_out1 top halo (grad_out1.shape = %s)" % (ctx.spatial_group_rank, str(list(grad_out1.shape))))
+            if ctx.spatial_group_rank < ctx.spatial_group_size-1:
+                torch.cuda.current_stream().wait_stream(ctx.stream2)
                 grad_out1[:,Hs-1:,:,:].copy_(btm_grad_out1_halo)
-                #print("ctx.local_rank = %d, apply grad_out1 btm halo (grad_out1.shape = %s)" % (ctx.local_rank, str(list(grad_out1.shape))))
+                #print("ctx.spatial_group_rank = %d, apply grad_out1 btm halo (grad_out1.shape = %s)" % (ctx.spatial_group_rank, str(list(grad_out1.shape))))

         fast_bottleneck.backward_rest(ctx.nhwc, ctx.stride_1x1, t_list, grads, grad_out2, grad_out1, wgrad2)

-        return (None, None, None, None, None, None, None, None, *grads)
+        return (None, None, None, None, None, None, None, None, None, *grads)

 spatial_bottleneck_function = SpatialBottleneckFunction.apply
+# Communication free halo exchanger.
+# NB! This halo exchanger does not exchange halos with neighbors as it should, it merely swaps the inputs
+# NB! This is only useful for performance testing.
+# NB! Do not use for actual production runs
+class HaloExchanger(object):
+    def __init__(self):
+        self.stream1 = torch.cuda.Stream()
+        self.stream2 = torch.cuda.Stream()
+
+class HaloExchangerNoComm(HaloExchanger):
+    def __init__(self, world_size, spatial_group_size, rank, comm):
+        super(HaloExchangerNoComm, self).__init__()
+    def left_right_halo_exchange(self, left_output_halo, right_output_halo):
+        return right_output_halo, left_output_halo
+
+class HaloExchangerAllGather(HaloExchanger):
+    def __init__(self, world_size, spatial_group_size, rank, comm):
+        super(HaloExchangerAllGather, self).__init__()
+        self.spatial_group_size = spatial_group_size
+        self.local_rank = rank % spatial_group_size
+        self.comm = comm
+    def left_right_halo_exchange(self, left_output_halo, right_output_halo):
+        N,Hh,W,C = list(left_output_halo.shape)
+        send_halos = torch.empty((N,2*Hh,W,C),dtype=left_output_halo.dtype,device=left_output_halo.device)
+        send_halos[:,:Hh,:,:].copy_(left_output_halo)
+        send_halos[:,Hh:,:,:].copy_(right_output_halo)
+        all_halos = torch.empty((N,2*Hh*self.spatial_group_size,W,C),dtype=left_output_halo.dtype,device=left_output_halo.device)
+        all_halos = [all_halos[:,i*2*Hh:(i+1)*2*Hh,:,:] for i in range(self.spatial_group_size)]
+        torch.distributed.all_gather(all_halos, send_halos, group=self.comm, no_copy=True)
+        left_input_halo = all_halos[(self.spatial_group_size+self.local_rank-1)%self.spatial_group_size][:,Hh:,:,:]
+        right_input_halo = all_halos[(self.local_rank+1)%self.spatial_group_size][:,:Hh,:,:]
+        return left_input_halo, right_input_halo
+
+class HaloExchangerSendRecv(HaloExchanger):
+    def __init__(self, world_size, spatial_group_size, rank, comm):
+        super(HaloExchangerSendRecv, self).__init__()
+        self.world_size = world_size
+        self.spatial_group_size = spatial_group_size
+        nccl_id = inc.get_unique_nccl_id(1).cuda()
+        torch.distributed.broadcast(nccl_id, 0)
+        nccl_id = nccl_id.cpu()
+        self.handle = inc.init_nccl_comm(nccl_id, rank, world_size)
+    def left_right_halo_exchange(self, left_output_halo, right_output_halo):
+        left_input_halo, right_input_halo = inc.left_right_halo_exchange(self.handle, left_output_halo, right_output_halo, self.spatial_group_size)
+        return left_input_halo, right_input_halo
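The modular indexing in HaloExchangerAllGather is easy to mis-read, so here is a CPU simulation of it, ours and not part of the commit: after the all_gather, rank r takes its left input halo from rank r-1's right output halo and its right input halo from rank r+1's left output halo.

import torch

G = 4                                      # spatial_group_size
outs = [(torch.tensor([10.0 * r]),         # left_output_halo "sent" by rank r
         torch.tensor([10.0 * r + 1.0]))   # right_output_halo "sent" by rank r
        for r in range(G)]
all_halos = [torch.cat(pair) for pair in outs]   # what all_gather leaves on every rank
for local_rank in range(G):
    left_input_halo = all_halos[(G + local_rank - 1) % G][1:]   # upper neighbor's right halo
    right_input_halo = all_halos[(local_rank + 1) % G][:1]      # lower neighbor's left halo
    print(local_rank, left_input_halo.item(), right_input_halo.item())
# rank 1 receives 1.0 (from rank 0) and 20.0 (from rank 2), and so on; ranks 0 and
# G-1 wrap around, which callers mask off with the spatial_group_rank > 0 and
# spatial_group_rank < spatial_group_size-1 checks.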
 class SpatialBottleneck(torch.nn.Module):
     # Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2)
     # while original implementation places the stride at the first 1x1 convolution(self.conv1)

@@ -393,7 +451,7 @@ class SpatialBottleneck(torch.nn.Module):
     def __init__(self, in_channels, bottleneck_channels, out_channels, stride=1, groups=1,
                  dilation=1, norm_func=None, use_cudnn=False, explicit_nhwc=False,
-                 spatial_group_size=1, communicator=None):
+                 spatial_parallel_args=None):
         super(SpatialBottleneck, self).__init__()
         if groups != 1:
             raise RuntimeError('Only support groups == 1')
@@ -447,26 +505,10 @@ class SpatialBottleneck(torch.nn.Module):
                 p.data = p.data.permute(0, 2, 3, 1).contiguous()

         # spatial communicator
-        self.spatial_group_size = spatial_group_size
-        if spatial_group_size > 1:
-            world_size = dist.get_world_size()
-            num_groups = world_size // spatial_group_size
-            assert(num_groups*spatial_group_size == world_size), "torch.distributed.get_world_size() must be multiple of group_size"
-            rank = dist.get_rank()
-            self.local_rank = rank % spatial_group_size
-            if communicator is None:
-                for group in range(num_groups):
-                    ranks = list(range(group*spatial_group_size,(group+1)*spatial_group_size))
-                    comm = torch.distributed.new_group(ranks=ranks)
-                    if rank in ranks:
-                        self.communicator = comm
-            else:
-                self.communicator = communicator
-            self.stream1 = torch.cuda.Stream()
-            self.spatial_args = self.spatial_group_size, self.local_rank, self.communicator, self.stream1
+        if spatial_parallel_args is None:
+            self.spatial_parallel_args = (1, 0, None, None, None)
         else:
-            self.spatial_args = 1, 0, None, None
+            self.spatial_parallel_args = spatial_parallel_args
         return

     def forward(self, x):
@@ -483,7 +525,7 @@ class SpatialBottleneck(torch.nn.Module):
             w_scale.append(s4)
             w_bias.append(b4)
-            out = spatial_bottleneck_function(*self.spatial_args, self.explicit_nhwc, self.stride, w_scale, w_bias, x, *self.w_conv)
+            out = spatial_bottleneck_function(*self.spatial_parallel_args, self.explicit_nhwc, self.stride, w_scale, w_bias, x, *self.w_conv)
             return out

         if self.explicit_nhwc:
@@ -510,3 +552,10 @@ class SpatialBottleneck(torch.nn.Module):
         out = self.relu(out)

         return out
+
+_HALO_EXCHANGERS = Registry({
+    "HaloExchangerNoComm": HaloExchangerNoComm,
+    "HaloExchangerAllGather": HaloExchangerAllGather,
+    "HaloExchangerSendRecv": HaloExchangerSendRecv,
+})
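With this commit, callers build the five-element tuple themselves instead of passing spatial_group_size/communicator. A sketch of that wiring, ours and not in the diff: the order (spatial_group_size, spatial_group_rank, spatial_communicator, spatial_halo_exchanger, spatial_stream) is read off SpatialBottleneckFunction.forward above, group creation follows the code this commit removed from SpatialBottleneck.__init__, and the function name make_spatial_parallel_args is hypothetical.

import torch
import torch.distributed as dist
# assumes this runs alongside the module above (_HALO_EXCHANGERS, HaloExchanger*)

def make_spatial_parallel_args(spatial_group_size, exchanger_name="HaloExchangerAllGather"):
    world_size, rank = dist.get_world_size(), dist.get_rank()
    assert world_size % spatial_group_size == 0
    comm = None
    # every rank must create every group, as the removed __init__ code did
    for g in range(world_size // spatial_group_size):
        ranks = list(range(g * spatial_group_size, (g + 1) * spatial_group_size))
        group = dist.new_group(ranks=ranks)
        if rank in ranks:
            comm = group
    halo_ex = _HALO_EXCHANGERS[exchanger_name](world_size, spatial_group_size, rank, comm)
    # forward() takes its streams from the exchanger, so the stream slot is a placeholder
    return (spatial_group_size, rank % spatial_group_size, comm, halo_ex, halo_ex.stream1)

# bottleneck = SpatialBottleneck(..., spatial_parallel_args=make_spatial_parallel_args(2))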
apex/contrib/csrc/peer_memory/peer_memory.cpp

@@ -17,12 +17,12 @@
 #include "peer_memory_cuda.cuh"

 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
-    m.def("allocate_raw", &apex::peer_memory::allocate_raw, "allocate_raw");
-    m.def("free_raw", &apex::peer_memory::free_raw, "free_raw");
-    m.def("get_raw_ipc_address", &apex::peer_memory::get_raw_ipc_address, "get_raw_ipc_address");
-    m.def("get_raw_peers", &apex::peer_memory::get_raw_peers, "get_raw_peers");
-    m.def("blob_view_half", &apex::peer_memory::blob_view_half, "blob_view_half");
-    m.def("blob_view_float", &apex::peer_memory::blob_view_float, "blob_view_float");
-    m.def("blob_view_int", &apex::peer_memory::blob_view_int, "blob_view_int");
-    m.def("push_pull_halos_1d", &apex::peer_memory::push_pull_halos_1d, "push_pull_halos_1d");
+    m.def("allocate_raw", &apex::contrib::peer_memory::allocate_raw, "allocate_raw");
+    m.def("free_raw", &apex::contrib::peer_memory::free_raw, "free_raw");
+    m.def("get_raw_ipc_address", &apex::contrib::peer_memory::get_raw_ipc_address, "get_raw_ipc_address");
+    m.def("get_raw_peers", &apex::contrib::peer_memory::get_raw_peers, "get_raw_peers");
+    m.def("blob_view_half", &apex::contrib::peer_memory::blob_view_half, "blob_view_half");
+    m.def("blob_view_float", &apex::contrib::peer_memory::blob_view_float, "blob_view_float");
+    m.def("blob_view_int", &apex::contrib::peer_memory::blob_view_int, "blob_view_int");
+    m.def("push_pull_halos_1d", &apex::contrib::peer_memory::push_pull_halos_1d, "push_pull_halos_1d");
 }
apex/contrib/csrc/peer_memory/peer_memory_cuda.cu

@@ -6,6 +6,7 @@
 #include <cassert>
 #include <cuda_runtime_api.h>
 #include <cooperative_groups.h>
+#include "nccl.h"

 namespace cg = cooperative_groups;

 #define CUDACHECK(cmd) do { \

@@ -214,9 +215,25 @@ __global__ void push_pull_halos_1d_kernel(
     strided_copy_kernel<T,is_HWC>(bih, bih_stride_C, bih_stride_H, bih_stride_W, bix, bix_stride_C, bix_stride_H, bix_stride_W, NC, NH, NW);
 }

+__global__ void delay_kernel(int delay_nanoseconds, int* counter)
+{
+    if (blockIdx.x == 0 && threadIdx.x == 0)
+    {
+        // waste time while doing something compiler can't predict, thus preventing it from optimizing away this code.
+        int new_counter = 0;
+        double elapsed = 0;
+        clock_t start = clock();
+        do {
+            clock_t now = clock();
+            elapsed = (double)(now - start) * 1e9 / CLOCKS_PER_SEC;
+            ++new_counter;
+        } while (elapsed < (double)delay_nanoseconds);
+        *counter = new_counter;
+    }
+}
+
 }

-namespace apex { namespace peer_memory {
+namespace apex { namespace contrib { namespace peer_memory {

 int64_t allocate_raw(int64_t size)
 {

@@ -460,5 +477,5 @@ void push_pull_halos_1d(
         }
     );
 }

-}}
+}}}
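The new delay_kernel busy-spins a single GPU thread until the requested number of nanoseconds has elapsed, presumably backing the inc.add_delay(10) calls added in bottleneck.py. A rough Python-side analogue for experiments, our assumption rather than this extension's API: torch.cuda._sleep is a private PyTorch helper that enqueues a comparable spin kernel.

import torch

if torch.cuda.is_available():
    start = torch.cuda.Event(enable_timing=True)
    stop = torch.cuda.Event(enable_timing=True)
    start.record()
    torch.cuda._sleep(10_000_000)   # spin roughly 10M GPU clock cycles
    stop.record()
    torch.cuda.synchronize()
    print("spun for %.2f ms" % start.elapsed_time(stop))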
apex/contrib/csrc/peer_memory/peer_memory_cuda.cuh

@@ -19,7 +19,7 @@
 #ifndef _peer_memory_h_
 #define _peer_memory_h_

-namespace apex { namespace peer_memory {
+namespace apex { namespace contrib { namespace peer_memory {

 int64_t allocate_raw(int64_t size);
 void free_raw(int64_t raw);
 at::Tensor get_raw_ipc_address(int64_t raw);

@@ -43,5 +43,5 @@ namespace apex { namespace peer_memory {
     at::Tensor btm_signal,  // btm input signal in receiver device memory
     at::Tensor waits        // top and btm signals for this rank
 );

-}}
+}}}

 #endif
setup.py

@@ -641,6 +641,20 @@ if "--peer_memory" in sys.argv:
         )
     )

+if "--nccl_p2p" in sys.argv:
+    sys.argv.remove("--nccl_p2p")
+    raise_if_cuda_home_none("--nccl_p2p")
+    ext_modules.append(
+        CUDAExtension(
+            name="nccl_p2p",
+            sources=[
+                "apex/contrib/csrc/nccl_p2p/nccl_p2p_cuda.cu",
+                "apex/contrib/csrc/nccl_p2p/nccl_p2p.cpp",
+            ],
+            extra_compile_args={"cxx": ["-O3"] + version_dependent_macros + generator_flag},
+        )
+    )
+
 setup(
     name="apex",
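As with the --peer_memory block just above it, the new extension is opt-in: it only builds when the flag is passed on the setup command line, e.g. python setup.py install --nccl_p2p (with CUDA_HOME set, which raise_if_cuda_home_none enforces). The resulting nccl_p2p module is what bottleneck.py imports as inc.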