// fps_cuda.cu
#include "fps_cuda.h"

#include <ATen/cuda/CUDAContext.h>

#include "atomics.cuh"
#include "utils.cuh"

#define THREADS 1024

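// Per-block helper: updates the squared distance of every point in
// [start_idx, end_idx) to the most recently sampled point `old`, keeps the
// running minimum in `dist`, and records the farthest point seen by the
// calling thread in (*best, *best_idx). `tmp_dist` is per-point scratch space.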
template <typename scalar_t> struct Dist {
  static inline __device__ void compute(int64_t idx, int64_t start_idx,
                                        int64_t end_idx, int64_t old,
                                        scalar_t *best, int64_t *best_idx,
                                        const scalar_t *src, scalar_t *dist,
                                        scalar_t *tmp_dist, int64_t dim) {

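    // Reset the per-point scratch distances for this batch element.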
    for (int64_t n = start_idx + idx; n < end_idx; n += THREADS) {
      tmp_dist[n] = 0;
    }

    __syncthreads();
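    // Accumulate the squared distance between each point and `old`, one
    // feature dimension per iteration; atomAdd combines contributions from
    // different threads writing to the same point.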
    for (int64_t i = start_idx * dim + idx; i < end_idx * dim; i += THREADS) {
      scalar_t d = src[(old * dim) + (i % dim)] - src[i];
      atomAdd(&tmp_dist[i / dim], d * d);
    }

    __syncthreads();
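    // Fold the new distances into the running minimum and track the farthest
    // point handled by this thread.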
    for (int64_t n = start_idx + idx; n < end_idx; n += THREADS) {
      dist[n] = min(dist[n], tmp_dist[n]);
      if (dist[n] > *best) {
        *best = dist[n];
        *best_idx = n;
      }
    }
  }
};

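// One block handles one batch element (delimited by the CSR offsets in `ptr`)
// and iteratively samples the point farthest from the points chosen so far,
// writing indices into `out` between out_ptr[b] and out_ptr[b + 1].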
template <typename scalar_t>
__global__ void fps_kernel(const scalar_t *src, const int64_t *ptr,
                           const int64_t *out_ptr, const int64_t *start,
                           scalar_t *dist, scalar_t *tmp_dist, int64_t *out,
                           int64_t dim) {

  const int64_t batch_idx = blockIdx.x;
  const int64_t thread_idx = threadIdx.x;

  const int64_t start_idx = ptr[batch_idx];
  const int64_t end_idx = ptr[batch_idx + 1];

  __shared__ scalar_t best_dist[THREADS];
  __shared__ int64_t best_dist_idx[THREADS];

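  // Thread 0 seeds the output with the (possibly random) start index of this
  // batch element.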
  if (thread_idx == 0) {
    out[out_ptr[batch_idx]] = start_idx + start[batch_idx];
  }

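  // Every further sample is the point with the largest minimum distance to
  // the samples already chosen.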
  for (int64_t m = out_ptr[batch_idx] + 1; m < out_ptr[batch_idx + 1]; m++) {
    scalar_t best = -1;
    int64_t best_idx = 0;

    __syncthreads();
    Dist<scalar_t>::compute(thread_idx, start_idx, end_idx, out[m - 1], &best,
                            &best_idx, src, dist, tmp_dist, dim);

    best_dist[thread_idx] = best;
    best_dist_idx[thread_idx] = best_idx;

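    // Tree reduction in shared memory: keep the larger distance (and its
    // index) in the lower slot, so slot 0 ends up holding the block-wide
    // farthest point.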
    for (int64_t u = 0; (1 << u) < THREADS; u++) {
      __syncthreads();
      if (thread_idx < (THREADS >> (u + 1))) {
        int64_t idx1 = (thread_idx * 2) << u;
        int64_t idx2 = (thread_idx * 2 + 1) << u;
        if (best_dist[idx1] < best_dist[idx2]) {
          best_dist[idx1] = best_dist[idx2];
          best_dist_idx[idx1] = best_dist_idx[idx2];
        }
      }
    }

    __syncthreads();
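    // Thread 0 records the winner as the next sampled index.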
    if (thread_idx == 0) {
      out[m] = best_dist_idx[0];
    }
  }
}

torch::Tensor fps_cuda(torch::Tensor src, torch::Tensor ptr, double ratio,
                       bool random_start) {

  CHECK_CUDA(src);
  CHECK_CUDA(ptr);
  CHECK_INPUT(ptr.dim() == 1);
  AT_ASSERTM(ratio > 0 && ratio < 1, "'ratio' must lie in the interval (0, 1)");
  cudaSetDevice(src.get_device());

  src = src.view({src.size(0), -1}).contiguous();
  ptr = ptr.contiguous();
  auto batch_size = ptr.size(0) - 1;

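  // Points per batch element and, via ceil(deg * ratio), the number of
  // samples drawn from each; `out_ptr` becomes the CSR offsets of the output.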
  auto deg = ptr.narrow(0, 1, batch_size) - ptr.narrow(0, 0, batch_size);
  auto out_ptr = deg.toType(torch::kFloat) * (float)ratio;
  out_ptr = out_ptr.ceil().toType(torch::kLong).cumsum(0);
  out_ptr = torch::cat({torch::zeros(1, ptr.options()), out_ptr}, 0);

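  // Index of the first sampled point within each batch element: either drawn
  // uniformly at random or fixed to zero.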
  torch::Tensor start;
  if (random_start) {
    start = torch::rand(batch_size, src.options());
    start = (start * deg.toType(torch::kFloat)).toType(torch::kLong);
  } else {
    start = torch::zeros(batch_size, ptr.options());
  }

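  // Running minimum squared distances (initialized to a large value) and
  // per-point scratch space used inside the kernel.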
  auto dist = torch::full(src.size(0), 1e38, src.options());
  auto tmp_dist = torch::empty(src.size(0), src.options());

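  // Copy the total number of samples (the last cumulative output offset) back
  // to the host in order to allocate the output tensor.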
  int64_t out_size = 0;
  cudaMemcpy(&out_size, out_ptr[-1].data_ptr<int64_t>(), sizeof(int64_t),
             cudaMemcpyDeviceToHost);
  auto out = torch::empty(out_size, out_ptr.options());

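  // Launch one block per batch element on the current CUDA stream.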
  auto stream = at::cuda::getCurrentCUDAStream();
  AT_DISPATCH_FLOATING_TYPES(src.scalar_type(), "fps_kernel", [&] {
    fps_kernel<scalar_t><<<batch_size, THREADS, 0, stream>>>(
        src.data_ptr<scalar_t>(), ptr.data_ptr<int64_t>(),
        out_ptr.data_ptr<int64_t>(), start.data_ptr<int64_t>(),
        dist.data_ptr<scalar_t>(), tmp_dist.data_ptr<scalar_t>(),
        out.data_ptr<int64_t>(), src.size(1));
  });

  return out;
}
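
// Usage sketch (illustrative only; the tensor names and shapes below are
// assumptions, not part of this file): with a CUDA tensor `src` of shape
// [N, F] and CSR batch offsets `ptr` such as {0, N1, N1 + N2}, a call could
// look like
//
//   auto idx = fps_cuda(src, ptr, /*ratio=*/0.5, /*random_start=*/true);
//
// `idx` then holds indices into `src` of roughly ceil(deg * ratio) sampled
// points per batch element.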