fps_kernel.cu 7.06 KB
Newer Older
rusty1s's avatar
rusty1s committed
1
2
#include <ATen/ATen.h>

rusty1s's avatar
rusty1s committed
3
#include "atomics.cuh"
rusty1s's avatar
rusty1s committed
4
5
6
7
#include "utils.cuh"

// Threads per block used for every fps_kernel launch in this file. The same
// constant is the stride of the per-thread point loops in Dist<...>::compute
// and the length of the shared reduction arrays inside fps_kernel.
#define THREADS 1024

rusty1s's avatar
rusty1s committed
8
// Primary template, declared but not defined here: only the specializations
// for Dim = 1, 2, 3 and -1 in this file provide a compute() implementation.
template <typename scalar_t, int64_t Dim> struct Dist;
rusty1s's avatar
rusty1s committed
9

rusty1s's avatar
rusty1s committed
10
11
12
13
// Distance update for points that consist of a single coordinate.
template <typename scalar_t> struct Dist<scalar_t, 1> {
  // For every point p in [start_idx, end_idx) handled by this thread
  // (p = start_idx + idx, advancing by THREADS), lowers dist[p] to the squared
  // difference between x[p] and x[old] if that is smaller, and records in
  // *best / *best_idx the largest resulting dist[p] seen so far.
  // tmp_dist and dim are accepted for signature parity with the other
  // specializations and are not used here.
  static __device__ void
  compute(ptrdiff_t idx, ptrdiff_t start_idx, ptrdiff_t end_idx, ptrdiff_t old,
          scalar_t *__restrict__ best, ptrdiff_t *__restrict__ best_idx,
          const scalar_t *__restrict__ x, scalar_t *__restrict__ dist,
          scalar_t *__restrict__ tmp_dist, size_t dim) {

    ptrdiff_t p = start_idx + idx;
    while (p < end_idx) {
      const scalar_t delta = x[old] - x[p];
      const scalar_t nearest = min(dist[p], delta * delta);
      dist[p] = nearest;

      // Strict comparison: on ties the earlier candidate is kept.
      if (nearest > *best) {
        *best = nearest;
        *best_idx = p;
      }
      p += THREADS;
    }
  }
};
rusty1s's avatar
rusty1s committed
27

rusty1s's avatar
rusty1s committed
28
29
30
31
// Distance update for points with two coordinates stored back to back in x
// (point p occupies x[2 * p] and x[2 * p + 1]).
template <typename scalar_t> struct Dist<scalar_t, 2> {
  // For every point p in [start_idx, end_idx) handled by this thread
  // (p = start_idx + idx, advancing by THREADS), lowers dist[p] to the squared
  // Euclidean distance between point p and point `old` if that is smaller, and
  // records in *best / *best_idx the largest resulting dist[p] seen so far.
  // tmp_dist and dim are accepted for signature parity with the other
  // specializations and are not used here.
  static __device__ void
  compute(ptrdiff_t idx, ptrdiff_t start_idx, ptrdiff_t end_idx, ptrdiff_t old,
          scalar_t *__restrict__ best, ptrdiff_t *__restrict__ best_idx,
          const scalar_t *__restrict__ x, scalar_t *__restrict__ dist,
          scalar_t *__restrict__ tmp_dist, size_t dim) {

    ptrdiff_t p = start_idx + idx;
    while (p < end_idx) {
      const scalar_t *ref = x + 2 * old;
      const scalar_t *cur = x + 2 * p;

      const scalar_t dx = ref[0] - cur[0];
      const scalar_t dy = ref[1] - cur[1];
      const scalar_t sq = dx * dx + dy * dy;

      const scalar_t nearest = min(dist[p], sq);
      dist[p] = nearest;

      // Strict comparison: on ties the earlier candidate is kept.
      if (nearest > *best) {
        *best = nearest;
        *best_idx = p;
      }
      p += THREADS;
    }
  }
};
rusty1s's avatar
rusty1s committed
47

rusty1s's avatar
rusty1s committed
48
49
50
51
// Distance update for points with three coordinates stored back to back in x
// (point p occupies x[3 * p] .. x[3 * p + 2]).
template <typename scalar_t> struct Dist<scalar_t, 3> {
  // For every point p in [start_idx, end_idx) handled by this thread
  // (p = start_idx + idx, advancing by THREADS), lowers dist[p] to the squared
  // Euclidean distance between point p and point `old` if that is smaller, and
  // records in *best / *best_idx the largest resulting dist[p] seen so far.
  // tmp_dist and dim are accepted for signature parity with the other
  // specializations and are not used here.
  static __device__ void
  compute(ptrdiff_t idx, ptrdiff_t start_idx, ptrdiff_t end_idx, ptrdiff_t old,
          scalar_t *__restrict__ best, ptrdiff_t *__restrict__ best_idx,
          const scalar_t *__restrict__ x, scalar_t *__restrict__ dist,
          scalar_t *__restrict__ tmp_dist, size_t dim) {

    ptrdiff_t p = start_idx + idx;
    while (p < end_idx) {
      const scalar_t *ref = x + 3 * old;
      const scalar_t *cur = x + 3 * p;

      const scalar_t dx = ref[0] - cur[0];
      const scalar_t dy = ref[1] - cur[1];
      const scalar_t dz = ref[2] - cur[2];
      const scalar_t sq = dx * dx + dy * dy + dz * dz;

      const scalar_t nearest = min(dist[p], sq);
      dist[p] = nearest;

      // Strict comparison: on ties the earlier candidate is kept.
      if (nearest > *best) {
        *best = nearest;
        *best_idx = p;
      }
      p += THREADS;
    }
  }
};

// Fallback distance update for an arbitrary number of coordinates per point,
// given at run time by `dim` (point p occupies x[p * dim] .. x[p * dim + dim - 1]).
template <typename scalar_t> struct Dist<scalar_t, -1> {
  // Works in three phases separated by block-wide barriers, so every thread of
  // the block has to call this function:
  //   1. zero tmp_dist for the points in [start_idx, end_idx);
  //   2. walk the flattened coordinates of those points and atomically add each
  //      squared coordinate difference to point `old` into tmp_dist[point];
  //   3. lower dist[p] to tmp_dist[p] where smaller and record in
  //      *best / *best_idx the largest resulting dist[p] this thread has seen.
  // In each phase the thread starts at offset idx and advances by THREADS.
  static __device__ void
  compute(ptrdiff_t idx, ptrdiff_t start_idx, ptrdiff_t end_idx, ptrdiff_t old,
          scalar_t *__restrict__ best, ptrdiff_t *__restrict__ best_idx,
          const scalar_t *__restrict__ x, scalar_t *__restrict__ dist,
          scalar_t *__restrict__ tmp_dist, size_t dim) {

    // Phase 1: reset the per-point accumulators.
    ptrdiff_t p = start_idx + idx;
    while (p < end_idx) {
      tmp_dist[p] = 0;
      p += THREADS;
    }

    __syncthreads();

    // Phase 2: one flattened coordinate per step; several threads may
    // contribute to the same point, hence the atomic accumulation.
    ptrdiff_t flat = start_idx * dim + idx;
    while (flat < end_idx * dim) {
      const size_t coord = flat % dim;
      const size_t point = flat / dim;
      const scalar_t delta = x[(old * dim) + coord] - x[flat];
      atomicAdd(&tmp_dist[point], delta * delta);
      flat += THREADS;
    }

    __syncthreads();

    // Phase 3: fold the completed squared distances into dist and track the
    // farthest point seen by this thread.
    ptrdiff_t q = start_idx + idx;
    while (q < end_idx) {
      const scalar_t nearest = min(dist[q], tmp_dist[q]);
      dist[q] = nearest;

      // Strict comparison: on ties the earlier candidate is kept.
      if (nearest > *best) {
        *best = nearest;
        *best_idx = q;
      }
      q += THREADS;
    }
  }
};

// Farthest point sampling for one batch element per thread block.
//
// Launch layout: blockIdx.x selects the batch element; the block must contain
// exactly THREADS threads, because the shared reduction arrays have THREADS
// entries and all of them are read during the reduction.
//
// Parameters:
//   x        point coordinates, `dim` values per point, stored contiguously.
//   cum_deg  cum_deg[b] .. cum_deg[b + 1] is the point index range of batch b.
//   cum_k    cum_k[b] .. cum_k[b + 1] is the range of `out` slots of batch b.
//   start    per-batch offset (relative to cum_deg[b]) of the first sample.
//   dist     per-point running squared distance to the closest sample chosen
//            so far; read and min-updated in place (fps_cuda fills it with
//            1e38 before the launch).
//   tmp_dist per-point scratch buffer, only used by Dist<scalar_t, -1>.
//   out      receives the selected point indices (indices into x).
//   dim      number of coordinates per point.
template <typename scalar_t, int64_t Dim>
__global__ void
fps_kernel(const scalar_t *__restrict__ x, const int64_t *__restrict__ cum_deg,
           const int64_t *__restrict__ cum_k, const int64_t *__restrict__ start,
           scalar_t *__restrict__ dist, scalar_t *__restrict__ tmp_dist,
           int64_t *__restrict__ out, size_t dim) {

  const ptrdiff_t batch_idx = blockIdx.x;
  const ptrdiff_t idx = threadIdx.x;

  const ptrdiff_t start_idx = cum_deg[batch_idx];
  const ptrdiff_t end_idx = cum_deg[batch_idx + 1];

  const ptrdiff_t out_begin = cum_k[batch_idx];
  const ptrdiff_t out_end = cum_k[batch_idx + 1];

  // A batch element without any output slot must not write its seed index:
  // out[out_begin] would belong to the next batch element (racing with that
  // block's own seed write) or lie past the end of `out` for the last one.
  // The condition is identical for all threads of the block, so returning
  // here cannot leave a barrier below partially reached.
  if (out_begin >= out_end) {
    return;
  }

  __shared__ scalar_t best_dist[THREADS];
  __shared__ int64_t best_dist_idx[THREADS];

  // The first sample is given by the caller-provided start offset.
  if (idx == 0) {
    out[out_begin] = start_idx + start[batch_idx];
  }

  for (ptrdiff_t m = out_begin + 1; m < out_end; m++) {
    // Distances are squared and therefore non-negative, so -1 loses against
    // any real candidate.
    scalar_t best = -1;
    ptrdiff_t best_idx = 0;

    // Makes out[m - 1], written by thread 0, visible to the whole block.
    __syncthreads();
    Dist<scalar_t, Dim>::compute(idx, start_idx, end_idx, out[m - 1], &best,
                                 &best_idx, x, dist, tmp_dist, dim);

    best_dist[idx] = best;
    best_dist_idx[idx] = best_idx;

    // Pairwise tree reduction over the shared arrays: in round u, entry
    // (2 * idx) << u absorbs entry (2 * idx + 1) << u when the latter holds a
    // strictly larger distance. The block-wide maximum ends up in entry 0.
    for (int64_t u = 0; (1 << u) < THREADS; u++) {
      __syncthreads();
      if (idx < (THREADS >> (u + 1))) {
        int64_t idx_1 = (idx * 2) << u;
        int64_t idx_2 = (idx * 2 + 1) << u;
        if (best_dist[idx_1] < best_dist[idx_2]) {
          best_dist[idx_1] = best_dist[idx_2];
          best_dist_idx[idx_1] = best_dist_idx[idx_2];
        }
      }
    }

    __syncthreads();
    if (idx == 0) {
      out[m] = best_dist_idx[0];
    }
  }
}

rusty1s's avatar
rusty1s committed
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
// Launches fps_kernel with batch_size blocks of THREADS threads, selecting the
// Dist specialization from the run-time value DIM: 1, 2 and 3 use the
// dedicated specializations, every other value uses the generic -1 one.
// The variadic arguments are forwarded as kernel arguments and DIM is appended
// as the last one. `scalar_t` and `batch_size` must be in scope where the
// macro is expanded. DIM appears both in the switch and in the selected
// launch, so the argument expression is evaluated twice.
#define FPS_KERNEL(DIM, ...)                                                   \
  [&] {                                                                        \
    switch (DIM) {                                                             \
    case 1:                                                                    \
      fps_kernel<scalar_t, 1><<<batch_size, THREADS>>>(__VA_ARGS__, DIM);      \
      break;                                                                   \
    case 2:                                                                    \
      fps_kernel<scalar_t, 2><<<batch_size, THREADS>>>(__VA_ARGS__, DIM);      \
      break;                                                                   \
    case 3:                                                                    \
      fps_kernel<scalar_t, 3><<<batch_size, THREADS>>>(__VA_ARGS__, DIM);      \
      break;                                                                   \
    default:                                                                   \
      fps_kernel<scalar_t, -1><<<batch_size, THREADS>>>(__VA_ARGS__, DIM);     \
    }                                                                          \
  }()

rusty1s's avatar
rusty1s committed
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
// Host entry point for farthest point sampling on the GPU.
//
// Parameters:
//   x       point coordinates; x.size(0) points with x.size(1) coordinates each.
//   batch   int64 device tensor assigning each point to a batch element. The
//           number of batch elements is taken as batch[-1] + 1, so this
//           presumably expects `batch` to be sorted -- confirm with callers.
//   ratio   fraction of each batch element's points to sample; the per-element
//           sample count is round(count * ratio).
//   random  if true, the first sample of each batch element is drawn at
//           random, otherwise the first point of the element is used.
//
// Returns an int64 tensor with the indices (into x) of the sampled points,
// grouped by batch element.
at::Tensor fps_cuda(at::Tensor x, at::Tensor batch, float ratio, bool random) {
  // Single-value device-to-host reads go into stack variables; the former
  // malloc'd one-element buffers were never freed and leaked on every call.
  int64_t last_batch = 0;
  cudaMemcpy(&last_batch, batch[-1].data<int64_t>(), sizeof(int64_t),
             cudaMemcpyDeviceToHost);
  auto batch_size = last_batch + 1;

  // Per-element point counts and sample counts, plus their exclusive prefix
  // sums (a leading zero followed by the cumulative sums).
  auto deg = degree(batch, batch_size);
  auto cum_deg = at::cat({at::zeros(1, deg.options()), deg.cumsum(0)}, 0);
  auto k = (deg.toType(at::kFloat) * ratio).round().toType(at::kLong);
  auto cum_k = at::cat({at::zeros(1, k.options()), k.cumsum(0)}, 0);

  // Offset of the first sample inside each batch element.
  at::Tensor start;
  if (random) {
    start = at::rand(batch_size, x.options());
    start = (start * deg.toType(at::kFloat)).toType(at::kLong);
  } else {
    start = at::zeros(batch_size, k.options());
  }

  // dist starts at a huge value so the first min-update always takes effect;
  // tmp_dist is scratch space for the generic-dimension kernel.
  auto dist = at::full(x.size(0), 1e38, x.options());
  auto tmp_dist = at::empty(x.size(0), x.options());

  // Total number of samples over all batch elements.
  int64_t k_sum = 0;
  cudaMemcpy(&k_sum, cum_k[-1].data<int64_t>(), sizeof(int64_t),
             cudaMemcpyDeviceToHost);
  auto out = at::empty(k_sum, k.options());

  AT_DISPATCH_FLOATING_TYPES(x.type(), "fps_kernel", [&] {
    FPS_KERNEL(x.size(1), x.data<scalar_t>(), cum_deg.data<int64_t>(),
               cum_k.data<int64_t>(), start.data<int64_t>(),
               dist.data<scalar_t>(), tmp_dist.data<scalar_t>(),
               out.data<int64_t>());
  });

  return out;
}