Commit 685b3770 authored by rusty1s's avatar rusty1s

template for dim

parent 5ab10c54
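The commit specializes the per-point distance computation on the dimensionality of the input, so the common 1D/2D/3D cases compile to unrolled coordinate arithmetic while a generic fallback handles any other dimension. The sketch below (plain C++, hypothetical names, not code from this commit) shows the same runtime-to-compile-time dispatch pattern in isolation.

#include <cstdio>

// Compile-time specialization on point dimensionality. The unspecialized
// template doubles as the fallback for arbitrary `dim`.
template <int Dim> struct SquaredDist {
  static float compute(const float *a, const float *b, int dim) {
    float d = 0;
    for (int i = 0; i < dim; i++)
      d += (a[i] - b[i]) * (a[i] - b[i]);
    return d;
  }
};

// 3D case: the coordinate loop is unrolled by hand, no runtime `dim` needed.
template <> struct SquaredDist<3> {
  static float compute(const float *a, const float *b, int /*dim*/) {
    float x = a[0] - b[0], y = a[1] - b[1], z = a[2] - b[2];
    return x * x + y * y + z * z;
  }
};

// Map the runtime value onto a template argument once, in the same spirit as
// the FPS_KERNEL macro further down.
float squared_dist(const float *a, const float *b, int dim) {
  switch (dim) {
  case 3:
    return SquaredDist<3>::compute(a, b, dim);
  default:
    return SquaredDist<-1>::compute(a, b, dim);
  }
}

int main() {
  float a[3] = {0.f, 0.f, 0.f}, b[3] = {1.f, 2.f, 2.f};
  std::printf("%f\n", squared_dist(a, b, 3)); // prints 9.000000
  return 0;
}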
@@ -5,29 +5,74 @@

#define THREADS 1024

// Per-point distance update, specialized on the dimensionality of the input
// so the common 1D/2D/3D cases avoid the runtime loop over coordinates.
template <typename scalar_t, int64_t Dim> struct Dist {};

template <typename scalar_t> struct Dist<scalar_t, 1> {
  static __device__ void
  compute(ptrdiff_t idx, ptrdiff_t start_idx, ptrdiff_t end_idx, ptrdiff_t old,
          scalar_t *__restrict__ best, ptrdiff_t *__restrict__ best_idx,
          scalar_t *__restrict__ x, scalar_t *__restrict__ dist,
          scalar_t *__restrict__ tmp_dist, size_t dim) {
    for (ptrdiff_t n = start_idx + idx; n < end_idx; n += THREADS) {
      scalar_t a = x[old * 1 + 0] - x[n * 1 + 0];
      scalar_t d = a * a;
      dist[n] = min(dist[n], d);
      if (dist[n] > *best) {
        *best = dist[n];
        *best_idx = n;
      }
    }
  }
};

template <typename scalar_t> struct Dist<scalar_t, 2> {
  static __device__ void
  compute(ptrdiff_t idx, ptrdiff_t start_idx, ptrdiff_t end_idx, ptrdiff_t old,
          scalar_t *__restrict__ best, ptrdiff_t *__restrict__ best_idx,
          scalar_t *__restrict__ x, scalar_t *__restrict__ dist,
          scalar_t *__restrict__ tmp_dist, size_t dim) {
    for (ptrdiff_t n = start_idx + idx; n < end_idx; n += THREADS) {
      scalar_t a = x[old * 2 + 0] - x[n * 2 + 0];
      scalar_t b = x[old * 2 + 1] - x[n * 2 + 1];
      scalar_t d = a * a + b * b;
      dist[n] = min(dist[n], d);
      if (dist[n] > *best) {
        *best = dist[n];
        *best_idx = n;
      }
    }
  }
};

template <typename scalar_t> struct Dist<scalar_t, 3> {
  static __device__ void
  compute(ptrdiff_t idx, ptrdiff_t start_idx, ptrdiff_t end_idx, ptrdiff_t old,
          scalar_t *__restrict__ best, ptrdiff_t *__restrict__ best_idx,
          scalar_t *__restrict__ x, scalar_t *__restrict__ dist,
          scalar_t *__restrict__ tmp_dist, size_t dim) {
    for (ptrdiff_t n = start_idx + idx; n < end_idx; n += THREADS) {
      scalar_t a = x[old * 3 + 0] - x[n * 3 + 0];
      scalar_t b = x[old * 3 + 1] - x[n * 3 + 1];
      scalar_t c = x[old * 3 + 2] - x[n * 3 + 2];
      scalar_t d = a * a + b * b + c * c;
      dist[n] = min(dist[n], d);
      if (dist[n] > *best) {
        *best = dist[n];
        *best_idx = n;
      }
    }
  }
};

// Fallback for arbitrary dimensionality: accumulate the squared distance
// coordinate-wise into tmp_dist before the min/argmax update.
template <typename scalar_t> struct Dist<scalar_t, -1> {
  static __device__ void
  compute(ptrdiff_t idx, ptrdiff_t start_idx, ptrdiff_t end_idx, ptrdiff_t old,
          scalar_t *__restrict__ best, ptrdiff_t *__restrict__ best_idx,
          scalar_t *__restrict__ x, scalar_t *__restrict__ dist,
          scalar_t *__restrict__ tmp_dist, size_t dim) {
    for (ptrdiff_t n = start_idx + idx; n < end_idx; n += THREADS) {
      tmp_dist[n] = 0;
    }
@@ -35,18 +80,47 @@

    __syncthreads();

    for (ptrdiff_t i = start_idx * dim + idx; i < end_idx * dim; i += THREADS) {
      scalar_t d = x[(old * dim) + (i % dim)] - x[i];
      atomicAdd(&tmp_dist[i / dim], d * d);
    }
    __syncthreads();

    for (ptrdiff_t n = start_idx + idx; n < end_idx; n += THREADS) {
      dist[n] = min(dist[n], tmp_dist[n]);
      if (dist[n] > *best) {
        *best = dist[n];
        *best_idx = n;
      }
    }
  }
};

// One block per example in the batch: iteratively select the point farthest
// from the set of already selected points (farthest point sampling).
template <typename scalar_t, int64_t Dim>
__global__ void
fps_kernel(scalar_t *__restrict__ x, int64_t *__restrict__ cum_deg,
           int64_t *__restrict__ cum_k, int64_t *__restrict__ start,
           scalar_t *__restrict__ dist, scalar_t *__restrict__ tmp_dist,
           int64_t *__restrict__ out, size_t dim) {

  const ptrdiff_t batch_idx = blockIdx.x;
  const ptrdiff_t idx = threadIdx.x;

  const ptrdiff_t start_idx = cum_deg[batch_idx];
  const ptrdiff_t end_idx = cum_deg[batch_idx + 1];

  __shared__ scalar_t best_dist[THREADS];
  __shared__ int64_t best_dist_idx[THREADS];

  // Seed the sampling with the (possibly random) start index of this example.
  if (idx == 0) {
    out[cum_k[batch_idx]] = start_idx + start[batch_idx];
  }

  for (ptrdiff_t m = cum_k[batch_idx] + 1; m < cum_k[batch_idx + 1]; m++) {
    scalar_t best = -1;
    ptrdiff_t best_idx = 0;

    // Update per-point distances to the last selected point and track this
    // thread's farthest candidate.
    Dist<scalar_t, Dim>::compute(idx, start_idx, end_idx, out[m - 1], &best,
                                 &best_idx, x, dist, tmp_dist, dim);

    best_dist[idx] = best;
    best_dist_idx[idx] = best_idx;
@@ -64,10 +138,29 @@

    }
    __syncthreads();

    // Thread 0 holds the reduced argmax: the next farthest point.
    if (idx == 0) {
      out[m] = best_dist_idx[0];
    }
  }
}

// Map the runtime dimensionality onto a compile-time template argument,
// falling back to the generic Dim = -1 kernel for anything but 1D/2D/3D.
#define FPS_KERNEL(DIM, ...)                                                   \
  [&] {                                                                        \
    switch (DIM) {                                                             \
    case 1:                                                                    \
      fps_kernel<scalar_t, 1><<<batch_size, THREADS>>>(__VA_ARGS__, DIM);      \
      break;                                                                   \
    case 2:                                                                    \
      fps_kernel<scalar_t, 2><<<batch_size, THREADS>>>(__VA_ARGS__, DIM);      \
      break;                                                                   \
    case 3:                                                                    \
      fps_kernel<scalar_t, 3><<<batch_size, THREADS>>>(__VA_ARGS__, DIM);      \
      break;                                                                   \
    default:                                                                   \
      fps_kernel<scalar_t, -1><<<batch_size, THREADS>>>(__VA_ARGS__, DIM);     \
    }                                                                          \
  }()

at::Tensor fps_cuda(at::Tensor x, at::Tensor batch, float ratio, bool random) {
  auto batch_sizes = (int64_t *)malloc(sizeof(int64_t));
  cudaMemcpy(batch_sizes, batch[-1].data<int64_t>(), sizeof(int64_t),
@@ -96,10 +189,10 @@

  auto out = at::empty(k_sum[0], k.options());

  AT_DISPATCH_FLOATING_TYPES(x.type(), "fps_kernel", [&] {
    FPS_KERNEL(x.size(1), x.data<scalar_t>(), cum_deg.data<int64_t>(),
               cum_k.data<int64_t>(), start.data<int64_t>(),
               dist.data<scalar_t>(), tmp_dist.data<scalar_t>(),
               out.data<int64_t>());
  });

  return out;
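For context, a minimal host-side usage sketch of the entry point above. It is hypothetical and not part of this commit; it assumes the extension has been built and linked and that a CUDA device is available.

#include <ATen/ATen.h>
#include <iostream>

// Declaration of the entry point defined in this file.
at::Tensor fps_cuda(at::Tensor x, at::Tensor batch, float ratio, bool random);

int main() {
  // Two point clouds of four 3D points each, concatenated along dim 0; the
  // `batch` vector assigns each point to its example.
  auto x = at::rand({8, 3}, at::kFloat).cuda();
  auto batch = at::cat({at::zeros({4}, at::kLong), at::ones({4}, at::kLong)});
  batch = batch.cuda();

  // Sample roughly half of the points of every example; `random = false`
  // always starts from the first point of each example.
  auto idx = fps_cuda(x, batch, /*ratio=*/0.5, /*random=*/false);
  std::cout << idx << std::endl;
  return 0;
}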