"vscode:/vscode.git/clone" did not exist on "a5a41922fbb2434c31ee4be5a9bf4fde81f9544a"
gather_points_cuda.cu 4.12 KB
Newer Older
wuyuefeng's avatar
wuyuefeng committed
1
2
#include <stdio.h>
#include <stdlib.h>
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <torch/types.h>

#include <ATen/cuda/CUDAApplyUtils.cuh>

#define TOTAL_THREADS 1024
#define THREADS_PER_BLOCK 256
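// Integer ceiling division: the number of n-sized blocks needed to cover m
// elements, used below to size the launch grid.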
#define DIVUP(m, n) ((m) / (n) + ((m) % (n) > 0))

template <typename scalar_t>
__global__ void gather_points_kernel(int b, int c, int n, int m,
                                     const scalar_t *__restrict__ points,
                                     const int *__restrict__ idx,
                                     scalar_t *__restrict__ out) {
  // points: (B, C, N)
  // idx: (B, M)
  // output:
  //      out: (B, C, M)

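  // One thread per output element: blockIdx.z walks the batch, blockIdx.y
  // the channel, and blockIdx.x * blockDim.x + threadIdx.x the output point.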
  int bs_idx = blockIdx.z;
  int c_idx = blockIdx.y;
  int pt_idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (bs_idx >= b || c_idx >= c || pt_idx >= m) return;

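  // Advance each pointer to this thread's slot: out[b][c][pt] and idx[b][pt],
  // plus the start of the (b, c) row of points; the gather is then a single
  // indexed read.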
  out += bs_idx * c * m + c_idx * m + pt_idx;
  idx += bs_idx * m + pt_idx;
  points += bs_idx * c * n + c_idx * n;
  out[0] = points[idx[0]];
}

void gather_points_kernel_launcher(int b, int c, int n, int npoints,
                                   const at::Tensor& points_tensor,
                                   const at::Tensor& idx_tensor,
                                   at::Tensor& out_tensor) {
  // points: (B, C, N)
  // idx: (B, npoints)
  // output:
  //      out: (B, C, npoints)

  cudaError_t err;
  dim3 blocks(DIVUP(npoints, THREADS_PER_BLOCK), c,
              b);  // grid covers (output points, channels, batches)
  dim3 threads(THREADS_PER_BLOCK);
  cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();

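  // The dispatch macro instantiates the lambda once per supported dtype
  // (float, double, at::Half), defining scalar_t to match out_tensor.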
  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
      out_tensor.scalar_type(), "gather_points_kernel", [&] {
        const scalar_t *points = points_tensor.data_ptr<scalar_t>();
        const int *idx = idx_tensor.data_ptr<int>();
        scalar_t *out = out_tensor.data_ptr<scalar_t>();
        gather_points_kernel<scalar_t><<<blocks, threads, 0, stream>>>(
            b, c, n, npoints, points, idx, out);
      });
  err = cudaGetLastError();
  if (cudaSuccess != err) {
    fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err));
    exit(-1);
  }
}

template <typename scalar_t>
__global__ void gather_points_grad_kernel(int b, int c, int n, int m,
                                          const scalar_t *__restrict__ grad_out,
                                          const int *__restrict__ idx,
                                          scalar_t *__restrict__ grad_points) {
  // grad_out: (B, C, M)
  // idx: (B, M)
  // output:
  //      grad_points: (B, C, N)

  int bs_idx = blockIdx.z;
  int c_idx = blockIdx.y;
  int pt_idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (bs_idx >= b || c_idx >= c || pt_idx >= m) return;

  grad_out += bs_idx * c * m + c_idx * m + pt_idx;
  idx += bs_idx * m + pt_idx;
  grad_points += bs_idx * c * n + c_idx * n;

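  // Several output positions may gather the same input index, so the
  // gradient scatter has to accumulate atomically.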
  atomicAdd(grad_points + idx[0], grad_out[0]);
}

void gather_points_grad_kernel_launcher(int b, int c, int n, int npoints,
                                        const at::Tensor& grad_out_tensor,
                                        const at::Tensor& idx_tensor,
                                        at::Tensor& grad_points_tensor) {
  // grad_out: (B, C, npoints)
  // idx: (B, npoints)
  // output:
  //      grad_points: (B, C, N)

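  // The kernel only accumulates into grad_points, so the buffer is assumed
  // to be zero-initialized by the caller.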
  cudaError_t err;
  dim3 blocks(DIVUP(npoints, THREADS_PER_BLOCK), c,
              b);  // grid covers (output points, channels, batches)
  dim3 threads(THREADS_PER_BLOCK);

  cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
      grad_points_tensor.scalar_type(), "gather_points_grad_kernel", [&] {
        const scalar_t *grad_out = grad_out_tensor.data_ptr<scalar_t>();
        const int *idx = idx_tensor.data_ptr<int>();
        scalar_t *grad_points = grad_points_tensor.data_ptr<scalar_t>();
        gather_points_grad_kernel<scalar_t><<<blocks, threads, 0, stream>>>(
            b, c, n, npoints, grad_out, idx, grad_points);
      });

  err = cudaGetLastError();
  if (cudaSuccess != err) {
    fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err));
    exit(-1);
  }
}
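
// For reference, a minimal sketch of how these launchers could be exposed to
// Python through a torch extension. Illustrative only: the repo's actual
// registration lives elsewhere, and the wrapper names below are hypothetical.
//
// #include <torch/extension.h>
//
// PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
//   m.def("gather_points_wrapper", &gather_points_kernel_launcher,
//         "gather_points forward (CUDA)");
//   m.def("gather_points_grad_wrapper", &gather_points_grad_kernel_launcher,
//         "gather_points backward (CUDA)");
// }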