// spmm.cpp
#include <torch/script.h>

#include "cpu/spmm_cpu.h"

#ifdef WITH_CUDA
#include "cuda/spmm_cuda.h"
#endif

// Forward sparse-dense matrix product: backend dispatcher.
//
// The backend is chosen from the device of `rowptr` alone. Tensors that do
// not live on a CUDA device are handed to `spmm_cpu`; CUDA tensors are handed
// to `spmm_cuda` when the extension was built with WITH_CUDA, and otherwise
// an error is raised through AT_ERROR. All arguments are forwarded unchanged
// and the backend's result is returned as-is.
std::tuple<torch::Tensor, torch::optional<torch::Tensor>>
spmm_fw(torch::Tensor rowptr, torch::Tensor col,
        torch::optional<torch::Tensor> optional_value, torch::Tensor mat,
        std::string reduce) {
  const bool on_gpu = rowptr.device().is_cuda();
  if (!on_gpu)
    return spmm_cpu(rowptr, col, optional_value, mat, reduce);

#ifdef WITH_CUDA
  return spmm_cuda(rowptr, col, optional_value, mat, reduce);
#else
  AT_ERROR("Not compiled with CUDA support");
#endif
}

// Backward pass with respect to the sparse values: backend dispatcher.
//
// The backend is chosen from the device of `row` alone. Tensors that do not
// live on a CUDA device are handed to `spmm_value_bw_cpu`; CUDA tensors are
// handed to `spmm_value_bw_cuda` when the extension was built with WITH_CUDA,
// and otherwise an error is raised through AT_ERROR. All arguments are
// forwarded unchanged and the backend's result is returned as-is.
torch::Tensor spmm_value_bw(torch::Tensor row, torch::Tensor rowptr,
                            torch::Tensor col, torch::Tensor mat,
                            torch::Tensor grad, std::string reduce) {
  if (!row.device().is_cuda())
    return spmm_value_bw_cpu(row, rowptr, col, mat, grad, reduce);

#ifdef WITH_CUDA
  return spmm_value_bw_cuda(row, rowptr, col, mat, grad, reduce);
#else
  AT_ERROR("Not compiled with CUDA support");
#endif
}

using torch::autograd::AutogradContext;
using torch::autograd::Variable;
using torch::autograd::variable_list;

// Autograd function wrapping `spmm_fw` with the "sum" reduction.
//
// forward() takes, in order: opt_row, rowptr, col, value, opt_colptr,
// opt_csr2csc, mat, has_value. backward() returns one entry per forward
// argument in that same order; only `value` (index 3) and `mat` (index 6)
// can receive a defined gradient.
class SPMMSum : public torch::autograd::Function<SPMMSum> {
public:
  // Computes `spmm_fw(rowptr, col, value-or-nullopt, mat, "sum")` and stores
  // what backward() needs.
  //
  // `has_value` tells whether `value` carries real data; when it is false,
  // `value` is only a stand-in and is not passed on to `spmm_fw`.
  // The optional index tensors are required (asserted) only when a gradient
  // that depends on them will be requested:
  //   - `row` when `value` requires grad,
  //   - `row`, `colptr` and `csr2csc` when `mat` requires grad.
  static variable_list forward(AutogradContext *ctx,
                               torch::optional<Variable> opt_row,
                               Variable rowptr, Variable col, Variable value,
                               torch::optional<Variable> opt_colptr,
                               torch::optional<Variable> opt_csr2csc,
                               Variable mat, bool has_value) {

    if (has_value && torch::autograd::any_variable_requires_grad({value})) {
      AT_ASSERTM(opt_row.has_value(), "Argument `row` is missing");
    }

    if (torch::autograd::any_variable_requires_grad({mat})) {
      AT_ASSERTM(opt_row.has_value(), "Argument `row` is missing");
      AT_ASSERTM(opt_colptr.has_value(), "Argument `colptr` is missing");
      AT_ASSERTM(opt_csr2csc.has_value(), "Argument `csr2csc` is missing");
    }

    // Missing optionals are replaced by `col` so that a tensor can always be
    // saved. backward() only reads these slots on the code paths guarded by
    // the assertions above, so the stand-ins are never consumed.
    auto row = opt_row.has_value() ? opt_row.value() : col;
    auto colptr = opt_colptr.has_value() ? opt_colptr.value() : col;
    auto csr2csc = opt_csr2csc.has_value() ? opt_csr2csc.value() : col;

    torch::optional<torch::Tensor> opt_value = torch::nullopt;
    if (has_value)
      opt_value = value;

    auto out = std::get<0>(spmm_fw(rowptr, col, opt_value, mat, "sum"));
    ctx->saved_data["has_value"] = has_value;
    ctx->save_for_backward({row, rowptr, col, value, colptr, csr2csc, mat});
    return {out};
  }

  // Produces the gradients for `value` and `mat` from `grad_outs[0]`.
  //
  // - grad_value: `spmm_value_bw(row, rowptr, col, mat, grad_out, "sum")`,
  //   computed only when `has_value` is set and `value` requires grad.
  // - grad_mat: `spmm_fw` applied to (colptr, row[csr2csc], value[csr2csc])
  //   and `grad_out`, computed only when `mat` requires grad.
  //   NOTE(review): judging by the names, `csr2csc` is the permutation from
  //   CSR to CSC entry order, making this a product with the transposed
  //   sparse matrix - confirm against the caller.
  // Every other slot is an undefined Variable.
  static variable_list backward(AutogradContext *ctx, variable_list grad_outs) {
    auto has_value = ctx->saved_data["has_value"].toBool();
    auto grad_out = grad_outs[0];
    auto saved = ctx->get_saved_variables();
    auto row = saved[0], rowptr = saved[1], col = saved[2], value = saved[3],
         colptr = saved[4], csr2csc = saved[5], mat = saved[6];

    auto grad_value = Variable();
    if (has_value && torch::autograd::any_variable_requires_grad({value})) {
      grad_value = spmm_value_bw(row, rowptr, col, mat, grad_out, "sum");
    }

    auto grad_mat = Variable();
    if (torch::autograd::any_variable_requires_grad({mat})) {
      torch::optional<torch::Tensor> opt_value = torch::nullopt;
      if (has_value)
        opt_value = value.index_select(0, csr2csc);

      grad_mat = std::get<0>(spmm_fw(colptr, row.index_select(0, csr2csc),
                                     opt_value, grad_out, "sum"));
    }

    // One slot per forward() argument, in declaration order.
    return {Variable(), Variable(), Variable(), grad_value,
            Variable(), Variable(), grad_mat,   Variable()};
  }
};

// Autograd function wrapping `spmm_fw` with the "mean" reduction.
//
// forward() takes, in order: opt_row, rowptr, col, value, opt_rowcount,
// opt_colptr, opt_csr2csc, mat, has_value. backward() returns one entry per
// forward argument in that same order; only `value` (index 3) and `mat`
// (index 7) can receive a defined gradient.
class SPMMMean : public torch::autograd::Function<SPMMMean> {
public:
  // Computes `spmm_fw(rowptr, col, value-or-nullopt, mat, "mean")` and stores
  // what backward() needs.
  //
  // `has_value` tells whether `value` carries real data; when it is false,
  // `value` is only a stand-in and is not passed on to `spmm_fw`.
  // The optional index tensors are required (asserted) only when a gradient
  // that depends on them will be requested:
  //   - `row` when `value` requires grad,
  //   - `row`, `rowcount`, `colptr` and `csr2csc` when `mat` requires grad.
  static variable_list forward(AutogradContext *ctx,
                               torch::optional<Variable> opt_row,
                               Variable rowptr, Variable col, Variable value,
                               torch::optional<Variable> opt_rowcount,
                               torch::optional<Variable> opt_colptr,
                               torch::optional<Variable> opt_csr2csc,
                               Variable mat, bool has_value) {

    if (has_value && torch::autograd::any_variable_requires_grad({value})) {
      AT_ASSERTM(opt_row.has_value(), "Argument `row` is missing");
    }

    if (torch::autograd::any_variable_requires_grad({mat})) {
      AT_ASSERTM(opt_row.has_value(), "Argument `row` is missing");
      AT_ASSERTM(opt_rowcount.has_value(), "Argument `rowcount` is missing");
      AT_ASSERTM(opt_colptr.has_value(), "Argument `colptr` is missing");
      AT_ASSERTM(opt_csr2csc.has_value(), "Argument `csr2csc` is missing");
    }

    // Missing optionals are replaced by `col` so that a tensor can always be
    // saved. backward() only reads these slots on the code paths guarded by
    // the assertions above, so the stand-ins are never consumed.
    auto row = opt_row.has_value() ? opt_row.value() : col;
    auto rowcount = opt_rowcount.has_value() ? opt_rowcount.value() : col;
    auto colptr = opt_colptr.has_value() ? opt_colptr.value() : col;
    auto csr2csc = opt_csr2csc.has_value() ? opt_csr2csc.value() : col;

    torch::optional<torch::Tensor> opt_value = torch::nullopt;
    if (has_value)
      opt_value = value;

    auto out = std::get<0>(spmm_fw(rowptr, col, opt_value, mat, "mean"));
    ctx->saved_data["has_value"] = has_value;
    ctx->save_for_backward(
        {row, rowptr, col, value, rowcount, colptr, csr2csc, mat});
    return {out};
  }

  // Produces the gradients for `value` and `mat` from `grad_outs[0]`.
  //
  // - grad_value: `spmm_value_bw(row, rowptr, col, mat, grad_out, "mean")`,
  //   computed only when `has_value` is set and `value` requires grad.
  // - grad_mat: a "sum" `spmm_fw` over (colptr, row[csr2csc]) whose
  //   per-entry weights are `value[csr2csc] / rowcount[row]` when values are
  //   present and `1 / rowcount[row]` otherwise; computed only when `mat`
  //   requires grad.
  //   NOTE(review): judging by the names, `csr2csc` is the permutation from
  //   CSR to CSC entry order - confirm against the caller.
  // Every other slot is an undefined Variable.
  static variable_list backward(AutogradContext *ctx, variable_list grad_outs) {
    auto has_value = ctx->saved_data["has_value"].toBool();
    auto grad_out = grad_outs[0];
    auto saved = ctx->get_saved_variables();
    auto row = saved[0], rowptr = saved[1], col = saved[2], value = saved[3],
         rowcount = saved[4], colptr = saved[5], csr2csc = saved[6],
         mat = saved[7];

    auto grad_value = Variable();
    if (has_value && torch::autograd::any_variable_requires_grad({value})) {
      grad_value = spmm_value_bw(row, rowptr, col, mat, grad_out, "mean");
    }

    auto grad_mat = Variable();
    if (torch::autograd::any_variable_requires_grad({mat})) {
      row = row.index_select(0, csr2csc);
      // Gather one count per entry, in the dtype of `mat`, and clamp it to a
      // minimum of 1 (presumably so that a zero count never becomes a
      // divisor).
      rowcount = rowcount.toType(mat.scalar_type()).index_select(0, row);
      rowcount.clamp_(1);

      // From here on `rowcount` holds the per-entry weights, not the counts.
      if (has_value)
        rowcount = value.index_select(0, csr2csc).div(rowcount);
      else
        rowcount.pow_(-1);

      grad_mat = std::get<0>(spmm_fw(colptr, row, rowcount, grad_out, "sum"));
    }

    // One slot per forward() argument, in declaration order.
    return {Variable(), Variable(), Variable(), grad_value, Variable(),
            Variable(), Variable(), grad_mat,   Variable()};
  }
};

// Sum-reduced sparse-dense matrix product with autograd support.
//
// Thin entry point around `SPMMSum::apply`. `opt_value` may be absent; in
// that case `col` is passed in the `value` slot purely as a stand-in and the
// trailing `has_value` flag is false, which tells SPMMSum to ignore it.
// Returns the first (and only) output of the autograd function.
torch::Tensor spmm_sum(torch::optional<torch::Tensor> opt_row,
                       torch::Tensor rowptr, torch::Tensor col,
                       torch::optional<torch::Tensor> opt_value,
                       torch::optional<torch::Tensor> opt_colptr,
                       torch::optional<torch::Tensor> opt_csr2csc,
                       torch::Tensor mat) {
  auto value = opt_value.has_value() ? opt_value.value() : col;
  return SPMMSum::apply(opt_row, rowptr, col, value, opt_colptr, opt_csr2csc,
                        mat, opt_value.has_value())[0];
}

// Mean-reduced sparse-dense matrix product with autograd support.
//
// Thin entry point around `SPMMMean::apply`. `opt_value` may be absent; in
// that case `col` is passed in the `value` slot purely as a stand-in and the
// trailing `has_value` flag is false, which tells SPMMMean to ignore it.
// Returns the first (and only) output of the autograd function.
torch::Tensor spmm_mean(torch::optional<torch::Tensor> opt_row,
                        torch::Tensor rowptr, torch::Tensor col,
                        torch::optional<torch::Tensor> opt_value,
                        torch::optional<torch::Tensor> opt_rowcount,
                        torch::optional<torch::Tensor> opt_colptr,
                        torch::optional<torch::Tensor> opt_csr2csc,
                        torch::Tensor mat) {
  auto value = opt_value.has_value() ? opt_value.value() : col;
  return SPMMMean::apply(opt_row, rowptr, col, value, opt_rowcount, opt_colptr,
                         opt_csr2csc, mat, opt_value.has_value())[0];
}

// Binds `spmm_sum` and `spmm_mean` to the operator names
// "torch_sparse::spmm_sum" and "torch_sparse::spmm_mean". The registration
// object is kept in a file-local static, so the calls below run as part of
// this translation unit's static initialisation.
static auto registry = torch::RegisterOperators()
                           .op("torch_sparse::spmm_sum", &spmm_sum)
                           .op("torch_sparse::spmm_mean", &spmm_mean);