Commit a5289067 authored by jjsjann123, committed by mcarilli

[syncbn update] (#287)

Update the input size check to fix GitHub issue #262.

Update the SyncBatchNorm count check so that a size-1 input with cross-GPU
synchronization runs fine.
parent ffbb52ba
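For context, a sketch of the rationale (not apex's CUDA implementation): batch norm needs more than one value per channel to estimate variance, since the unbiased correction count/(count-1) divides by zero at count == 1. Under cross-GPU synchronization, however, the effective count is aggregated over the process group, so one value per channel per GPU is legitimate whenever world_size > 1. The `combine_stats` helper below is hypothetical, a stand-in for the `syncbn.welford_parallel` reduction, and assumes equal counts on every GPU:

```python
import torch

def combine_stats(means, biased_vars, count_per_gpu):
    # Merge per-GPU (mean, biased variance) statistics, as if the tensors
    # had been gathered with torch.distributed.all_gather. Hypothetical
    # stand-in for syncbn.welford_parallel; assumes equal per-GPU counts.
    world_size = means.size(0)
    total_count = count_per_gpu * world_size        # aggregated sample count
    mean = means.mean(0)
    # E[x^2] - E[x]^2 over the combined population
    var_biased = (biased_vars + means.pow(2)).mean(0) - mean.pow(2)
    return mean, var_biased, total_count

# One value per channel on each of four "GPUs": the local biased variance
# is 0 and the local correction count/(count-1) would divide by zero.
per_gpu_means = torch.randn(4, 8)                   # world_size=4, 8 channels
mean, var_biased, total = combine_stats(per_gpu_means, torch.zeros(4, 8), 1)

# After aggregation the count is world_size * 1 = 4 > 1, so the unbiased
# variance is well defined and training can proceed.
var = var_biased * total / (total - 1)
```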
```diff
@@ -26,9 +26,6 @@ class SyncBatchnormFunction(Function):
         count = int(input.numel()/input.size(1))
         mean, var_biased = syncbn.welford_mean_var(input)
-        if count == 1:
-            raise ValueError('Expected more than 1 value per channel when training, got input size{}'.format(input.size()))
         if torch.distributed.is_initialized():
             if not process_group:
                 process_group = torch.distributed.group.WORLD
@@ -45,6 +42,9 @@ class SyncBatchnormFunction(Function):
             inv_std = 1.0 / torch.sqrt(var_biased + eps)
             var = var_biased * (count) / (count-1)
+        if count == 1 and world_size < 2:
+            raise ValueError('Expected more than 1 value per channel when training, got input size{}'.format(input.size()))
         r_m_inc = mean if running_mean.dtype != torch.float16 else mean.half()
         r_v_inc = var if running_variance.dtype != torch.float16 else var.half()
         running_mean.data = running_mean.data * (1-momentum) + momentum*r_m_inc
```
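In terms of user-visible behavior, a sketch of what this change enables; it assumes a process group launched with one process per GPU (e.g. via `python -m torch.distributed.launch`) and uses `apex.parallel.SyncBatchNorm`:

```python
# Sketch of the behavior after this commit; run under a distributed
# launcher with one process per GPU, not as a standalone script.
import torch
from apex.parallel import SyncBatchNorm

bn = SyncBatchNorm(8).cuda()
x = torch.randn(1, 8, 1, 1).cuda()  # a single value per channel on this GPU

# Before: raised ValueError even with multiple GPUs, because the check
# ran before the cross-GPU reduction.
# After: succeeds whenever world_size >= 2, since the per-channel count
# is aggregated across the group; a single process still raises, as the
# variance of one sample is undefined.
out = bn(x)
```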