spspmm_cpu.cpp 3.44 KB
Newer Older
rusty1s's avatar
rusty1s committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
#include "spspmm_cpu.h"

#include <algorithm>
#include <vector>

#include "utils.h"

// Sparse-sparse matrix product C = A @ B for CSR matrices on the CPU.
//
// A is given by (rowptrA, colA, optional_valueA) and B by
// (rowptrB, colB, optional_valueB). K bounds the column indices of B, and
// therefore of C: every entry of colB must lie in [0, K).
// If exactly one of the two value tensors is present, the missing one is
// treated as all ones. If neither is present, the result has no values; C
// then holds only the sparsity pattern of the product.
//
// Returns (rowptrC, colC, optional_valueC):
// - Columns within each row of C are in ascending order.
// - An entry whose accumulated value is exactly zero (e.g. after
//   cancellation) is dropped from the output.
//
// NOTE(review): `reduce` is currently ignored; products are always summed.
std::tuple<torch::Tensor, torch::Tensor, torch::optional<torch::Tensor>>
spspmm_cpu(torch::Tensor rowptrA, torch::Tensor colA,
           torch::optional<torch::Tensor> optional_valueA,
           torch::Tensor rowptrB, torch::Tensor colB,
           torch::optional<torch::Tensor> optional_valueB, int64_t K,
           std::string reduce) {

  CHECK_CPU(rowptrA);
  CHECK_CPU(colA);
  if (optional_valueA.has_value())
    CHECK_CPU(optional_valueA.value());
  CHECK_CPU(rowptrB);
  CHECK_CPU(colB);
  if (optional_valueB.has_value())
    CHECK_CPU(optional_valueB.value());

  CHECK_INPUT(rowptrA.dim() == 1);
  CHECK_INPUT(colA.dim() == 1);
  if (optional_valueA.has_value()) {
    CHECK_INPUT(optional_valueA.value().dim() == 1);
    CHECK_INPUT(optional_valueA.value().size(0) == colA.size(0));
  }
  CHECK_INPUT(rowptrB.dim() == 1);
  CHECK_INPUT(colB.dim() == 1);
  if (optional_valueB.has_value()) {
    CHECK_INPUT(optional_valueB.value().dim() == 1);
    CHECK_INPUT(optional_valueB.value().size(0) == colB.size(0));
  }

  // A CSR row pointer always holds at least one entry (the leading 0).
  // rowptrC is allocated with the same size and rowptrC_data[0] is written
  // below, so an empty rowptrA would cause an out-of-bounds write.
  CHECK_INPUT(rowptrA.numel() >= 1);

  // Column indices of A select rows of B (via rowptrB). Column indices of B
  // index a K-sized scratch buffer. Reject out-of-range values up front
  // instead of reading or writing out of bounds in the loops below.
  if (colA.numel() > 0) {
    CHECK_INPUT(colA.min().item<int64_t>() >= 0);
    CHECK_INPUT(colA.max().item<int64_t>() < rowptrB.numel() - 1);
  }
  if (colB.numel() > 0) {
    CHECK_INPUT(colB.min().item<int64_t>() >= 0);
    CHECK_INPUT(colB.max().item<int64_t>() < K);
  }

  // If only one side carries values, give the other side unit values of the
  // same dtype so that the product can be computed uniformly.
  if (!optional_valueA.has_value() && optional_valueB.has_value())
    optional_valueA =
        torch::ones(colA.numel(), optional_valueB.value().options());

  if (!optional_valueB.has_value() && optional_valueA.has_value())
    optional_valueB =
        torch::ones(colB.numel(), optional_valueA.value().options());

  // Without values, the accumulator type is only used for counting.
  auto scalar_type = torch::ScalarType::Float;
  if (optional_valueA.has_value())
    scalar_type = optional_valueA.value().scalar_type();

  auto rowptrA_data = rowptrA.data_ptr<int64_t>();
  auto colA_data = colA.data_ptr<int64_t>();
  auto rowptrB_data = rowptrB.data_ptr<int64_t>();
  auto colB_data = colB.data_ptr<int64_t>();

  const int64_t num_rows = rowptrA.numel() - 1;

  auto rowptrC = torch::empty_like(rowptrA);
  auto rowptrC_data = rowptrC.data_ptr<int64_t>();
  rowptrC_data[0] = 0;

  torch::Tensor colC;
  torch::optional<torch::Tensor> optional_valueC = torch::nullopt;

  AT_DISPATCH_ALL_TYPES(scalar_type, "spspmm", [&] {
    AT_DISPATCH_HAS_VALUE(optional_valueA, [&] {
      scalar_t *valA_data = nullptr, *valB_data = nullptr;
      if (HAS_VALUE) {
        valA_data = optional_valueA.value().data_ptr<scalar_t>();
        valB_data = optional_valueB.value().data_ptr<scalar_t>();
      }

      // Dense accumulator for the current row of C. Touched entries are
      // reset to zero after each row, so it is all-zero at the start of
      // every row.
      std::vector<scalar_t> tmp_vals(K, 0);
      // last_row[k] == rA iff column k has been touched while processing
      // row rA. This is used to collect each touched column exactly once.
      std::vector<int64_t> last_row(K, -1);
      // Columns touched in the current row (unsorted until emission).
      std::vector<int64_t> touched;

      int64_t nnz = 0;
      std::vector<int64_t> cols;
      std::vector<scalar_t> vals;

      for (int64_t rA = 0; rA < num_rows; rA++) {
        touched.clear();

        for (int64_t eA = rowptrA_data[rA]; eA < rowptrA_data[rA + 1]; eA++) {
          const int64_t cA = colA_data[eA];
          for (int64_t eB = rowptrB_data[cA]; eB < rowptrB_data[cA + 1];
               eB++) {
            const int64_t cB = colB_data[eB];
            if (last_row[cB] != rA) {
              last_row[cB] = rA;
              touched.push_back(cB);
            }

            if (HAS_VALUE)
              tmp_vals[cB] += valA_data[eA] * valB_data[eB];
            else
              tmp_vals[cB]++;
          }
        }

        // Visit only the touched columns, in ascending order. This yields
        // the same output as scanning all K columns, because untouched
        // columns are zero and would be skipped anyway. The cost is
        // proportional to the row's fill instead of to K.
        std::sort(touched.begin(), touched.end());
        for (const int64_t k : touched) {
          if (tmp_vals[k] != 0) {
            cols.push_back(k);
            if (HAS_VALUE)
              vals.push_back(tmp_vals[k]);
            nnz++;
          }
          tmp_vals[k] = static_cast<scalar_t>(0);
        }

        rowptrC_data[rA + 1] = nnz;
      }

      // from_blob does not take ownership of the vector storage, so clone
      // into tensor-owned memory before the vectors go out of scope.
      colC = torch::from_blob(cols.data(), {nnz}, colA.options()).clone();
      if (HAS_VALUE) {
        optional_valueC = torch::from_blob(vals.data(), {nnz},
                                           optional_valueA.value().options())
                              .clone();
      }
    });
  });

  return std::make_tuple(rowptrC, colC, optional_valueC);
}