// degree_padding.cpp
#include <Python.h>
#include <torch/script.h>

#ifdef WITH_CUDA
#include "cuda/degree_padding_cuda.h"
#endif

#ifdef _WIN32
// Dummy Python module-init entry point, compiled on Windows only.
//
// This translation unit exposes its functionality by registering TorchScript
// operators through the static `registry` initializer at the end of the file,
// not through a Python module object, so this function does no work.
// NOTE(review): presumably kept because the Windows extension build exports a
// `PyInit__degree_padding` symbol and the link fails without it — confirm
// against the build script. Returning nullptr without setting an exception
// means a plain Python `import` of this module would fail; presumably the
// library is loaded via `torch.ops.load_library` instead — confirm.
PyMODINIT_FUNC PyInit__degree_padding(void) { return nullptr; }
#endif

rusty1s's avatar
rusty1s committed
12
13
std::tuple<std::vector<torch::Tensor>, std::vector<int64_t>>
bin_assignment(torch::Tensor rowcount, torch::Tensor binptr) {
rusty1s's avatar
rusty1s committed
14
15
  if (rowcount.device().is_cuda()) {
#ifdef WITH_CUDA
rusty1s's avatar
rusty1s committed
16
    return bin_assignment_cuda(rowcount, binptr);
rusty1s's avatar
rusty1s committed
17
18
19
20
21
22
23
24
25
26
#else
    AT_ERROR("Not compiled with CUDA support");
#endif
  } else {
    AT_ERROR("Not implemented yet");
  }
}

// Gathers entries of `src` into a padded layout; the actual work is done by
// the CUDA kernel.
//
// Args:
//   src:        source tensor. Its device selects the backend; only CUDA
//               inputs are supported.
//   rowptr:     forwarded unchanged to the CUDA kernel.
//               NOTE(review): presumably CSR row pointers — confirm.
//   col:        forwarded unchanged to the CUDA kernel.
//               NOTE(review): presumably CSR column indices — confirm.
//   index:      forwarded unchanged to the CUDA kernel.
//               NOTE(review): presumably the rows to select — confirm.
//   length:     padded length passed to the kernel.
//   fill_value: value tensor passed to the kernel.
//               NOTE(review): presumably used for padding slots — confirm.
// Returns:
//   The pair of tensors produced by padded_index_select_cuda.
// Throws (via AT_ERROR):
//   - "Not compiled with CUDA support" if src lives on a CUDA device but the
//     extension was built without WITH_CUDA.
//   - "Not implemented yet" for non-CUDA inputs; there is no CPU kernel.
std::tuple<torch::Tensor, torch::Tensor>
padded_index_select(torch::Tensor src, torch::Tensor rowptr, torch::Tensor col,
                    torch::Tensor index, int64_t length,
                    torch::Tensor fill_value) {
  if (src.device().is_cuda()) {
#ifdef WITH_CUDA
    return padded_index_select_cuda(src, rowptr, col, index, length,
                                    fill_value);
#else
    AT_ERROR("Not compiled with CUDA support");
#endif
  } else {
    AT_ERROR("Not implemented yet");
  }
}

// Disabled variant of padded_index_select that additionally takes a `bin`
// tensor plus per-bin `node_counts` and `lengths`; kept for reference. Its
// operator registration at the end of the file is commented out as well.
//
// std::tuple<torch::Tensor, torch::Tensor>
// padded_index_select2(torch::Tensor src, torch::Tensor rowptr,
//                      torch::Tensor col, torch::Tensor bin,
//                      torch::Tensor index, std::vector<int64_t> node_counts,
//                      std::vector<int64_t> lengths, torch::Tensor fill_value) {
//   if (src.device().is_cuda()) {
// #ifdef WITH_CUDA
//     return padded_index_select_cuda2(src, rowptr, col, bin, index,
//                                      node_counts, lengths, fill_value);
// #else
//     AT_ERROR("Not compiled with CUDA support");
// #endif
//   } else {
//     AT_ERROR("Not implemented yet");
//   }
// }

rusty1s's avatar
rusty1s committed
61
62
63
64
static auto registry =
    torch::RegisterOperators()
        .op("torch_sparse::bin_assignment", &bin_assignment)
        .op("torch_sparse::padded_index_select", &padded_index_select);
rusty1s's avatar
rusty1s committed
65
// .op("torch_sparse::padded_index_select2", &padded_index_select2);