#include <Python.h>
#include <torch/script.h>

#include "cpu/weighting_cpu.h"
#include "utils.h"

#ifdef WITH_CUDA
#include "cuda/weighting_cuda.h"
#endif

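// On Windows, loading this extension as a Python module requires an exported
// PyInit__weighting symbol; the stub below only exists to satisfy that
// requirement.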
#ifdef _WIN32
PyMODINIT_FUNC PyInit__weighting(void) { return NULL; }
#endif

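// Forward pass: dispatches to the CUDA or CPU kernel based on the device of
// `x`, combining input features with the trainable weights and B-spline basis
// values.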
torch::Tensor spline_weighting_fw(torch::Tensor x, torch::Tensor weight,
                                  torch::Tensor basis,
                                  torch::Tensor weight_index) {
  if (x.device().is_cuda()) {
#ifdef WITH_CUDA
    return spline_weighting_fw_cuda(x, weight, basis, weight_index);
#else
    AT_ERROR("Not compiled with CUDA support");
#endif
  } else {
    return spline_weighting_fw_cpu(x, weight, basis, weight_index);
  }
}

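// Backward pass w.r.t. the input features `x`.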
torch::Tensor spline_weighting_bw_x(torch::Tensor grad_out,
                                    torch::Tensor weight, torch::Tensor basis,
                                    torch::Tensor weight_index) {
  if (grad_out.device().is_cuda()) {
#ifdef WITH_CUDA
    return spline_weighting_bw_x_cuda(grad_out, weight, basis, weight_index);
#else
    AT_ERROR("Not compiled with CUDA support");
#endif
  } else {
    return spline_weighting_bw_x_cpu(grad_out, weight, basis, weight_index);
  }
}

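// Backward pass w.r.t. the trainable weights; `kernel_size` (the size of the
// first weight dimension) is passed along so the kernel can allocate a
// gradient tensor of the correct shape.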
torch::Tensor spline_weighting_bw_weight(torch::Tensor grad_out,
                                         torch::Tensor x, torch::Tensor basis,
                                         torch::Tensor weight_index,
                                         int64_t kernel_size) {
  if (grad_out.device().is_cuda()) {
#ifdef WITH_CUDA
    return spline_weighting_bw_weight_cuda(grad_out, x, basis, weight_index,
                                           kernel_size);
#else
    AT_ERROR("Not compiled with CUDA support");
#endif
  } else {
    return spline_weighting_bw_weight_cpu(grad_out, x, basis, weight_index,
                                          kernel_size);
  }
}

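// Backward pass w.r.t. the B-spline basis values.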
torch::Tensor spline_weighting_bw_basis(torch::Tensor grad_out, torch::Tensor x,
                                        torch::Tensor weight,
                                        torch::Tensor weight_index) {
  if (grad_out.device().is_cuda()) {
#ifdef WITH_CUDA
    return spline_weighting_bw_basis_cuda(grad_out, x, weight, weight_index);
#else
    AT_ERROR("Not compiled with CUDA support");
#endif
  } else {
    return spline_weighting_bw_basis_cpu(grad_out, x, weight, weight_index);
  }
}

using torch::autograd::AutogradContext;
using torch::autograd::Variable;
using torch::autograd::variable_list;

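// Custom autograd function tying the forward kernel to the three backward
// kernels above. The inputs are saved in `forward` so that `backward` can
// compute gradients for `x`, `weight`, and `basis` on demand.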
class SplineWeighting : public torch::autograd::Function<SplineWeighting> {
public:
  static variable_list forward(AutogradContext *ctx, Variable x,
                               Variable weight, Variable basis,
                               Variable weight_index) {
    auto out = spline_weighting_fw(x, weight, basis, weight_index);
    ctx->save_for_backward({x, weight, basis, weight_index});
    return {out};
  }

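  // Compute only the gradients that are actually required; `weight_index` is
  // an integer index tensor and therefore receives an undefined gradient.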
  static variable_list backward(AutogradContext *ctx, variable_list grad_outs) {
    auto grad_out = grad_outs[0];
    auto saved = ctx->get_saved_variables();
    auto x = saved[0], weight = saved[1], basis = saved[2],
         weight_index = saved[3];

    auto grad_x = Variable();
    if (torch::autograd::any_variable_requires_grad({x})) {
      grad_x = spline_weighting_bw_x(grad_out, weight, basis, weight_index);
    }

    auto grad_weight = Variable();
    if (torch::autograd::any_variable_requires_grad({weight})) {
      grad_weight = spline_weighting_bw_weight(grad_out, x, basis, weight_index,
                                               weight.size(0));
    }

    auto grad_basis = Variable();
    if (torch::autograd::any_variable_requires_grad({basis})) {
      grad_basis = spline_weighting_bw_basis(grad_out, x, weight, weight_index);
    }

    return {grad_x, grad_weight, grad_basis, Variable()};
  }
};

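// Public entry point wrapping the autograd function; `apply` returns a
// variable_list, so the single output tensor is unpacked here.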
torch::Tensor spline_weighting(torch::Tensor x, torch::Tensor weight,
                               torch::Tensor basis,
                               torch::Tensor weight_index) {
  return SplineWeighting::apply(x, weight, basis, weight_index);
}

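// Expose the operator to TorchScript and the Python frontend under the
// `torch_spline_conv` namespace.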
static auto registry = torch::RegisterOperators().op(
    "torch_spline_conv::spline_weighting", &spline_weighting);