Unverified Commit 878ad2e0 authored by cheng052's avatar cheng052 Committed by GitHub
Browse files

mark idx in BallQuery as non_differentiable (#919)

parent fd20aabc
......@@ -35,11 +35,15 @@ class BallQuery(Function):
pointnet2.ball_query_wrapper(B, M, radius, nsample, new_xyz, new_xyz_batch_cnt, xyz, xyz_batch_cnt, idx)
empty_ball_mask = (idx[:, 0] == -1)
idx[empty_ball_mask] = 0
ctx.mark_non_differentiable(idx)
ctx.mark_non_differentiable(empty_ball_mask)
return idx, empty_ball_mask
@staticmethod
def backward(ctx, a=None):
return None, None, None, None
def backward(ctx, a=None, b=None):
return None, None, None, None, None, None
ball_query = BallQuery.apply
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment