matmul.cc 4.96 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
/**
 *  Copyright (c) 2022 by Contributors
 * @file matmul.cc
 * @brief DGL sparse matrix multiplication functions.
 */
#include "./matmul.h"

// clang-format off
#include <sparse/dgl_headers.h>
// clang-format on

#include <sparse/sparse_matrix.h>
#include <torch/script.h>

#include "./utils.h"

namespace dgl {
namespace sparse {

/**
 * @brief Sparse-dense matrix multiplication through DGL's SpMM kernels
 * ("mul" as the binary op, "sum" as the reducer).
 *
 * @param sparse_mat Sparse operand; provides the sparsity structure and shape.
 * @param sparse_val Tensor handed to the kernel as the sparse operand's
 * values (presumably one entry per non-zero of sparse_mat -- confirm with
 * callers).
 * @param dense_mat Dense operand; dense_mat.size(1) is the number of output
 * columns.
 * @param transpose_sparse When true, the transpose of sparse_mat is used as
 * the left operand.
 *
 * @return Tensor of shape (R, dense_mat.size(1)) created with dense_mat's
 * options, where R is sparse_mat->shape()[1] if transpose_sparse is set and
 * sparse_mat->shape()[0] otherwise.
 */
torch::Tensor SpMMNoAutoGrad(
    const c10::intrusive_ptr<SparseMatrix>& sparse_mat,
    torch::Tensor sparse_val, torch::Tensor dense_mat, bool transpose_sparse) {
  const char* const kOp = "mul";
  const char* const kReduce = "sum";

  // Transposing the sparse operand swaps which dimension yields the row count.
  const int64_t num_out_rows = sparse_mat->shape()[transpose_sparse ? 1 : 0];
  const std::vector<int64_t> out_shape = {num_out_rows, dense_mat.size(1)};

  // The kernels write their result through `out_arr`; `ret` is returned
  // afterwards, so the DGL array is assumed to alias the storage of `ret`.
  auto ret = torch::zeros(out_shape, dense_mat.options());
  auto val_arr = TorchTensorToDGLArray(sparse_val);
  auto dense_arr = TorchTensorToDGLArray(dense_mat);
  auto out_arr = TorchTensorToDGLArray(ret);

  if (transpose_sparse) {
    // Format preference: CSC, then COO. COO is only used when it is present
    // and CSC is not; otherwise CSCPtr() is used, which implicitly creates
    // CSC from CSR if CSC does not exist yet.
    const bool coo_only = !sparse_mat->HasCSC() && sparse_mat->HasCOO();
    if (coo_only) {
      // aten::COOSpMM calculates A^T @ X, so the COO is passed as is.
      auto coo = COOToOldDGLCOO(sparse_mat->COOPtr());
      aten::COOSpMM(kOp, kReduce, coo, dense_arr, val_arr, out_arr, {});
      return ret;
    }
    // Feeding CSC to DGL's CSRSpMM is equivalent to computing A^T @ X.
    auto csc = CSRToOldDGLCSR(sparse_mat->CSCPtr());
    aten::CSRSpMM(kOp, kReduce, csc, dense_arr, val_arr, out_arr, {});
    return ret;
  }

  // Format preference: CSR, then COO. COO is only used when it is present and
  // CSR is not; otherwise CSRPtr() is used, which implicitly creates CSR from
  // CSC if CSR does not exist yet.
  const bool coo_only = !sparse_mat->HasCSR() && sparse_mat->HasCOO();
  if (coo_only) {
    // aten::COOSpMM calculates A^T @ X, so transpose the COO first to obtain
    // A @ X.
    auto coo = COOToOldDGLCOO(sparse_mat->COOPtr());
    coo = aten::COOTranspose(coo);
    aten::COOSpMM(kOp, kReduce, coo, dense_arr, val_arr, out_arr, {});
    return ret;
  }
  auto csr = CSRToOldDGLCSR(sparse_mat->CSRPtr());
  aten::CSRSpMM(kOp, kReduce, csr, dense_arr, val_arr, out_arr, {});
  return ret;
}

/**
 * @brief Sampled dense-dense matrix multiplication through DGL's SDDMM
 * kernels using the "dot" operator.
 *
 * @param sparse_mat Sparse matrix whose sparsity pattern selects the entries
 * to compute.
 * @param mat1 Tensor passed to the kernel as the left operand (target u).
 * @param mat2_tr Tensor passed to the kernel as the right operand (target v);
 * judging by the name it is presumably the second matrix already transposed
 * -- confirm with callers.
 *
 * @return 1-D tensor with sparse_mat->nnz() elements, created with mat1's
 * options.
 */
torch::Tensor SDDMMNoAutoGrad(
    const c10::intrusive_ptr<SparseMatrix>& sparse_mat, torch::Tensor mat1,
    torch::Tensor mat2_tr) {
  // One output element per non-zero entry of the sparse matrix.
  const int64_t num_nonzeros = sparse_mat->nnz();
  const std::vector<int64_t> out_shape = {num_nonzeros};
  auto ret = torch::zeros(out_shape, mat1.options());

  const char* const kOp = "dot";
  const int kLhsTarget = 0;  // Lhs target: u
  const int kRhsTarget = 2;  // Rhs target: v

  // The kernels write their result through `out_arr`; `ret` is returned
  // afterwards, so the DGL array is assumed to alias the storage of `ret`.
  auto lhs_arr = TorchTensorToDGLArray(mat1);
  auto rhs_arr = TorchTensorToDGLArray(mat2_tr);
  auto out_arr = TorchTensorToDGLArray(ret);

  // Format preference: CSR, then COO. COO is only used when it is present and
  // CSR is not; otherwise CSRPtr() is used, which implicitly creates CSR from
  // CSC if CSR does not exist yet.
  const bool coo_only = !sparse_mat->HasCSR() && sparse_mat->HasCOO();
  if (coo_only) {
    auto coo = COOToOldDGLCOO(sparse_mat->COOPtr());
    aten::COOSDDMM(
        kOp, coo, lhs_arr, rhs_arr, out_arr, kLhsTarget, kRhsTarget);
    return ret;
  }
  auto csr = CSRToOldDGLCSR(sparse_mat->CSRPtr());
  aten::CSRSDDMM(kOp, csr, lhs_arr, rhs_arr, out_arr, kLhsTarget, kRhsTarget);
  return ret;
}

101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
/**
 * @brief Sparse-sparse matrix multiplication through DGL's aten::CSRMM.
 *
 * @param lhs_mat Left sparse operand.
 * @param lhs_val Values passed to the kernel for the left operand.
 * @param rhs_mat Right sparse operand.
 * @param rhs_val Values passed to the kernel for the right operand.
 * @param lhs_transpose When true, the left operand is used transposed (its
 * CSC representation is handed to the kernel as CSR).
 * @param rhs_transpose When true, the right operand is used transposed (its
 * CSC representation is handed to the kernel as CSR).
 *
 * @return Sparse matrix built from the CSR structure and values returned by
 * aten::CSRMM, with shape (rows of the effective lhs, columns of the
 * effective rhs).
 */
c10::intrusive_ptr<SparseMatrix> SpSpMMNoAutoGrad(
    const c10::intrusive_ptr<SparseMatrix>& lhs_mat, torch::Tensor lhs_val,
    const c10::intrusive_ptr<SparseMatrix>& rhs_mat, torch::Tensor rhs_val,
    bool lhs_transpose, bool rhs_transpose) {
  // A transposed operand is expressed by passing its CSC as if it were CSR.
  const auto as_dgl_csr = [](const c10::intrusive_ptr<SparseMatrix>& mat,
                             bool transposed) -> aten::CSRMatrix {
    if (transposed) {
      return CSRToOldDGLCSR(mat->CSCPtr());
    }
    return CSRToOldDGLCSR(mat->CSRPtr());
  };
  aten::CSRMatrix lhs_csr = as_dgl_csr(lhs_mat, lhs_transpose);
  aten::CSRMatrix rhs_csr = as_dgl_csr(rhs_mat, rhs_transpose);

  auto lhs_arr = TorchTensorToDGLArray(lhs_val);
  auto rhs_arr = TorchTensorToDGLArray(rhs_val);

  // Result shape: rows come from the (possibly transposed) lhs, columns from
  // the (possibly transposed) rhs.
  const int64_t num_rows = lhs_mat->shape()[lhs_transpose ? 1 : 0];
  const int64_t num_cols = rhs_mat->shape()[rhs_transpose ? 0 : 1];
  std::vector<int64_t> result_shape({num_rows, num_cols});

  aten::CSRMatrix result_csr;
  runtime::NDArray result_val;
  std::tie(result_csr, result_val) =
      aten::CSRMM(lhs_csr, lhs_arr, rhs_csr, rhs_arr);

  return SparseMatrix::FromCSR(
      CSRFromOldDGLCSR(result_csr), DGLArrayToTorchTensor(result_val),
      result_shape);
}

131
132
}  // namespace sparse
}  // namespace dgl