/**
 *  Copyright (c) 2022 by Contributors
 * @file matmul.cc
 * @brief DGL sparse matrix multiplication functions.
 */
#include "./matmul.h"

// clang-format off
#include <sparse/dgl_headers.h>
// clang-format on

#include <sparse/sparse_matrix.h>
#include <torch/script.h>

#include "./utils.h"

namespace dgl {
namespace sparse {

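/**
 * @brief Sparse-dense matrix multiplication without autograd: computes
 * `sparse_mat @ dense_mat` (or `sparse_mat^T @ dense_mat` when
 * `transpose_sparse` is true), taking the nonzero values from `sparse_val`.
 * A 2D `sparse_val` selects the batched path below.
 */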
torch::Tensor SpMMNoAutoGrad(
    const c10::intrusive_ptr<SparseMatrix>& sparse_mat,
    torch::Tensor sparse_val, torch::Tensor dense_mat, bool transpose_sparse) {
  const std::string op = "mul";
  const std::string reduce = "sum";
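  // With "mul" as the message op and "sum" as the reduction, SpMM amounts to
  // an ordinary (generalized) matrix multiplication.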
  const int64_t out_row =
      transpose_sparse ? sparse_mat->shape()[1] : sparse_mat->shape()[0];
  std::vector<int64_t> shape = {out_row, dense_mat.size(1)};
  // Batched SpMM
  if (sparse_val.dim() >= 2) {
    shape = {out_row, dense_mat.size(1), sparse_val.size(1)};
  }

  auto ret = torch::zeros(shape, dense_mat.options());
  auto dgl_sparse_val = TorchTensorToDGLArray(sparse_val);
  auto dgl_dense_mat = TorchTensorToDGLArray(dense_mat);
  auto dgl_ret = TorchTensorToDGLArray(ret);
  if (!transpose_sparse) {
    // The format for calculation will be chosen in the following order: CSR,
    // COO. CSR is created if the sparse matrix only has CSC format.
    if (sparse_mat->HasCSR() || !sparse_mat->HasCOO()) {
      // sparse_mat->CSRPtr() will implicitly convert CSC to CSR format if CSR
      // does not exist.
      auto csr = CSRToOldDGLCSR(sparse_mat->CSRPtr());
      aten::CSRSpMM(
          op.c_str(), reduce.c_str(), csr, dgl_dense_mat, dgl_sparse_val,
          dgl_ret, {});
    } else {  // COO
      // aten::COOSpMM computes A^T @ X, so transpose the COO matrix first.
      auto coo = COOToOldDGLCOO(sparse_mat->COOPtr());
      coo = aten::COOTranspose(coo);
      aten::COOSpMM(
          op.c_str(), reduce.c_str(), coo, dgl_dense_mat, dgl_sparse_val,
          dgl_ret, {});
    }
  } else {  // transpose_sparse
    // The format for calculation will be chosen in the following order: CSC,
    // COO. CSC is created if the sparse matrix only has CSR format.
    if (sparse_mat->HasCSC() || !sparse_mat->HasCOO()) {
      // sparse_mat->CSCPtr() will implicitly convert CSR to CSC format if CSC
      // does not exist.
      // Using the CSC with DGL's CSRSpMM is equivalent to computing A^T @ X.
      auto csc = CSRToOldDGLCSR(sparse_mat->CSCPtr());
      aten::CSRSpMM(
          op.c_str(), reduce.c_str(), csc, dgl_dense_mat, dgl_sparse_val,
          dgl_ret, {});
    } else {  // COO
      // aten::COOSpMM computes A^T @ X, which is exactly the transposed
      // product needed here, so no explicit transpose is required.
      auto coo = COOToOldDGLCOO(sparse_mat->COOPtr());
      aten::COOSpMM(
          op.c_str(), reduce.c_str(), coo, dgl_dense_mat, dgl_sparse_val,
          dgl_ret, {});
    }
  }
  return ret;
}

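/**
 * @brief Sampled dense-dense matrix multiplication without autograd: for each
 * nonzero (i, j) of `sparse_mat`, computes the dot product of row i of `mat1`
 * and row j of `mat2_tr`, i.e., samples `mat1 @ mat2_tr^T` at the sparsity
 * pattern. `mat2_tr` is the already-transposed second operand; a 3D `mat1`
 * selects the batched path below.
 */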
torch::Tensor SDDMMNoAutoGrad(
    const c10::intrusive_ptr<SparseMatrix>& sparse_mat, torch::Tensor mat1,
    torch::Tensor mat2_tr) {
  const int64_t out_row = sparse_mat->nnz();
  std::vector<int64_t> shape({out_row});
  // Batched SDDMM
  if (mat1.dim() >= 3) {
    shape.push_back(mat1.size(2));
    // (N, K, B) -> (N, B, K)
    mat1 = mat1.transpose(1, 2).contiguous();
    // (M, K, B) -> (M, B, K)
    mat2_tr = mat2_tr.transpose(1, 2).contiguous();
  }
  auto ret = torch::zeros(shape, mat1.options());
  const std::string op = "dot";
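  // "dot" takes the inner product of the lhs row and rhs row selected for
  // each nonzero.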
  auto dgl_mat1 = TorchTensorToDGLArray(mat1);
  auto dgl_mat2_tr = TorchTensorToDGLArray(mat2_tr);
  auto dgl_ret = TorchTensorToDGLArray(ret);
  // The format for calculation will be chosen in the following order: CSR,
  // COO. CSR is created if the sparse matrix only has CSC format.
  if (sparse_mat->HasCSR() || !sparse_mat->HasCOO()) {
    // sparse_mat->CSRPtr() will implicitly convert CSC to CSR format if CSR
    // does not exist.
    auto csr = CSRToOldDGLCSR(sparse_mat->CSRPtr());
    aten::CSRSDDMM(
        op.c_str(), csr, dgl_mat1, dgl_mat2_tr, dgl_ret, 0 /* lhs target: u */,
        2 /* rhs target: v */);
  } else {  // COO
    auto coo = COOToOldDGLCOO(sparse_mat->COOPtr());
    aten::COOSDDMM(
        op.c_str(), coo, dgl_mat1, dgl_mat2_tr, dgl_ret, 0 /* lhs target: u */,
        2 /* rhs target: v */);
  }
  return ret;
}

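/**
 * @brief Applies the elementwise broadcast operation `op` between the nonzero
 * values of `sparse_mat` (lhs target `e`) and the `dense_mat` rows addressed
 * by target `u`, without autograd. The "sub", "div", and "mul" wrappers below
 * dispatch to this function.
 */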
torch::Tensor BroadcastOpNoAutoGrad(
    const c10::intrusive_ptr<SparseMatrix>& sparse_mat, torch::Tensor dense_mat,
    const std::string& op) {
  auto sparse_val = sparse_mat->value();
  const int64_t out_row = sparse_mat->nnz();
  const std::vector<int64_t> shape({out_row, sparse_val.size(1)});
  auto ret = torch::zeros(shape, sparse_val.options());

  auto dgl_sparse_val = TorchTensorToDGLArray(sparse_val);
  auto dgl_dense_mat = TorchTensorToDGLArray(dense_mat);
  auto dgl_ret = TorchTensorToDGLArray(ret);

  // The format for calculation will be chosen in the following order: COO,
  // CSR. COO is created if the sparse matrix only has CSC format.
  if (sparse_mat->HasCOO() || !sparse_mat->HasCSR()) {
    // sparse_mat->COOPtr() will implicitly convert CSC to COO format if COO
    // does not exist.
    auto coo = COOToOldDGLCOO(sparse_mat->COOPtr());
    aten::COOSDDMM(
        op.c_str(), coo, dgl_sparse_val, dgl_dense_mat, dgl_ret,
        1 /* lhs target: e */, 0 /* rhs target: u due to transpose */);
  } else {
    auto csr = CSRToOldDGLCSR(sparse_mat->CSRPtr());
    aten::CSRSDDMM(
        op.c_str(), csr, dgl_sparse_val, dgl_dense_mat, dgl_ret,
        1 /* lhs target: e */, 0 /* rhs target: u due to transpose */);
  }
  return ret;
}

torch::Tensor BroadcastSubNoAutoGrad(
    const c10::intrusive_ptr<SparseMatrix>& sparse_mat,
    torch::Tensor dense_mat) {
  return BroadcastOpNoAutoGrad(sparse_mat, dense_mat, "sub");
}

torch::Tensor BroadcastDivNoAutoGrad(
    const c10::intrusive_ptr<SparseMatrix>& sparse_mat,
    torch::Tensor dense_mat) {
  return BroadcastOpNoAutoGrad(sparse_mat, dense_mat, "div");
}

torch::Tensor BroadcastMulNoAutoGrad(
    const c10::intrusive_ptr<SparseMatrix>& sparse_mat,
    torch::Tensor dense_mat) {
  return BroadcastOpNoAutoGrad(sparse_mat, dense_mat, "mul");
}

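/**
 * @brief Sparse-sparse matrix multiplication via aten::CSRMM, without
 * autograd. Operand transposes are handled through the CSC representation
 * (see the comment in the body); the product is returned as a new CSR-format
 * SparseMatrix.
 */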
c10::intrusive_ptr<SparseMatrix> SpSpMMNoAutoGrad(
    const c10::intrusive_ptr<SparseMatrix>& lhs_mat, torch::Tensor lhs_val,
    const c10::intrusive_ptr<SparseMatrix>& rhs_mat, torch::Tensor rhs_val,
    bool lhs_transpose, bool rhs_transpose) {
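  // A matrix's CSC is exactly the CSR of its transpose, so a transposed
  // operand can be fed to aten::CSRMM by using its CSC representation
  // directly, without materializing the transpose.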
  aten::CSRMatrix lhs_dgl_csr, rhs_dgl_csr;
  if (!lhs_transpose) {
    lhs_dgl_csr = CSRToOldDGLCSR(lhs_mat->CSRPtr());
  } else {
    lhs_dgl_csr = CSRToOldDGLCSR(lhs_mat->CSCPtr());
  }
  if (!rhs_transpose) {
    rhs_dgl_csr = CSRToOldDGLCSR(rhs_mat->CSRPtr());
  } else {
    rhs_dgl_csr = CSRToOldDGLCSR(rhs_mat->CSCPtr());
  }
  auto lhs_dgl_val = TorchTensorToDGLArray(lhs_val);
  auto rhs_dgl_val = TorchTensorToDGLArray(rhs_val);
  const int64_t ret_row =
      lhs_transpose ? lhs_mat->shape()[1] : lhs_mat->shape()[0];
  const int64_t ret_col =
      rhs_transpose ? rhs_mat->shape()[0] : rhs_mat->shape()[1];
  std::vector<int64_t> ret_shape({ret_row, ret_col});
  aten::CSRMatrix ret_dgl_csr;
  runtime::NDArray ret_val;
  std::tie(ret_dgl_csr, ret_val) =
      aten::CSRMM(lhs_dgl_csr, lhs_dgl_val, rhs_dgl_csr, rhs_dgl_val);
  return SparseMatrix::FromCSR(
      CSRFromOldDGLCSR(ret_dgl_csr), DGLArrayToTorchTensor(ret_val), ret_shape);
}

}  // namespace sparse
}  // namespace dgl