"docs/source/vscode:/vscode.git/clone" did not exist on "ac5d6ee6c2fab42229fcb7dc031240f09d55d951"
Commit ffbb52ba authored by jjsjann123's avatar jjsjann123 Committed by mcarilli
Browse files

[SyncBatchNorm update] (#285)

resolves issue #254

Added input casting for the pure Python implementation; this supports mismatched
input and layer dtypes.
parent 4d325d2f
...@@ -66,10 +66,23 @@ class SyncBatchNorm(_BatchNorm): ...@@ -66,10 +66,23 @@ class SyncBatchNorm(_BatchNorm):
torch.cuda.nvtx.range_push("sync_bn_fw_with_mean_var") torch.cuda.nvtx.range_push("sync_bn_fw_with_mean_var")
mean = None mean = None
var = None var = None
cast = None
out = None
# casting to handle mismatch input type to layer type
if self.running_mean is not None:
if self.running_mean.dtype != input.dtype:
input = input.to(self.running_mean.dtype)
cast = input.dtype
elif self.weight is not None:
if self.weight.dtype != input.dtype:
input = input.to(self.weight.dtype)
cast = input.dtype
if not self.training and self.track_running_stats: if not self.training and self.track_running_stats:
# fall back to pytorch implementation for inference # fall back to pytorch implementation for inference
torch.cuda.nvtx.range_pop() torch.cuda.nvtx.range_pop()
return F.batch_norm(input, self.running_mean, self.running_var, self.weight, self.bias, False, 0.0, self.eps) out = F.batch_norm(input, self.running_mean, self.running_var, self.weight, self.bias, False, 0.0, self.eps)
else: else:
process_group = self.process_group process_group = self.process_group
world_size = 1 world_size = 1
...@@ -114,4 +127,5 @@ class SyncBatchNorm(_BatchNorm): ...@@ -114,4 +127,5 @@ class SyncBatchNorm(_BatchNorm):
(m-1) * self.momentum * var + \ (m-1) * self.momentum * var + \
(1 - self.momentum) * self.running_var (1 - self.momentum) * self.running_var
torch.cuda.nvtx.range_pop() torch.cuda.nvtx.range_pop()
return SyncBatchnormFunction.apply(input, self.weight, self.bias, mean, var, self.eps, process_group, world_size) out = SyncBatchnormFunction.apply(input, self.weight, self.bias, mean, var, self.eps, process_group, world_size)
out = out.to(cast)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment