OpenDAS / apex / Commits

Commit 6d3c75e5, authored Dec 10, 2018 by Jie
Parent commit: 920da6da

Adding process group in convert_syncbn_model
Showing 3 changed files, with 30 additions and 15 deletions:
apex/parallel/__init__.py               +2  -2
apex/parallel/sync_batchnorm.py         +18 -7
apex/parallel/sync_batchnorm_kernel.py  +10 -6
apex/parallel/__init__.py
@@ -8,7 +8,7 @@ except ImportError:
     print("Warning: apex was installed without --cuda_ext. Fused syncbn kernels will be unavailable. Python fallbacks will be used instead.")
 
 from .sync_batchnorm import SyncBatchNorm
 
-def convert_syncbn_model(module):
+def convert_syncbn_model(module, process_group=None):
     '''
     Recursively traverse module and its children to replace all
     `torch.nn.modules.batchnorm._BatchNorm` with `apex.parallel.SyncBatchNorm`
@@ -27,7 +27,7 @@ def convert_syncbn_model(module):
     '''
     mod = module
     if isinstance(module, torch.nn.modules.batchnorm._BatchNorm):
-        mod = SyncBatchNorm(module.num_features, module.eps, module.momentum, module.affine, module.track_running_stats)
+        mod = SyncBatchNorm(module.num_features, module.eps, module.momentum, module.affine, module.track_running_stats, process_group)
         mod.running_mean = module.running_mean
         mod.running_var = module.running_var
         if module.affine:
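With the new parameter, callers can hand convert_syncbn_model a communication group so that batch-norm statistics are reduced over a subset of ranks instead of the whole world. A minimal usage sketch, assuming torch.distributed has already been initialized (e.g. via init_process_group) and apex is installed; the four-rank group below is purely illustrative:

import torch
import torch.distributed as dist
from apex.parallel import convert_syncbn_model

model = torch.nn.Sequential(
    torch.nn.Conv2d(3, 8, kernel_size=3),
    torch.nn.BatchNorm2d(8),
)

# Default behavior (process_group=None): statistics are reduced
# across the default (world) group, as before this commit.
model_world = convert_syncbn_model(model)

# New in this commit: reduce statistics only within a sub-group,
# e.g. the first four ranks of the job (illustrative choice).
group = dist.new_group(ranks=[0, 1, 2, 3])
model_grouped = convert_syncbn_model(model, process_group=group)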
apex/parallel/sync_batchnorm.py
@@ -44,8 +44,12 @@ class SyncBatchNorm(_BatchNorm):
         >>> out = sbn(inp)
     """
 
-    def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, track_running_stats=True):
+    def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, track_running_stats=True, process_group=None):
         super(SyncBatchNorm, self).__init__(num_features, eps=eps, momentum=momentum, affine=affine, track_running_stats=track_running_stats)
+        self.process_group = process_group
+
+    def _specify_process_group(self, process_group):
+        self.process_group = process_group
 
     def forward(self, input):
         torch.cuda.nvtx.range_push("sync_bn_fw_with_mean_var")
@@ -56,6 +60,13 @@ class SyncBatchNorm(_BatchNorm):
             torch.cuda.nvtx.range_pop()
             return F.batch_norm(input, self.running_mean, self.running_var, self.weight, self.bias, False, 0.0, self.eps)
         else:
+            process_group = self.process_group
+            world_size = 0
+            if self.process_group:
+                world_size = torch.distributed.get_world_size(process_group)
+            else:
+                process_group = torch.distributed.get_default_group()
+                world_size = torch.distributed.get_world_size()
             self.num_batches_tracked += 1
             with torch.no_grad():
                 channel_first_input = input.transpose(0, 1).contiguous()
@@ -69,12 +80,12 @@ class SyncBatchNorm(_BatchNorm):
                     squashed_input_tensor_view, 2).mean(1)
                 if torch.distributed.is_initialized():
                     torch.distributed.all_reduce(
-                        local_mean, op=torch.distributed.reduce_op.SUM)
-                    mean = local_mean / torch.distributed.get_world_size()
+                        local_mean, torch.distributed.reduce_op.SUM, process_group)
+                    mean = local_mean / world_size
                     torch.distributed.all_reduce(
-                        local_sqr_mean, op=torch.distributed.reduce_op.SUM)
-                    sqr_mean = local_sqr_mean / torch.distributed.get_world_size()
-                    m = local_m * torch.distributed.get_world_size()
+                        local_sqr_mean, torch.distributed.reduce_op.SUM, process_group)
+                    sqr_mean = local_sqr_mean / world_size
+                    m = local_m * world_size
                 else:
                     m = local_m
                     mean = local_mean
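The hunk above scopes the reductions to the chosen group, but the statistics math is unchanged: summing per-process means with all_reduce and dividing by world_size gives the group mean, and the variance falls out of the two reduced moments as E[x^2] - E[x]^2. A pure-Python stand-in for the collective calls (my sketch, not apex code; it assumes equal per-process batch sizes, as the kernel does):

import torch

world_size = 2
# Two processes, each holding a (N, C) slice of the global batch.
x0 = torch.randn(4, 3)
x1 = torch.randn(4, 3)

# Per-process first and second moments, per channel.
local_means = [x.mean(0) for x in (x0, x1)]
local_sqr_means = [x.pow(2).mean(0) for x in (x0, x1)]

# all_reduce(SUM) followed by division by world_size is an average.
mean = sum(local_means) / world_size
sqr_mean = sum(local_sqr_means) / world_size

# Biased variance from the reduced moments: Var[x] = E[x^2] - E[x]^2.
var = sqr_mean - mean.pow(2)

reference = torch.cat([x0, x1], 0)
assert torch.allclose(mean, reference.mean(0), atol=1e-6)
assert torch.allclose(var, reference.var(0, unbiased=False), atol=1e-6)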
@@ -94,4 +105,4 @@ class SyncBatchNorm(_BatchNorm):
                     (m - 1) * self.momentum * var + \
                     (1 - self.momentum) * self.running_var
         torch.cuda.nvtx.range_pop()
-        return SyncBatchnormFunction.apply(input, self.weight, self.bias, mean, var, self.eps)
+        return SyncBatchnormFunction.apply(input, self.weight, self.bias, mean, var, self.eps, process_group, world_size)
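Besides the constructor argument, this file gains a _specify_process_group helper so the group can be swapped after construction. A brief sketch, assuming an initialized torch.distributed backend; the two-rank group is illustrative:

import torch.distributed as dist
from apex.parallel import SyncBatchNorm

sbn = SyncBatchNorm(16)                    # process_group=None: default group
pair_group = dist.new_group(ranks=[0, 1])  # illustrative 2-rank group
sbn._specify_process_group(pair_group)     # stats now reduce over ranks 0 and 1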
apex/parallel/sync_batchnorm_kernel.py
@@ -5,7 +5,7 @@ from torch.autograd.function import Function
 class SyncBatchnormFunction(Function):
 
     @staticmethod
-    def forward(ctx, input, weight, bias, running_mean, running_variance, eps):
+    def forward(ctx, input, weight, bias, running_mean, running_variance, eps, process_group, world_size):
         torch.cuda.nvtx.range_push("sync_BN_fw")
         # transpose it to channel last to support broadcasting for input with different rank
         c_last_input = input.transpose(1, -1).contiguous().clone()
@@ -13,6 +13,8 @@ class SyncBatchnormFunction(Function):
         ctx.save_for_backward(c_last_input, weight, bias, running_mean, running_variance)
         ctx.eps = eps
+        ctx.process_group = process_group
+        ctx.world_size = world_size
 
         c_last_input = (c_last_input - running_mean) / \
             torch.sqrt(running_variance + eps)
@@ -34,6 +36,8 @@ class SyncBatchnormFunction(Function):
         c_last_input, weight, bias, running_mean, running_variance = ctx.saved_tensors
         eps = ctx.eps
+        process_group = ctx.process_group
+        world_size = ctx.world_size
         grad_input = grad_weight = grad_bias = None
         num_features = running_mean.size()[0]
@@ -53,11 +57,11 @@ class SyncBatchnormFunction(Function):
                 running_mean)).view(-1, num_features).mean(0)
             if torch.distributed.is_initialized():
                 torch.distributed.all_reduce(
-                    mean_dy, op=torch.distributed.reduce_op.SUM)
-                mean_dy = mean_dy / torch.distributed.get_world_size()
+                    mean_dy, torch.distributed.reduce_op.SUM, process_group)
+                mean_dy = mean_dy / world_size
                 torch.distributed.all_reduce(
-                    mean_dy_xmu, op=torch.distributed.reduce_op.SUM)
-                mean_dy_xmu = mean_dy_xmu / torch.distributed.get_world_size()
+                    mean_dy_xmu, torch.distributed.reduce_op.SUM, process_group)
+                mean_dy_xmu = mean_dy_xmu / world_size
             c_last_grad_input = (c_last_grad - mean_dy - (c_last_input - running_mean) / (
                 running_variance + eps) * mean_dy_xmu) / torch.sqrt(running_variance + eps)
             if weight is not None:
@@ -78,4 +82,4 @@ class SyncBatchnormFunction(Function):
             grad_bias = c_grad.sum(0)
 
         torch.cuda.nvtx.range_pop()
-        return grad_input, grad_weight, grad_bias, None, None, None
+        return grad_input, grad_weight, grad_bias, None, None, None, None, None
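The two extra Nones exist because autograd requires backward() to return one value per argument that forward() received, in order; non-differentiable arguments such as process_group and world_size get None. A minimal illustration of that rule (my sketch, with a plain scale factor standing in for the extra arguments):

import torch
from torch.autograd.function import Function

class Scale(Function):
    @staticmethod
    def forward(ctx, input, factor):  # factor is a plain Python number
        ctx.factor = factor
        return input * factor

    @staticmethod
    def backward(ctx, grad_output):
        # One return slot per forward input: a gradient for `input`,
        # None for the non-differentiable `factor`.
        return grad_output * ctx.factor, None

x = torch.randn(3, requires_grad=True)
Scale.apply(x, 2.0).sum().backward()
assert torch.allclose(x.grad, torch.full((3,), 2.0))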