"vscode:/vscode.git/clone" did not exist on "d73def35a3961cfc326f132fc8038cdbc5493d3d"
Commit 8c8014b9 authored by rusty1s

fps fixes

parent 388a2e2b
#pragma once

static inline __device__ void atomAdd(float *address, float val) {
  atomicAdd(address, val);
}

#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 600 || CUDA_VERSION < 8000)
// Native double-precision atomicAdd requires compute capability >= 6.0 and
// CUDA >= 8.0; emulate it with an atomicCAS loop on the 64-bit integer
// representation of the double.
static inline __device__ void atomAdd(double *address, double val) {
  unsigned long long int *address_as_ull = (unsigned long long int *)address;
  unsigned long long int old = *address_as_ull;
  unsigned long long int assumed;
  do {
    assumed = old;
    old = atomicCAS(address_as_ull, assumed,
                    __double_as_longlong(val + __longlong_as_double(assumed)));
  } while (assumed != old);
}
#else
static inline __device__ void atomAdd(double *address, double val) {
  atomicAdd(address, val);
}
#endif
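For context, not part of the commit: the point of providing atomAdd overloads for both float and double is that a kernel templated on scalar_t can accumulate through a single call site and still compile for architectures without a native double-precision atomicAdd. A minimal, hypothetical usage sketch (the kernel name and launch shape are illustrative only):

template <typename scalar_t>
__global__ void accumulate_kernel(const scalar_t *src, scalar_t *sum, int64_t n) {
  int64_t i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) {
    atomAdd(sum, src[i]);  // resolves to the float or the double overload above
  }
}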
@@ -2,11 +2,12 @@
 #include <ATen/cuda/CUDAContext.h>
+#include "atomics.cuh"
 #include "utils.cuh"
 #define THREADS 1024
-template <typename scalar_t> struct Dist<scalar_t> {
+template <typename scalar_t> struct Dist {
   static inline __device__ void compute(int64_t idx, int64_t start_idx,
                                         int64_t end_idx, int64_t old,
                                         scalar_t *best, int64_t *best_idx,
@@ -20,7 +21,7 @@ template <typename scalar_t> struct Dist<scalar_t> {
     __syncthreads();
     for (int64_t i = start_idx * dim + idx; i < end_idx * dim; i += THREADS) {
       scalar_t d = src[(old * dim) + (i % dim)] - src[i];
-      atomicAdd(&tmp_dist[i / dim], d * d);
+      atomAdd(&tmp_dist[i / dim], d * d);
     }
     __syncthreads();
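For context, not part of the commit: i runs over flattened (point, coordinate) pairs with a stride of THREADS, so several threads can contribute to the same tmp_dist[i / dim] slot concurrently, which is why the accumulation has to be atomic. Routing it through the atomAdd wrapper instead of calling atomicAdd directly keeps the double-precision instantiation compiling on architectures without a native double atomicAdd.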
@@ -58,11 +59,11 @@ __global__ void fps_kernel(const scalar_t *src, const int64_t *ptr,
     int64_t best_idx = 0;
     __syncthreads();
-    Dist<scalar_t, Dim>::compute(thread_idx, start_idx, end_idx, out[m - 1],
-                                 &best, &best_idx, src, dist, tmp_dist, dim);
+    Dist<scalar_t>::compute(thread_idx, start_idx, end_idx, out[m - 1], &best,
+                            &best_idx, src, dist, tmp_dist, dim);
-    best_dist[idx] = best;
-    best_dist_idx[idx] = best_idx;
+    best_dist[thread_idx] = best;
+    best_dist_idx[thread_idx] = best_idx;
     for (int64_t u = 0; (1 << u) < THREADS; u++) {
       __syncthreads();
@@ -77,7 +78,7 @@ __global__ void fps_kernel(const scalar_t *src, const int64_t *ptr,
     }
     __syncthreads();
-    if (idx == 0) {
+    if (thread_idx == 0) {
      out[m] = best_dist_idx[0];
     }
   }
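Standalone sketch, not part of the commit, of the block-wide reduction pattern that the elided body of fps_kernel presumably follows: each thread proposes a (distance, index) candidate, and a log2(THREADS)-step tree reduction leaves the overall farthest candidate at slot 0. The names mirror the kernel above, but the shared-memory layout and kernel signature here are assumptions.

#define THREADS 1024

template <typename scalar_t>
__global__ void block_argmax_kernel(const scalar_t *val, int64_t *result, int64_t n) {
  __shared__ scalar_t best_dist[THREADS];
  __shared__ int64_t best_dist_idx[THREADS];

  int64_t thread_idx = threadIdx.x;
  scalar_t best = (scalar_t)-1;  // squared distances are non-negative
  int64_t best_idx = 0;

  // Strided pass: each thread keeps the maximum over its slice of `val`.
  for (int64_t i = thread_idx; i < n; i += THREADS) {
    if (val[i] > best) {
      best = val[i];
      best_idx = i;
    }
  }
  best_dist[thread_idx] = best;
  best_dist_idx[thread_idx] = best_idx;

  // Tree reduction: after step u, every slot that is a multiple of 2^(u+1)
  // holds the maximum over a window of 2^(u+1) consecutive threads.
  for (int64_t u = 0; (1 << u) < THREADS; u++) {
    __syncthreads();
    int64_t partner = thread_idx + (1 << u);
    if (thread_idx % (2 << u) == 0 && partner < THREADS &&
        best_dist[partner] > best_dist[thread_idx]) {
      best_dist[thread_idx] = best_dist[partner];
      best_dist_idx[thread_idx] = best_dist_idx[partner];
    }
  }
  __syncthreads();

  if (thread_idx == 0) {
    result[0] = best_dist_idx[0];
  }
}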
@@ -99,7 +100,7 @@ torch::Tensor fps_cuda(torch::Tensor src, torch::Tensor ptr, double ratio,
   auto deg = ptr.narrow(0, 1, batch_size) - ptr.narrow(0, 0, batch_size);
   auto out_ptr = deg.toType(torch::kFloat) * (float)ratio;
   out_ptr = out_ptr.ceil().toType(torch::kLong).cumsum(0);
-  out_ptr = torch::cat({torch.zeros(1, ptr.options()), out_ptr}, 0);
+  out_ptr = torch::cat({torch::zeros(1, ptr.options()), out_ptr}, 0);
   torch::Tensor start;
   if (random_start) {
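For context, not part of the commit: the lines above size the output per point cloud. With, say, deg = [5, 3] and ratio = 0.5, deg * ratio = [2.5, 1.5], the ceil gives [3, 2], the cumulative sum gives [3, 5], and prepending a zero yields out_ptr = [0, 3, 5], i.e. the offsets of each cloud's samples in the output tensor.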
@@ -120,7 +121,7 @@ torch::Tensor fps_cuda(torch::Tensor src, torch::Tensor ptr, double ratio,
   auto stream = at::cuda::getCurrentCUDAStream();
   AT_DISPATCH_FLOATING_TYPES(src.scalar_type(), "fps_kernel", [&] {
     fps_kernel<scalar_t><<<batch_size, THREADS, 0, stream>>>(
-        src.data_ptr<scalar_t>(), rowptr.data_ptr<int64_t>(),
+        src.data_ptr<scalar_t>(), ptr.data_ptr<int64_t>(),
         out_ptr.data_ptr<int64_t>(), start.data_ptr<int64_t>(),
         dist.data_ptr<scalar_t>(), tmp_dist.data_ptr<scalar_t>(),
         out.data_ptr<int64_t>(), src.size(1));
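Hypothetical host-side usage sketch, not part of the commit. It assumes the full signature is fps_cuda(src, ptr, ratio, random_start), completing the truncated declaration in the hunk headers above and matching the if (random_start) branch in the function body.

#include <torch/torch.h>

// Assumed declaration of the function shown in the diff.
torch::Tensor fps_cuda(torch::Tensor src, torch::Tensor ptr, double ratio,
                       bool random_start);

torch::Tensor sample_half() {
  // Two point clouds with 4 and 2 points stored consecutively in `src`;
  // `ptr` holds the CSR-style offsets [0, 4, 6].
  auto src = torch::rand({6, 3}, torch::device(torch::kCUDA));
  auto ptr = torch::tensor({0, 4, 6},
                           torch::dtype(torch::kLong).device(torch::kCUDA));
  // Presumably returns int64 indices into `src` of the sampled points,
  // ceil(deg * ratio) of them per cloud.
  return fps_cuda(src, ptr, /*ratio=*/0.5, /*random_start=*/true);
}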