OpenDAS / apex

Commit a99e1875 (parent 35891b28)
Authored Dec 17, 2018 by Michael Carilli

Fix deprecation warnings for ReduceOp
Showing 4 changed files with 15 additions and 6 deletions

  apex/parallel/__init__.py                          +5  -0
  apex/parallel/optimized_sync_batchnorm_kernel.py   +3  -2
  apex/parallel/sync_batchnorm.py                    +3  -2
  apex/parallel/sync_batchnorm_kernel.py             +4  -2
apex/parallel/__init__.py

@@ -7,6 +7,11 @@ if hasattr(torch.distributed, 'get_default_group'):
 else:
     group_creator = torch.distributed.new_group
 
+if hasattr(torch.distributed, 'ReduceOp'):
+    ReduceOp = torch.distributed.ReduceOp
+else:
+    ReduceOp = torch.distributed.reduce_op
+
 from .distributed import DistributedDataParallel, Reducer
 try:
     import syncbn
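Newer PyTorch builds expose the torch.distributed.ReduceOp enum and emit a deprecation warning whenever the older torch.distributed.reduce_op constants are used. The block added above resolves whichever name the installed PyTorch provides once, at import time, and the sync-batchnorm modules below import the alias instead of writing torch.distributed.reduce_op at every call site. A minimal standalone sketch of the same guard (the allreduce_mean helper and its arguments are illustrative, not part of the commit):

    import torch.distributed as dist

    # Prefer the enum on newer PyTorch; fall back to the deprecated
    # module-level constants on older releases. Both expose .SUM, .MAX, etc.,
    # so downstream code can always write ReduceOp.SUM.
    if hasattr(dist, 'ReduceOp'):
        ReduceOp = dist.ReduceOp
    else:
        ReduceOp = dist.reduce_op

    # Hypothetical helper showing the call-site pattern used below:
    # sum a tensor across ranks, then rescale to the global mean
    # (assumes an initialized process group and equal per-rank batch sizes).
    def allreduce_mean(tensor, world_size, process_group=None):
        dist.all_reduce(tensor, ReduceOp.SUM, process_group)
        return tensor / world_size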
apex/parallel/optimized_sync_batchnorm_kernel.py

@@ -2,6 +2,7 @@ import torch
 from torch.autograd.function import Function
 
 import syncbn
+from apex.parallel import group_creator, ReduceOp
 
 
 class SyncBatchnormFunction(Function):
@@ -68,10 +69,10 @@ class SyncBatchnormFunction(Function):
         if torch.distributed.is_initialized():
-            torch.distributed.all_reduce(mean_dy, torch.distributed.reduce_op.SUM, process_group)
+            torch.distributed.all_reduce(mean_dy, ReduceOp.SUM, process_group)
             mean_dy = mean_dy / world_size
-            torch.distributed.all_reduce(mean_dy_xmu, torch.distributed.reduce_op.SUM, process_group)
+            torch.distributed.all_reduce(mean_dy_xmu, ReduceOp.SUM, process_group)
             mean_dy_xmu = mean_dy_xmu / world_size
         grad_input = syncbn.batchnorm_backward(grad_output, saved_input, running_mean, running_variance, weight, mean_dy, mean_dy_xmu, eps)
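In the backward hunk, mean_dy and mean_dy_xmu hold each rank's mean of dy and of dy * (x - mean); summing them across processes and dividing by world_size turns them into the global-batch means that the CUDA extension (syncbn.batchnorm_backward) consumes. A standalone sketch of that reduction step, assuming an initialized process group and equal per-rank batch sizes (the helper itself is hypothetical):

    import torch.distributed as dist
    from apex.parallel import ReduceOp  # resolves to reduce_op on older PyTorch

    def average_backward_stats(mean_dy, mean_dy_xmu, process_group=None):
        # Sum per-rank means across processes, then rescale to global means.
        if dist.is_initialized():
            world_size = dist.get_world_size(process_group)
            dist.all_reduce(mean_dy, ReduceOp.SUM, process_group)
            mean_dy = mean_dy / world_size
            dist.all_reduce(mean_dy_xmu, ReduceOp.SUM, process_group)
            mean_dy_xmu = mean_dy_xmu / world_size
        return mean_dy, mean_dy_xmu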
apex/parallel/sync_batchnorm.py

@@ -3,6 +3,7 @@ from torch.nn.modules.batchnorm import _BatchNorm
 from torch.nn import functional as F
 
 from .sync_batchnorm_kernel import SyncBatchnormFunction
+from apex.parallel import group_creator, ReduceOp
 
 
 class SyncBatchNorm(_BatchNorm):
@@ -80,10 +81,10 @@ class SyncBatchNorm(_BatchNorm):
                     squashed_input_tensor_view, 2).mean(1)
                 if torch.distributed.is_initialized():
-                    torch.distributed.all_reduce(local_mean, torch.distributed.reduce_op.SUM, process_group)
+                    torch.distributed.all_reduce(local_mean, ReduceOp.SUM, process_group)
                     mean = local_mean / world_size
-                    torch.distributed.all_reduce(local_sqr_mean, torch.distributed.reduce_op.SUM, process_group)
+                    torch.distributed.all_reduce(local_sqr_mean, ReduceOp.SUM, process_group)
                     sqr_mean = local_sqr_mean / world_size
                     m = local_m * world_size
                 else:
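The forward hunk applies the same recipe to the activation statistics: each rank computes a local per-channel mean and mean of squares over its shard, the sums are all-reduced, and dividing by world_size yields the global E[x] and E[x^2], from which the biased variance is E[x^2] - E[x]^2. A sketch of that computation under the same assumptions (the function is illustrative; variable names mirror the diff, and a flattened (C, -1) input layout is assumed):

    import torch
    import torch.distributed as dist
    from apex.parallel import ReduceOp

    def global_batch_stats(squashed_input_tensor_view, process_group=None):
        # Per-rank, per-channel statistics over an assumed (C, -1) view.
        local_mean = torch.mean(squashed_input_tensor_view, 1)
        local_sqr_mean = torch.pow(squashed_input_tensor_view, 2).mean(1)
        if dist.is_initialized():
            world_size = dist.get_world_size(process_group)
            dist.all_reduce(local_mean, ReduceOp.SUM, process_group)
            mean = local_mean / world_size
            dist.all_reduce(local_sqr_mean, ReduceOp.SUM, process_group)
            sqr_mean = local_sqr_mean / world_size
        else:
            mean, sqr_mean = local_mean, local_sqr_mean
        var = sqr_mean - mean.pow(2)  # E[x^2] - E[x]^2, biased variance
        return mean, var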
apex/parallel/sync_batchnorm_kernel.py

 import torch
 from torch.autograd.function import Function
 
+from apex.parallel import group_creator, ReduceOp
+
 
 class SyncBatchnormFunction(Function):
@@ -57,10 +59,10 @@ class SyncBatchnormFunction(Function):
             running_mean)).view(-1, num_features).mean(0)
         if torch.distributed.is_initialized():
-            torch.distributed.all_reduce(mean_dy, torch.distributed.reduce_op.SUM, process_group)
+            torch.distributed.all_reduce(mean_dy, ReduceOp.SUM, process_group)
             mean_dy = mean_dy / world_size
-            torch.distributed.all_reduce(mean_dy_xmu, torch.distributed.reduce_op.SUM, process_group)
+            torch.distributed.all_reduce(mean_dy_xmu, ReduceOp.SUM, process_group)
             mean_dy_xmu = mean_dy_xmu / world_size
         c_last_grad_input = (c_last_grad - mean_dy - (c_last_input - running_mean) / (
             running_variance + eps) * mean_dy_xmu) / torch.sqrt(running_variance + eps)
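For reference, the long context line at the end of this hunk is the batch-norm input gradient, evaluated with the globally averaged terms computed just above it. Writing dy for c_last_grad, x for c_last_input, mu and sigma^2 for the running_mean and running_variance tensors used here, and E[.] for the batch means held in mean_dy and mean_dy_xmu, the line transcribes to (LaTeX):

    \frac{\partial L}{\partial x}
      = \frac{1}{\sqrt{\sigma^{2} + \epsilon}}
        \left( dy
               - \mathrm{E}[dy]
               - \frac{x - \mu}{\sigma^{2} + \epsilon}\,
                 \mathrm{E}\big[dy\,(x - \mu)\big] \right)

This is the textbook expression for dL/dx in batch normalization; the all-reduce calls above are what make E[dy] and E[dy (x - mu)] means over the global batch rather than over a single process's shard.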