OpenDAS / apex / Commits / 3ade5b26

Commit 3ade5b26, authored Mar 24, 2022 by Thor Johnsen

    Add bottleneck block

parent b48898fb

Showing 5 changed files with 187 additions and 107 deletions (+187 −107)
apex/contrib/bottleneck/bottleneck.py               +144  −95
apex/contrib/csrc/peer_memory/peer_memory.cpp         +8   −8
apex/contrib/csrc/peer_memory/peer_memory_cuda.cu    +19   −2
apex/contrib/csrc/peer_memory/peer_memory_cuda.cuh    +2   −2
setup.py                                             +14   −0
apex/contrib/bottleneck/bottleneck.py
 import torch
 import torch.distributed as dist
 from torch import nn
-import fast_bottleneck
+from maskrcnn_benchmark.utils.registry import Registry
+import maskrcnn_benchmark.SpatialBottleneck as fast_bottleneck
+import nccl_p2p as inc

 def kaiming_uniform_(tensor, a=0, mode='fan_in', nonlinearity='leaky_relu'):
     weight_tensor_nchw = tensor
     nn.init.kaiming_uniform_(weight_tensor_nchw, a=a, mode=mode, nonlinearity=nonlinearity)

-class FrozenBatchNorm2d(torch.nn.Module):
+class FrozenBatchNorm2d(torch.jit.ScriptModule):
     """
     BatchNorm2d where the batch statistics and the affine parameters are fixed
     """
...
@@ -18,7 +20,9 @@ class FrozenBatchNorm2d(torch.nn.Module):
         self.register_buffer("running_mean", torch.zeros(n))
         self.register_buffer("running_var", torch.ones(n))

-    def get_scale_bias(self, nhwc=False):
+    @torch.jit.script_method
+    def get_scale_bias(self, nhwc):
+        # type: (bool) -> List[torch.Tensor]
         scale = self.weight * self.running_var.rsqrt()
         bias = self.bias - self.running_mean * scale
         if nhwc:
...
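The frozen-BN algebra above folds the whole normalization into one scale and one bias: y = (x - mean) / sqrt(var) * weight + bias = x * s + b, with s = weight * rsqrt(var) and b = bias - mean * s. A standalone numeric check (my illustration, not part of the commit; eps is ignored here just as in the code above):

    import torch
    import torch.nn.functional as F

    n = 8
    weight, bias = torch.randn(n), torch.randn(n)
    mean, var = torch.randn(n), torch.rand(n) + 0.5
    x = torch.randn(2, n, 4, 4)

    s = weight * var.rsqrt()                 # same scale get_scale_bias computes
    b = bias - mean * s                      # same bias
    ref = F.batch_norm(x, mean, var, weight, bias, training=False, eps=0.0)
    assert torch.allclose(x * s.reshape(1, -1, 1, 1) + b.reshape(1, -1, 1, 1), ref, atol=1e-5)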
@@ -29,11 +33,11 @@ class FrozenBatchNorm2d(torch.nn.Module):
             bias = bias.reshape(1, -1, 1, 1)
         return scale, bias

+    @torch.jit.script_method
     def forward(self, x):
-        scale, bias = self.get_scale_bias()
+        scale, bias = self.get_scale_bias(False)
         return x * scale + bias

 @torch.jit.script
 def drelu_dscale1(grad_o, output, scale1):
     relu_mask = (output > 0)
...
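For context, the masked-multiply trick that drelu_dscale1 opens with (the same one the "backward relu is not exposed, MUL with mask used now" comment further down refers to) can be checked standalone against autograd. Illustrative only, not part of the commit:

    # Verify that masking the upstream gradient with (output > 0)
    # reproduces ReLU's backward, which is what drelu_dscale1 relies on.
    import torch

    x = torch.randn(4, 8, requires_grad=True)
    out = torch.relu(x)
    grad_o = torch.randn_like(out)
    out.backward(grad_o)

    relu_mask = (out > 0)          # same mask drelu_dscale1 builds
    assert torch.equal(x.grad, relu_mask * grad_o)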
@@ -217,7 +221,11 @@ class Bottleneck(torch.nn.Module):
 class SpatialBottleneckFunction(torch.autograd.Function):
     @staticmethod
-    def forward(ctx, spatial_group_size, local_rank, comm, stream1, nhwc, stride_1x1, scale, bias, x, *conv):
+    def forward(ctx, spatial_group_size, spatial_group_rank, spatial_communicator, spatial_halo_exchanger, spatial_stream, nhwc, stride_1x1, scale, bias, x, *conv):
+        if spatial_group_size > 1:
+            stream1 = spatial_halo_exchanger.stream1
+            stream2 = spatial_halo_exchanger.stream2
         # TODO: clean up order of tensors
         args = [x, *conv[0:3], *scale[0:3], *bias[0:3]]
         ctx.downsample = len(conv) > 3
...
@@ -232,38 +240,38 @@ class SpatialBottleneckFunction(torch.autograd.Function):
         outputs = fast_bottleneck.forward_init(nhwc, stride_1x1, args)
         fast_bottleneck.forward_out1(nhwc, stride_1x1, args, outputs)
-        fast_bottleneck.forward_out2(nhwc, stride_1x1, args, outputs)

         # do halo exchange for outputs[0] (out1)
         # compute halo cells for outputs[1]
         if spatial_group_size > 1:
             out1 = outputs[0]
             N,Hs,W,C = list(out1.shape)
             stream1.wait_stream(torch.cuda.current_stream())
             with torch.cuda.stream(stream1):
-                # copy halos to send buffer
-                send_halos = torch.empty((N,2,W,C),dtype=out1.dtype,device=out1.device)
-                send_halos[:,:1,:,:].copy_(out1[:,:1,:,:])
-                send_halos[:,1:,:,:].copy_(out1[:,Hs-1:,:,:])
-                all_halos = torch.empty((N,2*spatial_group_size,W,C),dtype=out1.dtype,device=out1.device)
-                all_halos = [all_halos[:,i*2:(i+1)*2,:,:] for i in range(spatial_group_size)]
-                dist.all_gather(all_halos,send_halos,group=comm)
-                fat_halo = torch.empty((N,3,W,C),dtype=out1.dtype,device=out1.device)
-                top_out1_halo = all_halos[(spatial_group_size+local_rank-1)%spatial_group_size][:,1:,:,:]
-                if local_rank > 0:
-                    fat_halo[:,:1,:,:].copy_(top_out1_halo)
-                    fat_halo[:,1:3,:,:].copy_(out1[:,:2,:,:])
-                    top_out2 = fast_bottleneck.forward_out2_halo(nhwc, fat_halo, args)
-                btm_out1_halo = all_halos[(local_rank+1)%spatial_group_size][:,:1,:,:]
-                if local_rank < spatial_group_size-1:
-                    fat_halo[:,0:2,:,:].copy_(out1[:,Hs-2:,:,:])
-                    fat_halo[:,2:,:,:].copy_(btm_out1_halo)
-                    btm_out2 = fast_bottleneck.forward_out2_halo(nhwc, fat_halo, args)
-            torch.cuda.current_stream().wait_stream(stream1)
+                top_out1_halo, btm_out1_halo = spatial_halo_exchanger.left_right_halo_exchange(out1[:,:1,:,:], out1[:,Hs-1:,:,:])
+            if spatial_group_rank < spatial_group_size-1:
+                stream2.wait_stream(stream1)
+                with torch.cuda.stream(stream2):
+                    btm_fat_halo = torch.empty((N,3,W,C),dtype=out1.dtype,device=out1.device)
+                    btm_fat_halo[:,0:2,:,:].copy_(out1[:,Hs-2:,:,:])
+                    btm_fat_halo[:,2:,:,:].copy_(btm_out1_halo)
+                    btm_out2 = fast_bottleneck.forward_out2_halo(nhwc, btm_fat_halo, args)
+            if spatial_group_rank > 0:
+                with torch.cuda.stream(stream1):
+                    top_fat_halo = torch.empty((N,3,W,C),dtype=out1.dtype,device=out1.device)
+                    top_fat_halo[:,:1,:,:].copy_(top_out1_halo)
+                    top_fat_halo[:,1:3,:,:].copy_(out1[:,:2,:,:])
+                    top_out2 = fast_bottleneck.forward_out2_halo(nhwc, top_fat_halo, args)
+            inc.add_delay(10)
+        fast_bottleneck.forward_out2(nhwc, stride_1x1, args, outputs)

         # compute halo cells for outputs[1] (out2)
         if spatial_group_size > 1:
             out2 = outputs[1]
-            if local_rank > 0:
+            if spatial_group_rank > 0:
+                torch.cuda.current_stream().wait_stream(stream1)
                 out2[:,:1,:,:].copy_(top_out2)
-            if local_rank < spatial_group_size-1:
+            if spatial_group_rank < spatial_group_size-1:
+                torch.cuda.current_stream().wait_stream(stream2)
                 out2[:,Hs-1:,:,:].copy_(btm_out2)

         fast_bottleneck.forward_rest(nhwc, stride_1x1, args, outputs)
...
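The forward halo path above leans on one property of the 3x3 convolution in out2: three input rows (one halo row from the neighbor plus the shard's two boundary rows) produce exactly the one output row the shard is missing. A standalone sketch of that property using plain torch convolutions rather than the fused fast_bottleneck kernels (my illustration, not part of the commit):

    # A 3x3 conv with no padding along H maps 3 input rows to 1 output row,
    # so a "fat halo" [neighbor row | own first two rows] reproduces the
    # boundary row a single-GPU run would have computed.
    import torch
    import torch.nn.functional as F

    x = torch.randn(1, 8, 6, 16)                      # N, C, H, W
    w = torch.randn(8, 8, 3, 3)
    full = F.conv2d(x, w, padding=1)                  # single-GPU reference
    fat_halo = x[:, :, 0:3, :]                        # 3 rows around the cut
    halo_out = F.conv2d(fat_halo, w, padding=(0, 1))  # pad W only -> 1 row
    assert torch.allclose(full[:, :, 1:2, :], halo_out, atol=1e-5)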
@@ -276,9 +284,11 @@ class SpatialBottleneckFunction(torch.autograd.Function):
         ctx.nhwc = nhwc
         ctx.stride_1x1 = stride_1x1
         ctx.spatial_group_size = spatial_group_size
-        ctx.local_rank = local_rank
-        ctx.comm = comm
+        if spatial_group_size > 1:
+            ctx.spatial_group_rank = spatial_group_rank
+            ctx.spatial_halo_exchanger = spatial_halo_exchanger
+            ctx.stream1 = stream1
+            ctx.stream2 = stream2
         return outputs[2]

     # backward relu is not exposed, MUL with mask used now
...
@@ -312,54 +322,52 @@ class SpatialBottleneckFunction(torch.autograd.Function):
         grads = fast_bottleneck.backward_init(ctx.nhwc, ctx.stride_1x1, t_list)
         grad_out2 = fast_bottleneck.backward_grad_out2(ctx.nhwc, ctx.stride_1x1, t_list, grads)

+        # do halo exchange of grad_out2 here
+        # compute halo cells for grad_out1
+        if ctx.spatial_group_size > 1:
+            N,Hs,W,C = list(grad_out2.shape)
+            relu1 = t_list[12]
+            ctx.stream1.wait_stream(torch.cuda.current_stream())
+            with torch.cuda.stream(ctx.stream1):
+                top_halo, btm_halo = ctx.spatial_halo_exchanger.left_right_halo_exchange(grad_out2[:,:1,:,:], grad_out2[:,Hs-1:,:,:])
+            # copy halos to send buffer
+            if ctx.spatial_group_rank < ctx.spatial_group_size-1:
+                ctx.stream2.wait_stream(ctx.stream1)
+                with torch.cuda.stream(ctx.stream2):
+                    btm_fat_halo = torch.empty((N,3,W,C),dtype=grad_out2.dtype,device=grad_out2.device)
+                    btm_relu_halo = torch.empty((N,3,W,C),dtype=grad_out2.dtype,device=grad_out2.device)
+                    btm_fat_halo[:,:2,:,:].copy_(grad_out2[:,Hs-2:,:,:])
+                    btm_fat_halo[:,2:,:,:].copy_(btm_halo)
+                    btm_relu_halo[:,:2,:,:].copy_(relu1[:,Hs-2:,:,:])
+                    btm_relu_halo[:,2:,:,:].zero_()
+                    btm_grad_out1_halo = fast_bottleneck.backward_grad_out1_halo(ctx.nhwc, ctx.stride_1x1, t_list, grads, btm_fat_halo, btm_relu_halo)
+                    btm_grad_out1_halo = btm_grad_out1_halo[:,1:2,:,:]
+            if ctx.spatial_group_rank > 0:
+                with torch.cuda.stream(ctx.stream1):
+                    top_fat_halo = torch.empty((N,3,W,C),dtype=grad_out2.dtype,device=grad_out2.device)
+                    top_relu_halo = torch.empty((N,3,W,C),dtype=grad_out2.dtype,device=grad_out2.device)
+                    top_fat_halo[:,:1,:,:].copy_(top_halo)
+                    top_fat_halo[:,1:,:,:].copy_(grad_out2[:,:2,:,:])
+                    top_relu_halo[:,:1,:,:].zero_()
+                    top_relu_halo[:,1:,:,:].copy_(relu1[:,:2,:,:])
+                    top_grad_out1_halo = fast_bottleneck.backward_grad_out1_halo(ctx.nhwc, ctx.stride_1x1, t_list, grads, top_fat_halo, top_relu_halo)
+                    top_grad_out1_halo = top_grad_out1_halo[:,1:2,:,:]
+            inc.add_delay(10)

         # compute wgrad2 for internal cells
         wgrad2 = fast_bottleneck.backward_wgrad2(ctx.nhwc, ctx.stride_1x1, t_list, grads, grad_out2)

         # apply wgrad2 halos
         if ctx.spatial_group_size > 1:
-            if ctx.local_rank > 0:
+            if ctx.spatial_group_rank > 0:
                 top_grad2_halo = grad_out2[:,:1,:,:]
                 top_wgrad2_halo = fast_bottleneck.backward_wgrad2_halo(ctx.nhwc, ctx.stride_1x1, t_list, grads, top_out1_halo, top_grad2_halo)
                 wgrad2[:,:1,:,:].add_(top_wgrad2_halo)
-            if ctx.local_rank < ctx.spatial_group_size-1:
+            if ctx.spatial_group_rank < ctx.spatial_group_size-1:
                 btm_grad2_halo = grad_out2[:,-1:,:,:]
                 btm_wgrad2_halo = fast_bottleneck.backward_wgrad2_halo(ctx.nhwc, ctx.stride_1x1, t_list, grads, btm_out1_halo, btm_grad2_halo)
                 wgrad2[:,-1:,:,:].add_(btm_wgrad2_halo)

-        # do halo exchange of grad_out2 here
-        # compute halo cells for grad_out1
-        if ctx.spatial_group_size > 1:
-            N,Hs,W,C = list(grad_out2.shape)
-            ctx.stream1.wait_stream(torch.cuda.current_stream())
-            with torch.cuda.stream(ctx.stream1):
-                # copy halos to send buffer
-                send_halos = torch.empty((N,2,W,C),dtype=grad_out2.dtype,device=grad_out2.device)
-                send_halos[:,:1,:,:].copy_(grad_out2[:,:1,:,:])
-                send_halos[:,1:,:,:].copy_(grad_out2[:,Hs-1:,:,:])
-                all_halos = torch.empty((N,2*ctx.spatial_group_size,W,C),dtype=grad_out2.dtype,device=grad_out2.device)
-                all_halos = [all_halos[:,i*2:(i+1)*2,:,:] for i in range(ctx.spatial_group_size)]
-                dist.all_gather(all_halos,send_halos,group=ctx.comm)
-                relu1 = t_list[12]
-                fat_halo = torch.empty((N,3,W,C),dtype=grad_out2.dtype,device=grad_out2.device)
-                relu_halo = torch.empty((N,3,W,C),dtype=grad_out2.dtype,device=grad_out2.device)
-                if ctx.local_rank > 0:
-                    top_halo = all_halos[ctx.local_rank-1][:,1:,:,:]
-                    fat_halo[:,:1,:,:].copy_(top_halo)
-                    fat_halo[:,1:,:,:].copy_(grad_out2[:,:2,:,:])
-                    relu_halo[:,:1,:,:].zero_()
-                    relu_halo[:,1:,:,:].copy_(relu1[:,:2,:,:])
-                    top_grad_out1_halo = fast_bottleneck.backward_grad_out1_halo(ctx.nhwc, ctx.stride_1x1, t_list, grads, fat_halo, relu_halo)
-                    top_grad_out1_halo = top_grad_out1_halo[:,1:2,:,:]
-                if ctx.local_rank < ctx.spatial_group_size-1:
-                    btm_halo = all_halos[ctx.local_rank+1][:,:1,:,:]
-                    fat_halo[:,:2,:,:].copy_(grad_out2[:,Hs-2:,:,:])
-                    fat_halo[:,2:,:,:].copy_(btm_halo)
-                    relu_halo[:,:2,:,:].copy_(relu1[:,Hs-2:,:,:])
-                    relu_halo[:,2:,:,:].zero_()
-                    btm_grad_out1_halo = fast_bottleneck.backward_grad_out1_halo(ctx.nhwc, ctx.stride_1x1, t_list, grads, fat_halo, relu_halo)
-                    btm_grad_out1_halo = btm_grad_out1_halo[:,1:2,:,:]

         # compute grad_out1 for internal cells
         grad_out1 = fast_bottleneck.backward_grad_out1(ctx.nhwc, ctx.stride_1x1, t_list, grads, grad_out2)
...
@@ -369,20 +377,70 @@ class SpatialBottleneckFunction(torch.autograd.Function):
         z = t_list[4]
         relu1 = t_list[12]
         #print("w.shape = %s, z.shape = %s, relu1.shape = %s" % (str(list(w.shape)), str(list(z.shape)), str(list(relu1.shape))))

-        if ctx.local_rank > 0:
+        if ctx.spatial_group_rank > 0:
+            torch.cuda.current_stream().wait_stream(ctx.stream1)
             grad_out1[:,:1,:,:].copy_(top_grad_out1_halo)
-            #print("ctx.local_rank = %d, apply grad_out1 top halo (grad_out1.shape = %s)" % (ctx.local_rank, str(list(grad_out1.shape))))
+            #print("ctx.spatial_group_rank = %d, apply grad_out1 top halo (grad_out1.shape = %s)" % (ctx.spatial_group_rank, str(list(grad_out1.shape))))
-        if ctx.local_rank < ctx.spatial_group_size-1:
+        if ctx.spatial_group_rank < ctx.spatial_group_size-1:
+            torch.cuda.current_stream().wait_stream(ctx.stream2)
             grad_out1[:,Hs-1:,:,:].copy_(btm_grad_out1_halo)
-            #print("ctx.local_rank = %d, apply grad_out1 btm halo (grad_out1.shape = %s)" % (ctx.local_rank, str(list(grad_out1.shape))))
+            #print("ctx.spatial_group_rank = %d, apply grad_out1 btm halo (grad_out1.shape = %s)" % (ctx.spatial_group_rank, str(list(grad_out1.shape))))

         fast_bottleneck.backward_rest(ctx.nhwc, ctx.stride_1x1, t_list, grads, grad_out2, grad_out1, wgrad2)

-        return (None, None, None, None, None, None, None, None, *grads)
+        return (None, None, None, None, None, None, None, None, None, *grads)

 spatial_bottleneck_function = SpatialBottleneckFunction.apply
+# Communication free halo exchanger.
+# NB! This halo exchanger does not exchange halos with neighbors as it should, it merely swaps the inputs
+# NB! This is only useful for performance testing.
+# NB! Do not use for actual production runs
+class HaloExchanger(object):
+    def __init__(self):
+        self.stream1 = torch.cuda.Stream()
+        self.stream2 = torch.cuda.Stream()
+
+class HaloExchangerNoComm(HaloExchanger):
+    def __init__(self, world_size, spatial_group_size, rank, comm):
+        super(HaloExchangerNoComm, self).__init__()
+    def left_right_halo_exchange(self, left_output_halo, right_output_halo):
+        return right_output_halo, left_output_halo
+
+class HaloExchangerAllGather(HaloExchanger):
+    def __init__(self, world_size, spatial_group_size, rank, comm):
+        super(HaloExchangerAllGather, self).__init__()
+        self.spatial_group_size = spatial_group_size
+        self.local_rank = rank % spatial_group_size
+        self.comm = comm
+    def left_right_halo_exchange(self, left_output_halo, right_output_halo):
+        N,Hh,W,C = list(left_output_halo.shape)
+        send_halos = torch.empty((N,2*Hh,W,C),dtype=left_output_halo.dtype,device=left_output_halo.device)
+        send_halos[:,:Hh,:,:].copy_(left_output_halo)
+        send_halos[:,Hh:,:,:].copy_(right_output_halo)
+        all_halos = torch.empty((N,2*Hh*self.spatial_group_size,W,C),dtype=left_output_halo.dtype,device=left_output_halo.device)
+        all_halos = [all_halos[:,i*2*Hh:(i+1)*2*Hh,:,:] for i in range(self.spatial_group_size)]
+        torch.distributed.all_gather(all_halos, send_halos, group=self.comm, no_copy=True)
+        left_input_halo = all_halos[(self.spatial_group_size+self.local_rank-1)%self.spatial_group_size][:,Hh:,:,:]
+        right_input_halo = all_halos[(self.local_rank+1)%self.spatial_group_size][:,:Hh,:,:]
+        return left_input_halo, right_input_halo
+
+class HaloExchangerSendRecv(HaloExchanger):
+    def __init__(self, world_size, spatial_group_size, rank, comm):
+        super(HaloExchangerSendRecv, self).__init__()
+        self.world_size = world_size
+        self.spatial_group_size = spatial_group_size
+        nccl_id = inc.get_unique_nccl_id(1).cuda()
+        torch.distributed.broadcast(nccl_id, 0)
+        nccl_id = nccl_id.cpu()
+        self.handle = inc.init_nccl_comm(nccl_id, rank, world_size)
+    def left_right_halo_exchange(self, left_output_halo, right_output_halo):
+        left_input_halo, right_input_halo = inc.left_right_halo_exchange(self.handle, left_output_halo, right_output_halo, self.spatial_group_size)
+        return left_input_halo, right_input_halo
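A quick way to sanity-check the wrap-around neighbor arithmetic in HaloExchangerAllGather.left_right_halo_exchange, with a made-up group size (pure Python, illustrative only):

    # Each rank's left input halo is the right output halo of rank-1, and
    # its right input halo is the left output halo of rank+1, modulo the
    # spatial group size.
    spatial_group_size = 4
    for local_rank in range(spatial_group_size):
        left_src = (spatial_group_size + local_rank - 1) % spatial_group_size
        right_src = (local_rank + 1) % spatial_group_size
        print(f"rank {local_rank}: left halo from {left_src}, right halo from {right_src}")

The wrapped halos at rank 0 and rank spatial_group_size-1 are gathered but never consumed, because SpatialBottleneckFunction guards every halo use with spatial_group_rank > 0 or spatial_group_rank < spatial_group_size-1.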
 class SpatialBottleneck(torch.nn.Module):
     # Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2)
     # while original implementation places the stride at the first 1x1 convolution(self.conv1)
...
@@ -393,7 +451,7 @@ class SpatialBottleneck(torch.nn.Module):
-    def __init__(self, in_channels, bottleneck_channels, out_channels, stride=1, groups=1, dilation=1, norm_func=None, use_cudnn=False, explicit_nhwc=False, spatial_group_size=1, communicator=None):
+    def __init__(self, in_channels, bottleneck_channels, out_channels, stride=1, groups=1, dilation=1, norm_func=None, use_cudnn=False, explicit_nhwc=False, spatial_parallel_args=None):
         super(SpatialBottleneck, self).__init__()
         if groups != 1:
             raise RuntimeError('Only support groups == 1')
...
@@ -447,26 +505,10 @@ class SpatialBottleneck(torch.nn.Module):
                     p.data = p.data.permute(0, 2, 3, 1).contiguous()

         # spatial communicator
-        self.spatial_group_size = spatial_group_size
-        if spatial_group_size > 1:
-            world_size = dist.get_world_size()
-            num_groups = world_size // spatial_group_size
-            assert(num_groups*spatial_group_size == world_size), "torch.distributed.get_world_size() must be multiple of group_size"
-            rank = dist.get_rank()
-            self.local_rank = rank % spatial_group_size
-            if communicator is None:
-                for group in range(num_groups):
-                    ranks = list(range(group*spatial_group_size,(group+1)*spatial_group_size))
-                    comm = torch.distributed.new_group(ranks=ranks)
-                    if rank in ranks:
-                        self.communicator = comm
-            else:
-                self.communicator = communicator
-            self.stream1 = torch.cuda.Stream()
-            self.spatial_args = self.spatial_group_size, self.local_rank, self.communicator, self.stream1
-        else:
-            self.spatial_args = 1, 0, None, None
+        if spatial_parallel_args is None:
+            self.spatial_parallel_args = (1, 0, None, None, None)
+        else:
+            self.spatial_parallel_args = spatial_parallel_args
         return

     def forward(self, x):
...
@@ -483,7 +525,7 @@ class SpatialBottleneck(torch.nn.Module):
             w_scale.append(s4)
             w_bias.append(b4)
-            out = spatial_bottleneck_function(*self.spatial_args, self.explicit_nhwc, self.stride, w_scale, w_bias, x, *self.w_conv)
+            out = spatial_bottleneck_function(*self.spatial_parallel_args, self.explicit_nhwc, self.stride, w_scale, w_bias, x, *self.w_conv)
             return out

         if self.explicit_nhwc:
...
@@ -510,3 +552,10 @@ class SpatialBottleneck(torch.nn.Module):
         out = self.relu(out)
         return out

+_HALO_EXCHANGERS = Registry({
+    "HaloExchangerNoComm": HaloExchangerNoComm,
+    "HaloExchangerAllGather": HaloExchangerAllGather,
+    "HaloExchangerSendRecv": HaloExchangerSendRecv,
+})
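A hypothetical wiring sketch (not from this commit) showing how the pieces fit together: pick a halo exchanger by name from the registry above and hand it to SpatialBottleneck through the new spatial_parallel_args tuple, whose element order follows SpatialBottleneckFunction.forward (spatial_group_size, spatial_group_rank, spatial_communicator, spatial_halo_exchanger, spatial_stream). Assumes a torchrun launch with the apex and maskrcnn_benchmark extensions built; the import path mirrors the file location:

    import torch
    import torch.distributed as dist
    from apex.contrib.bottleneck.bottleneck import (
        SpatialBottleneck, FrozenBatchNorm2d, _HALO_EXCHANGERS)

    dist.init_process_group(backend="nccl")
    world_size, rank = dist.get_world_size(), dist.get_rank()
    spatial_group_size = world_size                 # one spatial group, for simplicity
    comm = dist.new_group(list(range(spatial_group_size)))

    halo_ex = _HALO_EXCHANGERS["HaloExchangerAllGather"](world_size, spatial_group_size, rank, comm)
    spatial_parallel_args = (spatial_group_size, rank % spatial_group_size, comm, halo_ex, None)
    block = SpatialBottleneck(64, 64, 256, norm_func=FrozenBatchNorm2d,
                              use_cudnn=True, explicit_nhwc=True,
                              spatial_parallel_args=spatial_parallel_args)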
apex/contrib/csrc/peer_memory/peer_memory.cpp
...
@@ -17,12 +17,12 @@
#include "peer_memory_cuda.cuh"
PYBIND11_MODULE
(
TORCH_EXTENSION_NAME
,
m
)
{
m
.
def
(
"allocate_raw"
,
&
apex
::
peer_memory
::
allocate_raw
,
"allocate_raw"
);
m
.
def
(
"free_raw"
,
&
apex
::
peer_memory
::
free_raw
,
"free_raw"
);
m
.
def
(
"get_raw_ipc_address"
,
&
apex
::
peer_memory
::
get_raw_ipc_address
,
"get_raw_ipc_address"
);
m
.
def
(
"get_raw_peers"
,
&
apex
::
peer_memory
::
get_raw_peers
,
"get_raw_peers"
);
m
.
def
(
"blob_view_half"
,
&
apex
::
peer_memory
::
blob_view_half
,
"blob_view_half"
);
m
.
def
(
"blob_view_float"
,
&
apex
::
peer_memory
::
blob_view_float
,
"blob_view_float"
);
m
.
def
(
"blob_view_int"
,
&
apex
::
peer_memory
::
blob_view_int
,
"blob_view_int"
);
m
.
def
(
"push_pull_halos_1d"
,
&
apex
::
peer_memory
::
push_pull_halos_1d
,
"push_pull_halos_1d"
);
m
.
def
(
"allocate_raw"
,
&
apex
::
contrib
::
peer_memory
::
allocate_raw
,
"allocate_raw"
);
m
.
def
(
"free_raw"
,
&
apex
::
contrib
::
peer_memory
::
free_raw
,
"free_raw"
);
m
.
def
(
"get_raw_ipc_address"
,
&
apex
::
contrib
::
peer_memory
::
get_raw_ipc_address
,
"get_raw_ipc_address"
);
m
.
def
(
"get_raw_peers"
,
&
apex
::
contrib
::
peer_memory
::
get_raw_peers
,
"get_raw_peers"
);
m
.
def
(
"blob_view_half"
,
&
apex
::
contrib
::
peer_memory
::
blob_view_half
,
"blob_view_half"
);
m
.
def
(
"blob_view_float"
,
&
apex
::
contrib
::
peer_memory
::
blob_view_float
,
"blob_view_float"
);
m
.
def
(
"blob_view_int"
,
&
apex
::
contrib
::
peer_memory
::
blob_view_int
,
"blob_view_int"
);
m
.
def
(
"push_pull_halos_1d"
,
&
apex
::
contrib
::
peer_memory
::
push_pull_halos_1d
,
"push_pull_halos_1d"
);
}
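Only the C++ namespace qualification changes here; the names the module exports are untouched, so Python callers are unaffected. A minimal smoke test under that assumption (the extension's Python module name depends on how setup.py names the build; peer_memory_cuda is a guess):

    # Hypothetical: exercise two bindings to confirm the module surface
    # survived the namespace move.
    import peer_memory_cuda as pm

    raw = pm.allocate_raw(1024)   # opaque handle to a device allocation
    pm.free_raw(raw)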
apex/contrib/csrc/peer_memory/peer_memory_cuda.cu
...
@@ -6,6 +6,7 @@
#include <cassert>
#include <cuda_runtime_api.h>
#include <cooperative_groups.h>
#include "nccl.h"
namespace
cg
=
cooperative_groups
;
#define CUDACHECK(cmd) do { \
...
...
@@ -214,9 +215,25 @@ __global__ void push_pull_halos_1d_kernel(
         strided_copy_kernel<T,is_HWC>(bih, bih_stride_C, bih_stride_H, bih_stride_W, bix, bix_stride_C, bix_stride_H, bix_stride_W, NC, NH, NW);
 }

+__global__ void delay_kernel(int delay_nanoseconds, int* counter)
+{
+    if (blockIdx.x == 0 && threadIdx.x == 0)
+    {
+        // waste time while doing something compiler can't predict, thus preventing it from optimizing away this code.
+        int new_counter = 0;
+        double elapsed = 0;
+        clock_t start = clock();
+        do {
+            clock_t now = clock();
+            elapsed = (double)(now - start) * 1e9 / CLOCKS_PER_SEC;
+            ++new_counter;
+        } while (elapsed < (double)delay_nanoseconds);
+        *counter = new_counter;
+    }
+}
 }

-namespace apex { namespace peer_memory {
+namespace apex { namespace contrib { namespace peer_memory {

 int64_t allocate_raw(int64_t size)
 {
...
@@ -460,5 +477,5 @@ void push_pull_halos_1d(
         }
     );
 }

-} }
+} } }
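delay_kernel above is a device-side busy-wait: it spins on the device clock and writes the loop counter out so the compiler cannot eliminate the loop, presumably backing the add_delay-style padding used in bottleneck.py (an assumption; the binding itself is not shown in this diff). The same pattern on the CPU, as a rough illustration only:

    # Spin until the requested time has elapsed; returning the counter
    # gives the loop an observable side effect, mirroring *counter in
    # the CUDA kernel.
    import time

    def add_delay(delay_nanoseconds: int) -> int:
        new_counter = 0
        start = time.perf_counter_ns()
        while time.perf_counter_ns() - start < delay_nanoseconds:
            new_counter += 1
        return new_counter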
apex/contrib/csrc/peer_memory/peer_memory_cuda.cuh
...
@@ -19,7 +19,7 @@
 #ifndef _peer_memory_h_
 #define _peer_memory_h_

-namespace apex { namespace peer_memory {
+namespace apex { namespace contrib { namespace peer_memory {

 int64_t allocate_raw(int64_t size);
 void free_raw(int64_t raw);
 at::Tensor get_raw_ipc_address(int64_t raw);
...
@@ -43,5 +43,5 @@ namespace apex { namespace peer_memory {
     at::Tensor btm_signal,   // btm input signal in receiver device memory
     at::Tensor waits         // top and btm signals for this rank
     );

-} }
+} } }

 #endif
setup.py
...
@@ -641,6 +641,20 @@ if "--peer_memory" in sys.argv:
     )
 )

+if "--nccl_p2p" in sys.argv:
+    sys.argv.remove("--nccl_p2p")
+    raise_if_cuda_home_none("--nccl_p2p")
+    ext_modules.append(
+        CUDAExtension(
+            name="nccl_p2p",
+            sources=[
+                "apex/contrib/csrc/nccl_p2p/nccl_p2p_cuda.cu",
+                "apex/contrib/csrc/nccl_p2p/nccl_p2p.cpp",
+            ],
+            extra_compile_args={"cxx": ["-O3"] + version_dependent_macros + generator_flag},
+        )
+    )

 setup(
     name="apex",
...
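With this in place, the new nccl_p2p extension builds alongside the existing peer_memory one when its flag is passed, e.g. python setup.py install --peer_memory --nccl_p2p (command shape assumed from how the surrounding flags are consumed via sys.argv).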