Unverified Commit d1eaa556 authored by Ningxin Zheng's avatar Ningxin Zheng Committed by GitHub
Browse files

update (#4241)

parent 45048734
...@@ -102,6 +102,7 @@ class AutoMaskInference: ...@@ -102,6 +102,7 @@ class AutoMaskInference:
self.weight_mask = {} self.weight_mask = {}
if weight_mask: if weight_mask:
self.weight_mask.update(weight_mask) self.weight_mask.update(weight_mask)
self.name = name
if isinstance(self.module, nn.Module): if isinstance(self.module, nn.Module):
# the function should not has parameters # the function should not has parameters
# get all the parameter tensors of the target module # get all the parameter tensors of the target module
...@@ -109,7 +110,6 @@ class AutoMaskInference: ...@@ -109,7 +110,6 @@ class AutoMaskInference:
self.weights[name] = para self.weights[name] = para
if name not in self.weight_mask: if name not in self.weight_mask:
self.weight_mask[name] = torch.ones_like(para.data) self.weight_mask[name] = torch.ones_like(para.data)
self.name = name
self.state_dict = state_dict self.state_dict = state_dict
# TODO support the other batch dimension in the future # TODO support the other batch dimension in the future
self.batch_dim = batch_dim self.batch_dim = batch_dim
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment