spmm.cpp 7.01 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
#include <torch/extension.h>

#include "compat.h"

// Asserts (via AT_ASSERTM) that tensor expression `x` does not report a CUDA
// type; the failure message is the spelled-out argument followed by
// " must be CPU tensor". The argument is parenthesized so that any
// expression, not just a plain identifier, is checked as a whole.
#define CHECK_CPU(x)                                                           \
  AT_ASSERTM(!(x).type().is_cuda(), #x " must be CPU tensor")

// Reduction applied to the contributions gathered for one output element.
enum ReductionType {
  SUM,  // contributions are added up
  MEAN, // contributions are added up, then divided by their count
  MIN,  // the smallest contribution (and its index) is kept
  MAX   // the largest contribution (and its index) is kept
};

// Lookup table from the reduction name passed in by callers to the matching
// ReductionType. "add" is accepted as a second spelling of "sum".
const std::map<std::string, ReductionType> reduce2REDUCE{
    {"sum", SUM},
    {"add", SUM},
    {"mean", MEAN},
    {"min", MIN},
    {"max", MAX},
};

// One `case` of the reduction dispatch: binds the compile-time constant
// `REDUCE` to KIND and returns the result of invoking the callable given as
// the remaining arguments.
#define SPMM_PRIVATE_REDUCTION_CASE(KIND, ...)                                 \
  case KIND: {                                                                 \
    const ReductionType REDUCE = KIND;                                         \
    return __VA_ARGS__();                                                      \
  }

// Turns the runtime reduction name `reduce` into a compile-time constant.
// The name is looked up in `reduce2REDUCE` (std::map::at, so an unknown name
// throws), and the callable passed as the remaining arguments is invoked with
// a `const ReductionType REDUCE` in scope that it can use as a template
// argument. The whole expression evaluates to the callable's return value.
#define AT_DISPATCH_REDUCTION_TYPES(reduce, ...)                               \
  [&] {                                                                        \
    switch (reduce2REDUCE.at(reduce)) {                                        \
      SPMM_PRIVATE_REDUCTION_CASE(SUM, __VA_ARGS__)                            \
      SPMM_PRIVATE_REDUCTION_CASE(MEAN, __VA_ARGS__)                           \
      SPMM_PRIVATE_REDUCTION_CASE(MIN, __VA_ARGS__)                            \
      SPMM_PRIVATE_REDUCTION_CASE(MAX, __VA_ARGS__)                            \
    }                                                                          \
  }()

// Turns "does the optional hold a value?" into a compile-time constant.
// The callable passed as the remaining arguments is invoked with a
// `const bool HAS_VAL` in scope that mirrors `value_opt.has_value()`; the
// whole expression evaluates to the callable's return value.
#define AT_DISPATCH_HAS_VAL(value_opt, ...)                                    \
  [&] {                                                                        \
    if (value_opt.has_value()) {                                               \
      const bool HAS_VAL = true;                                               \
      return __VA_ARGS__();                                                    \
    } else {                                                                   \
      const bool HAS_VAL = false;                                              \
      return __VA_ARGS__();                                                    \
    }                                                                          \
  }()

// Building blocks of a reduction, selected at compile time through REDUCE.
// A caller seeds an accumulator with init(), folds every contribution into
// it with update(), and stores the finished result with write().
template <typename scalar_t, ReductionType REDUCE> struct Reducer {
  // Starting value of an accumulator: the largest representable value for
  // MIN, the lowest one for MAX, and zero for every other reduction.
  static inline scalar_t init() {
    switch (REDUCE) {
    case MIN:
      return std::numeric_limits<scalar_t>::max();
    case MAX:
      return std::numeric_limits<scalar_t>::lowest();
    default:
      return static_cast<scalar_t>(0);
    }
  }

  // Folds `new_val` into `*val`. SUM and MEAN add it to the running total.
  // MIN and MAX replace `*val` only when `new_val` is strictly better, and
  // then also record `new_arg` in `*arg`.
  static inline void update(scalar_t *val, scalar_t new_val, int64_t *arg,
                            int64_t new_arg) {
    const bool accumulates = (REDUCE == SUM) || (REDUCE == MEAN);
    if (accumulates) {
      *val = *val + new_val;
      return;
    }

    const bool improves = (REDUCE == MIN) ? (new_val < *val)
                                          : (REDUCE == MAX && new_val > *val);
    if (improves) {
      *val = new_val;
      *arg = new_arg;
    }
  }

  // Stores the finished accumulator `val` at `address`; `count` is the
  // number of contributions that were folded in.
  //  - SUM:     stores `val` unchanged.
  //  - MEAN:    stores `val` divided by `count` (by one when `count` is not
  //             positive).
  //  - MIN/MAX: stores `val` and writes `arg` to `arg_address` when there
  //             was at least one contribution; otherwise stores zero and
  //             leaves `*arg_address` untouched.
  static inline void write(scalar_t *address, scalar_t val,
                           int64_t *arg_address, int64_t arg, int count) {
    const bool has_entries = count > 0;
    switch (REDUCE) {
    case SUM:
      *address = val;
      break;
    case MEAN: {
      const auto divisor = has_entries ? count : static_cast<scalar_t>(1);
      *address = val / divisor;
      break;
    }
    case MIN:
    case MAX:
      if (!has_entries) {
        *address = static_cast<scalar_t>(0);
        break;
      }
      *address = val;
      *arg_address = arg;
      break;
    }
  }
};

// Sparse-dense matrix multiplication on the CPU with a selectable reduction.
//
// The sparse operand is described by `rowptr` and `col`: entries
// rowptr[n] .. rowptr[n + 1] - 1 of `col` (and of `value_opt`, when given)
// belong to sparse row n. For every sparse row, the rows of `mat` selected
// by `col` are gathered, multiplied by the matching entry of `value_opt`
// when present, and combined with the reduction named by `reduce`.
//
// Parameters:
//   rowptr    1-D CPU tensor read as int64; rowptr.numel() - 1 must equal
//             mat.size(-2).
//   col       1-D CPU tensor read as int64.
//   value_opt optional 1-D CPU tensor, read with the scalar type of `mat`.
//   mat       CPU tensor with at least two dimensions; all dimensions before
//             the last two are iterated as batches.
//   reduce    a key of `reduce2REDUCE` ("sum", "add", "mean", "min", "max");
//             the lookup uses std::map::at, so any other name throws.
//
// Returns (out, arg_out). `out` has the sizes of `mat` with dimension -2 set
// to rowptr.numel() - 1. `arg_out` is only populated for "min"/"max": it has
// the shape of `out`, is pre-filled with mat.size(-2), and receives the index
// into `col` of the entry that produced each output element.
std::tuple<at::Tensor, at::optional<at::Tensor>>
spmm(at::Tensor rowptr, at::Tensor col, at::optional<at::Tensor> value_opt,
     at::Tensor mat, std::string reduce) {
  CHECK_CPU(rowptr);
  CHECK_CPU(col);
  if (value_opt.has_value()) {
    CHECK_CPU(value_opt.value());
  }
  CHECK_CPU(mat);

  // The kernel below indexes raw data pointers, so every operand has to be
  // laid out contiguously (contiguous() is a no-op when it already is).
  rowptr = rowptr.contiguous();
  col = col.contiguous();
  if (value_opt.has_value()) {
    value_opt = value_opt.value().contiguous();
  }
  mat = mat.contiguous();

  AT_ASSERTM(rowptr.dim() == 1, "Input mismatch");
  AT_ASSERTM(col.dim() == 1, "Input mismatch");
  if (value_opt.has_value()) {
    AT_ASSERTM(value_opt.value().dim() == 1, "Input mismatch");
  }
  AT_ASSERTM(mat.dim() >= 2, "Input mismatch");
  AT_ASSERTM(rowptr.numel() - 1 == mat.size(-2), "Input mismatch");

  auto sizes = mat.sizes().vec();
  sizes[mat.dim() - 2] = rowptr.numel() - 1;
  auto out = at::empty(sizes, mat.options());

  const ReductionType reduction = reduce2REDUCE.at(reduce);
  at::optional<at::Tensor> arg_out = at::nullopt;
  int64_t *arg_out_data = nullptr;
  if (reduction == MIN || reduction == MAX) {
    arg_out = at::full_like(out, mat.size(-2), rowptr.options());
    arg_out_data = arg_out.value().DATA_PTR<int64_t>();
  }

  auto rowptr_data = rowptr.DATA_PTR<int64_t>();
  auto col_data = col.DATA_PTR<int64_t>();

  // 64-bit extents: tensor sizes and flat offsets may not fit into an int.
  const int64_t N = rowptr.numel() - 1;
  const int64_t M = mat.size(-2);
  const int64_t K = mat.size(-1);
  // Number of batches = product of the leading dimensions. Computing it this
  // way (instead of numel / (M * K)) stays well-defined when M or K is zero.
  int64_t B = 1;
  for (int64_t d = 0; d < mat.dim() - 2; d++) {
    B *= mat.size(d);
  }

  AT_DISPATCH_ALL_TYPES(mat.scalar_type(), "spmm", [&] {
    // Read from the dense input, write to the freshly allocated output.
    const scalar_t *mat_data = mat.DATA_PTR<scalar_t>();
    scalar_t *out_data = out.DATA_PTR<scalar_t>();

    // Per-row accumulators, one slot per dense column.
    std::vector<scalar_t> vals(K);
    std::vector<int64_t> args(K);

    AT_DISPATCH_REDUCTION_TYPES(reduce, [&] {
      AT_DISPATCH_HAS_VAL(value_opt, [&] {
        const scalar_t *value_data = nullptr;
        if (HAS_VAL) {
          value_data = value_opt.value().DATA_PTR<scalar_t>();
        }

        for (int64_t b = 0; b < B; b++) {
          const int64_t mat_offset = b * M * K;

          for (int64_t n = 0; n < N; n++) {
            const int64_t row_start = rowptr_data[n];
            const int64_t row_end = rowptr_data[n + 1];

            for (int64_t k = 0; k < K; k++) {
              vals[k] = Reducer<scalar_t, REDUCE>::init();
            }

            // Fold every entry of sparse row n into the accumulators.
            for (int64_t e = row_start; e < row_end; e++) {
              const int64_t col_idx = col_data[e];
              const scalar_t *mat_row = mat_data + mat_offset + col_idx * K;
              if (HAS_VAL) {
                const scalar_t val = value_data[e];
                for (int64_t k = 0; k < K; k++) {
                  Reducer<scalar_t, REDUCE>::update(&vals[k], val * mat_row[k],
                                                    &args[k], e);
                }
              } else {
                for (int64_t k = 0; k < K; k++) {
                  Reducer<scalar_t, REDUCE>::update(&vals[k], mat_row[k],
                                                    &args[k], e);
                }
              }
            }

            const int64_t out_offset = b * N * K + n * K;
            const int count = static_cast<int>(row_end - row_start);
            for (int64_t k = 0; k < K; k++) {
              // arg_out only exists for min/max; never offset a null pointer.
              int64_t *arg_address = arg_out_data != nullptr
                                         ? arg_out_data + out_offset + k
                                         : nullptr;
              Reducer<scalar_t, REDUCE>::write(out_data + out_offset + k,
                                               vals[k], arg_address, args[k],
                                               count);
            }
          }
        }
      });
    });
  });

  return std::make_tuple(out, arg_out);
}

// Python bindings: registers the CPU kernel `spmm` under the name "spmm".
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  const char *const spmm_doc = "Sparse-Dense Matrix Multiplication (CPU)";
  m.def("spmm", &spmm, spmm_doc);
}