#include "fps_cuda.h"

#include <ATen/cuda/CUDAContext.h>

#include "utils.cuh"

#define THREADS 1024

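// One thread block processes one example of the (ptr-delimited) batch.
// Starting from a seed point, the kernel repeatedly adds the point that is
// farthest away from the set of already sampled points.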
template <typename scalar_t>
__global__ void fps_kernel(const scalar_t *src, const int64_t *ptr,
                           const int64_t *out_ptr, const int64_t *start,
                           scalar_t *dist, int64_t *out, int64_t dim) {

  const int64_t thread_idx = threadIdx.x;
  const int64_t batch_idx = blockIdx.x;

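  // `ptr` holds CSR-like offsets: example `batch_idx` owns the points in
  // [start_idx, end_idx).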
  const int64_t start_idx = ptr[batch_idx];
  const int64_t end_idx = ptr[batch_idx + 1];

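  // Shared buffers for reducing the per-thread farthest-point candidates.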
  __shared__ scalar_t best_dist[THREADS];
  __shared__ int64_t best_dist_idx[THREADS];

  if (thread_idx == 0) {
    out[out_ptr[batch_idx]] = start_idx + start[batch_idx];
  }

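  // Iteratively sample the point farthest away from the already sampled set.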
  for (int64_t m = out_ptr[batch_idx] + 1; m < out_ptr[batch_idx + 1]; m++) {
    // Make sure the sample index written by thread 0 in the previous step is
    // visible to all threads before it is used as the reference point.
    __syncthreads();
    int64_t old = out[m - 1];

    scalar_t best = (scalar_t)-1.;
    int64_t best_idx = 0;

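    // Each thread scans a strided subset of the example's points and keeps
    // its best (farthest) candidate.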
    for (int64_t n = start_idx + thread_idx; n < end_idx; n += THREADS) {
      scalar_t tmp;
      scalar_t dd = (scalar_t)0.;
      for (int64_t d = 0; d < dim; d++) {
        tmp = src[dim * old + d] - src[dim * n + d];
        dd += tmp * tmp;
      }
      dd = min(dist[n], dd);
      dist[n] = dd;
      if (dd > best) {
        best = dd;
        best_idx = n;
      }
    }

    best_dist[thread_idx] = best;
    best_dist_idx[thread_idx] = best_idx;

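    // Parallel max-reduction over the per-thread candidates in shared memory.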
    for (int64_t i = 1; i < THREADS; i *= 2) {
      __syncthreads();
      if ((thread_idx + i) < THREADS &&
          best_dist[thread_idx] < best_dist[thread_idx + i]) {
        best_dist[thread_idx] = best_dist[thread_idx + i];
        best_dist_idx[thread_idx] = best_dist_idx[thread_idx + i];
      }
    }

    __syncthreads();
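    // Thread 0 publishes the winning index as the next sample.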
    if (thread_idx == 0) {
      out[m] = best_dist_idx[0];
    }
  }
}

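// Host entry point: computes per-example output offsets, start indices and the
// distance buffer, then launches `fps_kernel` with one block per example.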
torch::Tensor fps_cuda(torch::Tensor src, torch::Tensor ptr, double ratio,
                       bool random_start) {

  CHECK_CUDA(src);
  CHECK_CUDA(ptr);
  CHECK_INPUT(ptr.dim() == 1);
  AT_ASSERTM(ratio > 0 && ratio < 1, "Invalid input: 'ratio' needs to be in (0, 1)");
  cudaSetDevice(src.get_device());

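  // Flatten trailing feature dimensions so every point is a contiguous vector.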
  src = src.view({src.size(0), -1}).contiguous();
  ptr = ptr.contiguous();
  auto batch_size = ptr.size(0) - 1;

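  // Per-example sizes (`deg`) and sampled counts (`ceil(deg * ratio)`), turned
  // into a prefix sum with a leading zero (the output counterpart of `ptr`).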
  auto deg = ptr.narrow(0, 1, batch_size) - ptr.narrow(0, 0, batch_size);
  auto out_ptr = deg.toType(torch::kFloat) * (float)ratio;
  out_ptr = out_ptr.ceil().toType(torch::kLong).cumsum(0);
  out_ptr = torch::cat({torch::zeros(1, ptr.options()), out_ptr}, 0);

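  // Index of the first sampled point within each example, either random or 0.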
  torch::Tensor start;
  if (random_start) {
    start = torch::rand(batch_size, src.options());
    start = (start * deg.toType(torch::kFloat)).toType(torch::kLong);
  } else {
    start = torch::zeros(batch_size, ptr.options());
  }

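  // Running minimum squared distance of every point to the sampled set.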
  auto dist = torch::full(src.size(0), 1e38, src.options());

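  // Copy the total number of samples (the last prefix-sum entry) back to the
  // host so the output tensor can be allocated with the right size.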
  int64_t out_size;
  cudaMemcpy(&out_size, out_ptr[-1].data_ptr<int64_t>(), sizeof(int64_t),
             cudaMemcpyDeviceToHost);
  auto out = torch::empty(out_size, out_ptr.options());

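  // Launch one thread block per example on the current CUDA stream.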
  auto stream = at::cuda::getCurrentCUDAStream();
  AT_DISPATCH_FLOATING_TYPES(src.scalar_type(), "fps_kernel", [&] {
    fps_kernel<scalar_t><<<batch_size, THREADS, 0, stream>>>(
        src.data_ptr<scalar_t>(), ptr.data_ptr<int64_t>(),
        out_ptr.data_ptr<int64_t>(), start.data_ptr<int64_t>(),
        dist.data_ptr<scalar_t>(), out.data_ptr<int64_t>(), src.size(1));
  });

  return out;
}