Commit 8c8014b9 authored by rusty1s's avatar rusty1s
Browse files

fps fixes

parent 388a2e2b
#pragma once
// Single-precision atomic add: every supported architecture has a native
// atomicAdd(float*) instruction, so this overload simply forwards to it.
// Provided so generic code can call `atomAdd` uniformly for float and double.
static inline __device__ void atomAdd(float *target, float increment) {
  atomicAdd(target, increment);
}
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 600 || CUDA_VERSION < 8000)
// Double-precision atomic add emulated with a compare-and-swap loop, for
// devices (SM < 60) or toolkits (CUDA < 8.0) that lack native
// atomicAdd(double*). Reinterprets the 8-byte double as an unsigned long long
// so atomicCAS can operate on it.
static inline __device__ void atomAdd(double *address, double val) {
  unsigned long long int *slot = (unsigned long long int *)address;
  unsigned long long int observed = *slot;
  unsigned long long int expected;
  do {
    expected = observed;
    double updated = val + __longlong_as_double(expected);
    observed = atomicCAS(slot, expected, __double_as_longlong(updated));
    // Note: comparing the raw bit patterns (not the doubles) ensures the
    // loop terminates even if the stored value is NaN.
  } while (observed != expected);
}
#else
// Native double-precision atomicAdd exists (SM >= 60 and CUDA >= 8.0);
// forward directly to the hardware instruction.
static inline __device__ void atomAdd(double *address, double val) {
  atomicAdd(address, val);
}
#endif
...@@ -2,11 +2,12 @@ ...@@ -2,11 +2,12 @@
#include <ATen/cuda/CUDAContext.h> #include <ATen/cuda/CUDAContext.h>
#include "atomics.cuh"
#include "utils.cuh" #include "utils.cuh"
#define THREADS 1024 #define THREADS 1024
template <typename scalar_t> struct Dist<scalar_t> { template <typename scalar_t> struct Dist {
static inline __device__ void compute(int64_t idx, int64_t start_idx, static inline __device__ void compute(int64_t idx, int64_t start_idx,
int64_t end_idx, int64_t old, int64_t end_idx, int64_t old,
scalar_t *best, int64_t *best_idx, scalar_t *best, int64_t *best_idx,
...@@ -20,7 +21,7 @@ template <typename scalar_t> struct Dist<scalar_t> { ...@@ -20,7 +21,7 @@ template <typename scalar_t> struct Dist<scalar_t> {
__syncthreads(); __syncthreads();
for (int64_t i = start_idx * dim + idx; i < end_idx * dim; i += THREADS) { for (int64_t i = start_idx * dim + idx; i < end_idx * dim; i += THREADS) {
scalar_t d = src[(old * dim) + (i % dim)] - src[i]; scalar_t d = src[(old * dim) + (i % dim)] - src[i];
atomicAdd(&tmp_dist[i / dim], d * d); atomAdd(&tmp_dist[i / dim], d * d);
} }
__syncthreads(); __syncthreads();
...@@ -58,11 +59,11 @@ __global__ void fps_kernel(const scalar_t *src, const int64_t *ptr, ...@@ -58,11 +59,11 @@ __global__ void fps_kernel(const scalar_t *src, const int64_t *ptr,
int64_t best_idx = 0; int64_t best_idx = 0;
__syncthreads(); __syncthreads();
Dist<scalar_t, Dim>::compute(thread_idx, start_idx, end_idx, out[m - 1], Dist<scalar_t>::compute(thread_idx, start_idx, end_idx, out[m - 1], &best,
&best, &best_idx, src, dist, tmp_dist, dim); &best_idx, src, dist, tmp_dist, dim);
best_dist[idx] = best; best_dist[thread_idx] = best;
best_dist_idx[idx] = best_idx; best_dist_idx[thread_idx] = best_idx;
for (int64_t u = 0; (1 << u) < THREADS; u++) { for (int64_t u = 0; (1 << u) < THREADS; u++) {
__syncthreads(); __syncthreads();
...@@ -77,7 +78,7 @@ __global__ void fps_kernel(const scalar_t *src, const int64_t *ptr, ...@@ -77,7 +78,7 @@ __global__ void fps_kernel(const scalar_t *src, const int64_t *ptr,
} }
__syncthreads(); __syncthreads();
if (idx == 0) { if (thread_idx == 0) {
out[m] = best_dist_idx[0]; out[m] = best_dist_idx[0];
} }
} }
...@@ -99,7 +100,7 @@ torch::Tensor fps_cuda(torch::Tensor src, torch::Tensor ptr, double ratio, ...@@ -99,7 +100,7 @@ torch::Tensor fps_cuda(torch::Tensor src, torch::Tensor ptr, double ratio,
auto deg = ptr.narrow(0, 1, batch_size) - ptr.narrow(0, 0, batch_size); auto deg = ptr.narrow(0, 1, batch_size) - ptr.narrow(0, 0, batch_size);
auto out_ptr = deg.toType(torch::kFloat) * (float)ratio; auto out_ptr = deg.toType(torch::kFloat) * (float)ratio;
out_ptr = out_ptr.ceil().toType(torch::kLong).cumsum(0); out_ptr = out_ptr.ceil().toType(torch::kLong).cumsum(0);
out_ptr = torch::cat({torch.zeros(1, ptr.options()), out_ptr}, 0); out_ptr = torch::cat({torch::zeros(1, ptr.options()), out_ptr}, 0);
torch::Tensor start; torch::Tensor start;
if (random_start) { if (random_start) {
...@@ -120,7 +121,7 @@ torch::Tensor fps_cuda(torch::Tensor src, torch::Tensor ptr, double ratio, ...@@ -120,7 +121,7 @@ torch::Tensor fps_cuda(torch::Tensor src, torch::Tensor ptr, double ratio,
auto stream = at::cuda::getCurrentCUDAStream(); auto stream = at::cuda::getCurrentCUDAStream();
AT_DISPATCH_FLOATING_TYPES(src.scalar_type(), "fps_kernel", [&] { AT_DISPATCH_FLOATING_TYPES(src.scalar_type(), "fps_kernel", [&] {
fps_kernel<scalar_t><<<batch_size, THREADS, 0, stream>>>( fps_kernel<scalar_t><<<batch_size, THREADS, 0, stream>>>(
src.data_ptr<scalar_t>(), rowptr.data_ptr<int64_t>(), src.data_ptr<scalar_t>(), ptr.data_ptr<int64_t>(),
out_ptr.data_ptr<int64_t>(), start.data_ptr<int64_t>(), out_ptr.data_ptr<int64_t>(), start.data_ptr<int64_t>(),
dist.data_ptr<scalar_t>(), tmp_dist.data_ptr<scalar_t>(), dist.data_ptr<scalar_t>(), tmp_dist.data_ptr<scalar_t>(),
out.data_ptr<int64_t>(), src.size(1)); out.data_ptr<int64_t>(), src.size(1));
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment