OpenDAS / apex: commit 8421cfb4 (unverified)
Authored Jan 02, 2019 by mcarilli; committed by GitHub on Jan 02, 2019

Merge pull request #113 from NVIDIA/sbn_issue

[syncBN]

Parents: 241dd6c4, fa719e8b
Showing 4 changed files with 9 additions and 22 deletions (+9, -22)
apex/parallel/__init__.py                           +0 -9
apex/parallel/optimized_sync_batchnorm_kernel.py    +4 -6
apex/parallel/sync_batchnorm.py                     +4 -6
apex/parallel/sync_batchnorm_kernel.py              +1 -1
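Taken together, the four diffs below remove the group_creator backward-compatibility shim from apex.parallel and make an unspecified process_group fall back to torch.distributed.group.WORLD, with the world size always queried from the resolved group. A minimal sketch of the pattern the changed files converge on (the helper name resolve_process_group is illustrative only, not apex API):

import torch

def resolve_process_group(process_group=None):
    # Illustrative helper, not part of apex: mirrors the pattern this commit
    # introduces. An unspecified group falls back to the default WORLD group,
    # and the world size is queried from whichever group was resolved.
    if not torch.distributed.is_initialized():
        return None, 1
    if not process_group:
        process_group = torch.distributed.group.WORLD
    return process_group, torch.distributed.get_world_size(process_group)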
apex/parallel/__init__.py

 import torch
 
-# Backward compatibility hack around
-# https://github.com/pytorch/pytorch/pull/14767
-if hasattr(torch.distributed, 'get_default_group'):
-    group_creator = torch.distributed.get_default_group
-elif hasattr(torch.distributed, 'new_group'):
-    group_creator = torch.distributed.new_group
-else:
-    group_creator = torch.distributed.deprecated.new_group
-
 if hasattr(torch.distributed, 'ReduceOp'):
     ReduceOp = torch.distributed.ReduceOp
 elif hasattr(torch.distributed, 'reduce_op'):
 ...
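The deleted block was a shim around pytorch/pytorch#14767 that picked get_default_group, new_group, or deprecated.new_group depending on the installed PyTorch. With the default-group fallback now handled at the call sites, only the ReduceOp/reduce_op alias remains in this module. A hedged sketch of how code that imported the shim migrates, grounded in the import changes in the three files below:

# Before this commit (removed):
#   from apex.parallel import group_creator, ReduceOp
#   process_group = group_creator()

# After this commit, callers import only ReduceOp and fall back to the
# default group themselves:
import torch
from apex.parallel import ReduceOp

process_group = torch.distributed.group.WORLD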
apex/parallel/optimized_sync_batchnorm_kernel.py

@@ -2,7 +2,7 @@ import torch
 from torch.autograd.function import Function
 
 import syncbn
-from apex.parallel import group_creator, ReduceOp
+from apex.parallel import ReduceOp
 
 
 class SyncBatchnormFunction(Function):
@@ -16,11 +16,9 @@ class SyncBatchnormFunction(Function):
         mean, var, var_biased = syncbn.welford_mean_var(input)
 
         if torch.distributed.is_initialized():
-            if process_group:
-                world_size = torch.distributed.get_world_size(process_group)
-            else:
-                process_group = group_creator()
-                world_size = torch.distributed.get_world_size()
+            if not process_group:
+                process_group = torch.distributed.group.WORLD
+            world_size = torch.distributed.get_world_size(process_group)
             mean_all = torch.empty(world_size, mean.size(0), dtype=mean.dtype, device=mean.device)
             var_all = torch.empty(world_size, var.size(0), dtype=var.dtype, device=var.device)
             mean_l = [mean_all.narrow(0, i, 1) for i in range(world_size)]
 ...
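With the inverted check, an explicitly passed process_group is used directly for get_world_size and the subsequent gathers, while None now selects the WORLD group. A hypothetical usage sketch, assuming torch.distributed is already initialized by a distributed launcher and that SyncBatchNorm forwards its process_group argument to this function (as the self.process_group reference in the next file suggests):

import torch
from apex.parallel import SyncBatchNorm

# Hypothetical example: reduce batch-norm statistics over the first half of
# the ranks only. new_group must be called collectively by all processes.
half = max(1, torch.distributed.get_world_size() // 2)
subgroup = torch.distributed.new_group(ranks=list(range(half)))

sbn = SyncBatchNorm(64, process_group=subgroup).cuda()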
apex/parallel/sync_batchnorm.py

@@ -3,7 +3,7 @@ from torch.nn.modules.batchnorm import _BatchNorm
 from torch.nn import functional as F
 
 from .sync_batchnorm_kernel import SyncBatchnormFunction
-from apex.parallel import group_creator, ReduceOp
+from apex.parallel import ReduceOp
 
 
 class SyncBatchNorm(_BatchNorm):
@@ -63,11 +63,9 @@ class SyncBatchNorm(_BatchNorm):
         else:
             process_group = self.process_group
             world_size = 0
-            if self.process_group:
-                world_size = torch.distributed.get_world_size(process_group)
-            else:
-                process_group = group_creator()
-                world_size = torch.distributed.get_world_size()
+            if not self.process_group:
+                process_group = torch.distributed.group.WORLD
+            world_size = torch.distributed.get_world_size(process_group)
             self.num_batches_tracked += 1
             with torch.no_grad():
                 channel_first_input = input.transpose(0, 1).contiguous()
 ...
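The module's forward path gets the same fallback: constructing SyncBatchNorm without a process_group now synchronizes statistics across every rank in the default WORLD group. A minimal sketch, assuming the script is started by a distributed launcher that sets RANK, WORLD_SIZE, MASTER_ADDR/PORT and LOCAL_RANK in the environment:

import os
import torch
from apex.parallel import SyncBatchNorm

# Sketch only: init_process_group reads the launcher-provided env variables.
torch.distributed.init_process_group(backend='nccl')
torch.cuda.set_device(int(os.environ.get('LOCAL_RANK', 0)))

# With no process_group argument, statistics are reduced over
# torch.distributed.group.WORLD, i.e. all ranks in the job.
model = SyncBatchNorm(32).cuda()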
apex/parallel/sync_batchnorm_kernel.py

 import torch
 from torch.autograd.function import Function
 
-from apex.parallel import group_creator, ReduceOp
+from apex.parallel import ReduceOp
 
 
 class SyncBatchnormFunction(Function):
 ...