// weighting.cpp: device dispatch and autograd wrapper for the spline weighting
// operator (torch_spline_conv::spline_weighting).
#include <Python.h>
#include <torch/script.h>

#include "cpu/weighting_cpu.h"

// CUDA kernel declarations are only available when built with WITH_CUDA.
#ifdef WITH_CUDA
#include "cuda/weighting_cuda.h"
// The *_cuda entry points are referenced only inside `#ifdef WITH_CUDA` below.
#endif

#ifdef _WIN32
// Windows-only module-init stubs. NOTE(review): these appear to exist only so
// the shared library links as a Python extension module (which expects an
// exported PyInit_<name> symbol); the actual operator is registered with the
// TorchScript operator registry at the bottom of this file. Confirm against
// the build setup before changing the module names.
#ifdef WITH_CUDA
PyMODINIT_FUNC PyInit__weighting_cuda(void) { return nullptr; }
#else
PyMODINIT_FUNC PyInit__weighting_cpu(void) { return nullptr; }
#endif
#endif

// Forward pass of spline weighting.
//
// Dispatches on the device of `x`: CUDA tensors go to the CUDA kernel (only
// when built with WITH_CUDA), everything else to the CPU kernel.
//
// Throws c10::Error if `x` is a CUDA tensor but the extension was compiled
// without CUDA support.
torch::Tensor spline_weighting_fw(torch::Tensor x, torch::Tensor weight,
                                  torch::Tensor basis,
                                  torch::Tensor weight_index) {
  if (x.device().is_cuda()) {
#ifdef WITH_CUDA
    return spline_weighting_fw_cuda(x, weight, basis, weight_index);
#else
    // AT_ERROR is a deprecated alias; TORCH_CHECK(false, ...) is the
    // recommended way to fail unconditionally.
    TORCH_CHECK(false, "Not compiled with CUDA support");
#endif
  }
  return spline_weighting_fw_cpu(x, weight, basis, weight_index);
}

// Gradient of the spline weighting output with respect to `x`.
//
// Dispatches on the device of `grad_out`: CUDA tensors go to the CUDA kernel
// (only when built with WITH_CUDA), everything else to the CPU kernel.
//
// Throws c10::Error if `grad_out` is a CUDA tensor but the extension was
// compiled without CUDA support.
torch::Tensor spline_weighting_bw_x(torch::Tensor grad_out,
                                    torch::Tensor weight, torch::Tensor basis,
                                    torch::Tensor weight_index) {
  if (grad_out.device().is_cuda()) {
#ifdef WITH_CUDA
    return spline_weighting_bw_x_cuda(grad_out, weight, basis, weight_index);
#else
    // AT_ERROR is a deprecated alias; TORCH_CHECK(false, ...) is the
    // recommended way to fail unconditionally.
    TORCH_CHECK(false, "Not compiled with CUDA support");
#endif
  }
  return spline_weighting_bw_x_cpu(grad_out, weight, basis, weight_index);
}

// Gradient of the spline weighting output with respect to `weight`.
//
// `kernel_size` is forwarded unchanged to the backend kernel; presumably it is
// the number of weight slices (the autograd backward in this file passes
// weight.size(0)).
//
// Dispatches on the device of `grad_out`. Throws c10::Error if `grad_out` is a
// CUDA tensor but the extension was compiled without CUDA support.
torch::Tensor spline_weighting_bw_weight(torch::Tensor grad_out,
                                         torch::Tensor x, torch::Tensor basis,
                                         torch::Tensor weight_index,
                                         int64_t kernel_size) {
  if (grad_out.device().is_cuda()) {
#ifdef WITH_CUDA
    return spline_weighting_bw_weight_cuda(grad_out, x, basis, weight_index,
                                           kernel_size);
#else
    // AT_ERROR is a deprecated alias; TORCH_CHECK(false, ...) is the
    // recommended way to fail unconditionally.
    TORCH_CHECK(false, "Not compiled with CUDA support");
#endif
  }
  return spline_weighting_bw_weight_cpu(grad_out, x, basis, weight_index,
                                        kernel_size);
}

// Gradient of the spline weighting output with respect to `basis`.
//
// Dispatches on the device of `grad_out`: CUDA tensors go to the CUDA kernel
// (only when built with WITH_CUDA), everything else to the CPU kernel.
//
// Throws c10::Error if `grad_out` is a CUDA tensor but the extension was
// compiled without CUDA support.
torch::Tensor spline_weighting_bw_basis(torch::Tensor grad_out, torch::Tensor x,
                                        torch::Tensor weight,
                                        torch::Tensor weight_index) {
  if (grad_out.device().is_cuda()) {
#ifdef WITH_CUDA
    return spline_weighting_bw_basis_cuda(grad_out, x, weight, weight_index);
#else
    // AT_ERROR is a deprecated alias; TORCH_CHECK(false, ...) is the
    // recommended way to fail unconditionally.
    TORCH_CHECK(false, "Not compiled with CUDA support");
#endif
  }
  return spline_weighting_bw_basis_cpu(grad_out, x, weight, weight_index);
}

using torch::autograd::AutogradContext;
using torch::autograd::Variable;
using torch::autograd::variable_list;

// Custom autograd function wrapping the spline weighting kernels.
//
// forward:  out = spline_weighting_fw(x, weight, basis, weight_index); all four
//           inputs are saved for the backward pass.
// backward: computes the gradient for x, weight and basis only when that input
//           requires grad; weight_index never receives a gradient (an
//           undefined Variable is returned in its slot).
class SplineWeighting : public torch::autograd::Function<SplineWeighting> {
public:
  static variable_list forward(AutogradContext *ctx, Variable x,
                               Variable weight, Variable basis,
                               Variable weight_index) {
    auto out = spline_weighting_fw(x, weight, basis, weight_index);
    ctx->save_for_backward({x, weight, basis, weight_index});
    return {out};
  }

  static variable_list backward(AutogradContext *ctx, variable_list grad_outs) {
    // The incoming gradient is not guaranteed to be dense: e.g. the gradient
    // of `out.sum()` is typically an expanded (stride-0) view. It is handed
    // straight to the backend kernels, which presumably index it as a dense
    // buffer (NOTE(review): confirm in the cpu/cuda sources), so force a
    // contiguous layout. `.contiguous()` is a no-op if it already is.
    auto grad_out = grad_outs[0].contiguous();

    auto saved = ctx->get_saved_variables();
    auto x = saved[0], weight = saved[1], basis = saved[2],
         weight_index = saved[3];

    auto grad_x = Variable();
    if (torch::autograd::any_variable_requires_grad({x})) {
      grad_x = spline_weighting_bw_x(grad_out, weight, basis, weight_index);
    }

    auto grad_weight = Variable();
    if (torch::autograd::any_variable_requires_grad({weight})) {
      // The first dimension of `weight` is passed as kernel_size.
      grad_weight = spline_weighting_bw_weight(grad_out, x, basis, weight_index,
                                               weight.size(0));
    }

    auto grad_basis = Variable();
    if (torch::autograd::any_variable_requires_grad({basis})) {
      grad_basis = spline_weighting_bw_basis(grad_out, x, weight, weight_index);
    }

    // One entry per forward input; weight_index gets no gradient.
    return {grad_x, grad_weight, grad_basis, Variable()};
  }
};

// Public entry point, registered below as torch_spline_conv::spline_weighting.
//
// Every tensor argument is made contiguous before it reaches the autograd
// function and the raw kernels. Previously only `x` and `weight` were, but
// `basis` and `weight_index` are passed to the same kernels. `.contiguous()`
// is a no-op for already-contiguous tensors, and it is differentiable, so
// gradients still flow back to the caller's `basis`.
torch::Tensor spline_weighting(torch::Tensor x, torch::Tensor weight,
                               torch::Tensor basis,
                               torch::Tensor weight_index) {
  x = x.contiguous();
  weight = weight.contiguous();
  basis = basis.contiguous();
  weight_index = weight_index.contiguous();
  return SplineWeighting::apply(x, weight, basis, weight_index)[0];
}

// Registers spline_weighting with the TorchScript operator registry under the
// name torch_spline_conv::spline_weighting. Registration happens during static
// initialization of this translation unit.
static auto registry = torch::RegisterOperators(
    "torch_spline_conv::spline_weighting", &spline_weighting);