spspmm.cpp 2.14 KB
Newer Older
rusty1s's avatar
rusty1s committed
1
2
#include <torch/extension.h>

rusty1s's avatar
rusty1s committed
3
4
#include "compat.h"

rusty1s's avatar
rusty1s committed
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
// Counts how often each value occurs in `row`.
//
// Builds a zero tensor of length `num_nodes` (same options as `row`) and
// scatter-adds a one into position `row[k]` for every entry k of `row`.
// Returns that tensor of per-index counts.
at::Tensor degree(at::Tensor row, int64_t num_nodes) {
  const auto opts = row.options();
  auto counts = at::zeros(num_nodes, opts);
  const auto increments = at::ones(row.size(0), opts);
  counts.scatter_add_(0, row, increments);
  return counts;
}

// Converts a (row, col) coordinate pair into a compressed row pointer plus the
// unchanged column tensor.
//
// The row pointer has `num_nodes + 1` entries: a leading zero followed by the
// cumulative sum of the per-row entry counts. `col` is returned as-is.
//
// The input is expected to be coalesced already; this is NOT verified here.
// The row pointer only describes `col` correctly if the entries are grouped by
// ascending row.
std::tuple<at::Tensor, at::Tensor> to_csr(at::Tensor row, at::Tensor col,
                                          int64_t num_nodes) {
  const auto cumulative = degree(row, num_nodes).cumsum(0);
  const auto leading_zero = at::zeros(1, cumulative.options());
  auto rowptr = at::cat({leading_zero, cumulative}, 0);
  return std::make_tuple(rowptr, col);
}

// Backward helper for sparse-sparse matrix multiplication (CPU).
//
// For every position e in `index` this computes
//
//     value[e] = sum over c of A[i, c] * B[j, c]
//
// with i = index[0][e] and j = index[1][e], i.e. the dot product of row i of A
// with row j of B, restricted to column indices present in both rows.
//
// Parameters:
//   index            - int64 tensor read as a dense [2, E] layout: row 0 holds
//                      the row ids into A, row 1 the row ids into B.
//   indexA / valueA  - coordinates ([2, nnzA], int64) and values of A.
//   indexB / valueB  - coordinates ([2, nnzB], int64) and values of B.
//   rowA_max         - number of rows used to build A's row pointer.
//   rowB_max         - number of rows used to build B's row pointer.
//
// Returns a 1-D tensor of length index.size(1) created with valueA's options.
//
// Assumptions (not checked here):
//   * A and B are coalesced; in particular the early `break` in the innermost
//     loop relies on B's column indices being ascending within each row.
//   * All ids in `index` are valid rows, i.e. < rowA_max resp. < rowB_max.
//   * valueB has the same dtype as valueA (the dispatch is on valueA only).
at::Tensor spspmm_bw(at::Tensor index, at::Tensor indexA, at::Tensor valueA,
                     at::Tensor indexB, at::Tensor valueB, size_t rowA_max,
                     size_t rowB_max) {
  // Everything below is accessed through raw pointers with unit stride (and
  // `index` additionally as a dense [2, E] block). Force a dense layout so
  // that strided / transposed inputs are not silently misread. This is a
  // no-op for tensors that are already contiguous.
  index = index.contiguous();
  indexA = indexA.contiguous();
  indexB = indexB.contiguous();
  valueA = valueA.contiguous();
  valueB = valueB.contiguous();

  int64_t *index_data = index.DATA_PTR<int64_t>();
  auto value = at::zeros(index.size(1), valueA.options());
  const int64_t num_entries = value.size(0);

  // Row pointer + column indices of A.
  at::Tensor rowA, colA;
  std::tie(rowA, colA) =
      to_csr(indexA[0], indexA[1], static_cast<int64_t>(rowA_max));
  int64_t *rowA_data = rowA.DATA_PTR<int64_t>();
  int64_t *colA_data = colA.DATA_PTR<int64_t>();

  // Row pointer + column indices of B.
  at::Tensor rowB, colB;
  std::tie(rowB, colB) =
      to_csr(indexB[0], indexB[1], static_cast<int64_t>(rowB_max));
  int64_t *rowB_data = rowB.DATA_PTR<int64_t>();
  int64_t *colB_data = colB.DATA_PTR<int64_t>();

  AT_DISPATCH_FLOATING_TYPES(valueA.scalar_type(), "spspmm_bw", [&] {
    scalar_t *value_data = value.DATA_PTR<scalar_t>();
    scalar_t *valueA_data = valueA.DATA_PTR<scalar_t>();
    scalar_t *valueB_data = valueB.DATA_PTR<scalar_t>();

    for (int64_t e = 0; e < num_entries; ++e) {
      // Second row of `index` starts `num_entries` elements after the first.
      const int64_t i = index_data[e];
      const int64_t j = index_data[num_entries + e];

      for (ptrdiff_t dA = rowA_data[i]; dA < rowA_data[i + 1]; ++dA) {
        const int64_t cA = colA_data[dA];

        // Scan row j of B until its column index reaches cA.
        for (ptrdiff_t dB = rowB_data[j]; dB < rowB_data[j + 1]; ++dB) {
          const int64_t cB = colB_data[dB];

          if (cA == cB) {
            value_data[e] += valueA_data[dA] * valueB_data[dB];
          }

          // Columns are assumed ascending, so nothing further can match cA.
          if (cB >= cA) {
            break;
          }
        }
      }
    }
  });

  return value;
}

// Python bindings: exposes the CPU backward kernel as `spspmm_bw`.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  const char *spspmm_bw_doc =
      "Sparse-Sparse Matrix Multiplication Backward (CPU)";
  m.def("spspmm_bw", &spspmm_bw, spspmm_bw_doc);
}