Unverified Commit b11c5660 authored by Kai Chen's avatar Kai Chen Committed by GitHub
Browse files

fix syncbn when dist is unavailabel (#388)

parent b2b42cbd
......@@ -109,7 +109,9 @@ class SyncBatchNormFunction(Function):
None, None, None, None
class SyncBatchNorm(Module):
if dist.is_available():
class SyncBatchNorm(Module):
def __init__(self,
num_features,
......@@ -175,8 +177,8 @@ class SyncBatchNorm(Module):
if self.training or not self.track_running_stats:
return SyncBatchNormFunction.apply(input, self.running_mean,
self.running_var, self.weight,
self.bias,
self.running_var,
self.weight, self.bias,
exponential_average_factor,
self.eps, self.group,
self.group_size)
......@@ -194,3 +196,12 @@ class SyncBatchNorm(Module):
s += f'track_running_stats={self.track_running_stats}, '
s += f'group_size={self.group_size})'
return s
else:
class SyncBatchNorm(Module):
def __init__(self, *args, **kwargs):
raise NotImplementedError(
'SyncBatchNorm is not supported in this OS since the '
'distributed package is not available')
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment