Unverified Commit 1feb8917 authored by Yezhen Cong's avatar Yezhen Cong Committed by GitHub
Browse files

mark tensor not differentiable (#377)

parent 22c47ac7
......@@ -32,6 +32,9 @@ class ThreeNN(Function):
idx = torch.cuda.IntTensor(B, N, 3)
interpolate_ext.three_nn_wrapper(B, N, m, target, source, dist2, idx)
ctx.mark_non_differentiable(idx)
return torch.sqrt(dist2), idx
@staticmethod
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment