Unverified Commit 2d73eafe authored by pc's avatar pc Committed by GitHub
Browse files

add mmdet3d op (#1425)


Co-authored-by: default avatarzhouzaida <zhouzaida@163.com>
parent 75cae78c
......@@ -61,10 +61,12 @@ class KNN(Function):
idx = center_xyz.new_zeros((B, npoint, k)).int()
dist2 = center_xyz.new_zeros((B, npoint, k)).float()
ext_module.knn_forward(B, N, npoint, k, xyz, center_xyz, idx, dist2)
ext_module.knn_forward(
xyz, center_xyz, idx, dist2, b=B, n=N, m=npoint, nsample=k)
# idx shape to [B, k, npoint]
idx = idx.transpose(2, 1).contiguous()
ctx.mark_non_differentiable(idx)
if torch.__version__ != 'parrots':
ctx.mark_non_differentiable(idx)
return idx
@staticmethod
......
......@@ -39,8 +39,8 @@ class ThreeInterpolate(Function):
ctx.three_interpolate_for_backward = (indices, weight, m)
output = torch.cuda.FloatTensor(B, c, n)
ext_module.three_interpolate_forward(B, c, m, n, features, indices,
weight, output)
ext_module.three_interpolate_forward(
features, indices, weight, output, b=B, c=c, m=m, n=n)
return output
@staticmethod
......@@ -60,8 +60,8 @@ class ThreeInterpolate(Function):
grad_features = torch.cuda.FloatTensor(B, c, m).zero_()
grad_out_data = grad_out.data.contiguous()
ext_module.three_interpolate_backward(B, c, n, m, grad_out_data, idx,
weight, grad_features.data)
ext_module.three_interpolate_backward(
grad_out_data, idx, weight, grad_features.data, b=B, c=c, n=n, m=m)
return grad_features, None, None
......
......@@ -37,9 +37,9 @@ class ThreeNN(Function):
dist2 = torch.cuda.FloatTensor(B, N, 3)
idx = torch.cuda.IntTensor(B, N, 3)
ext_module.three_nn_forward(B, N, m, target, source, dist2, idx)
ctx.mark_non_differentiable(idx)
ext_module.three_nn_forward(target, source, dist2, idx, b=B, n=N, m=m)
if torch.__version__ != 'parrots':
ctx.mark_non_differentiable(idx)
return torch.sqrt(dist2), idx
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment