#include <torch/extension.h>

#include <algorithm>
#include <limits>
#include <map>
#include <vector>

#include "compat.h"

#define CHECK_CPU(x) AT_ASSERTM(!x.type().is_cuda(), #x " must be CPU tensor")

enum ReductionType { SUM, MEAN, MIN, MAX };

const std::map<std::string, ReductionType> reduce2REDUCE = {
    {"sum", SUM}, {"add", SUM}, {"mean", MEAN}, {"min", MIN}, {"max", MAX},
};

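// Dispatch helpers: resolve the runtime `reduce` string and the presence of
// the optional value tensor into the constants REDUCE and HAS_VAL, which are
// visible inside the dispatched lambda (mirroring AT_DISPATCH_ALL_TYPES).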
#define AT_DISPATCH_REDUCTION_TYPES(reduce, ...)                               \
  [&] {                                                                        \
    switch (reduce2REDUCE.at(reduce)) {                                        \
    case SUM: {                                                                \
      const ReductionType REDUCE = SUM;                                        \
      return __VA_ARGS__();                                                    \
    }                                                                          \
    case MEAN: {                                                               \
      const ReductionType REDUCE = MEAN;                                       \
      return __VA_ARGS__();                                                    \
    }                                                                          \
    case MIN: {                                                                \
      const ReductionType REDUCE = MIN;                                        \
      return __VA_ARGS__();                                                    \
    }                                                                          \
    case MAX: {                                                                \
      const ReductionType REDUCE = MAX;                                        \
      return __VA_ARGS__();                                                    \
    }                                                                          \
    }                                                                          \
  }()

#define AT_DISPATCH_HAS_VAL(value_opt, ...)                                    \
  [&] {                                                                        \
    switch (value_opt.has_value()) {                                           \
    case true: {                                                               \
      const bool HAS_VAL = true;                                               \
      return __VA_ARGS__();                                                    \
    }                                                                          \
    case false: {                                                              \
      const bool HAS_VAL = false;                                              \
      return __VA_ARGS__();                                                    \
    }                                                                          \
    }                                                                          \
  }()

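// Per-reduction helpers: `init` returns the identity element, `update` folds
// in one candidate value (tracking its edge index for min/max), and `write`
// stores the final result, dividing by `count` for mean and falling back to 0
// (leaving the argument index untouched) for empty rows under min/max.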
template <typename scalar_t, ReductionType REDUCE> struct Reducer {
  static inline scalar_t init() {
    if (REDUCE == MIN) {
      return std::numeric_limits<scalar_t>::max();
    } else if (REDUCE == MAX) {
      return std::numeric_limits<scalar_t>::lowest();
    } else {
      return (scalar_t)0;
    }
  }

  static inline void update(scalar_t *val, scalar_t new_val, int64_t *arg,
                            int64_t new_arg) {
    if (REDUCE == SUM || REDUCE == MEAN) {
      *val = *val + new_val;
    } else if ((REDUCE == MIN && new_val < *val) ||
               (REDUCE == MAX && new_val > *val)) {
      *val = new_val;
      *arg = new_arg;
    }
  }

  static inline void write(scalar_t *address, scalar_t val,
                           int64_t *arg_address, int64_t arg, int count) {
    if (REDUCE == SUM) {
      *address = val;
    } else if (REDUCE == MEAN) {
      *address = val / (count > 0 ? count : (scalar_t)1);
    } else if (REDUCE == MIN || REDUCE == MAX) {
      if (count > 0) {
        *address = val;
        *arg_address = arg;
      } else {
        *address = (scalar_t)0;
      }
    }
  }
};

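// Sparse-dense matrix multiplication with a CSR sparse matrix:
//   out[m, k] = reduce_{e in [rowptr[m], rowptr[m+1])} value[e] * mat[col[e], k]
// where value[e] is treated as 1 if no value tensor is given. Leading batch
// dimensions of `mat` are handled independently. For "min"/"max" reductions,
// a second tensor holding the index of the selected non-zero is returned.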
std::tuple<at::Tensor, at::optional<at::Tensor>>
spmm(at::Tensor rowptr, at::Tensor col, at::optional<at::Tensor> value_opt,
     at::Tensor mat, std::string reduce) {
  CHECK_CPU(rowptr);
  CHECK_CPU(col);
  if (value_opt.has_value())
    CHECK_CPU(value_opt.value());
  CHECK_CPU(mat);

  AT_ASSERTM(rowptr.dim() == 1, "Input mismatch");
  AT_ASSERTM(col.dim() == 1, "Input mismatch");
  if (value_opt.has_value())
    AT_ASSERTM(value_opt.value().dim() == 1, "Input mismatch");
  AT_ASSERTM(mat.dim() >= 2, "Input mismatch");

  mat = mat.contiguous();

  auto sizes = mat.sizes().vec();
  sizes[mat.dim() - 2] = rowptr.numel() - 1;
  auto out = at::empty(sizes, mat.options());

  at::optional<at::Tensor> arg_out = at::nullopt;
  int64_t *arg_out_data = nullptr;
  if (reduce2REDUCE.at(reduce) == MIN || reduce2REDUCE.at(reduce) == MAX) {
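    // Min/max also track which non-zero was selected; initialize the argument
    // output with the out-of-range sentinel col.numel() for empty rows.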
    arg_out = at::full_like(out, col.numel(), rowptr.options());
    arg_out_data = arg_out.value().DATA_PTR<int64_t>();
  }

  auto rowptr_data = rowptr.DATA_PTR<int64_t>();
  auto col_data = col.DATA_PTR<int64_t>();

  auto M = rowptr.numel() - 1;
  auto N = mat.size(-2);
  auto K = mat.size(-1);
  auto B = mat.numel() / (N * K);

  AT_DISPATCH_ALL_TYPES(mat.scalar_type(), "spmm", [&] {
    scalar_t *value_data = nullptr;
    auto mat_data = mat.DATA_PTR<scalar_t>();
    auto out_data = out.DATA_PTR<scalar_t>();

    scalar_t val;
    std::vector<scalar_t> vals(K);
    int64_t row_start, row_end, c;
    std::vector<int64_t> args(K);

    AT_DISPATCH_REDUCTION_TYPES(reduce, [&] {
      AT_DISPATCH_HAS_VAL(value_opt, [&] {
        if (HAS_VAL) {
          value_data = value_opt.value().DATA_PTR<scalar_t>();
        }

        for (int b = 0; b < B; b++) {
          for (int m = 0; m < M; m++) {
            row_start = rowptr_data[m], row_end = rowptr_data[m + 1];

            for (int k = 0; k < K; k++)
              vals[k] = Reducer<scalar_t, REDUCE>::init();

            int offset = b * N * K;
            for (int e = row_start; e < row_end; e++) {
              c = col_data[e];
              if (HAS_VAL)
                val = value_data[e];
              for (int k = 0; k < K; k++) {
                if (HAS_VAL)
                  Reducer<scalar_t, REDUCE>::update(
                      &vals[k], val * mat_data[offset + c * K + k], &args[k],
                      e);
                else
                  Reducer<scalar_t, REDUCE>::update(
                      &vals[k], mat_data[offset + c * K + k], &args[k], e);
              }
            }
            offset = b * M * K + m * K;
            for (int k = 0; k < K; k++)
              Reducer<scalar_t, REDUCE>::write(out_data + offset + k, vals[k],
                                               arg_out_data + offset + k,
                                               args[k], row_end - row_start);
          }
        }
      });
    });
  });

  return std::make_tuple(out, arg_out);
}

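// Gradient of spmm with respect to the non-zero values. `index` holds the COO
// (row, col) indices of the non-zeros; for edge e = (row, col),
//   grad_value[e] = sum_{b, k} mat[b, col, k] * grad[b, row, k],
// divided by the number of non-zeros in `row` when reduce == "mean".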
at::Tensor spmm_val_bw(at::Tensor index, at::Tensor rowptr, at::Tensor mat,
                       at::Tensor grad, std::string reduce) {
  CHECK_CPU(index);
  CHECK_CPU(rowptr);
  CHECK_CPU(mat);
  CHECK_CPU(grad);

  AT_ASSERTM(index.dim() == 2, "Input mismatch");
  AT_ASSERTM(index.size(0) == 2, "Input mismatch");
  AT_ASSERTM(rowptr.dim() == 1, "Input mismatch");
  AT_ASSERTM(mat.dim() >= 2, "Input mismatch");
  AT_ASSERTM(mat.dim() == grad.dim(), "Input mismatch");
  AT_ASSERTM(reduce2REDUCE.at(reduce) == SUM ||
                 reduce2REDUCE.at(reduce) == MEAN,
             "Reduce operation not supported");

  index = index.contiguous();
  mat = mat.contiguous();
  grad = grad.contiguous();

  auto M = grad.size(-2);
  auto N = mat.size(-2);
  auto E = index.size(1);
  auto K = mat.size(-1);
  auto B = mat.numel() / (N * K);

  auto out = at::zeros(index.size(1), grad.options());

  auto index_data = index.DATA_PTR<int64_t>();
  auto rowptr_data = rowptr.DATA_PTR<int64_t>();
  AT_DISPATCH_ALL_TYPES(mat.scalar_type(), "spmm_val_bw", [&] {
    auto mat_data = mat.DATA_PTR<scalar_t>();
    auto grad_data = grad.DATA_PTR<scalar_t>();
    auto out_data = out.DATA_PTR<scalar_t>();

    scalar_t val;
    int64_t row, col;
    AT_DISPATCH_REDUCTION_TYPES(reduce, [&] {
      for (int b = 0; b < B; b++) {
        for (int e = 0; e < E; e++) {
          row = index_data[e], col = index_data[E + e], val = (scalar_t)0;
          for (int k = 0; k < K; k++) {
            val += mat_data[b * N * K + col * K + k] *
                   grad_data[b * M * K + row * K + k];
          }
          if (REDUCE == MEAN) {
            int row_start = rowptr_data[row], row_end = rowptr_data[row + 1];
            val /= (scalar_t)std::max(row_end - row_start, 1);
          }
          out_data[e] += val;
        }
      }
    });
  });

  return out;
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("spmm", &spmm, "Sparse-Dense Matrix Multiplication (CPU)");
  m.def("spmm_val_bw", &spmm_val_bw,
        "Sparse-Dense Matrix Multiplication Value Backward (CPU)");
}
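
// Usage sketch (illustrative only; assumes this translation unit is compiled
// as a PyTorch C++ extension with "compat.h" available on the include path):
//
//   auto rowptr = torch::tensor({0, 2, 3}, torch::kInt64); // CSR row pointers, M = 2
//   auto col    = torch::tensor({0, 2, 1}, torch::kInt64); // column indices, nnz = 3
//   auto value  = torch::tensor({1.0f, 2.0f, 3.0f});       // optional non-zero values
//   auto mat    = torch::rand({3, 4});                     // dense matrix, N = 3, K = 4
//   auto result = spmm(rowptr, col, value, mat, "sum");
//   auto out    = std::get<0>(result);                     // dense output of shape [2, 4]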