Commit 37cdaf4a authored by jjsjann123's avatar jjsjann123 Committed by mcarilli
Browse files

fixing batchnorm 1d input (#590)

parent 606c3dcc
......@@ -71,7 +71,7 @@ class SyncBatchNorm(_BatchNorm):
# if input.dim() == 2, we switch to channel_last for efficient memory accessing
channel_last = self.channel_last if input.dim() != 2 else True
if not self.training and self.track_running_stats and not self.channel_last and not self.fuse_relu and z == None:
if not self.training and self.track_running_stats and not channel_last and not self.fuse_relu and z == None:
# fall back to pytorch implementation for inference
return F.batch_norm(input, self.running_mean, self.running_var, self.weight, self.bias, False, 0.0, self.eps)
else:
......@@ -82,4 +82,4 @@ class SyncBatchNorm(_BatchNorm):
exponential_average_factor = 1.0 / float(self.num_batches_tracked)
else:
exponential_average_factor = self.momentum
return SyncBatchnormFunction.apply(input, z, self.weight, self.bias, self.running_mean, self.running_var, self.eps, self.training or not self.track_running_stats, exponential_average_factor, self.process_group, self.channel_last, self.fuse_relu)
return SyncBatchnormFunction.apply(input, z, self.weight, self.bias, self.running_mean, self.running_var, self.eps, self.training or not self.track_running_stats, exponential_average_factor, self.process_group, channel_last, self.fuse_relu)
import torch
import apex
model = apex.parallel.SyncBatchNorm(4).cuda()
model.weight.data.uniform_()
model.bias.data.uniform_()
data = torch.rand((8,4)).cuda()
model_ref = torch.nn.BatchNorm1d(4).cuda()
model_ref.load_state_dict(model.state_dict())
data_ref = data.clone()
output = model(data)
output_ref = model_ref(data_ref)
assert(output.allclose(output_ref))
assert(model.running_mean.allclose(model_ref.running_mean))
assert(model.running_var.allclose(model_ref.running_var))
python python_single_gpu_unit_test.py
python single_gpu_unit_test.py
python test_batchnorm1d.py
python -m torch.distributed.launch --nproc_per_node=2 two_gpu_unit_test.py
python -m torch.distributed.launch --nproc_per_node=2 two_gpu_unit_test.py --fp16
#beware, you need a system with at least 4 gpus to test group_size<world_size
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment