Unverified Commit 089149d3 authored by Asit's avatar Asit Committed by GitHub
Browse files

Merge pull request #2 from a-maci/a-maci-patch-1

Fixing mask multiplication with grad tensors
parents 3dd36070 774de913
...@@ -131,7 +131,8 @@ class ASP: ...@@ -131,7 +131,8 @@ class ASP:
# prune gradients before step method # prune gradients before step method
with torch.no_grad(): with torch.no_grad():
for module_name, module, p_name, p, mask, pruned in cls.__sparse_parameters: for module_name, module, p_name, p, mask, pruned in cls.__sparse_parameters:
p.grad.mul_(mask) if p.grad is not None: #thx pjudd
p.grad.mul_(mask)
# call original optimizer step method # call original optimizer step method
rval = opt_self.__step(*args, **kwargs) rval = opt_self.__step(*args, **kwargs)
# prune parameters after step method # prune parameters after step method
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment