"vscode:/vscode.git/clone" did not exist on "64e0e383148a613c327d4bf9e866b7a185df8277"
spspmm.cc 5 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
/**
 *  Copyright (c) 2022 by Contributors
 * @file spspmm.cc
 * @brief DGL C++ sparse SpSpMM operator implementation.
 */

#include <sparse/sddmm.h>
#include <sparse/sparse_matrix.h>
#include <sparse/spspmm.h>
#include <torch/script.h>

#include "./matmul.h"
#include "./utils.h"

namespace dgl {
namespace sparse {

using namespace torch::autograd;

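// Autograd function for sparse-sparse matrix multiplication. forward()
// computes C = A @ B with SpSpMMNoAutoGrad; backward() produces gradients
// with respect to the nonzero values of A and B only, since the sparsity
// patterns themselves are not differentiable.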
class SpSpMMAutoGrad : public Function<SpSpMMAutoGrad> {
 public:
  static variable_list forward(
      AutogradContext* ctx, c10::intrusive_ptr<SparseMatrix> lhs_mat,
      torch::Tensor lhs_val, c10::intrusive_ptr<SparseMatrix> rhs_mat,
      torch::Tensor rhs_val);

  static tensor_list backward(AutogradContext* ctx, tensor_list grad_outputs);
};

void _SpSpMMSanityCheck(
    const c10::intrusive_ptr<SparseMatrix>& lhs_mat,
    const c10::intrusive_ptr<SparseMatrix>& rhs_mat) {
  const auto& lhs_shape = lhs_mat->shape();
  const auto& rhs_shape = rhs_mat->shape();
  TORCH_CHECK(
      lhs_shape[1] == rhs_shape[0],
      "SpSpMM: the second dim of lhs_mat should be equal to the first dim ",
      "of rhs_mat");
  TORCH_CHECK(
      lhs_mat->value().dim() == 1,
      "SpSpMM: the value shape of lhs_mat should be 1-D");
  TORCH_CHECK(
      rhs_mat->value().dim() == 1,
      "SpSpMM: the value shape of rhs_mat should be 1-D");
  TORCH_CHECK(
      lhs_mat->device() == rhs_mat->device(),
      "SpSpMM: lhs_mat and rhs_mat should be on the same device");
  TORCH_CHECK(
      lhs_mat->dtype() == rhs_mat->dtype(),
      "SpSpMM: lhs_mat and rhs_mat should have the same dtype");
  TORCH_CHECK(
      !lhs_mat->HasDuplicate(),
      "SpSpMM does not support lhs_mat with duplicate indices. ",
      "Call A = A.coalesce() to dedup first.");
  TORCH_CHECK(
      !rhs_mat->HasDuplicate(),
      "SpSpMM does not support rhs_mat with duplicate indices. ",
      "Call B = B.coalesce() to dedup first.");
}

// For each nonzero position (i, j) of `sub_mat`, look up the corresponding
// entry of `value` (laid out in `mat`'s CSR order); positions absent from
// `mat` yield 0.
torch::Tensor _CSRMask(
    const c10::intrusive_ptr<SparseMatrix>& mat, torch::Tensor value,
    const c10::intrusive_ptr<SparseMatrix>& sub_mat) {
  auto csr = CSRToOldDGLCSR(mat->CSRPtr());
  auto val = TorchTensorToDGLArray(value);
  auto row = TorchTensorToDGLArray(sub_mat->COOPtr()->row);
  auto col = TorchTensorToDGLArray(sub_mat->COOPtr()->col);
  runtime::NDArray ret;
  if (val->dtype.bits == 32) {
    ret = aten::CSRGetData<float>(csr, row, col, val, 0.);
  } else if (val->dtype.bits == 64) {
    ret = aten::CSRGetData<double>(csr, row, col, val, 0.);
  } else {
    TORCH_CHECK(
        false, "Dtype of value for SpSpMM should be 32 or 64 bits but got: " +
                   std::to_string(val->dtype.bits));
  }
  return DGLArrayToTorchTensor(ret);
}

variable_list SpSpMMAutoGrad::forward(
    AutogradContext* ctx, c10::intrusive_ptr<SparseMatrix> lhs_mat,
    torch::Tensor lhs_val, c10::intrusive_ptr<SparseMatrix> rhs_mat,
    torch::Tensor rhs_val) {
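  // The two trailing booleans of SpSpMMNoAutoGrad select whether lhs/rhs are
  // transposed; both are false here, while backward() uses them to form
  // dC @ B^T and A^T @ dC.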
  auto ret_mat =
      SpSpMMNoAutoGrad(lhs_mat, lhs_val, rhs_mat, rhs_val, false, false);

  ctx->saved_data["lhs_mat"] = lhs_mat;
  ctx->saved_data["rhs_mat"] = rhs_mat;
  ctx->saved_data["ret_mat"] = ret_mat;
  ctx->saved_data["lhs_require_grad"] = lhs_val.requires_grad();
  ctx->saved_data["rhs_require_grad"] = rhs_val.requires_grad();
  ctx->save_for_backward({lhs_val, rhs_val});

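  // Autograd tracks plain tensors rather than SparseMatrix objects, so
  // forward() returns the CSR components of the result; SpSpMM() rebuilds
  // the SparseMatrix from them.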
  auto csr = ret_mat->CSRPtr();
  auto val = ret_mat->value();
  TORCH_CHECK(
      !csr->value_indices.has_value(),
      "SpSpMM: the result CSR should not carry a value permutation.");
  return {csr->indptr, csr->indices, val};
}

tensor_list SpSpMMAutoGrad::backward(
    AutogradContext* ctx, tensor_list grad_outputs) {
  auto saved = ctx->get_saved_variables();
  auto lhs_val = saved[0];
  auto rhs_val = saved[1];
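  // grad_outputs matches forward()'s return order {indptr, indices, value};
  // only the value tensor (index 2) carries a gradient.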
  auto output_grad = grad_outputs[2];
  auto lhs_mat = ctx->saved_data["lhs_mat"].toCustomClass<SparseMatrix>();
  auto rhs_mat = ctx->saved_data["rhs_mat"].toCustomClass<SparseMatrix>();
  auto ret_mat = ctx->saved_data["ret_mat"].toCustomClass<SparseMatrix>();
  torch::Tensor lhs_val_grad, rhs_val_grad;

  if (ctx->saved_data["lhs_require_grad"].toBool()) {
    // A @ B = C -> dA = dC @ (B^T)
    auto lhs_mat_grad =
        SpSpMMNoAutoGrad(ret_mat, output_grad, rhs_mat, rhs_val, false, true);
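    // dC @ B^T generally has a denser sparsity pattern than A, so mask the
    // product down to A's nonzero positions.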
    lhs_val_grad = _CSRMask(lhs_mat_grad, lhs_mat_grad->value(), lhs_mat);
  }
  if (ctx->saved_data["rhs_require_grad"].toBool()) {
    // A @ B = C -> dB = (A^T) @ dC
    auto rhs_mat_grad =
        SpSpMMNoAutoGrad(lhs_mat, lhs_val, ret_mat, output_grad, true, false);
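    // Likewise, mask A^T @ dC down to B's nonzero positions.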
    rhs_val_grad = _CSRMask(rhs_mat_grad, rhs_mat_grad->value(), rhs_mat);
  }
  return {torch::Tensor(), lhs_val_grad, torch::Tensor(), rhs_val_grad};
}

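// Sparse-sparse matrix multiplication: computes lhs_mat @ rhs_mat and returns
// the result as a CSR-backed SparseMatrix.
//
// A minimal usage sketch; the tensor names and the `CreateFromCOO` constructor
// are illustrative assumptions, not definitions from this file:
//
//   auto A = CreateFromCOO(indices_a, values_a, /*shape=*/{3, 4});
//   auto B = CreateFromCOO(indices_b, values_b, /*shape=*/{4, 5});
//   auto C = SpSpMM(A, B);  // 3 x 5 sparse matrix.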
c10::intrusive_ptr<SparseMatrix> SpSpMM(
    const c10::intrusive_ptr<SparseMatrix>& lhs_mat,
    const c10::intrusive_ptr<SparseMatrix>& rhs_mat) {
  _SpSpMMSanityCheck(lhs_mat, rhs_mat);
  auto results = SpSpMMAutoGrad::apply(
      lhs_mat, lhs_mat->value(), rhs_mat, rhs_mat->value());
  std::vector<int64_t> ret_shape({lhs_mat->shape()[0], rhs_mat->shape()[1]});
  auto indptr = results[0];
  auto indices = results[1];
  auto value = results[2];
  return CreateFromCSR(indptr, indices, value, ret_shape);
}

}  // namespace sparse
}  // namespace dgl