Commit ea7c2098 authored by Michael Carilli's avatar Michael Carilli
Browse files

Merge branch 'master' of https://github.com/NVIDIA/apex

parents 427e82cd 78c38db4
...@@ -37,6 +37,8 @@ def convert_syncbn_model(module, process_group=None, channel_last=False): ...@@ -37,6 +37,8 @@ def convert_syncbn_model(module, process_group=None, channel_last=False):
>>> sync_bn_model = apex.parallel.convert_syncbn_model(model) >>> sync_bn_model = apex.parallel.convert_syncbn_model(model)
''' '''
mod = module mod = module
if isinstance(module, torch.nn.modules.instancenorm._InstanceNorm):
return module
if isinstance(module, torch.nn.modules.batchnorm._BatchNorm): if isinstance(module, torch.nn.modules.batchnorm._BatchNorm):
mod = SyncBatchNorm(module.num_features, module.eps, module.momentum, module.affine, module.track_running_stats, process_group, channel_last=channel_last) mod = SyncBatchNorm(module.num_features, module.eps, module.momentum, module.affine, module.track_running_stats, process_group, channel_last=channel_last)
mod.running_mean = module.running_mean mod.running_mean = module.running_mean
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment