spspmm.cpp 1.11 KB
Newer Older
1
#ifdef WITH_PYTHON
rusty1s's avatar
rusty1s committed
2
#include <Python.h>
3
#endif
rusty1s's avatar
rusty1s committed
4
5
6
7
8
9
10
11
#include <torch/script.h>

#include "cpu/spspmm_cpu.h"

#ifdef WITH_CUDA
#include "cuda/spspmm_cuda.h"
#endif

rusty1s's avatar
rusty1s committed
12
// Windows-only Python extension entry points.
//
// NOTE(review): presumably these exist because a Windows extension DLL is
// expected to export a `PyInit_<module>` symbol, even though the operator in
// this file is registered through torch::RegisterOperators rather than a
// Python module object. Confirm against the build setup.
// Each stub returns a null module pointer. Exactly one symbol is emitted,
// chosen by whether this is the CUDA or the CPU build of the extension.
#if defined(_WIN32) && defined(WITH_PYTHON)
#  ifdef WITH_CUDA
PyMODINIT_FUNC PyInit__spspmm_cuda(void) { return nullptr; }
#  else
PyMODINIT_FUNC PyInit__spspmm_cpu(void) { return nullptr; }
#  endif
#endif
rusty1s's avatar
rusty1s committed
21

rusty1s's avatar
rusty1s committed
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
// Sparse-sparse matrix product of A and B using the "sum" reduction.
//
// Each operand is passed as a (rowptr, col, optional value) triplet.
// NOTE(review): this looks like CSR layout; confirm against the CPU/CUDA
// kernels. An empty optional value tensor is forwarded to the kernel unchanged.
//
// Args:
//   rowptrA, colA, optional_valueA: first operand.
//   rowptrB, colB, optional_valueB: second operand.
//   K: forwarded unchanged to the kernel. Presumably the column count of B
//      (TODO confirm).
//
// Returns:
//   The kernel's (Tensor, Tensor, optional<Tensor>) tuple. Presumably this is
//   the rowptr, col and optional value of the product (TODO confirm).
//
// Device dispatch inspects rowptrA only. A CUDA rowptrA selects the CUDA
// kernel, or throws "Not compiled with CUDA support" in CPU-only builds. Any
// other device goes to the CPU kernel.
std::tuple<torch::Tensor, torch::Tensor, torch::optional<torch::Tensor>>
spspmm_sum(torch::Tensor rowptrA, torch::Tensor colA,
           torch::optional<torch::Tensor> optional_valueA,
           torch::Tensor rowptrB, torch::Tensor colB,
           torch::optional<torch::Tensor> optional_valueB, int64_t K) {
  const bool on_cuda = rowptrA.device().is_cuda();

  // Guard clause: non-CUDA inputs go straight to the CPU kernel.
  if (!on_cuda) {
    return spspmm_cpu(rowptrA, colA, optional_valueA, rowptrB, colB,
                      optional_valueB, K, "sum");
  }

#ifdef WITH_CUDA
  return spspmm_cuda(rowptrA, colA, optional_valueA, rowptrB, colB,
                     optional_valueB, K, "sum");
#else
  // AT_ERROR throws, so no return statement is needed on this path.
  AT_ERROR("Not compiled with CUDA support");
#endif
}

static auto registry =
    torch::RegisterOperators().op("torch_sparse::spspmm_sum", &spspmm_sum);