Commit f561b8ae authored by Shaoshuai Shi's avatar Shaoshuai Shi
Browse files

bugfixed: backward bug in ThreeInterpolate of stacked version pointnet2

parent 3b2bcdaf
......@@ -77,5 +77,6 @@ void three_interpolate_grad_wrapper_stack(at::Tensor grad_out_tensor, at::Tensor
const int *idx = idx_tensor.data<int>();
float *grad_features = grad_features_tensor.data<float>();
// printf("N=%d, channels=%d\n", N, channels);
three_interpolate_grad_kernel_launcher_stack(N, channels, grad_out, idx, weight, grad_features);
}
\ No newline at end of file
......@@ -164,9 +164,11 @@ __global__ void three_interpolate_grad_kernel_stack(int N, int channels, const f
weight += pt_idx * 3;
idx += pt_idx * 3;
atomicAdd(grad_features + idx[0], grad_out[0] * weight[0]);
atomicAdd(grad_features + idx[1], grad_out[0] * weight[1]);
atomicAdd(grad_features + idx[2], grad_out[0] * weight[2]);
// printf("pt_idx=%d, c_idx=%d, idx=(%d, %d, %d), grad_out=%f\n", pt_idx, c_idx, idx[0], idx[1], idx[2], grad_out[0]);
atomicAdd(grad_features + idx[0] * channels + c_idx, grad_out[0] * weight[0]);
atomicAdd(grad_features + idx[1] * channels + c_idx, grad_out[0] * weight[1]);
atomicAdd(grad_features + idx[2] * channels + c_idx, grad_out[0] * weight[2]);
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment