"gallery/transforms/plot_transforms_illustrations.py" did not exist on "a18b4af17504edddfdd8a4832b1a6a64946b73fa"
Unverified Commit 2d73eafe authored by pc's avatar pc Committed by GitHub
Browse files

add mmdet3d op (#1425)


Co-authored-by: default avatarzhouzaida <zhouzaida@163.com>
parent 75cae78c
...@@ -61,10 +61,12 @@ class KNN(Function): ...@@ -61,10 +61,12 @@ class KNN(Function):
idx = center_xyz.new_zeros((B, npoint, k)).int() idx = center_xyz.new_zeros((B, npoint, k)).int()
dist2 = center_xyz.new_zeros((B, npoint, k)).float() dist2 = center_xyz.new_zeros((B, npoint, k)).float()
ext_module.knn_forward(B, N, npoint, k, xyz, center_xyz, idx, dist2) ext_module.knn_forward(
xyz, center_xyz, idx, dist2, b=B, n=N, m=npoint, nsample=k)
# idx shape to [B, k, npoint] # idx shape to [B, k, npoint]
idx = idx.transpose(2, 1).contiguous() idx = idx.transpose(2, 1).contiguous()
ctx.mark_non_differentiable(idx) if torch.__version__ != 'parrots':
ctx.mark_non_differentiable(idx)
return idx return idx
@staticmethod @staticmethod
......
...@@ -39,8 +39,8 @@ class ThreeInterpolate(Function): ...@@ -39,8 +39,8 @@ class ThreeInterpolate(Function):
ctx.three_interpolate_for_backward = (indices, weight, m) ctx.three_interpolate_for_backward = (indices, weight, m)
output = torch.cuda.FloatTensor(B, c, n) output = torch.cuda.FloatTensor(B, c, n)
ext_module.three_interpolate_forward(B, c, m, n, features, indices, ext_module.three_interpolate_forward(
weight, output) features, indices, weight, output, b=B, c=c, m=m, n=n)
return output return output
@staticmethod @staticmethod
...@@ -60,8 +60,8 @@ class ThreeInterpolate(Function): ...@@ -60,8 +60,8 @@ class ThreeInterpolate(Function):
grad_features = torch.cuda.FloatTensor(B, c, m).zero_() grad_features = torch.cuda.FloatTensor(B, c, m).zero_()
grad_out_data = grad_out.data.contiguous() grad_out_data = grad_out.data.contiguous()
ext_module.three_interpolate_backward(B, c, n, m, grad_out_data, idx, ext_module.three_interpolate_backward(
weight, grad_features.data) grad_out_data, idx, weight, grad_features.data, b=B, c=c, n=n, m=m)
return grad_features, None, None return grad_features, None, None
......
...@@ -37,9 +37,9 @@ class ThreeNN(Function): ...@@ -37,9 +37,9 @@ class ThreeNN(Function):
dist2 = torch.cuda.FloatTensor(B, N, 3) dist2 = torch.cuda.FloatTensor(B, N, 3)
idx = torch.cuda.IntTensor(B, N, 3) idx = torch.cuda.IntTensor(B, N, 3)
ext_module.three_nn_forward(B, N, m, target, source, dist2, idx) ext_module.three_nn_forward(target, source, dist2, idx, b=B, n=N, m=m)
if torch.__version__ != 'parrots':
ctx.mark_non_differentiable(idx) ctx.mark_non_differentiable(idx)
return torch.sqrt(dist2), idx return torch.sqrt(dist2), idx
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment