"git@developer.sourcefind.cn:OpenDAS/openpcdet.git" did not exist on "e685140975ca203242e826e13b088654509d6620"
Unverified Commit 878ad2e0 authored by cheng052's avatar cheng052 Committed by GitHub
Browse files

mark idx in BallQuery as non_differentiable (#919)

parent fd20aabc
...@@ -35,11 +35,15 @@ class BallQuery(Function): ...@@ -35,11 +35,15 @@ class BallQuery(Function):
pointnet2.ball_query_wrapper(B, M, radius, nsample, new_xyz, new_xyz_batch_cnt, xyz, xyz_batch_cnt, idx) pointnet2.ball_query_wrapper(B, M, radius, nsample, new_xyz, new_xyz_batch_cnt, xyz, xyz_batch_cnt, idx)
empty_ball_mask = (idx[:, 0] == -1) empty_ball_mask = (idx[:, 0] == -1)
idx[empty_ball_mask] = 0 idx[empty_ball_mask] = 0
ctx.mark_non_differentiable(idx)
ctx.mark_non_differentiable(empty_ball_mask)
return idx, empty_ball_mask return idx, empty_ball_mask
@staticmethod @staticmethod
def backward(ctx, a=None): def backward(ctx, a=None, b=None):
return None, None, None, None return None, None, None, None, None, None
ball_query = BallQuery.apply ball_query = BallQuery.apply
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment