"docs/vscode:/vscode.git/clone" did not exist on "4efec7326970ddc31aaea7faff8f49dca45b41c7"
spmm.cpp 8.49 KB
Newer Older
rusty1s's avatar
rusty1s committed
1
#include <torch/script.h>
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87

#include "compat.h"

// Asserts (via AT_ASSERTM) that the tensor expression `x` is not a CUDA
// tensor; the failure message is the stringified argument followed by
// " must be CPU tensor". The argument is parenthesized so that compound
// expressions (e.g. `cond ? a : b`) are tested as a whole instead of having
// `!` and `.type()` bind to only part of them.
#define CHECK_CPU(x) AT_ASSERTM(!(x).type().is_cuda(), #x " must be CPU tensor")

// Reduction applied when combining the contributions that land in one output
// entry.
enum ReductionType {
  SUM,
  MEAN,
  MIN,
  MAX,
};

// Translates the user-facing reduction name into its ReductionType.
// Both "sum" and "add" select SUM.
const std::map<std::string, ReductionType> reduce2REDUCE = {
    {"sum", SUM},
    {"add", SUM},
    {"mean", MEAN},
    {"min", MIN},
    {"max", MAX},
};

// Runtime-to-compile-time dispatch on the reduction name.
//
// `reduce` is looked up in `reduce2REDUCE` (std::map::at, so an unknown name
// throws std::out_of_range). The callable passed as the remaining argument(s)
// is then invoked inside the matching `case`, where a local constant named
// `REDUCE` holds the selected ReductionType; the callable's body can therefore
// use `REDUCE` as a constant, e.g. as a template argument. The whole macro
// expands to an immediately-invoked `[&]` lambda whose result is whatever the
// callable returns.
#define AT_DISPATCH_REDUCTION_TYPES(reduce, ...)                               \
  [&] {                                                                        \
    switch (reduce2REDUCE.at(reduce)) {                                        \
    case SUM: {                                                                \
      const ReductionType REDUCE = SUM;                                        \
      return __VA_ARGS__();                                                    \
    }                                                                          \
    case MEAN: {                                                               \
      const ReductionType REDUCE = MEAN;                                       \
      return __VA_ARGS__();                                                    \
    }                                                                          \
    case MIN: {                                                                \
      const ReductionType REDUCE = MIN;                                        \
      return __VA_ARGS__();                                                    \
    }                                                                          \
    case MAX: {                                                                \
      const ReductionType REDUCE = MAX;                                        \
      return __VA_ARGS__();                                                    \
    }                                                                          \
    }                                                                          \
  }()

// Dispatches on whether the optional `value_opt` currently holds a value.
// The callable given as the remaining argument(s) is invoked with a local
// constant `HAS_VAL` in scope (true when a value is present, false otherwise).
// Expands to an immediately-invoked `[&]` lambda that yields the callable's
// result.
#define AT_DISPATCH_HAS_VAL(value_opt, ...)                                    \
  [&] {                                                                        \
    if (value_opt.has_value()) {                                               \
      const bool HAS_VAL = true;                                               \
      return __VA_ARGS__();                                                    \
    } else {                                                                   \
      const bool HAS_VAL = false;                                              \
      return __VA_ARGS__();                                                    \
    }                                                                          \
  }()

// Stateless helper implementing one reduction (selected at compile time via
// REDUCE) over values of type scalar_t. A caller seeds an accumulator with
// init(), folds every contribution in with update(), and finally stores the
// result with write().
template <typename scalar_t, ReductionType REDUCE> struct Reducer {
  // Starting accumulator value: the largest finite value for MIN, the lowest
  // finite value for MAX, and zero for SUM / MEAN.
  static inline scalar_t init() {
    switch (REDUCE) {
    case MIN:
      return std::numeric_limits<scalar_t>::max();
    case MAX:
      return std::numeric_limits<scalar_t>::lowest();
    default:
      return (scalar_t)0;
    }
  }

  // Folds `new_val` into `*val`. SUM / MEAN accumulate and leave `*arg`
  // untouched; MIN / MAX replace `*val` only on a strict improvement and then
  // record `new_arg` in `*arg`.
  static inline void update(scalar_t *val, scalar_t new_val, int64_t *arg,
                            int64_t new_arg) {
    switch (REDUCE) {
    case SUM:
    case MEAN:
      *val += new_val;
      break;
    case MIN:
      if (new_val < *val) {
        *val = new_val;
        *arg = new_arg;
      }
      break;
    case MAX:
      if (new_val > *val) {
        *val = new_val;
        *arg = new_arg;
      }
      break;
    }
  }

  // Stores the finished accumulator at `address`. `count` is the number of
  // contributions that were reduced: MEAN divides by it (by 1 when it is not
  // positive); MIN / MAX write zero and leave `*arg_address` alone when it is
  // not positive, otherwise they store both the value and `arg`. SUM never
  // touches `arg_address`.
  static inline void write(scalar_t *address, scalar_t val,
                           int64_t *arg_address, int64_t arg, int count) {
    switch (REDUCE) {
    case SUM:
      *address = val;
      break;
    case MEAN:
      *address = val / (count > 0 ? count : (scalar_t)1);
      break;
    case MIN:
    case MAX:
      if (count > 0) {
        *address = val;
        *arg_address = arg;
      } else {
        *address = (scalar_t)0;
      }
      break;
    }
  }
};

rusty1s's avatar
rusty1s committed
88
89
90
91
// Sparse (compressed-row) times dense matrix product on CPU with a selectable
// reduction.
//
// Parameters
//   rowptr    1-D int64 row pointer; row m owns the nonzeros
//             [rowptr[m], rowptr[m + 1]). The output has rowptr.numel() - 1
//             rows.
//   col       1-D int64 column index of every nonzero.
//   value_opt optional 1-D tensor with one value per nonzero (same length as
//             `col`, read with mat's scalar type). When absent, every nonzero
//             is treated as contributing the plain `mat` entry.
//   mat       dense tensor with at least two dimensions; the last two are
//             (N, K), all leading dimensions are treated as a batch.
//   reduce    one of the keys of `reduce2REDUCE` ("sum", "add", "mean", "min",
//             "max"); any other string makes std::map::at throw
//             std::out_of_range.
//
// Returns (out, arg_out)
//   out       same shape as `mat` except that dimension -2 is
//             rowptr.numel() - 1. Rows without nonzeros produce 0.
//   arg_out   only set for "min" / "max": for every output entry the nonzero
//             index `e` that was selected, or col.numel() for rows without
//             nonzeros. torch::nullopt for the other reductions.
std::tuple<torch::Tensor, torch::optional<torch::Tensor>>
spmm(torch::Tensor rowptr, torch::Tensor col,
     torch::optional<torch::Tensor> value_opt, torch::Tensor mat,
     std::string reduce) {
  CHECK_CPU(rowptr);
  CHECK_CPU(col);
  if (value_opt.has_value())
    CHECK_CPU(value_opt.value());
  CHECK_CPU(mat);

  AT_ASSERTM(rowptr.dim() == 1, "Input mismatch");
  AT_ASSERTM(col.dim() == 1, "Input mismatch");
  if (value_opt.has_value()) {
    AT_ASSERTM(value_opt.value().dim() == 1, "Input mismatch");
    // value is indexed with the same nonzero index as col, so the lengths
    // have to agree or the kernel would read out of bounds.
    AT_ASSERTM(value_opt.value().numel() == col.numel(), "Input mismatch");
  }
  AT_ASSERTM(mat.dim() >= 2, "Input mismatch");

  // Everything below walks raw pointers linearly, so every input that is
  // dereferenced that way must be densely laid out.
  rowptr = rowptr.contiguous();
  col = col.contiguous();
  mat = mat.contiguous();
  torch::Tensor value;
  if (value_opt.has_value())
    value = value_opt.value().contiguous();

  auto sizes = mat.sizes().vec();
  sizes[mat.dim() - 2] = rowptr.numel() - 1;
  auto out = torch::empty(sizes, mat.options());

  const ReductionType reduce_type = reduce2REDUCE.at(reduce);
  torch::optional<torch::Tensor> arg_out = torch::nullopt;
  int64_t *arg_out_data = nullptr;
  if (reduce_type == MIN || reduce_type == MAX) {
    // Pre-filled with col.numel() so rows without nonzeros keep that marker.
    arg_out = torch::full_like(out, col.numel(), rowptr.options());
    arg_out_data = arg_out.value().DATA_PTR<int64_t>();
  }

  auto rowptr_data = rowptr.DATA_PTR<int64_t>();
  auto col_data = col.DATA_PTR<int64_t>();

  const int64_t M = rowptr.numel() - 1;
  const int64_t N = mat.size(-2);
  const int64_t K = mat.size(-1);
  // Batch size = product of the leading dimensions. Computed directly rather
  // than as numel / (N * K), which divides by zero when N or K is 0.
  int64_t B = 1;
  for (int64_t d = 0; d < mat.dim() - 2; d++)
    B *= mat.size(d);

  AT_DISPATCH_ALL_TYPES(mat.scalar_type(), "spmm", [&] {
    scalar_t *value_data = nullptr;
    auto mat_data = mat.DATA_PTR<scalar_t>();
    auto out_data = out.DATA_PTR<scalar_t>();

    // Per-row accumulators, allocated once and reused for every output row.
    std::vector<scalar_t> vals(K);
    std::vector<int64_t> args(K);

    AT_DISPATCH_REDUCTION_TYPES(reduce, [&] {
      AT_DISPATCH_HAS_VAL(value_opt, [&] {
        if (HAS_VAL) {
          value_data = value.DATA_PTR<scalar_t>();
        }

        // All offsets are 64-bit: a 32-bit offset would wrap once a batch
        // slice starts beyond 2^31 - 1 elements.
        for (int64_t b = 0; b < B; b++) {
          const int64_t mat_offset = b * N * K;

          for (int64_t m = 0; m < M; m++) {
            const int64_t row_start = rowptr_data[m];
            const int64_t row_end = rowptr_data[m + 1];

            for (int64_t k = 0; k < K; k++)
              vals[k] = Reducer<scalar_t, REDUCE>::init();

            for (int64_t e = row_start; e < row_end; e++) {
              const int64_t c = col_data[e];
              const scalar_t *mat_row = mat_data + mat_offset + c * K;

              if (HAS_VAL) {
                const scalar_t val = value_data[e];
                for (int64_t k = 0; k < K; k++)
                  Reducer<scalar_t, REDUCE>::update(&vals[k], val * mat_row[k],
                                                    &args[k], e);
              } else {
                for (int64_t k = 0; k < K; k++)
                  Reducer<scalar_t, REDUCE>::update(&vals[k], mat_row[k],
                                                    &args[k], e);
              }
            }

            const int64_t out_offset = b * M * K + m * K;
            for (int64_t k = 0; k < K; k++) {
              // arg_out_data is null for sum / mean; do not offset a null
              // pointer (write() ignores the address in those cases).
              int64_t *arg_address =
                  arg_out_data ? arg_out_data + out_offset + k : nullptr;
              Reducer<scalar_t, REDUCE>::write(out_data + out_offset + k,
                                               vals[k], arg_address, args[k],
                                               row_end - row_start);
            }
          }
        }
      });
    });
  });

  return std::make_tuple(out, arg_out);
}

rusty1s's avatar
rusty1s committed
178
179
180
// Computes, for every nonzero e, the inner product between row col[e] of
// `mat` and row row[e] of `grad`, summed over all batch slices:
//
//   out[e] = sum_b sum_k mat[b, col[e], k] * grad[b, row[e], k]
//
// For reduce == "mean" each per-batch contribution is additionally divided by
// the number of nonzeros in row row[e] (taken from `rowptr`, at least 1).
// Every other reduction name accepted by `reduce2REDUCE` follows the plain
// summation path; an unknown name makes std::map::at throw std::out_of_range.
//
// Parameters
//   row     int64 row index of every nonzero (defines the output length).
//   rowptr  int64 row pointer, only read for "mean".
//   col     int64 column index of every nonzero.
//   mat     dense tensor whose last two dimensions are (N, K); leading
//           dimensions are treated as a batch.
//   grad    dense tensor whose last two dimensions are (M, K), read with
//           mat's scalar type.
//   reduce  reduction name.
//
// Returns a 1-D tensor with row.numel() entries created with grad's options.
torch::Tensor spmm_val_bw(torch::Tensor row, torch::Tensor rowptr,
                          torch::Tensor col, torch::Tensor mat,
                          torch::Tensor grad, std::string reduce) {
  CHECK_CPU(row);
  CHECK_CPU(rowptr);
  CHECK_CPU(col);
  CHECK_CPU(mat);
  CHECK_CPU(grad);

  // Raw pointers are walked linearly below, so all inputs must be dense.
  row = row.contiguous();
  rowptr = rowptr.contiguous();
  col = col.contiguous();
  mat = mat.contiguous();
  grad = grad.contiguous();

  const int64_t M = grad.size(-2);
  const int64_t N = mat.size(-2);
  const int64_t E = row.numel();
  const int64_t K = mat.size(-1);
  // Batch size = product of the leading dimensions. Computed directly rather
  // than as numel / (N * K), which divides by zero when N or K is 0.
  int64_t B = 1;
  for (int64_t d = 0; d < mat.dim() - 2; d++)
    B *= mat.size(d);

  auto out = torch::zeros(row.numel(), grad.options());

  auto row_data = row.DATA_PTR<int64_t>();
  auto rowptr_data = rowptr.DATA_PTR<int64_t>();
  auto col_data = col.DATA_PTR<int64_t>();

  AT_DISPATCH_ALL_TYPES(mat.scalar_type(), "spmm_val_bw", [&] {
    auto mat_data = mat.DATA_PTR<scalar_t>();
    auto grad_data = grad.DATA_PTR<scalar_t>();
    auto out_data = out.DATA_PTR<scalar_t>();

    AT_DISPATCH_REDUCTION_TYPES(reduce, [&] {
      for (int64_t b = 0; b < B; b++) {
        const scalar_t *mat_b = mat_data + b * N * K;
        const scalar_t *grad_b = grad_data + b * M * K;

        for (int64_t e = 0; e < E; e++) {
          // Scalar indices get their own names so they do not shadow the
          // `row` / `col` tensor parameters.
          const int64_t r = row_data[e];
          const int64_t c = col_data[e];

          scalar_t val = (scalar_t)0;
          for (int64_t k = 0; k < K; k++)
            val += mat_b[c * K + k] * grad_b[r * K + k];

          if (REDUCE == MEAN) {
            // Row degree is kept in 64 bits; rowptr entries are int64.
            const int64_t degree = rowptr_data[r + 1] - rowptr_data[r];
            val /= (scalar_t)std::max<int64_t>(degree, 1);
          }

          out_data[e] += val;
        }
      }
    });
  });

  return out;
}

rusty1s's avatar
rusty1s committed
229
230
// Registers the two kernels above as custom operators named
// "torch_sparse_cpu::spmm" and "torch_sparse_cpu::spmm_val_bw". The
// registration happens as a side effect of initializing this file-scope
// static object.
static auto registry = torch::RegisterOperators("torch_sparse_cpu::spmm", &spmm)
                           .op("torch_sparse_cpu::spmm_val_bw", &spmm_val_bw);