// fps_kernel.cu
#include <ATen/ATen.h>

#include "utils.cuh"

#define THREADS 1024

template <typename scalar_t>
__global__ void
fps_kernel(const scalar_t *__restrict__ x, const int64_t *__restrict__ cum_deg,
           const int64_t *__restrict__ cum_k, const int64_t *__restrict__ start,
           scalar_t *__restrict__ tmp_dist, int64_t dim,
           int64_t *__restrict__ out) {
  // One block per example, modeled on the reference implementation at the
  // bottom of this file: nodes cum_deg[b]..cum_deg[b + 1] are sampled into
  // output slots cum_k[b]..cum_k[b + 1].
  const int64_t b = blockIdx.x;
  const int64_t n_start = cum_deg[b], n_end = cum_deg[b + 1];
  const int64_t k_start = cum_k[b], k_end = cum_k[b + 1];

  if (k_start == k_end)
    return; // Nothing to sample for this example.

  __shared__ scalar_t best_dist[THREADS];
  __shared__ int64_t best_idx[THREADS];

  int64_t old = n_start + start[b];
  if (threadIdx.x == 0)
    out[k_start] = old;

  for (int64_t m = k_start + 1; m < k_end; m++) {
    scalar_t best = (scalar_t)-1;
    int64_t best_i = n_start;
    // Each thread scans a strided slice of the example's nodes, updating the
    // running minimum squared distance to the already sampled set.
    for (int64_t i = n_start + threadIdx.x; i < n_end; i += blockDim.x) {
      scalar_t dist = (scalar_t)0;
      for (int64_t d = 0; d < dim; d++) {
        const scalar_t diff = x[i * dim + d] - x[old * dim + d];
        dist += diff * diff;
      }
      dist = dist < tmp_dist[i] ? dist : tmp_dist[i];
      tmp_dist[i] = dist;
      if (dist > best) {
        best = dist;
        best_i = i;
      }
    }
    best_dist[threadIdx.x] = best;
    best_idx[threadIdx.x] = best_i;
    // Tree max-reduction over the block: the farthest node ends up in slot 0.
    for (unsigned int u = blockDim.x / 2; u > 0; u >>= 1) {
      __syncthreads();
      if (threadIdx.x < u &&
          best_dist[threadIdx.x + u] > best_dist[threadIdx.x]) {
        best_dist[threadIdx.x] = best_dist[threadIdx.x + u];
        best_idx[threadIdx.x] = best_idx[threadIdx.x + u];
      }
    }
    __syncthreads();
    old = best_idx[0];
    if (threadIdx.x == 0)
      out[m] = old;
    __syncthreads(); // Every thread must read best_idx[0] before it is reused.
  }
}
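
// Design note: the inner loop over `dim` keeps fps_kernel agnostic to the
// feature dimension, at the cost of uncoalesced reads for large `dim`. The
// reference implementation at the bottom of this file instead hard-codes
// 3-D points and stages them in shared memory.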

at::Tensor fps_cuda(at::Tensor x, at::Tensor batch, float ratio, bool random) {
  // The batch size is one more than the largest example index, found in the
  // last entry of the (sorted) batch vector.
  int64_t batch_size;
  cudaMemcpy(&batch_size, batch[-1].data<int64_t>(), sizeof(int64_t),
             cudaMemcpyDeviceToHost);
  batch_size += 1;

  // Exclusive prefix sums over node counts and sample counts per example.
  auto deg = degree(batch, batch_size);
  auto cum_deg = at::cat({at::zeros(1, deg.options()), deg.cumsum(0)}, 0);
  auto k = (deg.toType(at::kFloat) * ratio).round().toType(at::kLong);
  auto cum_k = at::cat({at::zeros(1, k.options()), k.cumsum(0)}, 0);

  // Local index of the first sampled node in each example: either drawn
  // uniformly at random from [0, deg) or fixed to node zero.
  at::Tensor start;
  if (random) {
    start = at::rand(batch_size, x.options());
    start = (start * deg.toType(at::kFloat)).toType(at::kLong);
  } else {
    start = at::zeros(batch_size, k.options());
  }

  // Running squared distance of every node to the sampled set, initialized
  // to "infinity".
  auto tmp_dist = at::full(x.size(0), 1e38, x.options());

  // The total number of samples determines the output size.
  int64_t k_sum;
  cudaMemcpy(&k_sum, cum_k[-1].data<int64_t>(), sizeof(int64_t),
             cudaMemcpyDeviceToHost);
  auto out = at::empty(k_sum, k.options());

  // One block per example; x is assumed to be a contiguous [N, dim] matrix.
  AT_DISPATCH_FLOATING_TYPES(x.type(), "fps_kernel", [&] {
    fps_kernel<scalar_t><<<batch_size, THREADS>>>(
        x.data<scalar_t>(), cum_deg.data<int64_t>(), cum_k.data<int64_t>(),
        start.data<int64_t>(), tmp_dist.data<scalar_t>(), x.size(1),
        out.data<int64_t>());
  });

  return out;
}
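
// A hedged usage sketch (the option calls are assumptions about the ATen API
// used by this code base, not part of this file): sample half the points of
// a single example.
//
//   auto opts = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA);
//   auto x = at::rand({8, 3}, opts);                    // positions [N, F]
//   auto batch = at::zeros(8, opts.dtype(at::kLong));   // one example
//   auto idx = fps_cuda(x, batch, 0.5, false);          // 4 indices into x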

// at::Tensor ifp_cuda(at::Tensor x, at::Tensor batch, float ratio) {
//   AT_DISPATCH_FLOATING_TYPES(x.type(), "ifp_kernel", [&] {
//     ifp_kernel<scalar_t><<<BLOCKS(x.numel()), THREADS>>>(
//         x.data<scalar_t>(), batch.data<int64_t>(), ratio, x.numel());
//   });

//   return x;
// }

// __global__ void ifp_kernel() {}

// // x: [N, F]
// // count: [B]
// // batch: [N]
// // tmp min distances: [N]
// // start node idx

// // We parallelize over n times f: distances can be accumulated via
// // atomicAdd (see the sketch below), and each block corresponds to one
// // example in the batch.
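
// A minimal commented sketch of that atomicAdd variant (hypothetical; the
// kernel name and signature are ours): parallelize over all n * f
// coordinates of one example and accumulate each squared difference to the
// latest sample `old` into a per-node distance slot.
//
// __global__ void sq_dist_kernel(const float *__restrict__ x, int64_t old,
//                                int64_t n, int64_t f,
//                                float *__restrict__ dist) {
//   const int64_t i = blockIdx.x * blockDim.x + threadIdx.x;
//   if (i < n * f) {
//     const float diff = x[i] - x[old * f + i % f];
//     atomicAdd(&dist[i / f], diff * diff); // one add per coordinate
//   }
// }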

// __global__ void farthestpointsamplingKernel(int b, int n, int m,
//                                             const float *__restrict__ dataset,
//                                             float *__restrict__ temp,
//                                             int *__restrict__ idxs) {
//   // dataset: [b * n * 3] coordinates
//   // b: batch size
//   // n: number of nodes per example
//   // m: number of sample points per example

//   if (m <= 0)
//     return;
//   const int BlockSize = 512;
//   __shared__ float dists[BlockSize];
//   __shared__ int dists_i[BlockSize];
//   const int BufferSize = 3072;
//   __shared__ float buf[BufferSize * 3];
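//   // The first BufferSize points of each example are staged in shared
//   // memory so the hot part of the distance scan avoids global loads.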
//   for (int i = blockIdx.x; i < b; i += gridDim.x) { // iterate over batches
//     int old = 0;
//     if (threadIdx.x == 0)
//       idxs[i * m + 0] = old;
//     for (int j = threadIdx.x; j < n; j += blockDim.x) { // iterate over all n
//       temp[blockIdx.x * n + j] = 1e38;
//     }
//     for (int j = threadIdx.x; j < min(BufferSize, n) * 3; j += blockDim.x) {
//       buf[j] = dataset[i * n * 3 + j];
//     }
//     __syncthreads();
//     for (int j = 1; j < m; j++) {
//       int besti = 0;
//       float best = -1;
//       float x1 = dataset[i * n * 3 + old * 3 + 0];
//       float y1 = dataset[i * n * 3 + old * 3 + 1];
//       float z1 = dataset[i * n * 3 + old * 3 + 2];
//       for (int k = threadIdx.x; k < n; k += blockDim.x) {
//         float td = temp[blockIdx.x * n + k];
//         float x2, y2, z2;
//         if (k < BufferSize) {
//           x2 = buf[k * 3 + 0];
//           y2 = buf[k * 3 + 1];
//           z2 = buf[k * 3 + 2];
//         } else {
//           x2 = dataset[i * n * 3 + k * 3 + 0];
//           y2 = dataset[i * n * 3 + k * 3 + 1];
//           z2 = dataset[i * n * 3 + k * 3 + 2];
//         }
//         float d = (x2 - x1) * (x2 - x1) + (y2 - y1) * (y2 - y1) +
//                   (z2 - z1) * (z2 - z1);
//         float d2 = min(d, td);
//         if (d2 != td)
//           temp[blockIdx.x * n + k] = d2;
//         if (d2 > best) {
//           best = d2;
//           besti = k;
//         }
//       }
//       dists[threadIdx.x] = best;
//       dists_i[threadIdx.x] = besti;
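//       // Tree max-reduction across the block: after log2(blockDim.x)
//       // rounds, dists[0] holds the largest distance and dists_i[0] the
//       // index of the farthest node.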
//       for (int u = 0; (1 << u) < blockDim.x; u++) {
//         __syncthreads();
//         if (threadIdx.x < (blockDim.x >> (u + 1))) {
//           int i1 = (threadIdx.x * 2) << u;
//           int i2 = (threadIdx.x * 2 + 1) << u;
//           if (dists[i1] < dists[i2]) {
//             dists[i1] = dists[i2];
//             dists_i[i1] = dists_i[i2];
//           }
//         }
//       }
//       __syncthreads();
//       old = dists_i[0];
//       if (threadIdx.x == 0)
//         idxs[i * m + j] = old;
//     }
//   }
// }