// gather_points.cpp
#include <ATen/cuda/CUDAContext.h>
#include <ATen/TensorUtils.h>
#include <THC/THC.h>
#include <torch/extension.h>
#include <torch/serialize/tensor.h>

#include <vector>

// NOTE(review): `state` is not referenced anywhere in this file. It looks like
// the legacy THC global (THCState presumably comes from <THC/THC.h>, included
// above). Confirm no other translation unit relies on this declaration; if
// not, drop it together with the THC include when migrating off THC (newer
// PyTorch releases no longer ship the THC headers — verify for the targeted
// version).
extern THCState *state;

int gather_points_wrapper(int b, int c, int n, int npoints,
13
14
                          at::Tensor& points_tensor, at::Tensor& idx_tensor,
                          at::Tensor& out_tensor);

void gather_points_kernel_launcher(int b, int c, int n, int npoints,
17
18
19
                                   const at::Tensor& points_tensor,
                                   const at::Tensor& idx_tensor,
                                   at::Tensor& out_tensor);

int gather_points_grad_wrapper(int b, int c, int n, int npoints,
22
23
24
                               at::Tensor& grad_out_tensor,
                               at::Tensor& idx_tensor,
                               at::Tensor& grad_points_tensor);

void gather_points_grad_kernel_launcher(int b, int c, int n, int npoints,
27
28
29
                                        const at::Tensor& grad_out_tensor,
                                        const at::Tensor& idx_tensor,
                                        at::Tensor& grad_points_tensor);

int gather_points_wrapper(int b, int c, int n, int npoints,
32
33
34
35
                          at::Tensor& points_tensor, at::Tensor& idx_tensor,
                          at::Tensor& out_tensor)
{
  gather_points_kernel_launcher(b, c, n, npoints, points_tensor, idx_tensor, out_tensor);
zhangwenwei's avatar
zhangwenwei committed
36
  return 1;
wuyuefeng's avatar
wuyuefeng committed
37
38
39
}

int gather_points_grad_wrapper(int b, int c, int n, int npoints,
40
41
42
43
44
45
                               at::Tensor& grad_out_tensor,
                               at::Tensor& idx_tensor,
                               at::Tensor& grad_points_tensor)
{
  gather_points_grad_kernel_launcher(b, c, n, npoints, grad_out_tensor, idx_tensor,
                                     grad_points_tensor);
zhangwenwei's avatar
zhangwenwei committed
46
  return 1;
wuyuefeng's avatar
wuyuefeng committed
47
48
}

// Python bindings for this extension.
//
// Exposes both wrappers under their C++ names. Each binding's docstring is
// simply the function name. The `_grad_` variant presumably backs the autograd
// backward pass on the Python side (confirm against the caller).
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
  m.def("gather_points_wrapper", &gather_points_wrapper,
        "gather_points_wrapper");
  m.def("gather_points_grad_wrapper", &gather_points_grad_wrapper,
        "gather_points_grad_wrapper");
}