weighting.cpp 2.3 KB
Newer Older
1
#include <torch/extension.h>
rusty1s's avatar
rusty1s committed
2

rusty1s's avatar
rusty1s committed
3
4
// Fails with "<argument text> must be CUDA tensor" unless `x` reports a CUDA
// device via x.device().is_cuda().
//
// TORCH_CHECK is used instead of the deprecated AT_ASSERTM so that a tensor on
// the wrong device is reported as an ordinary argument error rather than as an
// internal assertion failure. The argument is parenthesised so that any
// expression (not only a plain identifier) can be passed safely.
#define CHECK_CUDA(x)                                                          \
  TORCH_CHECK((x).device().is_cuda(), #x " must be CUDA tensor")
rusty1s's avatar
rusty1s committed
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61

// Forward declarations of the device-side implementations. Their definitions
// are not part of this file (presumably they live in the companion CUDA
// source -- confirm against the build script). Each one is invoked by the
// host wrapper of the same name without the `_cuda` suffix, after that
// wrapper has run CHECK_CUDA on every tensor argument.

// Implementation behind weighting_fw.
at::Tensor weighting_fw_cuda(at::Tensor x, at::Tensor weight, at::Tensor basis,
                             at::Tensor weight_index);

// Implementation behind weighting_bw_x (takes grad_out in place of x).
at::Tensor weighting_bw_x_cuda(at::Tensor grad_out, at::Tensor weight,
                               at::Tensor basis, at::Tensor weight_index);

// Implementation behind weighting_bw_w. Unlike the other three it takes an
// extra integer K; its meaning is defined by the implementation, not here.
at::Tensor weighting_bw_w_cuda(at::Tensor grad_out, at::Tensor x,
                               at::Tensor basis, at::Tensor weight_index,
                               int64_t K);

// Implementation behind weighting_bw_b (takes weight instead of basis).
at::Tensor weighting_bw_b_cuda(at::Tensor grad_out, at::Tensor x,
                               at::Tensor weight, at::Tensor weight_index);

// Host entry point registered with Python as "weighting_fw".
// Applies CHECK_CUDA to each tensor in argument order, then forwards the
// arguments unchanged to weighting_fw_cuda and returns its result.
at::Tensor weighting_fw(at::Tensor x,
                        at::Tensor weight,
                        at::Tensor basis,
                        at::Tensor weight_index) {
  // Device validation happens before the implementation is invoked.
  CHECK_CUDA(x);
  CHECK_CUDA(weight);
  CHECK_CUDA(basis);
  CHECK_CUDA(weight_index);

  at::Tensor out = weighting_fw_cuda(x, weight, basis, weight_index);
  return out;
}

// Host entry point registered with Python as "weighting_bw_x".
// Applies CHECK_CUDA to each tensor in argument order, then forwards the
// arguments unchanged to weighting_bw_x_cuda and returns its result.
at::Tensor weighting_bw_x(at::Tensor grad_out,
                          at::Tensor weight,
                          at::Tensor basis,
                          at::Tensor weight_index) {
  // Device validation happens before the implementation is invoked.
  CHECK_CUDA(grad_out);
  CHECK_CUDA(weight);
  CHECK_CUDA(basis);
  CHECK_CUDA(weight_index);

  at::Tensor grad_x =
      weighting_bw_x_cuda(grad_out, weight, basis, weight_index);
  return grad_x;
}

// Host entry point registered with Python as "weighting_bw_w".
// Applies CHECK_CUDA to each tensor in argument order, then forwards all
// arguments (including the integer K, which is not validated here) unchanged
// to weighting_bw_w_cuda and returns its result.
at::Tensor weighting_bw_w(at::Tensor grad_out,
                          at::Tensor x,
                          at::Tensor basis,
                          at::Tensor weight_index,
                          int64_t K) {
  // Device validation happens before the implementation is invoked.
  CHECK_CUDA(grad_out);
  CHECK_CUDA(x);
  CHECK_CUDA(basis);
  CHECK_CUDA(weight_index);

  at::Tensor grad_weight =
      weighting_bw_w_cuda(grad_out, x, basis, weight_index, K);
  return grad_weight;
}

// Host entry point registered with Python as "weighting_bw_b".
// Applies CHECK_CUDA to each tensor in argument order, then forwards the
// arguments unchanged to weighting_bw_b_cuda and returns its result.
at::Tensor weighting_bw_b(at::Tensor grad_out,
                          at::Tensor x,
                          at::Tensor weight,
                          at::Tensor weight_index) {
  // Device validation happens before the implementation is invoked.
  CHECK_CUDA(grad_out);
  CHECK_CUDA(x);
  CHECK_CUDA(weight);
  CHECK_CUDA(weight_index);

  at::Tensor grad_basis =
      weighting_bw_b_cuda(grad_out, x, weight, weight_index);
  return grad_basis;
}

// Python bindings: registers the four host wrappers on the extension module
// under their C++ names. `def` returns the module itself, so the
// registrations are chained into a single expression.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("weighting_fw", &weighting_fw, "Weighting Forward (CUDA)")
      .def("weighting_bw_x", &weighting_bw_x, "Weighting Backward X (CUDA)")
      .def("weighting_bw_w", &weighting_bw_w,
           "Weighting Backward Weight (CUDA)")
      .def("weighting_bw_b", &weighting_bw_b,
           "Weighting Backward Basis (CUDA)");
}
// Number of THREADS-sized blocks needed to cover N items, i.e. N / THREADS
// rounded up. THREADS must be defined wherever the macro is expanded (it is
// not defined in the visible part of this file).
//
// Both the argument and the whole expansion are parenthesised so the macro
// behaves like a single value inside larger expressions, e.g. `a / BLOCKS(n)`
// or `BLOCKS(cond ? a : b)`.
#define BLOCKS(N) (((N) + THREADS - 1) / THREADS)