/**
 *  Copyright (c) 2022 by Contributors
 * @file sparse/sparse_matrix.h
 * @brief DGL C++ sparse matrix header.
 */
#ifndef SPARSE_SPARSE_MATRIX_H_
#define SPARSE_SPARSE_MATRIX_H_

// clang-format off
#include <sparse/dgl_headers.h>
// clang-format on

#include <sparse/sparse_format.h>

#include <torch/custom_class.h>
#include <torch/script.h>

#include <cstdint>
#include <memory>
#include <tuple>
#include <utility>
#include <vector>

namespace dgl {
namespace sparse {

25
/** @brief SparseMatrix bound to Python.  */
26
27
class SparseMatrix : public torch::CustomClassHolder {
 public:
Hongzhi (Steve), Chen's avatar
Hongzhi (Steve), Chen committed
28
  /**
29
30
31
32
33
34
35
36
37
38
39
40
   * @brief General constructor to construct a sparse matrix for different
   * sparse formats. At least one of the sparse formats should be provided,
   * while others could be nullptrs.
   *
   * @param coo The COO format.
   * @param csr The CSR format.
   * @param csc The CSC format.
   * @param value Value of the sparse matrix.
   * @param shape Shape of the sparse matrix.
   */
  SparseMatrix(
      const std::shared_ptr<COO>& coo, const std::shared_ptr<CSR>& csr,
41
42
      const std::shared_ptr<CSR>& csc, const std::shared_ptr<Diag>& diag,
      torch::Tensor value, const std::vector<int64_t>& shape);
43

Hongzhi (Steve), Chen's avatar
Hongzhi (Steve), Chen committed
44
  /**
45
46
47
48
49
50
51
   * @brief Construct a SparseMatrix from a COO format.
   * @param coo The COO format
   * @param value Values of the sparse matrix
   * @param shape Shape of the sparse matrix
   *
   * @return SparseMatrix
   */
52
  static c10::intrusive_ptr<SparseMatrix> FromCOOPointer(
53
54
55
      const std::shared_ptr<COO>& coo, torch::Tensor value,
      const std::vector<int64_t>& shape);

Hongzhi (Steve), Chen's avatar
Hongzhi (Steve), Chen committed
56
  /**
57
58
59
60
61
62
63
   * @brief Construct a SparseMatrix from a CSR format.
   * @param csr The CSR format
   * @param value Values of the sparse matrix
   * @param shape Shape of the sparse matrix
   *
   * @return SparseMatrix
   */
64
  static c10::intrusive_ptr<SparseMatrix> FromCSRPointer(
65
66
67
      const std::shared_ptr<CSR>& csr, torch::Tensor value,
      const std::vector<int64_t>& shape);

Hongzhi (Steve), Chen's avatar
Hongzhi (Steve), Chen committed
68
  /**
69
70
71
72
73
74
75
   * @brief Construct a SparseMatrix from a CSC format.
   * @param csc The CSC format
   * @param value Values of the sparse matrix
   * @param shape Shape of the sparse matrix
   *
   * @return SparseMatrix
   */
76
  static c10::intrusive_ptr<SparseMatrix> FromCSCPointer(
77
78
79
      const std::shared_ptr<CSR>& csc, torch::Tensor value,
      const std::vector<int64_t>& shape);

80
81
82
83
84
85
86
87
88
89
90
91
  /**
   * @brief Construct a SparseMatrix from a Diag format.
   * @param diag The Diag format
   * @param value Values of the sparse matrix
   * @param shape Shape of the sparse matrix
   *
   * @return SparseMatrix
   */
  static c10::intrusive_ptr<SparseMatrix> FromDiagPointer(
      const std::shared_ptr<Diag>& diag, torch::Tensor value,
      const std::vector<int64_t>& shape);

92
93
  /**
   * @brief Create a SparseMatrix from tensors in COO format.
94
   * @param indices COO coordinates with shape (2, nnz).
95
96
97
98
99
100
   * @param value Values of the sparse matrix.
   * @param shape Shape of the sparse matrix.
   *
   * @return SparseMatrix
   */
  static c10::intrusive_ptr<SparseMatrix> FromCOO(
101
      torch::Tensor indices, torch::Tensor value,
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
      const std::vector<int64_t>& shape);

  /**
   * @brief Create a SparseMatrix from tensors in CSR format.
   * @param indptr Index pointer array of the CSR
   * @param indices Indices array of the CSR
   * @param value Values of the sparse matrix
   * @param shape Shape of the sparse matrix
   *
   * @return SparseMatrix
   */
  static c10::intrusive_ptr<SparseMatrix> FromCSR(
      torch::Tensor indptr, torch::Tensor indices, torch::Tensor value,
      const std::vector<int64_t>& shape);

  /**
   * @brief Create a SparseMatrix from tensors in CSC format.
   * @param indptr Index pointer array of the CSC
   * @param indices Indices array of the CSC
   * @param value Values of the sparse matrix
   * @param shape Shape of the sparse matrix
   *
   * @return SparseMatrix
   */
  static c10::intrusive_ptr<SparseMatrix> FromCSC(
      torch::Tensor indptr, torch::Tensor indices, torch::Tensor value,
      const std::vector<int64_t>& shape);

130
131
132
133
134
135
136
137
138
139
  /**
   * @brief Create a SparseMatrix with Diag format.
   * @param value Values of the sparse matrix
   * @param shape Shape of the sparse matrix
   *
   * @return SparseMatrix
   */
  static c10::intrusive_ptr<SparseMatrix> FromDiag(
      torch::Tensor value, const std::vector<int64_t>& shape);

140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
  /**
   * @brief Create a SparseMatrix by selecting rows or columns based on provided
   * indices.
   *
   * This function allows you to create a new SparseMatrix by selecting specific
   * rows or columns from the original SparseMatrix based on the provided
   * indices. The selection can be performed either row-wise or column-wise,
   * determined by the 'dim' parameter.
   *
   * @param dim Select rows (dim=0) or columns (dim=1).
   * @param ids A tensor containing the indices of the selected rows or columns.
   *
   * @return A new SparseMatrix containing the selected rows or columns.
   *
   * @note The 'dim' parameter should be either 0 (for row-wise selection) or 1
   * (for column-wise selection).
   * @note The 'ids' tensor should contain valid indices within the range of the
   * original SparseMatrix's dimensions.
   */
  c10::intrusive_ptr<SparseMatrix> IndexSelect(int64_t dim, torch::Tensor ids);

  /**
   * @brief Create a SparseMatrix by selecting a range of rows or columns based
   * on provided indices.
   *
   * This function allows you to create a new SparseMatrix by selecting a range
   * of specific rows or columns from the original SparseMatrix based on the
   * provided indices. The selection can be performed either row-wise or
   * column-wise, determined by the 'dim' parameter.
   *
   * @param dim Select rows (dim=0) or columns (dim=1).
   * @param start The starting index (inclusive) of the range.
   * @param end The ending index (exclusive) of the range.
   *
   * @return A new SparseMatrix containing the selected range of rows or
   * columns.
   *
   * @note The 'dim' parameter should be either 0 (for row-wise selection) or 1
   * (for column-wise selection).
   * @note The 'start' and 'end' indices should be valid indices within
   * the valid range of the original SparseMatrix's dimensions.
   */
  c10::intrusive_ptr<SparseMatrix> RangeSelect(
      int64_t dim, int64_t start, int64_t end);

185
186
187
188
189
190
191
192
193
194
  /**
   * @brief Create a SparseMatrix from a SparseMatrix using new values.
   * @param mat An existing sparse matrix
   * @param value New values of the sparse matrix
   *
   * @return SparseMatrix
   */
  static c10::intrusive_ptr<SparseMatrix> ValLike(
      const c10::intrusive_ptr<SparseMatrix>& mat, torch::Tensor value);

Hongzhi (Steve), Chen's avatar
Hongzhi (Steve), Chen committed
195
  /** @return Value of the sparse matrix. */
196
  inline torch::Tensor value() const { return value_; }
Hongzhi (Steve), Chen's avatar
Hongzhi (Steve), Chen committed
197
  /** @return Shape of the sparse matrix. */
198
  inline const std::vector<int64_t>& shape() const { return shape_; }
Hongzhi (Steve), Chen's avatar
Hongzhi (Steve), Chen committed
199
  /** @return Number of non-zero values */
200
  inline int64_t nnz() const { return value_.size(0); }
Hongzhi (Steve), Chen's avatar
Hongzhi (Steve), Chen committed
201
  /** @return Non-zero value data type */
202
  inline caffe2::TypeMeta dtype() const { return value_.dtype(); }
Hongzhi (Steve), Chen's avatar
Hongzhi (Steve), Chen committed
203
  /** @return Device of the sparse matrix */
204
205
  inline torch::Device device() const { return value_.device(); }

Hongzhi (Steve), Chen's avatar
Hongzhi (Steve), Chen committed
206
  /** @return COO of the sparse matrix. The COO is created if not exists. */
207
  std::shared_ptr<COO> COOPtr();
Hongzhi (Steve), Chen's avatar
Hongzhi (Steve), Chen committed
208
  /** @return CSR of the sparse matrix. The CSR is created if not exists. */
209
  std::shared_ptr<CSR> CSRPtr();
Hongzhi (Steve), Chen's avatar
Hongzhi (Steve), Chen committed
210
  /** @return CSC of the sparse matrix. The CSC is created if not exists. */
211
  std::shared_ptr<CSR> CSCPtr();
212
213
214
215
216
  /**
   * @return Diagonal format of the sparse matrix. An error will be raised if
   * it does not have a diagonal format.
   */
  std::shared_ptr<Diag> DiagPtr();
217

Hongzhi (Steve), Chen's avatar
Hongzhi (Steve), Chen committed
218
  /** @brief Check whether this sparse matrix has COO format. */
219
  inline bool HasCOO() const { return coo_ != nullptr; }
Hongzhi (Steve), Chen's avatar
Hongzhi (Steve), Chen committed
220
  /** @brief Check whether this sparse matrix has CSR format. */
221
  inline bool HasCSR() const { return csr_ != nullptr; }
Hongzhi (Steve), Chen's avatar
Hongzhi (Steve), Chen committed
222
  /** @brief Check whether this sparse matrix has CSC format. */
223
  inline bool HasCSC() const { return csc_ != nullptr; }
224
225
  /** @brief Check whether this sparse matrix has Diag format. */
  inline bool HasDiag() const { return diag_ != nullptr; }
226

227
228
  /** @return {row, col} tensors in the COO format. */
  std::tuple<torch::Tensor, torch::Tensor> COOTensors();
229
230
  /** @return Stacked row and col tensors in the COO format. */
  torch::Tensor Indices();
231
232
233
234
235
236
  /** @return {row, col, value_indices} tensors in the CSR format. */
  std::tuple<torch::Tensor, torch::Tensor, torch::optional<torch::Tensor>>
  CSRTensors();
  /** @return {row, col, value_indices} tensors in the CSC format. */
  std::tuple<torch::Tensor, torch::Tensor, torch::optional<torch::Tensor>>
  CSCTensors();
237

238
239
240
241
242
  /** @brief Return the transposition of the sparse matrix. It transposes the
   * first existing sparse format by checking COO, CSR, and CSC.
   */
  c10::intrusive_ptr<SparseMatrix> Transpose() const;

243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
  /**
   * @brief Return a new coalesced matrix.
   *
   * A coalesced sparse matrix satisfies the following properties:
   *   - the indices of the non-zero elements are unique,
   *   - the indices are sorted in lexicographical order.
   *
   * @return A coalesced sparse matrix.
   */
  c10::intrusive_ptr<SparseMatrix> Coalesce();

  /**
   * @brief Return true if this sparse matrix contains duplicate indices.
   * @return A bool flag.
   */
  bool HasDuplicate();

260
 private:
Hongzhi (Steve), Chen's avatar
Hongzhi (Steve), Chen committed
261
  /** @brief Create the COO format for the sparse matrix internally */
262
  void _CreateCOO();
Hongzhi (Steve), Chen's avatar
Hongzhi (Steve), Chen committed
263
  /** @brief Create the CSR format for the sparse matrix internally */
264
  void _CreateCSR();
Hongzhi (Steve), Chen's avatar
Hongzhi (Steve), Chen committed
265
  /** @brief Create the CSC format for the sparse matrix internally */
266
267
  void _CreateCSC();

268
  // COO/CSC/CSR/Diag pointers. Nullptr indicates non-existence.
269
270
  std::shared_ptr<COO> coo_;
  std::shared_ptr<CSR> csr_, csc_;
271
  std::shared_ptr<Diag> diag_;
272
273
274
275
276
277
278
279
280
  // Value of the SparseMatrix
  torch::Tensor value_;
  // Shape of the SparseMatrix
  const std::vector<int64_t> shape_;
};
}  // namespace sparse
}  // namespace dgl

#endif  //  SPARSE_SPARSE_MATRIX_H_