Commit 0a991543 authored by jjsjann123, committed by mcarilli

[SyncBatchNorm] (#206)

Support 2-dimensional input, resolving issue #194.

Implementation:
  For 2D input, switch the channel_last flag to true for a better memory
  access pattern in the kernel.
parent 570fde70
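
For context, a minimal usage sketch of the case this commit enables (assumes apex is installed with its CUDA/syncbn extensions and, for actual cross-GPU synchronization, that torch.distributed is initialized; single-process use shown here):

    import torch
    import apex

    # 2D input (batch, features), e.g. the output of a fully connected layer;
    # previously this shape was not handled by the optimized SyncBatchNorm path.
    bn = apex.parallel.SyncBatchNorm(128).cuda()
    x = torch.randn(32, 128, device='cuda')
    y = bn(x)  # internally routed through the channel-last kernel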
@@ -67,7 +67,10 @@ class SyncBatchNorm(_BatchNorm):
         self.channel_last = channel_last

     def forward(self, input):
-        if not self.training and self.track_running_stats and not self.channel_last:
+        # if input.dim() == 2, we switch to channel_last for efficient memory accessing
+        channel_last = self.channel_last if input.dim() != 2 else True
+
+        if not self.training and self.track_running_stats and not channel_last:
             # fall back to pytorch implementation for inference
             return F.batch_norm(input, self.running_mean, self.running_var, self.weight, self.bias, False, 0.0, self.eps)
         else:
@@ -78,4 +81,4 @@ class SyncBatchNorm(_BatchNorm):
                     exponential_average_factor = 1.0 / float(self.num_batches_tracked)
                 else:
                     exponential_average_factor = self.momentum
-            return SyncBatchnormFunction.apply(input, self.weight, self.bias, self.running_mean, self.running_var, self.eps, self.training or not self.track_running_stats, exponential_average_factor, self.process_group, self.channel_last)
+            return SyncBatchnormFunction.apply(input, self.weight, self.bias, self.running_mean, self.running_var, self.eps, self.training or not self.track_running_stats, exponential_average_factor, self.process_group, channel_last)
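
In effect, the layout decision is now made per call rather than only at construction time. A hypothetical helper, not part of apex, that mirrors the dispatch above:

    def effective_channel_last(input_dim, channel_last_flag):
        # 2D (N, C) inputs are always treated as channel-last: for a 2D tensor the
        # contiguous and channel-last layouts coincide, and the channel-last kernel
        # has the better memory access pattern, per the commit message.
        return True if input_dim == 2 else channel_last_flag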
@@ -22,10 +22,13 @@ class SyncBatchnormFunction(Function):
         if channel_last:
             count = int(input.numel()/input.size(-1))
             mean, var_biased = syncbn.welford_mean_var_c_last(input)
-        else :
+        else:
             count = int(input.numel()/input.size(1))
             mean, var_biased = syncbn.welford_mean_var(input)

+        if count == 1:
+            raise ValueError('Expected more than 1 value per channel when training, got input size{}'.format(input.size()))
+
         if torch.distributed.is_initialized():
             if not process_group:
                 process_group = torch.distributed.group.WORLD
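
The new count == 1 check guards the degenerate case where each channel sees only a single value, e.g. training on a 2D input with batch size 1, for which the batch variance is meaningless. A small standalone illustration (not apex code):

    import torch

    x = torch.randn(1, 64)               # batch size 1, 64 features
    count = int(x.numel() / x.size(-1))  # = 1: one value per channel
    # With count == 1 the forward pass now raises ValueError instead of
    # silently producing a degenerate (zero/undefined) batch variance.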
@@ -72,10 +72,9 @@ class SyncBatchNorm(_BatchNorm):
             return F.batch_norm(input, self.running_mean, self.running_var, self.weight, self.bias, False, 0.0, self.eps)
         else:
             process_group = self.process_group
-            world_size = 0
+            world_size = 1
             if not self.process_group:
                 process_group = torch.distributed.group.WORLD
-                world_size = torch.distributed.get_world_size(process_group)
             self.num_batches_tracked += 1
             with torch.no_grad():
                 channel_first_input = input.transpose(0, 1).contiguous()
@@ -88,6 +87,7 @@ class SyncBatchNorm(_BatchNorm):
                 local_sqr_mean = torch.pow(
                     squashed_input_tensor_view, 2).mean(1)
                 if torch.distributed.is_initialized():
+                    world_size = torch.distributed.get_world_size(process_group)
                     torch.distributed.all_reduce(
                         local_mean, ReduceOp.SUM, process_group)
                     mean = local_mean / world_size
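
These last two hunks fix the pure-Python fallback for single-process runs: world_size now defaults to 1 and torch.distributed.get_world_size() is only called once the process group is known to be initialized, so normalization no longer divides by zero or queries an uninitialized backend when running without distributed. A condensed sketch of the resulting pattern, with names as in the hunks above:

    world_size = 1                        # safe default without torch.distributed
    if torch.distributed.is_initialized():
        world_size = torch.distributed.get_world_size(process_group)
        # ... all-reduce the local statistics and divide by world_size ...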