Commit 37cdaf4a

fixing batchnorm 1d input (#590)

Authored Nov 06, 2019 by jjsjann123; committed by mcarilli, Nov 06, 2019
Parent: 606c3dcc
Showing 3 changed files with 21 additions and 2 deletions:

  apex/parallel/optimized_sync_batchnorm.py                +2  -2
  tests/distributed/synced_batchnorm/test_batchnorm1d.py  +18  -0
  tests/distributed/synced_batchnorm/unit_test.sh          +1  -0
apex/parallel/optimized_sync_batchnorm.py
@@ -71,7 +71,7 @@ class SyncBatchNorm(_BatchNorm):
         # if input.dim() == 2, we switch to channel_last for efficient memory accessing
         channel_last = self.channel_last if input.dim() != 2 else True
 
-        if not self.training and self.track_running_stats and not self.channel_last and not self.fuse_relu and z == None:
+        if not self.training and self.track_running_stats and not channel_last and not self.fuse_relu and z == None:
             # fall back to pytorch implementation for inference
             return F.batch_norm(input, self.running_mean, self.running_var, self.weight, self.bias, False, 0.0, self.eps)
         else:
@@ -82,4 +82,4 @@ class SyncBatchNorm(_BatchNorm):
                     exponential_average_factor = 1.0 / float(self.num_batches_tracked)
                 else:
                     exponential_average_factor = self.momentum
-            return SyncBatchnormFunction.apply(input, z, self.weight, self.bias, self.running_mean, self.running_var, self.eps, self.training or not self.track_running_stats, exponential_average_factor, self.process_group, self.channel_last, self.fuse_relu)
+            return SyncBatchnormFunction.apply(input, z, self.weight, self.bias, self.running_mean, self.running_var, self.eps, self.training or not self.track_running_stats, exponential_average_factor, self.process_group, channel_last, self.fuse_relu)
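These two hunks are the whole fix: forward() promotes a 2D (N, C) input to the channel-last path via the local channel_last variable, but both the inference fast-path check and the SyncBatchnormFunction.apply call previously read the stale self.channel_last attribute, so BatchNorm1d-style inputs were dispatched with the wrong layout flag. A minimal standalone sketch of the mismatch (not from the commit; attr_channel_last is a stand-in for self.channel_last):

import torch

input = torch.rand(8, 4)     # 2D (N, C) input, as BatchNorm1d would see
attr_channel_last = False    # stand-in for a module built with channel_last=False

# The layout decision kept by the diff: a 2D input forces channel-last.
channel_last = attr_channel_last if input.dim() != 2 else True

assert channel_last is True        # the layout actually chosen for the data
assert attr_channel_last is False  # the value the pre-fix code consulted
# Before this commit, the fallback check and the apply() call used the
# attribute and therefore disagreed with the layout chosen above; the fix
# reads the local variable instead.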
tests/distributed/synced_batchnorm/test_batchnorm1d.py (new file, mode 100644)
import torch
import apex

model = apex.parallel.SyncBatchNorm(4).cuda()
model.weight.data.uniform_()
model.bias.data.uniform_()
data = torch.rand((8, 4)).cuda()

model_ref = torch.nn.BatchNorm1d(4).cuda()
model_ref.load_state_dict(model.state_dict())
data_ref = data.clone()

output = model(data)
output_ref = model_ref(data_ref)

assert(output.allclose(output_ref))
assert(model.running_mean.allclose(model_ref.running_mean))
assert(model.running_var.allclose(model_ref.running_var))
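The test covers the default training-mode path, where the pre-fix code handed the wrong layout flag to SyncBatchnormFunction.apply. A hypothetical extension (not part of this commit) could also exercise eval mode, where the corrected fallback condition now consults the local layout decision:

# Hypothetical follow-up (not in this commit), appended to the test above:
# repeat the comparison in eval mode. With the fix, a 2D input skips the
# F.batch_norm fallback (channel_last is forced True) and should still
# match the PyTorch reference.
model.eval()
model_ref.eval()
eval_out = model(data)
eval_out_ref = model_ref(data_ref)
assert(eval_out.allclose(eval_out_ref))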
tests/distributed/synced_batchnorm/unit_test.sh

 python python_single_gpu_unit_test.py
 python single_gpu_unit_test.py
+python test_batchnorm1d.py
 python -m torch.distributed.launch --nproc_per_node=2 two_gpu_unit_test.py
 python -m torch.distributed.launch --nproc_per_node=2 two_gpu_unit_test.py --fp16
 #beware, you need a system with at least 4 gpus to test group_size<world_size