Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
apex
Commits
878ba512
Unverified
Commit
878ba512
authored
Jan 28, 2019
by
mcarilli
Committed by
GitHub
Jan 28, 2019
Browse files
Merge pull request #138 from NVIDIA/sbn_test_cases
[syncBN]
parents
95fe7f6a
d0624f4f
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
5 additions
and
5 deletions
+5
-5
tests/synced_batchnorm/two_gpu_unit_test.py
tests/synced_batchnorm/two_gpu_unit_test.py
+5
-5
No files found.
tests/synced_batchnorm/two_gpu_unit_test.py
View file @
878ba512
...
@@ -92,6 +92,10 @@ inp_bn = inp_t.clone().requires_grad_()
...
@@ -92,6 +92,10 @@ inp_bn = inp_t.clone().requires_grad_()
grad_bn
=
grad_output_t
.
clone
().
detach
()
grad_bn
=
grad_output_t
.
clone
().
detach
()
out_bn
=
bn
(
inp_bn
)
out_bn
=
bn
(
inp_bn
)
out_bn
.
backward
(
grad_bn
)
out_bn
.
backward
(
grad_bn
)
# compensating the averaging over processes done by DDP
# in order to produce mathematically equivalent result
for
param
in
bn
.
parameters
():
param
.
grad
=
param
.
grad
/
args
.
world_size
bn_opt
=
optim
.
SGD
(
bn
.
parameters
(),
lr
=
1.0
)
bn_opt
=
optim
.
SGD
(
bn
.
parameters
(),
lr
=
1.0
)
sbn
=
apex
.
parallel
.
SyncBatchNorm
(
feature_size
).
cuda
()
sbn
=
apex
.
parallel
.
SyncBatchNorm
(
feature_size
).
cuda
()
...
@@ -103,7 +107,7 @@ if args.fp16:
...
@@ -103,7 +107,7 @@ if args.fp16:
if
args
.
fp64
:
if
args
.
fp64
:
sbn
.
double
()
sbn
.
double
()
sbn
=
DDP
(
sbn
)
sbn
=
DDP
(
sbn
)
sbn_opt
=
optim
.
SGD
(
sbn
.
parameters
(),
lr
=
1.0
*
args
.
world_size
)
sbn_opt
=
optim
.
SGD
(
sbn
.
parameters
(),
lr
=
1.0
)
inp_sbn
=
inp_t
.
clone
().
requires_grad_
()
inp_sbn
=
inp_t
.
clone
().
requires_grad_
()
grad_sbn
=
grad_output_t
.
clone
().
detach
()
grad_sbn
=
grad_output_t
.
clone
().
detach
()
out_sbn
=
sbn
(
inp_sbn
[
start
:
finish
])
out_sbn
=
sbn
(
inp_sbn
[
start
:
finish
])
...
@@ -159,11 +163,7 @@ sbn_opt.step()
...
@@ -159,11 +163,7 @@ sbn_opt.step()
if
args
.
local_rank
==
0
:
if
args
.
local_rank
==
0
:
compare
(
"comparing bn vs sbn bias: "
,
bn
.
bias
,
sbn
.
module
.
bias
,
error
)
compare
(
"comparing bn vs sbn bias: "
,
bn
.
bias
,
sbn
.
module
.
bias
,
error
)
compare
(
"comparing bn vs ref bias: "
,
bn
.
bias
,
bias_r
.
view
(
-
1
)
-
grad_bias_r
,
error
)
sbn_result
=
compare
(
"comparing sbn vs ref bias: "
,
sbn
.
module
.
bias
,
bias_r
.
view
(
-
1
)
-
grad_bias_r
,
error
)
and
sbn_result
compare
(
"comparing bn vs sbn weight: "
,
bn
.
weight
,
sbn
.
module
.
weight
,
error
)
compare
(
"comparing bn vs sbn weight: "
,
bn
.
weight
,
sbn
.
module
.
weight
,
error
)
compare
(
"comparing bn vs ref weight: "
,
bn
.
weight
,
(
weight_r
.
view
(
-
1
)
-
grad_weight_r
),
error
)
sbn_result
=
compare
(
"comparing sbn vs ref weight: "
,
sbn
.
module
.
weight
,
(
weight_r
.
view
(
-
1
)
-
grad_weight_r
),
error
)
and
sbn_result
if
sbn_result
:
if
sbn_result
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment