padding.cpp

#include <Python.h>
#include <torch/script.h>

#include "cpu/padding_cpu.h"

#ifdef WITH_CUDA
#include "cuda/padding_cuda.h"
#endif

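// On Windows, provide a stub Python module initializer so the extension
// library can be imported as `_padding`.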
#ifdef _WIN32
PyMODINIT_FUNC PyInit__padding(void) { return NULL; }
#endif

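// Builds padded index information for the CSR structure (rowptr, col) from the
// per-row counts (`rowcount`) and bin boundaries (`binptr`), dispatching to the
// CUDA or CPU kernel based on the device of `rowptr`.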
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor,
           std::vector<int64_t>, std::vector<int64_t>>
padded_index(torch::Tensor rowptr, torch::Tensor col, torch::Tensor rowcount,
             torch::Tensor binptr) {
  if (rowptr.device().is_cuda()) {
#ifdef WITH_CUDA
    return padded_index_cuda(rowptr, col, rowcount, binptr);
#else
    AT_ERROR("Not compiled with CUDA support");
#endif
  } else {
    return padded_index_cpu(rowptr, col, rowcount, binptr);
  }
}

using torch::autograd::AutogradContext;
using torch::autograd::Variable;
using torch::autograd::variable_list;

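// Custom autograd function: the forward pass gathers rows of `src` according to
// a padded index (padded slots are filled with `fill_value`), and the backward
// pass scatters gradients back to the original rows.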
class PaddedIndexSelect : public torch::autograd::Function<PaddedIndexSelect> {
public:
  static variable_list forward(AutogradContext *ctx, Variable src,
                               Variable index, Variable fill_value) {
    ctx->saved_data["N"] = src.size(0);

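    // Gather rows of `src` at the padded indices on the matching device.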
    torch::Tensor out;
    if (src.device().is_cuda()) {
#ifdef WITH_CUDA
      out = padded_index_select_cuda(src, index, fill_value);
#else
      AT_ERROR("Not compiled with CUDA support");
#endif
    } else {
      out = padded_index_select_cpu(src, index, fill_value);
    }
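    // The index is needed again in the backward pass to scatter gradients.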
    ctx->save_for_backward({index});
    return {out};
  }

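  // Backward: scatter the incoming gradient from the padded layout back to the
  // `N` original rows of `src`.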
  static variable_list backward(AutogradContext *ctx, variable_list grad_outs) {
    auto grad_out = grad_outs[0];
    auto saved = ctx->get_saved_variables();
    auto index = saved[0];
    auto N = ctx->saved_data["N"].toInt();
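    // Scatter gradients to the original rows, mirroring the forward dispatch.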
    torch::Tensor grad_in;
    if (grad_out.device().is_cuda()) {
#ifdef WITH_CUDA
      grad_in = padded_index_scatter_cuda(grad_out, index, N);
#else
      AT_ERROR("Not compiled with CUDA support");
#endif
    } else {
      grad_in = padded_index_scatter_cpu(grad_out, index, N);
    }
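    // Only `src` is differentiable; `index` and `fill_value` receive no gradient.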
    return {grad_in, Variable(), Variable()};
  }
};

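// Public entry point that routes through the autograd function above.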
torch::Tensor padded_index_select(torch::Tensor src, torch::Tensor index,
                                  torch::Tensor fill_value) {
  return PaddedIndexSelect::apply(src, index, fill_value)[0];
}

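// Expose both operators under the `torch_sparse` namespace so they can be
// called as `torch.ops.torch_sparse.*`.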
static auto registry =
    torch::RegisterOperators()
        .op("torch_sparse::padded_index", &padded_index)
        .op("torch_sparse::padded_index_select", &padded_index_select);