// ball_query_cuda.cu
// Modified from
// https://github.com/sshaoshuai/Pointnet2.PyTorch/tree/master/pointnet2/src/ball_query_gpu.cu

// C headers; <stdio.h> and <stdlib.h> provide fprintf() and exit() for the
// launch-error path of the host-side launcher.
#include <math.h>
#include <stdio.h>
#include <stdlib.h>

#define THREADS_PER_BLOCK 256
// Ceiling division for non-negative m and positive n: the number of size-n
// chunks needed to cover m items.
#define DIVUP(m, n) ((m) / (n) + ((m) % (n) > 0))
// Ball query: one thread per query point.
//
// Each thread scans the n points of its batch in index order and records, in
// its own row of `idx`, the indices of the first `nsample` points whose squared
// distance d2 to the query satisfies
//     d2 == 0  ||  (min_radius^2 <= d2 < max_radius^2).
// A point that coincides with the query is therefore accepted even when
// min_radius > 0.
//
// On the first accepted point the whole row is pre-filled with that index, so
// a row with fewer than `nsample` hits is padded with its first hit. If no
// point is accepted the row is left untouched.
//
// Launch layout: blockIdx.y selects the batch; blockIdx.x * blockDim.x +
// threadIdx.x selects the query point. Threads outside (b, m) exit at once.
//
// Args:
//   b, n, m:     batch size, points per batch, query points per batch.
//   min_radius:  inner radius of the ball (inclusive).
//   max_radius:  outer radius of the ball (exclusive).
//   nsample:     width of one idx row; nothing is written when nsample <= 0.
//   new_xyz:     (B, M, 3) query coordinates.
//   xyz:         (B, N, 3) point coordinates.
//   idx:         (B, M, nsample) output indices into the N axis.
__global__ void ball_query_kernel(int b, int n, int m,
                                  float min_radius,
                                  float max_radius,
                                  int nsample,
                                  const float *__restrict__ new_xyz,
                                  const float *__restrict__ xyz,
                                  int *__restrict__ idx) {
  // new_xyz: (B, M, 3)
  // xyz: (B, N, 3)
  // output:
  //      idx: (B, M, nsample)
  int bs_idx = blockIdx.y;
  int pt_idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (bs_idx >= b || pt_idx >= m) return;

  // A zero-width (or negative) output row has no valid slot; without this
  // guard the first accepted point would still be stored at idx[0].
  if (nsample <= 0) return;

  // Row offsets are computed in 64-bit so that large B * M * nsample or
  // B * N * 3 products cannot wrap a 32-bit int.
  const long long query_row = (long long)bs_idx * m + pt_idx;
  new_xyz += query_row * 3;
  xyz += (long long)bs_idx * n * 3;
  idx += query_row * nsample;

  float max_radius2 = max_radius * max_radius;
  float min_radius2 = min_radius * min_radius;
  float new_x = new_xyz[0];
  float new_y = new_xyz[1];
  float new_z = new_xyz[2];

  int cnt = 0;
  for (int k = 0; k < n; ++k) {
    const float *pt = xyz + (long long)k * 3;
    float x = pt[0];
    float y = pt[1];
    float z = pt[2];
    float d2 = (new_x - x) * (new_x - x) + (new_y - y) * (new_y - y) +
               (new_z - z) * (new_z - z);
    if (d2 == 0.0f || (d2 >= min_radius2 && d2 < max_radius2)) {
      if (cnt == 0) {
        // First hit: pad the whole row so unfilled slots stay valid indices.
        for (int l = 0; l < nsample; ++l) {
          idx[l] = k;
        }
      }
      idx[cnt] = k;
      ++cnt;
      if (cnt >= nsample) break;
    }
  }
}

// Host-side launcher for ball_query_kernel.
//
// Enqueues the kernel on `stream` with one thread per query point:
// gridDim.x = ceil(m / THREADS_PER_BLOCK), gridDim.y = b. The call does not
// synchronize; only errors reported by cudaGetLastError() right after the
// launch are detected here, and they terminate the process via exit(-1).
//
// Args:
//   b, n, m:     batch size, points per batch, query points per batch.
//   min_radius:  inner radius of the ball.
//   max_radius:  outer radius of the ball.
//   nsample:     number of index slots per query point.
//   new_xyz:     (B, M, 3) query coordinates, passed through to the kernel.
//   xyz:         (B, N, 3) point coordinates, passed through to the kernel.
//   idx:         (B, M, nsample) output buffer, passed through to the kernel.
//   stream:      CUDA stream the kernel is enqueued on.
//
// NOTE(review): b is used directly as gridDim.y, so it must stay within the
// device's grid y-dimension limit; a larger batch surfaces as a launch error.
void ball_query_kernel_launcher(int b, int n, int m, float min_radius, float max_radius,
                                int nsample, const float *new_xyz, const float *xyz,
                                int *idx, cudaStream_t stream) {
  // new_xyz: (B, M, 3)
  // xyz: (B, N, 3)
  // output:
  //      idx: (B, M, nsample)

  // Empty problem: there is no output slot to fill. Launching anyway would
  // request a grid with a zero dimension (b == 0 or m == 0), which CUDA
  // rejects, and the error path below would kill the process.
  if (b <= 0 || m <= 0 || nsample <= 0) return;

  dim3 blocks(DIVUP(m, THREADS_PER_BLOCK),
              b);  // blockIdx.x(col), blockIdx.y(row)
  dim3 threads(THREADS_PER_BLOCK);

  ball_query_kernel<<<blocks, threads, 0, stream>>>(b, n, m, min_radius, max_radius,
                                                    nsample, new_xyz, xyz, idx);

  // cudaDeviceSynchronize();  // for using printf in kernel function
  // Launch-configuration errors are only visible through cudaGetLastError().
  cudaError_t err = cudaGetLastError();
  if (cudaSuccess != err) {
    fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err));
    exit(-1);
  }
}