"vscode:/vscode.git/clone" did not exist on "1168eaaadd69457d1e460512ab235b29bc552907"
spmm.cc 3.91 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
/**
 * Copyright (c) 2022 by Contributors
 * @file spmm.cc
 * @brief DGL C++ sparse SpMM operator implementation.
 */

#include <sparse/sddmm.h>
#include <sparse/sparse_matrix.h>
#include <sparse/spmm.h>
#include <torch/script.h>

#include "./matmul.h"
#include "./utils.h"

namespace dgl {
namespace sparse {

using namespace torch::autograd;

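/**
 * @brief Autograd function for sparse-dense matrix multiplication (SpMM).
 * The forward pass computes `sparse_mat @ dense_mat`; the backward pass
 * computes gradients for the non-zero values and/or the dense matrix,
 * depending on which inputs require gradients.
 */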
class SpMMAutoGrad : public Function<SpMMAutoGrad> {
 public:
  static torch::Tensor forward(
      AutogradContext* ctx, c10::intrusive_ptr<SparseMatrix> sparse_mat,
      torch::Tensor sparse_val, torch::Tensor dense_mat);

  static tensor_list backward(AutogradContext* ctx, tensor_list grad_outputs);
};

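/**
 * @brief Validate shape, dtype, and device constraints of the SpMM operands,
 * raising a fatal error with a descriptive message on the first violation.
 */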
void _SpMMSanityCheck(
    c10::intrusive_ptr<SparseMatrix> sparse_mat, torch::Tensor sparse_val,
    torch::Tensor dense_mat) {
  const auto& sparse_mat_shape = sparse_mat->shape();
  auto val_shape = sparse_val.sizes();
  auto dense_shape = dense_mat.sizes();
  CHECK_EQ(sparse_mat_shape[1], dense_shape[0])
      << "SpMM: the second dimension of the sparse matrix should be equal to "
         "the first dimension of the dense matrix.";
  CHECK_EQ(val_shape.size(), 1)
      << "SpMM: the value tensor for SpMM must be 1-dimensional.";
  CHECK_EQ(val_shape[0], sparse_mat->nnz())
      << "SpMM: the value shape does not match nnz of the sparse matrix.";
  CHECK_LE(dense_shape.size(), 2)
      << "SpMM: the dense matrix can have at most two dimensions.";
  CHECK_EQ(sparse_val.dtype(), dense_mat.dtype())
      << "SpMM: the non-zero values do not have the same dtype as the dense "
         "matrix.";
  CHECK(
      sparse_val.device() == sparse_mat->device() &&
      sparse_val.device() == dense_mat.device())
      << "SpMM: the sparse matrix, non-zero values, and dense matrix should "
         "be on the same device.";
}

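/**
 * @brief Forward pass of SpMM. Only the tensors needed by the backward pass
 * are cached: `sparse_val` is saved when the dense matrix requires gradients,
 * and `dense_mat` is saved when the sparse values do.
 */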
torch::Tensor SpMMAutoGrad::forward(
    AutogradContext* ctx, c10::intrusive_ptr<SparseMatrix> sparse_mat,
    torch::Tensor sparse_val, torch::Tensor dense_mat) {
  auto ret = SpMMNoAutoGrad(sparse_mat, sparse_val, dense_mat, false);

  const bool sparse_requires_grad = sparse_val.requires_grad();
  const bool dense_requires_grad = dense_mat.requires_grad();
  torch::Tensor cache_sparse_val, cache_dense_mat;
  if (dense_requires_grad) {
    cache_sparse_val = sparse_val;
  }
  if (sparse_requires_grad) {
    cache_dense_mat = dense_mat;
  }
  ctx->saved_data["sparse_matrix"] = sparse_mat;
  ctx->saved_data["sparse_requires_grad"] = sparse_requires_grad;
  ctx->saved_data["dense_requires_grad"] = dense_requires_grad;
  ctx->save_for_backward({cache_sparse_val, cache_dense_mat});
  return ret;
}

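/**
 * @brief Backward pass of SpMM. For C = A @ B, the gradient of A's non-zero
 * values is the SDDMM (dC @ B^T) evaluated only on A's sparsity pattern, and
 * the gradient of B is A^T @ dC. The first returned tensor is an undefined
 * placeholder for the non-differentiable sparse matrix argument.
 */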
tensor_list SpMMAutoGrad::backward(
    AutogradContext* ctx, tensor_list grad_outputs) {
  auto saved = ctx->get_saved_variables();
  auto sparse_val = saved[0];
  auto dense_mat = saved[1];
  auto output_grad = grad_outputs[0];

  auto sparse_mat =
      ctx->saved_data["sparse_matrix"].toCustomClass<SparseMatrix>();
  const bool sparse_requires_grad =
      ctx->saved_data["sparse_requires_grad"].toBool();
  const bool dense_requires_grad =
      ctx->saved_data["dense_requires_grad"].toBool();

  torch::Tensor dense_mat_grad, sparse_val_grad;
  if (sparse_requires_grad) {
    // A @ B = C -> dA = dC @ B^T, restricted to the non-zero entries of A.
    sparse_val_grad = SDDMMNoAutoGrad(sparse_mat, output_grad, dense_mat);
  }
  if (dense_requires_grad) {
    // A @ B = C -> dB = (A^T) @ dC
    dense_mat_grad = SpMMNoAutoGrad(sparse_mat, sparse_val, output_grad, true);
  }
  return {torch::Tensor(), sparse_val_grad, dense_mat_grad};
}

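/**
 * @brief User-facing SpMM entry point, e.g. `SpMM(A, X)` for A of shape
 * (n, m) and X of shape (m, k) yields an (n, k) tensor. A 1-D dense input is
 * treated as a column vector: it is reshaped to (n, 1) for the multiplication
 * and the result is flattened back to 1-D.
 */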
torch::Tensor SpMM(
    const c10::intrusive_ptr<SparseMatrix>& sparse_mat,
    torch::Tensor dense_mat) {
  _SpMMSanityCheck(sparse_mat, sparse_mat->value(), dense_mat);
  bool expand_dim = false;
  if (dense_mat.dim() == 1) {
    dense_mat = dense_mat.view({-1, 1});
    expand_dim = true;
  }
  auto ret = SpMMAutoGrad::apply(sparse_mat, sparse_mat->value(), dense_mat);
  if (expand_dim) {
    ret = ret.view(-1);
  }
  return ret;
}

}  // namespace sparse
}  // namespace dgl