// !!! This is a file automatically generated by hipify!!!
/**
 *  Copyright (c) 2022 by Contributors
 * @file matmul.cc
 * @brief DGL sparse matrix multiplication functions.
 */
#include "matmul.h"

// clang-format off
#include <sparse/dgl_headers.h>
// clang-format on

#include <sparse/sparse_matrix.h>
#include <torch/script.h>

#include "utils.h"

namespace dgl {
namespace sparse {

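/**
 * @brief Sparse-dense matrix multiplication without autograd, computing
 * `A @ X` (or `A^T @ X` when `transpose_sparse` is true), where `A` is
 * `sparse_mat` with its non-zero values replaced by `sparse_val` and `X` is
 * `dense_mat`. A 2D `sparse_val` is treated as batched values and adds a
 * trailing batch dimension to the output.
 */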
torch::Tensor SpMMNoAutoGrad(
    const c10::intrusive_ptr<SparseMatrix>& sparse_mat,
    torch::Tensor sparse_val, torch::Tensor dense_mat, bool transpose_sparse) {
  const std::string op = "mul";
  const std::string reduce = "sum";
  const int64_t out_row =
      transpose_sparse ? sparse_mat->shape()[1] : sparse_mat->shape()[0];
  std::vector<int64_t> shape = {out_row, dense_mat.size(1)};
  // Batched SpMM: a 2D sparse_val of shape (nnz, B) adds a trailing batch
  // dimension of size B to the output.
  if (sparse_val.dim() >= 2) {
    shape = {out_row, dense_mat.size(1), sparse_val.size(1)};
  }

  auto ret = torch::zeros(shape, dense_mat.options());
  auto dgl_sparse_val = TorchTensorToDGLArray(sparse_val);
  auto dgl_dense_mat = TorchTensorToDGLArray(dense_mat);
  auto dgl_ret = TorchTensorToDGLArray(ret);
  if (!transpose_sparse) {
    // The format used for computation is chosen in the following order: CSR,
    // COO. CSR is created if the sparse matrix only has the CSC format.
    if (sparse_mat->HasCSR() || !sparse_mat->HasCOO()) {
      // sparse_mat->CSRPtr() will implicitly convert CSC to CSR format if CSR
      // does not exist.
      auto csr = CSRToOldDGLCSR(sparse_mat->CSRPtr());
      aten::CSRSpMM(
          op.c_str(), reduce.c_str(), csr, dgl_dense_mat, dgl_sparse_val,
          dgl_ret, {});
    } else {  // COO
      // Transpose the COO matrix first, because aten::COOSpMM computes A^T @ X.
      auto coo = COOToOldDGLCOO(sparse_mat->COOPtr());
      coo = aten::COOTranspose(coo);
      aten::COOSpMM(
          op.c_str(), reduce.c_str(), coo, dgl_dense_mat, dgl_sparse_val,
          dgl_ret, {});
    }
  } else {  // transpose_sparse
    // The format used for computation is chosen in the following order: CSC,
    // COO. CSC is created if the sparse matrix only has the CSR format.
    if (sparse_mat->HasCSC() || !sparse_mat->HasCOO()) {
      // sparse_mat->CSCPtr() will implicitly convert CSR to CSC format if CSC
      // does not exist.
      // Using the CSC with DGL's CSRSpMM is equivalent to computing A^T @ X.
      auto csc = CSRToOldDGLCSR(sparse_mat->CSCPtr());
      aten::CSRSpMM(
          op.c_str(), reduce.c_str(), csc, dgl_dense_mat, dgl_sparse_val,
          dgl_ret, {});
    } else {  // COO
      // aten::COOSpMM computes A^T @ X directly, so no transpose is needed.
      auto coo = COOToOldDGLCOO(sparse_mat->COOPtr());
      aten::COOSpMM(
          op.c_str(), reduce.c_str(), coo, dgl_dense_mat, dgl_sparse_val,
          dgl_ret, {});
    }
  }
  return ret;
}

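/**
 * @brief Sampled dense-dense matrix multiplication without autograd: for each
 * non-zero (i, j) of `sparse_mat`, computes the dot product of row i of
 * `mat1` with row j of `mat2_tr` (the second operand is expected to be
 * pre-transposed). 3D inputs are treated as batched and add a trailing batch
 * dimension to the output.
 */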
torch::Tensor SDDMMNoAutoGrad(
    const c10::intrusive_ptr<SparseMatrix>& sparse_mat, torch::Tensor mat1,
    torch::Tensor mat2_tr) {
  const int64_t out_row = sparse_mat->nnz();
  std::vector<int64_t> shape({out_row});
  // Batched SDDMM
  if (mat1.dim() >= 3) {
    shape.push_back(mat1.size(2));
    // (N, K, B) -> (N, B, K)
    mat1 = mat1.transpose(1, 2);
    // (M, K, B) -> (M, B, K)
    mat2_tr = mat2_tr.transpose(1, 2);
  }
  auto ret = torch::zeros(shape, mat1.options());
  const std::string op = "dot";
  auto dgl_mat1 = TorchTensorToDGLArray(mat1);
  auto dgl_mat2_tr = TorchTensorToDGLArray(mat2_tr);
  auto dgl_ret = TorchTensorToDGLArray(ret);
  // The format used for computation is chosen in the following order: CSR,
  // COO. CSR is created if the sparse matrix only has the CSC format.
  if (sparse_mat->HasCSR() || !sparse_mat->HasCOO()) {
    // sparse_mat->CSRPtr() will implicitly convert CSC to CSR format if CSR
    // does not exist.
    auto csr = CSRToOldDGLCSR(sparse_mat->CSRPtr());
    aten::CSRSDDMM(
        op.c_str(), csr, dgl_mat1, dgl_mat2_tr, dgl_ret, 0 /* lhs target: u */,
        2 /* rhs target: v */);
  } else {  // COO
    auto coo = COOToOldDGLCOO(sparse_mat->COOPtr());
    aten::COOSDDMM(
        op.c_str(), coo, dgl_mat1, dgl_mat2_tr, dgl_ret, 0 /* lhs target: u */,
        2 /* rhs target: v */);
  }
  return ret;
}

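/**
 * @brief Elementwise broadcast operation without autograd: applies `op`
 * ("sub", "div", or "mul") between the non-zero values of `sparse_mat` and
 * the entries of `dense_mat` picked per non-zero by its row or column
 * coordinate, as selected by `dim`. Returns the per-edge results as a tensor
 * of shape (nnz, D), where D is the value dimension.
 */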
torch::Tensor BroadcastOpNoAutoGrad(
    const c10::intrusive_ptr<SparseMatrix>& sparse_mat, torch::Tensor dense_mat,
    const std::string& op, int64_t dim) {
  auto sparse_val = sparse_mat->value();
  const int64_t out_row = sparse_mat->nnz();
  const std::vector<int64_t> shape({out_row, sparse_val.size(1)});
  auto ret = torch::zeros(shape, sparse_val.options());

  auto dgl_sparse_val = TorchTensorToDGLArray(sparse_val);
  auto dgl_dense_mat = TorchTensorToDGLArray(dense_mat);
  auto dgl_ret = TorchTensorToDGLArray(ret);
  // Setting dgl_rhs_target to 0 or 2 means using row or column coordinates
  // to access dgl_dense_mat for each edge, respectively.
  auto dgl_rhs_target = dim == 0 ? 2 : 0;

  // The format used for computation is chosen in the following order: COO,
  // CSR. COO is created if the sparse matrix only has the CSC format.
  if (sparse_mat->HasCOO() || !sparse_mat->HasCSR()) {
    // sparse_mat->COOPtr() will implicitly convert CSC to COO format if COO
    // does not exist.
    auto coo = COOToOldDGLCOO(sparse_mat->COOPtr());
    aten::COOSDDMM(
        op.c_str(), coo, dgl_sparse_val, dgl_dense_mat, dgl_ret,
        1 /* lhs target: e */, dgl_rhs_target);
  } else {
    auto csr = CSRToOldDGLCSR(sparse_mat->CSRPtr());
    aten::CSRSDDMM(
        op.c_str(), csr, dgl_sparse_val, dgl_dense_mat, dgl_ret,
        1 /* lhs target: e */, dgl_rhs_target);
  }
  return ret;
}

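/** @brief Broadcast subtraction: BroadcastOpNoAutoGrad with op == "sub". */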
torch::Tensor BroadcastSubNoAutoGrad(
    const c10::intrusive_ptr<SparseMatrix>& sparse_mat, torch::Tensor dense_mat,
    int64_t dim) {
  return BroadcastOpNoAutoGrad(sparse_mat, dense_mat, "sub", dim);
}

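/** @brief Broadcast division: BroadcastOpNoAutoGrad with op == "div". */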
torch::Tensor BroadcastDivNoAutoGrad(
    const c10::intrusive_ptr<SparseMatrix>& sparse_mat, torch::Tensor dense_mat,
    int64_t dim) {
  return BroadcastOpNoAutoGrad(sparse_mat, dense_mat, "div", dim);
}

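/** @brief Broadcast multiplication: BroadcastOpNoAutoGrad with op == "mul". */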
torch::Tensor BroadcastMulNoAutoGrad(
    const c10::intrusive_ptr<SparseMatrix>& sparse_mat, torch::Tensor dense_mat,
    int64_t dim) {
  return BroadcastOpNoAutoGrad(sparse_mat, dense_mat, "mul", dim);
}

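/**
 * @brief Sparse-sparse matrix multiplication without autograd, computing
 * `A @ B` with either operand optionally transposed and `lhs_val` / `rhs_val`
 * used as the non-zero values. Both operands are lowered to (old) DGL CSR
 * matrices, multiplied with aten::CSRMM, and the result is wrapped back into
 * a SparseMatrix.
 */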
c10::intrusive_ptr<SparseMatrix> SpSpMMNoAutoGrad(
    const c10::intrusive_ptr<SparseMatrix>& lhs_mat, torch::Tensor lhs_val,
    const c10::intrusive_ptr<SparseMatrix>& rhs_mat, torch::Tensor rhs_val,
    bool lhs_transpose, bool rhs_transpose) {
  aten::CSRMatrix lhs_dgl_csr, rhs_dgl_csr;
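  // Reading a matrix's CSC representation as CSR yields its transpose without
  // materializing a new matrix, which is how the transposed cases are handled.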
  if (!lhs_transpose) {
    lhs_dgl_csr = CSRToOldDGLCSR(lhs_mat->CSRPtr());
  } else {
    lhs_dgl_csr = CSRToOldDGLCSR(lhs_mat->CSCPtr());
  }
  if (!rhs_transpose) {
    rhs_dgl_csr = CSRToOldDGLCSR(rhs_mat->CSRPtr());
  } else {
    rhs_dgl_csr = CSRToOldDGLCSR(rhs_mat->CSCPtr());
  }
  auto lhs_dgl_val = TorchTensorToDGLArray(lhs_val);
  auto rhs_dgl_val = TorchTensorToDGLArray(rhs_val);
  const int64_t ret_row =
      lhs_transpose ? lhs_mat->shape()[1] : lhs_mat->shape()[0];
  const int64_t ret_col =
      rhs_transpose ? rhs_mat->shape()[0] : rhs_mat->shape()[1];
  std::vector<int64_t> ret_shape({ret_row, ret_col});
  aten::CSRMatrix ret_dgl_csr;
  runtime::NDArray ret_val;
  std::tie(ret_dgl_csr, ret_val) =
      aten::CSRMM(lhs_dgl_csr, lhs_dgl_val, rhs_dgl_csr, rhs_dgl_val);
  return SparseMatrix::FromCSRPointer(
      CSRFromOldDGLCSR(ret_dgl_csr), DGLArrayToTorchTensor(ret_val), ret_shape);
}

}  // namespace sparse
}  // namespace dgl