three_nn_cuda.cu 2.45 KB
Newer Older
wuyuefeng's avatar
Credit  
wuyuefeng committed
1
2
3
// Modified from
// https://github.com/sshaoshuai/Pointnet2.PyTorch/tree/master/pointnet2/src/interpolate_gpu.cu

wuyuefeng's avatar
wuyuefeng committed
4
5
6
7
8
#include <math.h>
#include <stdio.h>
#include <stdlib.h>

// Number of threads launched per block by the launchers in this file.
#define THREADS_PER_BLOCK 256
// Ceiling division for non-negative m and positive n: the quotient m / n,
// plus one when there is a remainder. Written without (m + n - 1) so the
// intermediate value cannot overflow. Note that both arguments are evaluated
// twice, so do not pass expressions with side effects.
#define DIVUP(m, n) ((m) / (n) + ((m) % (n) > 0))
wuyuefeng's avatar
wuyuefeng committed
10

zhangwenwei's avatar
zhangwenwei committed
11
12
13
14
15
16
17
18
19
20
// Brute-force three-nearest-neighbour search.
//
// For each query point in `unknown`, scans every point of `known` in the same
// batch and records the three smallest squared Euclidean distances together
// with the indices (into the M axis of `known`) at which they occur.
//
// Launch layout: blockIdx.y selects the batch element and
// blockIdx.x * blockDim.x + threadIdx.x selects the query point; each thread
// handles exactly one query point. No shared memory is used.
//
// Args:
//   b, n, m:  batch size, number of query points, number of reference points.
//   unknown:  query coordinates, laid out as (B, N, 3).
//   known:    reference coordinates, laid out as (B, M, 3).
//   dist2:    output, (B, N, 3): squared distances in ascending order.
//   idx:      output, (B, N, 3): indices of the matching reference points.
//
// When m < 3 the unfilled distance slots hold +inf and the matching index
// slots hold 0, so callers must not treat those indices as valid neighbours.
__global__ void three_nn_kernel(int b, int n, int m,
                                const float *__restrict__ unknown,
                                const float *__restrict__ known,
                                float *__restrict__ dist2,
                                int *__restrict__ idx) {
  // unknown: (B, N, 3)
  // known: (B, M, 3)
  // output:
  //      dist2: (B, N, 3)
  //      idx: (B, N, 3)

  int bs_idx = blockIdx.y;
  int pt_idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (bs_idx >= b || pt_idx >= n) return;

  // Offsets are computed in 64-bit: B * N * 3 (or B * M * 3) can exceed the
  // range of a 32-bit int for large batches of dense point clouds.
  const long long query_offset =
      (static_cast<long long>(bs_idx) * n + pt_idx) * 3;
  unknown += query_offset;
  known += static_cast<long long>(bs_idx) * m * 3;
  dist2 += query_offset;
  idx += query_offset;

  float ux = unknown[0];
  float uy = unknown[1];
  float uz = unknown[2];

  // Running three smallest squared distances, best1 <= best2 <= best3.
  // Kept in float (the precision of both the distance computation and the
  // output buffer); +inf marks a slot that has not been filled yet.
  float best1 = INFINITY, best2 = INFINITY, best3 = INFINITY;
  int besti1 = 0, besti2 = 0, besti3 = 0;
  for (int k = 0; k < m; ++k) {
    const float *pt = known + static_cast<long long>(k) * 3;
    float x = pt[0];
    float y = pt[1];
    float z = pt[2];
    float d = (ux - x) * (ux - x) + (uy - y) * (uy - y) + (uz - z) * (uz - z);
    // Insert d into the sorted triple, shifting worse entries down. Strict
    // '<' means that on ties the earlier reference index is kept.
    if (d < best1) {
      best3 = best2;
      besti3 = besti2;
      best2 = best1;
      besti2 = besti1;
      best1 = d;
      besti1 = k;
    } else if (d < best2) {
      best3 = best2;
      besti3 = besti2;
      best2 = d;
      besti2 = k;
    } else if (d < best3) {
      best3 = d;
      besti3 = k;
    }
  }
  dist2[0] = best1;
  dist2[1] = best2;
  dist2[2] = best3;
  idx[0] = besti1;
  idx[1] = besti2;
  idx[2] = besti3;
}

// Host-side launcher for three_nn_kernel.
//
// Enqueues one thread per query point on `stream`, using a 2D grid:
// gridDim.x covers the N query points in chunks of THREADS_PER_BLOCK and
// gridDim.y covers the B batch elements. The function does not synchronize
// the stream; it only checks for errors reported at launch time.
//
// Args:
//   b, n, m:  batch size, number of query points, number of reference points.
//   unknown:  device pointer, query coordinates laid out as (B, N, 3).
//   known:    device pointer, reference coordinates laid out as (B, M, 3).
//   dist2:    device pointer, output (B, N, 3) squared distances.
//   idx:      device pointer, output (B, N, 3) neighbour indices.
//   stream:   CUDA stream the kernel is launched on.
//
// On a launch error the message is printed to stderr and the process exits.
void three_nn_kernel_launcher(int b, int n, int m, const float *unknown,
                              const float *known, float *dist2, int *idx,
                              cudaStream_t stream) {
  // unknown: (B, N, 3)
  // known: (B, M, 3)
  // output:
  //      dist2: (B, N, 3)
  //      idx: (B, N, 3)

  // Empty input: there is nothing to write, and a grid with a zero dimension
  // is an invalid launch configuration that would otherwise abort the process.
  if (b <= 0 || n <= 0) return;

  cudaError_t err;
  dim3 blocks(DIVUP(n, THREADS_PER_BLOCK),
              b);  // blockIdx.x(col), blockIdx.y(row)
  dim3 threads(THREADS_PER_BLOCK);

  three_nn_kernel<<<blocks, threads, 0, stream>>>(b, n, m, unknown, known,
                                                  dist2, idx);

  // A kernel launch returns no status itself; query the runtime for
  // launch-time failures such as an invalid configuration.
  err = cudaGetLastError();
  if (cudaSuccess != err) {
    fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err));
    exit(-1);
  }
}