"docs/vscode:/vscode.git/clone" did not exist on "dec4fdd60039ee34dae41e4bc52f1fc6deb9ddeb"
spmm.cc 4.33 KB
Newer Older
sangwzh's avatar
sangwzh committed
1
// !!! This is a file automatically generated by hipify!!!
/**
 *  Copyright (c) 2022 by Contributors
 * @file spmm.cc
 * @brief DGL C++ sparse SpMM operator implementation.
 */

#include <sparse/sddmm.h>
#include <sparse/sparse_matrix.h>
#include <sparse/spmm.h>
#include <torch/script.h>

#include <sstream>

#include "matmul.h"
#include "utils.h"

namespace dgl {
namespace sparse {

using namespace torch::autograd;
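
// Autograd wrapper for sparse-dense matrix multiplication: forward calls the
// non-differentiable SpMM kernel, and backward produces gradients for both the
// non-zero values of the sparse matrix and the dense operand.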

class SpMMAutoGrad : public Function<SpMMAutoGrad> {
 public:
  static torch::Tensor forward(
      AutogradContext* ctx, c10::intrusive_ptr<SparseMatrix> sparse_mat,
      torch::Tensor sparse_val, torch::Tensor dense_mat);

  static tensor_list backward(AutogradContext* ctx, tensor_list grad_outputs);
};
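
// Checks that the sparse matrix, its non-zero value tensor, and the dense
// operand have compatible shapes, dtypes, and devices before running SpMM.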

void _SpMMSanityCheck(
    c10::intrusive_ptr<SparseMatrix> sparse_mat, torch::Tensor sparse_val,
    torch::Tensor dense_mat) {
  const auto& sparse_mat_shape = sparse_mat->shape();
  auto val_shape = sparse_val.sizes();
  auto dense_shape = dense_mat.sizes();
  bool shape_check = true;
  shape_check &= sparse_mat_shape[1] == dense_shape[0];
  shape_check &= val_shape.size() <= 2;
  shape_check &= val_shape[0] == sparse_mat->nnz();
  shape_check &= dense_shape.size() <= 3;
  if (dense_shape.size() == 3 || val_shape.size() == 2) {
    shape_check &= dense_shape.size() == val_shape.size() + 1;
    shape_check &= dense_shape[2] == val_shape[1];
  }
  if (!shape_check) {
    std::stringstream error;
    error << "SpMM: Invalid input shapes. sparse_mat: "
          << c10::IntArrayRef(sparse_mat->shape())
          << ", sparse_val: " << sparse_mat->value().sizes()
          << ", dense_mat: " << dense_mat.sizes()
          << ". Valid input shapes (sparse_mat, dense_mat) are: (1) (n, m) and "
             "(m, k); (2) (n, m) and (m,); (3) (n, m, b) and (m, k, b).";
    TORCH_CHECK(false, error.str());
  }
  TORCH_CHECK(
      sparse_val.dtype() == dense_mat.dtype(),
      "SpMM: the non-zero values does not have the same dtype as the dense "
      "matrix.");
  TORCH_CHECK(
      sparse_val.device() == sparse_mat->device() &&
          sparse_val.device() == dense_mat.device(),
      "SpMM: sparse matrix, non-zero values and the dense matrix should be "
      "on the same device.");
}

torch::Tensor SpMMAutoGrad::forward(
    AutogradContext* ctx, c10::intrusive_ptr<SparseMatrix> sparse_mat,
    torch::Tensor sparse_val, torch::Tensor dense_mat) {
  auto ret = SpMMNoAutoGrad(sparse_mat, sparse_val, dense_mat, false);

  const bool sparse_requires_grad = sparse_val.requires_grad();
  const bool dense_requires_grad = dense_mat.requires_grad();
  torch::Tensor cache_sparse_val, cache_dense_mat;
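  // Cache only what backward needs: the gradient w.r.t. the dense matrix uses
  // the sparse values, and the gradient w.r.t. the sparse values uses the
  // dense matrix.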
  if (dense_requires_grad) {
    cache_sparse_val = sparse_val;
  }
  if (sparse_requires_grad) {
    cache_dense_mat = dense_mat;
  }
  ctx->saved_data["sparse_matrix"] = sparse_mat;
  ctx->saved_data["sparse_requires_grad"] = sparse_requires_grad;
  ctx->saved_data["dense_requires_grad"] = dense_requires_grad;
  ctx->save_for_backward({cache_sparse_val, cache_dense_mat});
  return ret;
}

tensor_list SpMMAutoGrad::backward(
    AutogradContext* ctx, tensor_list grad_outputs) {
  auto saved = ctx->get_saved_variables();
  auto sparse_val = saved[0];
  auto dense_mat = saved[1];
  auto output_grad = grad_outputs[0];

  auto sparse_mat =
      ctx->saved_data["sparse_matrix"].toCustomClass<SparseMatrix>();
  const bool sparse_requires_grad =
      ctx->saved_data["sparse_requires_grad"].toBool();
  const bool dense_requires_grad =
      ctx->saved_data["dense_requires_grad"].toBool();

  torch::Tensor dense_mat_grad, sparse_val_grad;
  if (sparse_requires_grad) {
    // A @ B = C -> dA = dC @ (B^T)
    sparse_val_grad = SDDMMNoAutoGrad(sparse_mat, output_grad, dense_mat);
  }
  if (dense_requires_grad) {
    // A @ B = C -> dB = (A^T) @ dC
    dense_mat_grad = SpMMNoAutoGrad(sparse_mat, sparse_val, output_grad, true);
  }
  return {torch::Tensor(), sparse_val_grad, dense_mat_grad};
}

torch::Tensor SpMM(
    const c10::intrusive_ptr<SparseMatrix>& sparse_mat,
    torch::Tensor dense_mat) {
  _SpMMSanityCheck(sparse_mat, sparse_mat->value(), dense_mat);
  bool expand_dim = false;
  if (dense_mat.dim() == 1) {
    dense_mat = dense_mat.view({-1, 1});
    expand_dim = true;
  }
  auto ret = SpMMAutoGrad::apply(sparse_mat, sparse_mat->value(), dense_mat);
  if (expand_dim) {
    ret = ret.view(-1);
  }
  return ret;
}
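
// A minimal usage sketch from the C++ side. Only SpMM above is defined in this
// file; the COO construction call below is an assumption about the SparseMatrix
// API, and n, m, k, indices, and val are placeholders.
//
//   // indices: (2, nnz) COO coordinates, val: (nnz,) non-zero values.
//   auto A = SparseMatrix::FromCOO(indices, val, /*shape=*/{n, m});
//   auto X = torch::randn({m, k}, torch::requires_grad());
//   auto Y = SpMM(A, X);  // dense (n, k); differentiable w.r.t. val and X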

}  // namespace sparse
}  // namespace dgl