/**
 *  Copyright (c) 2022 by Contributors
 * @file sparse/sparse_matrix.h
 * @brief DGL C++ sparse matrix header.
 */
#ifndef SPARSE_SPARSE_MATRIX_H_
#define SPARSE_SPARSE_MATRIX_H_

// clang-format off
#include <sparse/dgl_headers.h>
// clang-format on

#include <sparse/sparse_format.h>
#include <torch/custom_class.h>
#include <torch/script.h>

#include <cstdint>
#include <memory>
#include <tuple>
#include <utility>
#include <vector>

namespace dgl {
namespace sparse {

25
/** @brief SparseMatrix bound to Python.  */
26
27
class SparseMatrix : public torch::CustomClassHolder {
 public:
Hongzhi (Steve), Chen's avatar
Hongzhi (Steve), Chen committed
28
  /**
29
30
31
32
33
34
35
36
37
38
39
40
   * @brief General constructor to construct a sparse matrix for different
   * sparse formats. At least one of the sparse formats should be provided,
   * while others could be nullptrs.
   *
   * @param coo The COO format.
   * @param csr The CSR format.
   * @param csc The CSC format.
   * @param value Value of the sparse matrix.
   * @param shape Shape of the sparse matrix.
   */
  SparseMatrix(
      const std::shared_ptr<COO>& coo, const std::shared_ptr<CSR>& csr,
41
42
      const std::shared_ptr<CSR>& csc, const std::shared_ptr<Diag>& diag,
      torch::Tensor value, const std::vector<int64_t>& shape);
43

Hongzhi (Steve), Chen's avatar
Hongzhi (Steve), Chen committed
44
  /**
45
46
47
48
49
50
51
   * @brief Construct a SparseMatrix from a COO format.
   * @param coo The COO format
   * @param value Values of the sparse matrix
   * @param shape Shape of the sparse matrix
   *
   * @return SparseMatrix
   */
52
  static c10::intrusive_ptr<SparseMatrix> FromCOOPointer(
53
54
55
      const std::shared_ptr<COO>& coo, torch::Tensor value,
      const std::vector<int64_t>& shape);

Hongzhi (Steve), Chen's avatar
Hongzhi (Steve), Chen committed
56
  /**
57
58
59
60
61
62
63
   * @brief Construct a SparseMatrix from a CSR format.
   * @param csr The CSR format
   * @param value Values of the sparse matrix
   * @param shape Shape of the sparse matrix
   *
   * @return SparseMatrix
   */
64
  static c10::intrusive_ptr<SparseMatrix> FromCSRPointer(
65
66
67
      const std::shared_ptr<CSR>& csr, torch::Tensor value,
      const std::vector<int64_t>& shape);

Hongzhi (Steve), Chen's avatar
Hongzhi (Steve), Chen committed
68
  /**
69
70
71
72
73
74
75
   * @brief Construct a SparseMatrix from a CSC format.
   * @param csc The CSC format
   * @param value Values of the sparse matrix
   * @param shape Shape of the sparse matrix
   *
   * @return SparseMatrix
   */
76
  static c10::intrusive_ptr<SparseMatrix> FromCSCPointer(
77
78
79
      const std::shared_ptr<CSR>& csc, torch::Tensor value,
      const std::vector<int64_t>& shape);

80
81
82
83
84
85
86
87
88
89
90
91
  /**
   * @brief Construct a SparseMatrix from a Diag format.
   * @param diag The Diag format
   * @param value Values of the sparse matrix
   * @param shape Shape of the sparse matrix
   *
   * @return SparseMatrix
   */
  static c10::intrusive_ptr<SparseMatrix> FromDiagPointer(
      const std::shared_ptr<Diag>& diag, torch::Tensor value,
      const std::vector<int64_t>& shape);

92
93
  /**
   * @brief Create a SparseMatrix from tensors in COO format.
94
   * @param indices COO coordinates with shape (2, nnz).
95
96
97
98
99
100
   * @param value Values of the sparse matrix.
   * @param shape Shape of the sparse matrix.
   *
   * @return SparseMatrix
   */
  static c10::intrusive_ptr<SparseMatrix> FromCOO(
101
      torch::Tensor indices, torch::Tensor value,
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
      const std::vector<int64_t>& shape);

  /**
   * @brief Create a SparseMatrix from tensors in CSR format.
   * @param indptr Index pointer array of the CSR
   * @param indices Indices array of the CSR
   * @param value Values of the sparse matrix
   * @param shape Shape of the sparse matrix
   *
   * @return SparseMatrix
   */
  static c10::intrusive_ptr<SparseMatrix> FromCSR(
      torch::Tensor indptr, torch::Tensor indices, torch::Tensor value,
      const std::vector<int64_t>& shape);

  /**
   * @brief Create a SparseMatrix from tensors in CSC format.
   * @param indptr Index pointer array of the CSC
   * @param indices Indices array of the CSC
   * @param value Values of the sparse matrix
   * @param shape Shape of the sparse matrix
   *
   * @return SparseMatrix
   */
  static c10::intrusive_ptr<SparseMatrix> FromCSC(
      torch::Tensor indptr, torch::Tensor indices, torch::Tensor value,
      const std::vector<int64_t>& shape);

130
131
132
133
134
135
136
137
138
139
  /**
   * @brief Create a SparseMatrix with Diag format.
   * @param value Values of the sparse matrix
   * @param shape Shape of the sparse matrix
   *
   * @return SparseMatrix
   */
  static c10::intrusive_ptr<SparseMatrix> FromDiag(
      torch::Tensor value, const std::vector<int64_t>& shape);

140
141
142
143
144
145
146
147
148
149
  /**
   * @brief Create a SparseMatrix from a SparseMatrix using new values.
   * @param mat An existing sparse matrix
   * @param value New values of the sparse matrix
   *
   * @return SparseMatrix
   */
  static c10::intrusive_ptr<SparseMatrix> ValLike(
      const c10::intrusive_ptr<SparseMatrix>& mat, torch::Tensor value);

Hongzhi (Steve), Chen's avatar
Hongzhi (Steve), Chen committed
150
  /** @return Value of the sparse matrix. */
151
  inline torch::Tensor value() const { return value_; }
Hongzhi (Steve), Chen's avatar
Hongzhi (Steve), Chen committed
152
  /** @return Shape of the sparse matrix. */
153
  inline const std::vector<int64_t>& shape() const { return shape_; }
Hongzhi (Steve), Chen's avatar
Hongzhi (Steve), Chen committed
154
  /** @return Number of non-zero values */
155
  inline int64_t nnz() const { return value_.size(0); }
Hongzhi (Steve), Chen's avatar
Hongzhi (Steve), Chen committed
156
  /** @return Non-zero value data type */
157
  inline caffe2::TypeMeta dtype() const { return value_.dtype(); }
Hongzhi (Steve), Chen's avatar
Hongzhi (Steve), Chen committed
158
  /** @return Device of the sparse matrix */
159
160
  inline torch::Device device() const { return value_.device(); }

Hongzhi (Steve), Chen's avatar
Hongzhi (Steve), Chen committed
161
  /** @return COO of the sparse matrix. The COO is created if not exists. */
162
  std::shared_ptr<COO> COOPtr();
Hongzhi (Steve), Chen's avatar
Hongzhi (Steve), Chen committed
163
  /** @return CSR of the sparse matrix. The CSR is created if not exists. */
164
  std::shared_ptr<CSR> CSRPtr();
Hongzhi (Steve), Chen's avatar
Hongzhi (Steve), Chen committed
165
  /** @return CSC of the sparse matrix. The CSC is created if not exists. */
166
  std::shared_ptr<CSR> CSCPtr();
167
168
169
170
171
  /**
   * @return Diagonal format of the sparse matrix. An error will be raised if
   * it does not have a diagonal format.
   */
  std::shared_ptr<Diag> DiagPtr();
172

Hongzhi (Steve), Chen's avatar
Hongzhi (Steve), Chen committed
173
  /** @brief Check whether this sparse matrix has COO format. */
174
  inline bool HasCOO() const { return coo_ != nullptr; }
Hongzhi (Steve), Chen's avatar
Hongzhi (Steve), Chen committed
175
  /** @brief Check whether this sparse matrix has CSR format. */
176
  inline bool HasCSR() const { return csr_ != nullptr; }
Hongzhi (Steve), Chen's avatar
Hongzhi (Steve), Chen committed
177
  /** @brief Check whether this sparse matrix has CSC format. */
178
  inline bool HasCSC() const { return csc_ != nullptr; }
179
180
  /** @brief Check whether this sparse matrix has Diag format. */
  inline bool HasDiag() const { return diag_ != nullptr; }
181

182
183
  /** @return {row, col} tensors in the COO format. */
  std::tuple<torch::Tensor, torch::Tensor> COOTensors();
184
185
  /** @return Stacked row and col tensors in the COO format. */
  torch::Tensor Indices();
186
187
188
189
190
191
  /** @return {row, col, value_indices} tensors in the CSR format. */
  std::tuple<torch::Tensor, torch::Tensor, torch::optional<torch::Tensor>>
  CSRTensors();
  /** @return {row, col, value_indices} tensors in the CSC format. */
  std::tuple<torch::Tensor, torch::Tensor, torch::optional<torch::Tensor>>
  CSCTensors();
192

193
194
195
196
197
  /** @brief Return the transposition of the sparse matrix. It transposes the
   * first existing sparse format by checking COO, CSR, and CSC.
   */
  c10::intrusive_ptr<SparseMatrix> Transpose() const;

198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
  /**
   * @brief Return a new coalesced matrix.
   *
   * A coalesced sparse matrix satisfies the following properties:
   *   - the indices of the non-zero elements are unique,
   *   - the indices are sorted in lexicographical order.
   *
   * @return A coalesced sparse matrix.
   */
  c10::intrusive_ptr<SparseMatrix> Coalesce();

  /**
   * @brief Return true if this sparse matrix contains duplicate indices.
   * @return A bool flag.
   */
  bool HasDuplicate();

215
 private:
Hongzhi (Steve), Chen's avatar
Hongzhi (Steve), Chen committed
216
  /** @brief Create the COO format for the sparse matrix internally */
217
  void _CreateCOO();
Hongzhi (Steve), Chen's avatar
Hongzhi (Steve), Chen committed
218
  /** @brief Create the CSR format for the sparse matrix internally */
219
  void _CreateCSR();
Hongzhi (Steve), Chen's avatar
Hongzhi (Steve), Chen committed
220
  /** @brief Create the CSC format for the sparse matrix internally */
221
222
  void _CreateCSC();

223
  // COO/CSC/CSR/Diag pointers. Nullptr indicates non-existence.
224
225
  std::shared_ptr<COO> coo_;
  std::shared_ptr<CSR> csr_, csc_;
226
  std::shared_ptr<Diag> diag_;
227
228
229
230
231
232
233
234
235
  // Value of the SparseMatrix
  torch::Tensor value_;
  // Shape of the SparseMatrix
  const std::vector<int64_t> shape_;
};
}  // namespace sparse
}  // namespace dgl

#endif  //  SPARSE_SPARSE_MATRIX_H_