Commit f00ca88b authored by rusty1s's avatar rusty1s
Browse files

mark non differentiable

parent 729b4fb2
@@ -30,6 +30,7 @@ class SPMM(torch.autograd.Function):
                             mat, arg_out)
        if reduce == 'min' or reduce == 'max':
            ctx.mark_non_differentiable(arg_out)
            return out, arg_out
        else:
            return out
@@ -123,6 +124,8 @@ class SPSPMM(torch.autograd.Function):
        colC = torch.from_numpy(C.col).to(torch.int64)
        indexC = torch.stack([rowC, colC], dim=0)
        ctx.mark_non_differentiable(indexC, rowptrC)
        # We cannot return `NoneType` in torch.autograd :(
        if valueC is None:
            return indexC, rowptrC
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment