"...text-generation-inference.git" did not exist on "042180d88f91d4bc9acd42ae4de3c0236d272de4"
Commit 78c38db4 authored by Du Xingjian's avatar Du Xingjian Committed by mcarilli
Browse files

skip instancenorm in convert_syncbn_model (#438)

parent 880ab925
...@@ -37,6 +37,8 @@ def convert_syncbn_model(module, process_group=None, channel_last=False): ...@@ -37,6 +37,8 @@ def convert_syncbn_model(module, process_group=None, channel_last=False):
>>> sync_bn_model = apex.parallel.convert_syncbn_model(model) >>> sync_bn_model = apex.parallel.convert_syncbn_model(model)
''' '''
mod = module mod = module
if isinstance(module, torch.nn.modules.instancenorm._InstanceNorm):
return module
if isinstance(module, torch.nn.modules.batchnorm._BatchNorm): if isinstance(module, torch.nn.modules.batchnorm._BatchNorm):
mod = SyncBatchNorm(module.num_features, module.eps, module.momentum, module.affine, module.track_running_stats, process_group, channel_last=channel_last) mod = SyncBatchNorm(module.num_features, module.eps, module.momentum, module.affine, module.track_running_stats, process_group, channel_last=channel_last)
mod.running_mean = module.running_mean mod.running_mean = module.running_mean
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment