// padding.cpp
#include <Python.h>
#include <torch/script.h>

#ifdef WITH_CUDA
#include "cuda/padding_cuda.h"
#endif

#if defined(_WIN32)
// Stub module-init entry point, defined on Windows builds only. It creates no
// module object and simply returns a null pointer.
// NOTE(review): presumably exists so the `_padding` extension exports the
// PyInit symbol the Windows build expects — confirm against the build setup.
PyMODINIT_FUNC PyInit__padding() { return nullptr; }
#endif

/// Entry point for the `torch_sparse::padded_index` operator.
///
/// Forwards all four tensors unchanged to `padded_index_cuda` and returns its
/// result as-is. No CPU implementation is wired up in this file.
///
/// @param rowptr   Passed through to the CUDA implementation.
/// @param col      Passed through to the CUDA implementation.
/// @param rowcount Passed through to the CUDA implementation.
/// @param binptr   Passed through to the CUDA implementation.
/// @return The six-element tuple (four tensors, two `int64_t` vectors)
///         produced by `padded_index_cuda`.
/// @throws c10::Error if the extension was built without `WITH_CUDA`.
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor,
           std::vector<int64_t>, std::vector<int64_t>>
padded_index(torch::Tensor rowptr, torch::Tensor col, torch::Tensor rowcount,
             torch::Tensor binptr) {
#ifdef WITH_CUDA
  return padded_index_cuda(rowptr, col, rowcount, binptr);
#else
  // The CUDA header is only included under WITH_CUDA, so the call above would
  // not even compile in a CPU-only build. Fail at call time with a clear
  // message instead of breaking the build.
  TORCH_CHECK(false, "Not compiled with CUDA support");
#endif
}

rusty1s's avatar
DONE  
rusty1s committed
19
/// Entry point for the `torch_sparse::padded_index_select` operator.
///
/// Forwards its arguments unchanged to `padded_index_select_cuda` and returns
/// that function's result. No CPU implementation is wired up in this file.
///
/// @param src        Passed through to the CUDA implementation.
/// @param index      Passed through to the CUDA implementation.
/// @param fill_value Passed through to the CUDA implementation.
/// @return The tensor produced by `padded_index_select_cuda`.
/// @throws c10::Error if the extension was built without `WITH_CUDA`.
torch::Tensor padded_index_select(torch::Tensor src, torch::Tensor index,
                                  torch::Tensor fill_value) {
#ifdef WITH_CUDA
  return padded_index_select_cuda(src, index, fill_value);
#else
  // The CUDA header is only included under WITH_CUDA, so the call above would
  // not even compile in a CPU-only build. Fail at call time with a clear
  // message instead of breaking the build.
  TORCH_CHECK(false, "Not compiled with CUDA support");
#endif
}

static auto registry =
    torch::RegisterOperators()
        .op("torch_sparse::padded_index", &padded_index)
        .op("torch_sparse::padded_index_select", &padded_index_select);