Unverified Commit 774de913 authored by Asit, committed by GitHub

Fixing mask multiplication with grad tensors

Gradients can be None. This fix skips the multiplication with masks when that is the case.
parent 3dd36070
@@ -131,7 +131,8 @@ class ASP:
             # prune gradients before step method
             with torch.no_grad():
                 for module_name, module, p_name, p, mask, pruned in cls.__sparse_parameters:
-                    p.grad.mul_(mask)
+                    if p.grad is not None: #thx pjudd
+                        p.grad.mul_(mask)
             # call original optimizer step method
             rval = opt_self.__step(*args, **kwargs)
             # prune parameters after step method
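
For context, a parameter's `.grad` stays `None` whenever it never takes part in a backward pass (for example, it is unused in the loss, or `backward()` has not run yet), and calling `.mul_()` on it then raises an `AttributeError`. Below is a minimal standalone sketch of the failure mode and the guard; it is not apex's actual ASP code, and the parameter names `used` and `unused` are hypothetical:

```python
import torch

# Sketch only (not apex code): a parameter that never enters the loss
# keeps p.grad == None after backward(), which is what the None check
# in the commit protects against.
used = torch.nn.Parameter(torch.randn(4))
unused = torch.nn.Parameter(torch.randn(4))
mask = torch.ones(4)

loss = used.sum()   # `unused` does not participate in the graph
loss.backward()

print(used.grad)    # tensor([1., 1., 1., 1.])
print(unused.grad)  # None -- unused.grad.mul_(mask) would raise AttributeError

# The guarded multiplication, mirroring the commit:
with torch.no_grad():
    for p in (used, unused):
        if p.grad is not None:
            p.grad.mul_(mask)
```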