// three_interpolate_cuda.cu
#include <math.h>
#include <stdio.h>
#include <stdlib.h>

// Number of threads per block used by the launchers in this file; the grid's
// x dimension is sized with DIVUP(n, THREADS_PER_BLOCK) so that every target
// point gets one thread.
#define THREADS_PER_BLOCK 256

// Integer ceiling division ceil(m / n), intended for non-negative m and
// positive n (e.g. DIVUP(0, 256) == 0, DIVUP(257, 256) == 2).
#define DIVUP(m, n) ((m) / (n) + ((m) % (n) > 0))

__global__ void three_interpolate_kernel(int b, int c, int m, int n,
                                         const float *__restrict__ points,
                                         const int *__restrict__ idx,
                                         const float *__restrict__ weight,
                                         float *__restrict__ out) {
  // Weighted three-neighbour interpolation; each thread writes one output:
  //   out[b][c][i] = sum_{k<3} weight[b][i][k] * points[b][c][idx[b][i][k]]
  //
  // Expected launch layout: blockIdx.z -> batch, blockIdx.y -> channel,
  // blockIdx.x * blockDim.x + threadIdx.x -> target point; threads past the
  // end of any dimension exit immediately.
  //
  // points: (B, C, M)
  // idx: (B, N, 3)   -- entries are assumed to lie in [0, M); not checked here
  // weight: (B, N, 3)
  // output:
  //      out: (B, C, N)

  int bs_idx = blockIdx.z;
  int c_idx = blockIdx.y;
  int pt_idx = blockIdx.x * blockDim.x + threadIdx.x;

  if (bs_idx >= b || c_idx >= c || pt_idx >= n) return;

  // Form the offsets in 64-bit: bs_idx * c * m and bs_idx * c * n can exceed
  // INT_MAX for large batches / feature maps and would silently wrap in int.
  const long long nbr_off = ((long long)bs_idx * n + pt_idx) * 3;
  const long long row = (long long)bs_idx * c + c_idx;

  weight += nbr_off;
  idx += nbr_off;
  points += row * m;
  out += row * n;

  out[pt_idx] = weight[0] * points[idx[0]] + weight[1] * points[idx[1]] +
                weight[2] * points[idx[2]];
}

void three_interpolate_kernel_launcher(int b, int c, int m, int n,
                                       const float *points, const int *idx,
                                       const float *weight, float *out,
                                       cudaStream_t stream) {
  // Host-side launcher for three_interpolate_kernel (forward pass).
  // Enqueues the kernel on `stream` and returns without synchronising, so
  // only launch-configuration errors are detected here; faults during
  // execution surface at a later synchronising call. A launch error prints
  // a message and terminates the process.
  //
  // points: (B, C, M)
  // idx: (B, N, 3)
  // weight: (B, N, 3)
  // output:
  //      out: (B, C, N)

  // Nothing to compute. Launching anyway would use a zero-sized grid
  // dimension, which is an invalid configuration and would hit exit(-1).
  if (b <= 0 || c <= 0 || n <= 0) return;

  // Grid: x covers the N target points in chunks of THREADS_PER_BLOCK,
  // y indexes channels, z indexes batches (matching the kernel's reads of
  // blockIdx). gridDim.y / gridDim.z have a device-specific upper bound, so
  // extremely large c or b will fail to launch.
  dim3 blocks(DIVUP(n, THREADS_PER_BLOCK), c, b);
  dim3 threads(THREADS_PER_BLOCK);
  three_interpolate_kernel<<<blocks, threads, 0, stream>>>(b, c, m, n, points,
                                                           idx, weight, out);

  cudaError_t err = cudaGetLastError();
  if (cudaSuccess != err) {
    fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err));
    exit(-1);
  }
}

__global__ void three_interpolate_grad_kernel(
    int b, int c, int n, int m, const float *__restrict__ grad_out,
    const int *__restrict__ idx, const float *__restrict__ weight,
    float *__restrict__ grad_points) {
  // Backward pass of three_interpolate_kernel. Each thread scatters one
  // output gradient back onto its three source points:
  //   grad_points[b][c][idx[b][i][k]] += weight[b][i][k] * grad_out[b][c][i]
  // Different target points may share a neighbour, so the scatter uses
  // atomicAdd. grad_points is therefore accumulated into, not overwritten,
  // and must be zero-initialised by the caller. The float summation order
  // (and hence the exact rounding) is nondeterministic.
  //
  // Note the argument order is (b, c, n, m), unlike the forward kernel's
  // (b, c, m, n). Launch layout: blockIdx.z -> batch, blockIdx.y -> channel,
  // blockIdx.x * blockDim.x + threadIdx.x -> target point.
  //
  // grad_out: (B, C, N)
  // idx: (B, N, 3)   -- entries are assumed to lie in [0, M); not checked here
  // weight: (B, N, 3)
  // output:
  //      grad_points: (B, C, M)

  int bs_idx = blockIdx.z;
  int c_idx = blockIdx.y;
  int pt_idx = blockIdx.x * blockDim.x + threadIdx.x;

  if (bs_idx >= b || c_idx >= c || pt_idx >= n) return;

  // 64-bit offsets: bs_idx * c * n and bs_idx * c * m can exceed INT_MAX.
  const long long nbr_off = ((long long)bs_idx * n + pt_idx) * 3;
  const long long row = (long long)bs_idx * c + c_idx;

  const float g = grad_out[row * n + pt_idx];
  weight += nbr_off;
  idx += nbr_off;
  grad_points += row * m;

  atomicAdd(grad_points + idx[0], g * weight[0]);
  atomicAdd(grad_points + idx[1], g * weight[1]);
  atomicAdd(grad_points + idx[2], g * weight[2]);
}

void three_interpolate_grad_kernel_launcher(int b, int c, int n, int m,
                                            const float *grad_out,
                                            const int *idx, const float *weight,
                                            float *grad_points,
                                            cudaStream_t stream) {
  // Host-side launcher for three_interpolate_grad_kernel (backward pass).
  // The kernel accumulates into grad_points with atomicAdd, so the caller
  // must zero grad_points before calling. Enqueues on `stream` without
  // synchronising; only launch-configuration errors are detected here, and
  // such an error prints a message and terminates the process.
  //
  // grad_out: (B, C, N)
  // idx: (B, N, 3)
  // weight: (B, N, 3)
  // output:
  //      grad_points: (B, C, M)

  // Nothing to scatter: grad_points is left as the caller initialised it.
  // Launching anyway would use a zero-sized grid dimension (an invalid
  // configuration) and hit exit(-1).
  if (b <= 0 || c <= 0 || n <= 0) return;

  // Grid: x covers the N target points in chunks of THREADS_PER_BLOCK,
  // y indexes channels, z indexes batches (matching the kernel's reads of
  // blockIdx). gridDim.y / gridDim.z have a device-specific upper bound.
  dim3 blocks(DIVUP(n, THREADS_PER_BLOCK), c, b);
  dim3 threads(THREADS_PER_BLOCK);
  three_interpolate_grad_kernel<<<blocks, threads, 0, stream>>>(
      b, c, n, m, grad_out, idx, weight, grad_points);

  cudaError_t err = cudaGetLastError();
  if (cudaSuccess != err) {
    fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err));
    exit(-1);
  }
}