// padding.cpp
#include <Python.h>
#include <torch/script.h>

#ifdef WITH_CUDA
#include "cuda/padding_cuda.h"
#endif

#ifdef _WIN32
// NOTE(review): presumably present so the Windows build, which expects a
// PyInit_<module> export from the extension DLL, links successfully. The
// operators themselves are registered through torch::RegisterOperators, so
// this stub intentionally returns no module object.
PyMODINIT_FUNC PyInit__padding(void) {
  return nullptr;
}
#endif

/// Computes the padded-index structures used by padded_index_select.
///
/// Thin dispatcher over `padded_index_cuda`: its four tensors and two int64
/// vectors are returned unchanged. NOTE(review): the meaning of each argument
/// (CSR-style row pointer, columns, per-row counts, bin boundaries) is inferred
/// from the parameter names only. Confirm against cuda/padding_cuda.h.
///
/// @throws c10::Error if `rowptr` is not on a CUDA device (no CPU kernel
///         exists), or if the extension was built without WITH_CUDA.
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor,
           std::vector<int64_t>, std::vector<int64_t>>
padded_index(torch::Tensor rowptr, torch::Tensor col, torch::Tensor rowcount,
             torch::Tensor binptr) {
  if (rowptr.device().is_cuda()) {
#ifdef WITH_CUDA
    return padded_index_cuda(rowptr, col, rowcount, binptr);
#else
    // padded_index_cuda is only declared when WITH_CUDA is defined, so a
    // CPU-only build must not reference it.
    AT_ERROR("Not compiled with CUDA support");
#endif
  } else {
    AT_ERROR("No CPU version supported");
  }
}

using torch::autograd::AutogradContext;
using torch::autograd::Variable;
using torch::autograd::variable_list;

/// Autograd function wrapping the padded index-select CUDA kernels.
///
/// forward() calls `padded_index_select_cuda(src, index, fill_value)`.
/// backward() calls `padded_index_scatter_cuda(grad_out, index, N)` with
/// `N = src.size(0)` saved from forward. Only `src` receives a gradient;
/// `index` and `fill_value` get undefined (empty) gradients.
///
/// Only CUDA kernels exist. Non-CUDA inputs, or a build without WITH_CUDA,
/// raise a c10::Error at runtime instead of failing to compile.
class PaddedIndexSelect : public torch::autograd::Function<PaddedIndexSelect> {
public:
  /// @param src        input tensor; must live on a CUDA device.
  /// @param index      selection index, saved for the backward pass.
  /// @param fill_value forwarded to the kernel; presumably the value written
  ///                   into padded slots (TODO confirm in padding_cuda).
  /// @return single-element list holding the kernel's output.
  static variable_list forward(AutogradContext *ctx, Variable src,
                               Variable index, Variable fill_value) {
    if (src.device().is_cuda()) {
#ifdef WITH_CUDA
      // Remember src's leading size so backward can size its gradient.
      ctx->saved_data["N"] = src.size(0);
      auto out = padded_index_select_cuda(src, index, fill_value);
      ctx->save_for_backward({index});
      return {out};
#else
      AT_ERROR("Not compiled with CUDA support");
#endif
    } else {
      AT_ERROR("No CPU version supported");
    }
  }

  /// @param grad_outs gradients w.r.t. forward's outputs (one entry).
  /// @return gradients for (src, index, fill_value); the last two are empty.
  static variable_list backward(AutogradContext *ctx, variable_list grad_outs) {
#ifdef WITH_CUDA
    auto grad_out = grad_outs[0];
    auto saved = ctx->get_saved_variables();
    auto index = saved[0];
    auto N = ctx->saved_data["N"].toInt();
    auto grad_in = padded_index_scatter_cuda(grad_out, index, N);
    return {grad_in, Variable(), Variable()};
#else
    // Not reachable in practice: forward() already throws in non-CUDA
    // builds. The guard keeps this translation unit compilable.
    AT_ERROR("Not compiled with CUDA support");
#endif
  }
};

/// Differentiable padded index-select exposed as a TorchScript operator.
///
/// Runs PaddedIndexSelect through the autograd machinery, so gradients flow
/// back to `src`, and returns the single tensor its forward pass produces.
torch::Tensor padded_index_select(torch::Tensor src, torch::Tensor index,
                                  torch::Tensor fill_value) {
  auto outputs = PaddedIndexSelect::apply(src, index, fill_value);
  return outputs.front();
}

static auto registry =
    torch::RegisterOperators()
        .op("torch_sparse::padded_index", &padded_index)
        .op("torch_sparse::padded_index_select", &padded_index_select);