"git@developer.sourcefind.cn:OpenDAS/nni.git" did not exist on "d8fd9eda9dd95f19048043394afbd5fefecfdcdd"
Commit 8092c8bd authored by chicm-ms's avatar chicm-ms Committed by GitHub
Browse files

Add weight_mask to calc_mask kwargs parameter (#2024)

parent b210695f
...@@ -225,15 +225,17 @@ class PrunerModuleWrapper(torch.nn.Module): ...@@ -225,15 +225,17 @@ class PrunerModuleWrapper(torch.nn.Module):
# config and pruner # config and pruner
self.config = config self.config = config
self.pruner = pruner self.pruner = pruner
self.registered_buffers = {}
# register buffer for mask # register buffer for mask
self.register_buffer("weight_mask", torch.ones(self.module.weight.shape)) self.register_buffer("weight_mask", torch.ones(self.module.weight.shape))
self.registered_buffers['weight_mask'] = self.weight_mask
if hasattr(self.module, 'bias') and self.module.bias is not None: if hasattr(self.module, 'bias') and self.module.bias is not None:
self.register_buffer("bias_mask", torch.ones(self.module.bias.shape)) self.register_buffer("bias_mask", torch.ones(self.module.bias.shape))
else: else:
self.register_buffer("bias_mask", None) self.register_buffer("bias_mask", None)
self.registered_buffers['bias_mask'] = self.bias_mask
# register user specified buffer # register user specified buffer
self.registered_buffers = {}
for name in self.pruner.buffers: for name in self.pruner.buffers:
self.register_buffer(name, self.pruner.buffers[name].clone()) self.register_buffer(name, self.pruner.buffers[name].clone())
self.registered_buffers[name] = getattr(self, name) self.registered_buffers[name] = getattr(self, name)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment