gather_points_cuda.cu 3.03 KB
Newer Older
wuyuefeng's avatar
wuyuefeng committed
1
2
3
4
5
#include <stdio.h>
#include <stdlib.h>

// Not referenced by any of the code visible in this file.
#define TOTAL_THREADS 1024
// Block size (threads along x) used by both kernel launchers in this file,
// and the divisor used to size the grid's x dimension.
#define THREADS_PER_BLOCK 256
zhangwenwei's avatar
zhangwenwei committed
6
// Integer division rounded up: m / n, plus one when the remainder is positive.
// Each argument is expanded twice, so do not pass expressions with side
// effects.
#define DIVUP(m, n) ((m) / (n) + ((m) % (n) > 0))
wuyuefeng's avatar
wuyuefeng committed
7
8

// Gathers points along the last axis: for each batch element and channel,
//   out[bs][ch][j] = points[bs][ch][idx[bs][j]].
//
// Thread mapping: blockIdx.z selects the batch element, blockIdx.y the
// channel, and blockIdx.x * blockDim.x + threadIdx.x the output position j.
// Threads that fall outside (b, c, m) exit without touching memory.
//
// Args:
//   b, c:   batch size and number of channels.
//   n:      number of source points per (batch, channel) row.
//   m:      number of gathered points per (batch, channel) row.
//   points: input, laid out as (B, C, N).
//   idx:    indices, laid out as (B, M). Each value is used directly as an
//           offset into a row of N points and is NOT range-checked here.
//   out:    output, laid out as (B, C, M).
__global__ void gather_points_kernel(int b, int c, int n, int m,
                                     const float *__restrict__ points,
                                     const int *__restrict__ idx,
                                     float *__restrict__ out) {
  const int bs_idx = blockIdx.z;
  const int c_idx = blockIdx.y;
  // Widen before multiplying so the flat thread index cannot wrap.
  const long long pt_idx =
      (long long)blockIdx.x * blockDim.x + threadIdx.x;
  if (bs_idx >= b || c_idx >= c || pt_idx >= m) return;

  // Element offsets are computed in 64-bit: products such as b * c * n
  // overflow a 32-bit int once a tensor holds more than 2^31 elements.
  const long long row = (long long)bs_idx * c + c_idx;
  const int src = idx[(long long)bs_idx * m + pt_idx];
  out[row * m + pt_idx] = points[row * n + src];
}

// Launches gather_points_kernel on `stream`.
//
// Args:
//   b, c:    batch size and number of channels.
//   n:       number of source points per (batch, channel) row.
//   npoints: number of indices per batch element.
//   points:  device pointer, laid out as (B, C, N).
//   idx:     device pointer, laid out as (B, npoints).
//   out:     device pointer, laid out as (B, C, npoints); written by the
//            kernel.
//   stream:  CUDA stream the kernel is enqueued on.
//
// The launch is not synchronized here; only errors reported by
// cudaGetLastError() right after the launch are detected. On such an error
// the message is printed to stderr and the process exits with status -1.
//
// Grid layout: x covers npoints in chunks of THREADS_PER_BLOCK, y = c and
// z = b, so c and b must fit within the device's grid y/z limits.
void gather_points_kernel_launcher(int b, int c, int n, int npoints,
                                   const float *points, const int *idx,
                                   float *out, cudaStream_t stream) {
  // Empty input: there is nothing to gather. Without this guard a grid
  // dimension of 0 is an invalid launch configuration, which would take the
  // exit(-1) path below for a perfectly legitimate empty tensor.
  if (b == 0 || c == 0 || npoints == 0) return;

  dim3 blocks(DIVUP(npoints, THREADS_PER_BLOCK), c,
              b);  // blockIdx.x: point chunk, blockIdx.y: channel, blockIdx.z: batch
  dim3 threads(THREADS_PER_BLOCK);

  gather_points_kernel<<<blocks, threads, 0, stream>>>(b, c, n, npoints, points,
                                                       idx, out);

  const cudaError_t err = cudaGetLastError();
  if (cudaSuccess != err) {
    fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err));
    exit(-1);
  }
}

zhangwenwei's avatar
zhangwenwei committed
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
// Scatter-adds output gradients onto the gathered source points: for each
// batch element, channel and output position j,
//   grad_points[bs][ch][idx[bs][j]] += grad_out[bs][ch][j].
//
// Thread mapping: blockIdx.z selects the batch element, blockIdx.y the
// channel, and blockIdx.x * blockDim.x + threadIdx.x the output position j.
// Threads that fall outside (b, c, m) exit without touching memory.
//
// Args:
//   b, c:        batch size and number of channels.
//   n:           number of source points per (batch, channel) row.
//   m:           number of gathered points per (batch, channel) row.
//   grad_out:    input gradient, laid out as (B, C, M).
//   idx:         indices, laid out as (B, M). Each value is used directly as
//                an offset into a row of N points and is NOT range-checked.
//   grad_points: output gradient, laid out as (B, C, N). Values are
//                accumulated into it, never overwritten, so its initial
//                contents matter (presumably zeroed by the caller — confirm).
__global__ void gather_points_grad_kernel(int b, int c, int n, int m,
                                          const float *__restrict__ grad_out,
                                          const int *__restrict__ idx,
                                          float *__restrict__ grad_points) {
  const int bs_idx = blockIdx.z;
  const int c_idx = blockIdx.y;
  // Widen before multiplying so the flat thread index cannot wrap.
  const long long pt_idx =
      (long long)blockIdx.x * blockDim.x + threadIdx.x;
  if (bs_idx >= b || c_idx >= c || pt_idx >= m) return;

  // Element offsets are computed in 64-bit: products such as b * c * n
  // overflow a 32-bit int once a tensor holds more than 2^31 elements.
  const long long row = (long long)bs_idx * c + c_idx;
  const int dst = idx[(long long)bs_idx * m + pt_idx];

  // Several output positions may carry the same index, so different threads
  // can target the same address; the add has to be atomic.
  atomicAdd(grad_points + row * n + dst, grad_out[row * m + pt_idx]);
}

// Launches gather_points_grad_kernel on `stream`.
//
// Args:
//   b, c:        batch size and number of channels.
//   n:           number of source points per (batch, channel) row.
//   npoints:     number of indices per batch element.
//   grad_out:    device pointer, laid out as (B, C, npoints).
//   idx:         device pointer, laid out as (B, npoints).
//   grad_points: device pointer, laid out as (B, C, N); the kernel adds
//                into it with atomicAdd rather than overwriting it.
//   stream:      CUDA stream the kernel is enqueued on.
//
// The launch is not synchronized here; only errors reported by
// cudaGetLastError() right after the launch are detected. On such an error
// the message is printed to stderr and the process exits with status -1.
//
// Grid layout: x covers npoints in chunks of THREADS_PER_BLOCK, y = c and
// z = b, so c and b must fit within the device's grid y/z limits.
void gather_points_grad_kernel_launcher(int b, int c, int n, int npoints,
                                        const float *grad_out, const int *idx,
                                        float *grad_points,
                                        cudaStream_t stream) {
  // Empty input: no gradient to scatter. Without this guard a grid dimension
  // of 0 is an invalid launch configuration, which would take the exit(-1)
  // path below for a perfectly legitimate empty tensor.
  if (b == 0 || c == 0 || npoints == 0) return;

  dim3 blocks(DIVUP(npoints, THREADS_PER_BLOCK), c,
              b);  // blockIdx.x: point chunk, blockIdx.y: channel, blockIdx.z: batch
  dim3 threads(THREADS_PER_BLOCK);

  gather_points_grad_kernel<<<blocks, threads, 0, stream>>>(
      b, c, n, npoints, grad_out, idx, grad_points);

  const cudaError_t err = cudaGetLastError();
  if (cudaSuccess != err) {
    fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err));
    exit(-1);
  }
}