fps_kernel.cu 3.22 KB
Newer Older
rusty1s's avatar
rusty1s committed
1
2
#include <ATen/ATen.h>

#include <stdexcept>
#include <string>

#include "atomics.cuh"
#include "utils.cuh"

#define THREADS 1024

// Farthest point sampling; each thread block processes one example.
//
// Launch layout: blockIdx.x selects the example and the block must contain
// exactly THREADS threads, because the shared-memory reduction below is sized
// and indexed by THREADS. Uses THREADS * (sizeof(scalar_t) + sizeof(int64_t))
// bytes of static shared memory.
//
// x        : point coordinates, addressed as x[point * dim + d].
// cum_deg  : example b owns the points [cum_deg[b], cum_deg[b + 1]).
// cum_k    : example b fills out[cum_k[b]] .. out[cum_k[b + 1] - 1].
// start    : offset, relative to cum_deg[b], of the first sample of example b.
// dist     : per-point running minimum of the squared distance to the samples
//            picked so far; read and updated in place, so the caller is
//            expected to pre-fill it with a value larger than any real
//            squared distance.
// tmp_dist : per-point scratch, overwritten in every iteration.
// out      : receives the (global) indices of the selected points.
// dim      : number of coordinates per point.
template <typename scalar_t>
__global__ void
fps_kernel(scalar_t *__restrict__ x, int64_t *__restrict__ cum_deg,
           int64_t *__restrict__ cum_k, int64_t *__restrict__ start,
           scalar_t *__restrict__ dist, scalar_t *__restrict__ tmp_dist,
           int64_t *__restrict__ out, size_t dim) {

  const size_t batch_idx = blockIdx.x;
  const size_t idx = threadIdx.x;

  const size_t start_idx = cum_deg[batch_idx];
  const size_t end_idx = cum_deg[batch_idx + 1];

  const ptrdiff_t out_begin = cum_k[batch_idx];
  const ptrdiff_t out_end = cum_k[batch_idx + 1];

  // An example that was assigned zero samples owns no slot in out. Writing the
  // start index anyway would clobber the first slot of the next example (or
  // run past the end of out for the last one). The condition depends only on
  // blockIdx.x, so the whole block leaves together and no barrier is skipped
  // by a subset of threads.
  if (out_begin >= out_end) {
    return;
  }

  __shared__ scalar_t best_dist[THREADS];
  __shared__ int64_t best_dist_idx[THREADS];

  // Index of the most recently selected point. Every thread tracks it in a
  // register so the distance loop does not have to re-read out[] from global
  // memory for every coordinate.
  int64_t last = start_idx + start[batch_idx];
  if (idx == 0) {
    out[out_begin] = last;
  }

  for (ptrdiff_t m = out_begin + 1; m < out_end; m++) {
    ptrdiff_t best_idx = 0;
    scalar_t best = -1; // below any squared distance, so the first point wins

    // Reset the scratch accumulator for this example's points.
    for (ptrdiff_t n = start_idx + idx; n < end_idx; n += THREADS) {
      tmp_dist[n] = 0;
    }

    __syncthreads(); // zeroing must finish before any thread accumulates

    // Accumulate the squared distance to the last sample, one coordinate per
    // loop step: i walks the flattened [point][coordinate] range of this
    // example, so several threads may contribute to the same tmp_dist entry.
    for (ptrdiff_t i = start_idx * dim + idx; i < end_idx * dim; i += THREADS) {
      scalar_t d = x[(last * dim) + (i % dim)] - x[i];
      atomicAdd(&tmp_dist[i / dim], d * d);
    }

    __syncthreads(); // all contributions must land before tmp_dist is read

    // Fold the new distances into the running minimum and remember this
    // thread's farthest point.
    for (ptrdiff_t n = start_idx + idx; n < end_idx; n += THREADS) {
      dist[n] = min(dist[n], tmp_dist[n]);
      if (dist[n] > best) {
        best = dist[n];
        best_idx = n;
      }
    }

    best_dist[idx] = best;
    best_dist_idx[idx] = best_idx;

    // Block-wide arg-max over the per-thread candidates. In round u the
    // active thread idx merges slot (2 * idx + 1) << u into slot
    // (2 * idx) << u; the strict '<' keeps the lower slot on ties. The
    // barrier sits outside the divergent branch so every thread reaches it.
    for (int64_t u = 0; (1 << u) < THREADS; u++) {
      __syncthreads();
      if (idx < (THREADS >> (u + 1))) {
        int64_t idx_1 = (idx * 2) << u;
        int64_t idx_2 = (idx * 2 + 1) << u;
        if (best_dist[idx_1] < best_dist[idx_2]) {
          best_dist[idx_1] = best_dist[idx_2];
          best_dist_idx[idx_1] = best_dist_idx[idx_2];
        }
      }
    }

    __syncthreads(); // slot 0 now holds the winner of the reduction

    // Slot 0 is not rewritten until after the two barriers of the next
    // iteration, so every thread can safely pick up the winner here.
    last = best_dist_idx[0];
    if (idx == 0) {
      out[m] = last;
    }
  }
}

// Turns a failed CUDA runtime status into a std::runtime_error whose message
// names the step that failed; does nothing on cudaSuccess.
static void fps_check_cuda(cudaError_t status, const char *what) {
  if (status != cudaSuccess) {
    throw std::runtime_error(std::string(what) + ": " +
                             cudaGetErrorString(status));
  }
}

// Host entry point for farthest point sampling on CUDA tensors.
//
// x      : point coordinates; x.size(0) points with x.size(1) coordinates
//          each, in a floating-point dtype.
// batch  : int64 example id per point. The number of examples is taken from
//          the last entry, so ids are presumably sorted ascending (confirm
//          against callers).
// ratio  : fraction of each example's points to sample; the per-example
//          count is round(deg * ratio).
// random : when true, the first sample of each example is drawn at random,
//          otherwise the example's first point is used.
//
// Returns an int64 tensor holding the indices (into x) of the selected
// points, grouped by example.
//
// Throws std::runtime_error if a device-to-host copy or the kernel launch
// reports a CUDA error.
at::Tensor fps_cuda(at::Tensor x, at::Tensor batch, float ratio, bool random) {
  // Number of examples = last batch id + 1. The scalar is copied into a stack
  // variable, so there is no host allocation to release.
  int64_t last_batch = 0;
  fps_check_cuda(cudaMemcpy(&last_batch, batch[-1].data<int64_t>(),
                            sizeof(int64_t), cudaMemcpyDeviceToHost),
                 "fps_cuda: copying last batch index to host");
  const auto batch_size = last_batch + 1;

  // degree() comes from utils.cuh; presumably the number of points per
  // example. The cumulative sums get a leading zero so that entry b and
  // entry b + 1 delimit example b.
  auto deg = degree(batch, batch_size);
  auto cum_deg = at::cat({at::zeros(1, deg.options()), deg.cumsum(0)}, 0);
  auto k = (deg.toType(at::kFloat) * ratio).round().toType(at::kLong);
  auto cum_k = at::cat({at::zeros(1, k.options()), k.cumsum(0)}, 0);

  // Per-example offset of the first sample, relative to the example's start.
  at::Tensor start;
  if (random) {
    start = at::rand(batch_size, x.options());
    start = (start * deg.toType(at::kFloat)).toType(at::kLong);
  } else {
    start = at::zeros(batch_size, k.options());
  }

  // dist starts at a huge value so the first min() in the kernel always
  // takes the freshly computed distance; tmp_dist is kernel scratch.
  auto dist = at::full(x.size(0), 1e38, x.options());
  auto tmp_dist = at::empty(x.size(0), x.options());

  // Total number of samples over all examples, needed on the host to size
  // the output tensor.
  int64_t k_sum = 0;
  fps_check_cuda(cudaMemcpy(&k_sum, cum_k[-1].data<int64_t>(), sizeof(int64_t),
                            cudaMemcpyDeviceToHost),
                 "fps_cuda: copying total sample count to host");
  auto out = at::empty(k_sum, k.options());

  // One block per example, THREADS threads per block (required by the
  // kernel's shared-memory reduction).
  AT_DISPATCH_FLOATING_TYPES(x.type(), "fps_kernel", [&] {
    fps_kernel<scalar_t><<<batch_size, THREADS>>>(
        x.data<scalar_t>(), cum_deg.data<int64_t>(), cum_k.data<int64_t>(),
        start.data<int64_t>(), dist.data<scalar_t>(), tmp_dist.data<scalar_t>(),
        out.data<int64_t>(), x.size(1));
  });
  // A launch does not return a status itself; configuration errors only show
  // up through cudaGetLastError().
  fps_check_cuda(cudaGetLastError(), "fps_cuda: launching fps_kernel");

  return out;
}