Commit a62b87ea authored by Jie's avatar Jie
Browse files

fixing utility function convert_syncbn_model to accept channel_last flag and...

fixing utility function convert_syncbn_model to accept channel_last flag and properly set attribute for nested layers
parent 443fa76e
......@@ -19,7 +19,7 @@ except ImportError:
warned_syncbn = True
from .sync_batchnorm import SyncBatchNorm
def convert_syncbn_model(module, process_group=None):
def convert_syncbn_model(module, process_group=None, channel_last=False):
'''
Recursively traverse module and its children to replace all
`torch.nn.modules.batchnorm._BatchNorm` with `apex.parallel.SyncBatchNorm`
......@@ -38,14 +38,16 @@ def convert_syncbn_model(module, process_group=None):
'''
mod = module
if isinstance(module, torch.nn.modules.batchnorm._BatchNorm):
mod = SyncBatchNorm(module.num_features, module.eps, module.momentum, module.affine, module.track_running_stats, process_group)
mod = SyncBatchNorm(module.num_features, module.eps, module.momentum, module.affine, module.track_running_stats, process_group, channel_last=channel_last)
mod.running_mean = module.running_mean
mod.running_var = module.running_var
if module.affine:
mod.weight.data = module.weight.data.clone().detach()
mod.bias.data = module.bias.data.clone().detach()
for name, child in module.named_children():
mod.add_module(name, convert_syncbn_model(child))
mod.add_module(name, convert_syncbn_model(child,
process_group=process_group,
channel_last=channel_last))
# TODO(jie) should I delete model explicitly?
del module
return mod
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment