Commit 8922371e authored by Shaoshuai Shi's avatar Shaoshuai Shi
Browse files

Support StackFarthestPointSampling

parent df299e7c
......@@ -184,6 +184,43 @@ class FarthestPointSampling(Function):
farthest_point_sample = furthest_point_sample = FarthestPointSampling.apply
class StackFarthestPointSampling(Function):
    """Farthest point sampling over a stacked (batch-concatenated) point cloud."""

    @staticmethod
    def forward(ctx, xyz, xyz_batch_cnt, npoint):
        """
        Args:
            ctx:
            xyz: (N1 + N2 + ..., 3) stacked point coordinates, where each Ni > npoint_i
            xyz_batch_cnt: [N1, N2, ...] number of points in each batch element
            npoint: int, list or int tensor; number of points to sample per batch element

        Returns:
            output: (npoint.sum()) int tensor of sampled point indices into xyz
        """
        assert xyz.is_contiguous() and xyz.shape[1] == 3

        batch_size = len(xyz_batch_cnt)
        if not isinstance(npoint, torch.Tensor):
            if not isinstance(npoint, list):
                # A single int means "sample that many points from every batch element".
                npoint = [npoint] * batch_size
            npoint = torch.tensor(npoint, device=xyz.device).int()

        N, _ = xyz.size()
        # Per-point running min squared distance to the sampled set ("infinity" to start).
        temp = torch.cuda.FloatTensor(N).fill_(1e10)
        output = torch.cuda.IntTensor(npoint.sum().item())

        pointnet2.stack_farthest_point_sampling_wrapper(xyz, temp, xyz_batch_cnt, output, npoint)
        return output

    @staticmethod
    def backward(ctx, a=None):
        # Sampling indices are not differentiable; autograd requires one gradient
        # (here None) per forward input: xyz, xyz_batch_cnt, npoint.
        return None, None, None


stack_farthest_point_sample = StackFarthestPointSampling.apply
class ThreeNN(Function):
@staticmethod
def forward(ctx, unknown, unknown_batch_cnt, known, known_batch_cnt):
......
......@@ -13,6 +13,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("voxel_query_wrapper", &voxel_query_wrapper_stack, "voxel_query_wrapper_stack");
m.def("farthest_point_sampling_wrapper", &farthest_point_sampling_wrapper, "farthest_point_sampling_wrapper");
m.def("stack_farthest_point_sampling_wrapper", &stack_farthest_point_sampling_wrapper, "stack_farthest_point_sampling_wrapper");
m.def("group_points_wrapper", &group_points_wrapper_stack, "group_points_wrapper_stack");
m.def("group_points_grad_wrapper", &group_points_grad_wrapper_stack, "group_points_grad_wrapper_stack");
......
......@@ -35,3 +35,26 @@ int farthest_point_sampling_wrapper(int b, int n, int m,
farthest_point_sampling_kernel_launcher(b, n, m, points, temp, idx);
return 1;
}
// Python-facing entry point for stacked farthest point sampling.
//
// points_tensor:             (N1 + N2 + ..., 3) stacked xyz coordinates
// temp_tensor:               (N1 + N2 + ...)    per-point scratch distances,
//                            expected to be pre-filled with a large value
// xyz_batch_cnt_tensor:      (batch_size)       [N1, N2, ...] points per batch element
// idx_tensor:                (M1 + M2 + ...)    output, sampled indices into the stacked cloud
// num_sampled_points_tensor: (batch_size)       [M1, M2, ...] samples per batch element
//
// Returns 1 on success (kernel failures abort inside the launcher).
int stack_farthest_point_sampling_wrapper(at::Tensor points_tensor,
    at::Tensor temp_tensor, at::Tensor xyz_batch_cnt_tensor, at::Tensor idx_tensor,
    at::Tensor num_sampled_points_tensor) {

    CHECK_INPUT(points_tensor);
    CHECK_INPUT(temp_tensor);
    CHECK_INPUT(idx_tensor);
    CHECK_INPUT(xyz_batch_cnt_tensor);
    CHECK_INPUT(num_sampled_points_tensor);

    int batch_size = xyz_batch_cnt_tensor.size(0);
    int N = points_tensor.size(0);

    // data_ptr<T>() replaces the deprecated Tensor::data<T>() accessor.
    const float *points = points_tensor.data_ptr<float>();
    float *temp = temp_tensor.data_ptr<float>();
    int *xyz_batch_cnt = xyz_batch_cnt_tensor.data_ptr<int>();
    int *idx = idx_tensor.data_ptr<int>();
    int *num_sampled_points = num_sampled_points_tensor.data_ptr<int>();

    stack_farthest_point_sampling_kernel_launcher(N, batch_size, points, temp, xyz_batch_cnt, idx, num_sampled_points);
    return 1;
}
\ No newline at end of file
......@@ -182,3 +182,168 @@ void farthest_point_sampling_kernel_launcher(int b, int n, int m,
exit(-1);
}
}
// Farthest point sampling on a stacked (batch-concatenated) point cloud.
// One thread block handles one batch element (blockIdx.x == batch index);
// blockDim.x must equal the block_size template parameter, since the
// shared-memory reduction arrays are sized by it.
//
// Args:
//     dataset: (N1 + N2 + ..., 3) stacked xyz coordinates
//     temp: (N1 + N2 + ...) per-point running min squared distance to the
//         already-sampled set; expected to be pre-filled with a large value
//     xyz_batch_cnt: [N1, N2, ...] number of points per batch element
//     num_sampled_points: [M1, M2, ...] number of points to sample per element
// Returns:
//     idxs: (M1 + M2 + ...) sampled point indices into the stacked cloud
template <unsigned int block_size>
__global__ void stack_farthest_point_sampling_kernel(int batch_size, int N,
    const float *dataset, float *temp, int *xyz_batch_cnt, int *idxs, int *num_sampled_points) {
    __shared__ float dists[block_size];
    __shared__ int dists_i[block_size];

    int bs_idx = blockIdx.x;
    // Prefix sums locating this batch element's slice in the stacked
    // input (dataset/temp) and output (idxs) arrays.
    int xyz_batch_start_idx = 0, idxs_start_idx = 0;
    for (int k = 0; k < bs_idx; k++){
        xyz_batch_start_idx += xyz_batch_cnt[k];
        idxs_start_idx += num_sampled_points[k];
    }

    dataset += xyz_batch_start_idx * 3;
    temp += xyz_batch_start_idx;
    idxs += idxs_start_idx;

    int n = xyz_batch_cnt[bs_idx];  // points in this batch element
    int m = num_sampled_points[bs_idx];  // samples to draw from it

    int tid = threadIdx.x;
    const int stride = block_size;

    // The first sample is always point 0 of this element, stored as a global
    // index into the stacked cloud (hence the + xyz_batch_start_idx offset).
    int old = 0;
    if (threadIdx.x == 0) idxs[0] = xyz_batch_start_idx;

    __syncthreads();
    for (int j = 1; j < m; j++) {
        int besti = 0;   // this thread's current farthest candidate (local index)
        float best = -1; // its squared distance to the sampled set
        // Coordinates of the most recently sampled point (local index `old`).
        float x1 = dataset[old * 3 + 0];
        float y1 = dataset[old * 3 + 1];
        float z1 = dataset[old * 3 + 2];
        // Strided scan: each thread refreshes temp[k] (min squared distance to
        // the sampled set) for its subset of points and tracks the farthest one.
        for (int k = tid; k < n; k += stride) {
            float x2, y2, z2;
            x2 = dataset[k * 3 + 0];
            y2 = dataset[k * 3 + 1];
            z2 = dataset[k * 3 + 2];
            // float mag = (x2 * x2) + (y2 * y2) + (z2 * z2);
            // if (mag <= 1e-3)
            //     continue;

            float d = (x2 - x1) * (x2 - x1) + (y2 - y1) * (y2 - y1) + (z2 - z1) * (z2 - z1);
            float d2 = min(d, temp[k]);
            temp[k] = d2;
            besti = d2 > best ? k : besti;
            best = d2 > best ? d2 : best;
        }
        dists[tid] = best;
        dists_i[tid] = besti;
        __syncthreads();

        // Shared-memory tree reduction for the per-block argmax. The outer
        // `if (block_size >= ...)` branches are compile-time uniform (template
        // parameter), so the __syncthreads() calls are not divergent.
        if (block_size >= 1024) {
            if (tid < 512) {
                __update(dists, dists_i, tid, tid + 512);
            }
            __syncthreads();
        }

        if (block_size >= 512) {
            if (tid < 256) {
                __update(dists, dists_i, tid, tid + 256);
            }
            __syncthreads();
        }
        if (block_size >= 256) {
            if (tid < 128) {
                __update(dists, dists_i, tid, tid + 128);
            }
            __syncthreads();
        }
        if (block_size >= 128) {
            if (tid < 64) {
                __update(dists, dists_i, tid, tid + 64);
            }
            __syncthreads();
        }
        if (block_size >= 64) {
            if (tid < 32) {
                __update(dists, dists_i, tid, tid + 32);
            }
            __syncthreads();
        }
        if (block_size >= 32) {
            if (tid < 16) {
                __update(dists, dists_i, tid, tid + 16);
            }
            __syncthreads();
        }
        if (block_size >= 16) {
            if (tid < 8) {
                __update(dists, dists_i, tid, tid + 8);
            }
            __syncthreads();
        }
        if (block_size >= 8) {
            if (tid < 4) {
                __update(dists, dists_i, tid, tid + 4);
            }
            __syncthreads();
        }
        if (block_size >= 4) {
            if (tid < 2) {
                __update(dists, dists_i, tid, tid + 2);
            }
            __syncthreads();
        }
        if (block_size >= 2) {
            if (tid < 1) {
                __update(dists, dists_i, tid, tid + 1);
            }
            __syncthreads();
        }

        // All threads read the winning index from slot 0 after the final
        // barrier; it becomes the next sampled point.
        // NOTE(review): __update is a helper defined elsewhere in this file —
        // presumably a max-combine of (dists, dists_i) pairs; confirm there.
        old = dists_i[0];
        if (tid == 0)
            idxs[j] = old + xyz_batch_start_idx;
    }
}
// Launches stack_farthest_point_sampling_kernel with one block per batch
// element. The kernel's shared-memory reduction arrays are sized by its
// template parameter, so the launch width must match the instantiation (1024).
//
// Args:
//     N: total number of stacked points (N1 + N2 + ...)
//     batch_size: number of batch elements (one CUDA block each)
//     dataset: (N1 + N2 + ..., 3) stacked xyz coordinates
//     temp: (N1 + N2 + ...) scratch distances, pre-filled with a large value
//     xyz_batch_cnt: [N1, N2, ...] points per batch element
//     idxs: (M1 + M2 + ...) output, sampled point indices
//     num_sampled_points: [M1, M2, ...] samples per batch element
void stack_farthest_point_sampling_kernel_launcher(int N, int batch_size,
    const float *dataset, float *temp, int *xyz_batch_cnt, int *idxs, int *num_sampled_points) {
    cudaError_t err;

    stack_farthest_point_sampling_kernel<1024><<<batch_size, 1024>>>(
        batch_size, N, dataset, temp, xyz_batch_cnt, idxs, num_sampled_points
    );

    // Kernel launches are asynchronous: check for launch/configuration errors
    // immediately after the launch.
    err = cudaGetLastError();
    if (cudaSuccess != err) {
        fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err));
        exit(-1);
    }
}
\ No newline at end of file
......@@ -12,4 +12,12 @@ int farthest_point_sampling_wrapper(int b, int n, int m,
void farthest_point_sampling_kernel_launcher(int b, int n, int m,
const float *dataset, float *temp, int *idxs);
// Python-facing wrapper: stacked farthest point sampling over a
// batch-concatenated point cloud; fills idx_tensor with sampled indices.
int stack_farthest_point_sampling_wrapper(
    at::Tensor points_tensor, at::Tensor temp_tensor, at::Tensor xyz_batch_cnt_tensor,
    at::Tensor idx_tensor, at::Tensor num_sampled_points_tensor);

// CUDA launcher: one block per batch element; writes sampled point indices
// into idxs (temp is per-point scratch, pre-filled with a large value).
void stack_farthest_point_sampling_kernel_launcher(int N, int batch_size,
    const float *dataset, float *temp, int *xyz_batch_cnt, int *idxs, int *num_sampled_points);
#endif
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment