Unverified Commit b11c5660 authored by Kai Chen's avatar Kai Chen Committed by GitHub
Browse files

fix syncbn when dist is unavailabel (#388)

parent b2b42cbd
...@@ -109,7 +109,9 @@ class SyncBatchNormFunction(Function): ...@@ -109,7 +109,9 @@ class SyncBatchNormFunction(Function):
None, None, None, None None, None, None, None
class SyncBatchNorm(Module): if dist.is_available():
class SyncBatchNorm(Module):
def __init__(self, def __init__(self,
num_features, num_features,
...@@ -175,8 +177,8 @@ class SyncBatchNorm(Module): ...@@ -175,8 +177,8 @@ class SyncBatchNorm(Module):
if self.training or not self.track_running_stats: if self.training or not self.track_running_stats:
return SyncBatchNormFunction.apply(input, self.running_mean, return SyncBatchNormFunction.apply(input, self.running_mean,
self.running_var, self.weight, self.running_var,
self.bias, self.weight, self.bias,
exponential_average_factor, exponential_average_factor,
self.eps, self.group, self.eps, self.group,
self.group_size) self.group_size)
...@@ -194,3 +196,12 @@ class SyncBatchNorm(Module): ...@@ -194,3 +196,12 @@ class SyncBatchNorm(Module):
s += f'track_running_stats={self.track_running_stats}, ' s += f'track_running_stats={self.track_running_stats}, '
s += f'group_size={self.group_size})' s += f'group_size={self.group_size})'
return s return s
else:
class SyncBatchNorm(Module):
def __init__(self, *args, **kwargs):
raise NotImplementedError(
'SyncBatchNorm is not supported in this OS since the '
'distributed package is not available')
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment