Unverified Commit d1eaa556 authored by Ningxin Zheng's avatar Ningxin Zheng Committed by GitHub
Browse files

update (#4241)

parent 45048734
......@@ -102,6 +102,7 @@ class AutoMaskInference:
self.weight_mask = {}
if weight_mask:
self.weight_mask.update(weight_mask)
self.name = name
if isinstance(self.module, nn.Module):
# the function should not has parameters
# get all the parameter tensors of the target module
......@@ -109,7 +110,6 @@ class AutoMaskInference:
self.weights[name] = para
if name not in self.weight_mask:
self.weight_mask[name] = torch.ones_like(para.data)
self.name = name
self.state_dict = state_dict
# TODO support the other batch dimension in the future
self.batch_dim = batch_dim
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment