"git@developer.sourcefind.cn:OpenDAS/torch-cluster.git" did not exist on "c7aeb8d2598ecd56bb2ffbbbb003975b94c7dce9"
Commit a62b87ea authored by Jie's avatar Jie
Browse files

fixing utility function convert_syncbn_model to accept channel_last flag and...

fixing utility function convert_syncbn_model to accept channel_last flag and properly set attribute for nested layers
parent 443fa76e
...@@ -19,7 +19,7 @@ except ImportError: ...@@ -19,7 +19,7 @@ except ImportError:
warned_syncbn = True warned_syncbn = True
from .sync_batchnorm import SyncBatchNorm from .sync_batchnorm import SyncBatchNorm
def convert_syncbn_model(module, process_group=None): def convert_syncbn_model(module, process_group=None, channel_last=False):
''' '''
Recursively traverse module and its children to replace all Recursively traverse module and its children to replace all
`torch.nn.modules.batchnorm._BatchNorm` with `apex.parallel.SyncBatchNorm` `torch.nn.modules.batchnorm._BatchNorm` with `apex.parallel.SyncBatchNorm`
...@@ -38,14 +38,16 @@ def convert_syncbn_model(module, process_group=None): ...@@ -38,14 +38,16 @@ def convert_syncbn_model(module, process_group=None):
''' '''
mod = module mod = module
if isinstance(module, torch.nn.modules.batchnorm._BatchNorm): if isinstance(module, torch.nn.modules.batchnorm._BatchNorm):
mod = SyncBatchNorm(module.num_features, module.eps, module.momentum, module.affine, module.track_running_stats, process_group) mod = SyncBatchNorm(module.num_features, module.eps, module.momentum, module.affine, module.track_running_stats, process_group, channel_last=channel_last)
mod.running_mean = module.running_mean mod.running_mean = module.running_mean
mod.running_var = module.running_var mod.running_var = module.running_var
if module.affine: if module.affine:
mod.weight.data = module.weight.data.clone().detach() mod.weight.data = module.weight.data.clone().detach()
mod.bias.data = module.bias.data.clone().detach() mod.bias.data = module.bias.data.clone().detach()
for name, child in module.named_children(): for name, child in module.named_children():
mod.add_module(name, convert_syncbn_model(child)) mod.add_module(name, convert_syncbn_model(child,
process_group=process_group,
channel_last=channel_last))
# TODO(jie) should I delete model explicitly? # TODO(jie) should I delete model explicitly?
del module del module
return mod return mod
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment