Commit 78c38db4 authored by Du Xingjian's avatar Du Xingjian Committed by mcarilli
Browse files

skip instancenorm in convert_syncbn_model (#438)

parent 880ab925
......@@ -37,6 +37,8 @@ def convert_syncbn_model(module, process_group=None, channel_last=False):
>>> sync_bn_model = apex.parallel.convert_syncbn_model(model)
'''
mod = module
if isinstance(module, torch.nn.modules.instancenorm._InstanceNorm):
return module
if isinstance(module, torch.nn.modules.batchnorm._BatchNorm):
mod = SyncBatchNorm(module.num_features, module.eps, module.momentum, module.affine, module.track_running_stats, process_group, channel_last=channel_last)
mod.running_mean = module.running_mean
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment