sparse_matrix.h 11.3 KB
Newer Older
Hongzhi (Steve), Chen's avatar
Hongzhi (Steve), Chen committed
1
/**
 *  Copyright (c) 2022 by Contributors
 * @file sparse/sparse_matrix.h
 * @brief DGL C++ sparse matrix header.
 */
#ifndef SPARSE_SPARSE_MATRIX_H_
#define SPARSE_SPARSE_MATRIX_H_

czkkkkkk's avatar
czkkkkkk committed
9
10
// clang-format off
#include <sparse/dgl_headers.h>
11
#include <sparse/torch_headers.h>
czkkkkkk's avatar
czkkkkkk committed
12
13
14
// clang-format on

#include <sparse/sparse_format.h>
15
16

#include <memory>
17
18
#include <tuple>
#include <utility>
19
20
21
22
23
#include <vector>

namespace dgl {
namespace sparse {

24
/** @brief SparseMatrix bound to Python.  */
25
26
class SparseMatrix : public torch::CustomClassHolder {
 public:
Hongzhi (Steve), Chen's avatar
Hongzhi (Steve), Chen committed
27
  /**
28
29
30
31
32
33
34
35
36
37
38
39
   * @brief General constructor to construct a sparse matrix for different
   * sparse formats. At least one of the sparse formats should be provided,
   * while others could be nullptrs.
   *
   * @param coo The COO format.
   * @param csr The CSR format.
   * @param csc The CSC format.
   * @param value Value of the sparse matrix.
   * @param shape Shape of the sparse matrix.
   */
  SparseMatrix(
      const std::shared_ptr<COO>& coo, const std::shared_ptr<CSR>& csr,
40
41
      const std::shared_ptr<CSR>& csc, const std::shared_ptr<Diag>& diag,
      torch::Tensor value, const std::vector<int64_t>& shape);
42

Hongzhi (Steve), Chen's avatar
Hongzhi (Steve), Chen committed
43
  /**
44
45
46
47
48
49
50
   * @brief Construct a SparseMatrix from a COO format.
   * @param coo The COO format
   * @param value Values of the sparse matrix
   * @param shape Shape of the sparse matrix
   *
   * @return SparseMatrix
   */
51
  static c10::intrusive_ptr<SparseMatrix> FromCOOPointer(
52
53
54
      const std::shared_ptr<COO>& coo, torch::Tensor value,
      const std::vector<int64_t>& shape);

Hongzhi (Steve), Chen's avatar
Hongzhi (Steve), Chen committed
55
  /**
56
57
58
59
60
61
62
   * @brief Construct a SparseMatrix from a CSR format.
   * @param csr The CSR format
   * @param value Values of the sparse matrix
   * @param shape Shape of the sparse matrix
   *
   * @return SparseMatrix
   */
63
  static c10::intrusive_ptr<SparseMatrix> FromCSRPointer(
64
65
66
      const std::shared_ptr<CSR>& csr, torch::Tensor value,
      const std::vector<int64_t>& shape);

Hongzhi (Steve), Chen's avatar
Hongzhi (Steve), Chen committed
67
  /**
68
69
70
71
72
73
74
   * @brief Construct a SparseMatrix from a CSC format.
   * @param csc The CSC format
   * @param value Values of the sparse matrix
   * @param shape Shape of the sparse matrix
   *
   * @return SparseMatrix
   */
75
  static c10::intrusive_ptr<SparseMatrix> FromCSCPointer(
76
77
78
      const std::shared_ptr<CSR>& csc, torch::Tensor value,
      const std::vector<int64_t>& shape);

79
80
81
82
83
84
85
86
87
88
89
90
  /**
   * @brief Construct a SparseMatrix from a Diag format.
   * @param diag The Diag format
   * @param value Values of the sparse matrix
   * @param shape Shape of the sparse matrix
   *
   * @return SparseMatrix
   */
  static c10::intrusive_ptr<SparseMatrix> FromDiagPointer(
      const std::shared_ptr<Diag>& diag, torch::Tensor value,
      const std::vector<int64_t>& shape);

91
92
  /**
   * @brief Create a SparseMatrix from tensors in COO format.
93
   * @param indices COO coordinates with shape (2, nnz).
94
95
96
97
98
99
   * @param value Values of the sparse matrix.
   * @param shape Shape of the sparse matrix.
   *
   * @return SparseMatrix
   */
  static c10::intrusive_ptr<SparseMatrix> FromCOO(
100
      torch::Tensor indices, torch::Tensor value,
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
      const std::vector<int64_t>& shape);

  /**
   * @brief Create a SparseMatrix from tensors in CSR format.
   * @param indptr Index pointer array of the CSR
   * @param indices Indices array of the CSR
   * @param value Values of the sparse matrix
   * @param shape Shape of the sparse matrix
   *
   * @return SparseMatrix
   */
  static c10::intrusive_ptr<SparseMatrix> FromCSR(
      torch::Tensor indptr, torch::Tensor indices, torch::Tensor value,
      const std::vector<int64_t>& shape);

  /**
   * @brief Create a SparseMatrix from tensors in CSC format.
   * @param indptr Index pointer array of the CSC
   * @param indices Indices array of the CSC
   * @param value Values of the sparse matrix
   * @param shape Shape of the sparse matrix
   *
   * @return SparseMatrix
   */
  static c10::intrusive_ptr<SparseMatrix> FromCSC(
      torch::Tensor indptr, torch::Tensor indices, torch::Tensor value,
      const std::vector<int64_t>& shape);

129
130
131
132
133
134
135
136
137
138
  /**
   * @brief Create a SparseMatrix with Diag format.
   * @param value Values of the sparse matrix
   * @param shape Shape of the sparse matrix
   *
   * @return SparseMatrix
   */
  static c10::intrusive_ptr<SparseMatrix> FromDiag(
      torch::Tensor value, const std::vector<int64_t>& shape);

139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
  /**
   * @brief Create a SparseMatrix by selecting rows or columns based on provided
   * indices.
   *
   * This function allows you to create a new SparseMatrix by selecting specific
   * rows or columns from the original SparseMatrix based on the provided
   * indices. The selection can be performed either row-wise or column-wise,
   * determined by the 'dim' parameter.
   *
   * @param dim Select rows (dim=0) or columns (dim=1).
   * @param ids A tensor containing the indices of the selected rows or columns.
   *
   * @return A new SparseMatrix containing the selected rows or columns.
   *
   * @note The 'dim' parameter should be either 0 (for row-wise selection) or 1
   * (for column-wise selection).
   * @note The 'ids' tensor should contain valid indices within the range of the
   * original SparseMatrix's dimensions.
   */
  c10::intrusive_ptr<SparseMatrix> IndexSelect(int64_t dim, torch::Tensor ids);

  /**
   * @brief Create a SparseMatrix by selecting a range of rows or columns based
   * on provided indices.
   *
   * This function allows you to create a new SparseMatrix by selecting a range
   * of specific rows or columns from the original SparseMatrix based on the
   * provided indices. The selection can be performed either row-wise or
   * column-wise, determined by the 'dim' parameter.
   *
   * @param dim Select rows (dim=0) or columns (dim=1).
   * @param start The starting index (inclusive) of the range.
   * @param end The ending index (exclusive) of the range.
   *
   * @return A new SparseMatrix containing the selected range of rows or
   * columns.
   *
   * @note The 'dim' parameter should be either 0 (for row-wise selection) or 1
   * (for column-wise selection).
   * @note The 'start' and 'end' indices should be valid indices within
   * the valid range of the original SparseMatrix's dimensions.
   */
  c10::intrusive_ptr<SparseMatrix> RangeSelect(
      int64_t dim, int64_t start, int64_t end);

184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
  /**
   * @brief Create a SparseMatrix by sampling elements based on the specified
   * dimension and sample count.
   *
   * If `ids` is provided, this function samples elements from the specified
   * set of row or column IDs, resulting in a sparse matrix containing only
   * the sampled rows or columns.
   *
   * @param dim Select rows (dim=0) or columns (dim=1) for sampling.
   * @param fanout The number of elements to randomly sample from each row or
   * column.
   * @param ids An optional tensor containing row or column IDs from which to
   * sample elements.
   * @param replace Indicates whether repeated sampling of the same element
   * is allowed. If True, repeated sampling is allowed; otherwise, it is not
   * allowed.
   * @param bias An optional boolean flag indicating whether to enable biasing
   * during sampling. If True, the values of the sparse matrix will be used as
   * bias weights, meaning that elements with higher values will be more likely
   * to be sampled. Otherwise, all elements will be sampled uniformly,
   * regardless of their value.
   *
   * @return A new SparseMatrix with the same shape as the original matrix
   * containing the sampled elements.
   *
   * @note If 'replace = false' and there are fewer elements than 'fanout',
   * all non-zero elements will be sampled.
   * @note If 'ids' is not provided, the function will sample from
   * all rows or columns.
   */
  c10::intrusive_ptr<SparseMatrix> Sample(
      int64_t dim, int64_t fanout, torch::Tensor ids, bool replace, bool bias);

217
218
219
220
221
222
223
224
225
226
  /**
   * @brief Create a SparseMatrix from a SparseMatrix using new values.
   * @param mat An existing sparse matrix
   * @param value New values of the sparse matrix
   *
   * @return SparseMatrix
   */
  static c10::intrusive_ptr<SparseMatrix> ValLike(
      const c10::intrusive_ptr<SparseMatrix>& mat, torch::Tensor value);

Hongzhi (Steve), Chen's avatar
Hongzhi (Steve), Chen committed
227
  /** @return Value of the sparse matrix. */
228
  inline torch::Tensor value() const { return value_; }
Hongzhi (Steve), Chen's avatar
Hongzhi (Steve), Chen committed
229
  /** @return Shape of the sparse matrix. */
230
  inline const std::vector<int64_t>& shape() const { return shape_; }
Hongzhi (Steve), Chen's avatar
Hongzhi (Steve), Chen committed
231
  /** @return Number of non-zero values */
232
  inline int64_t nnz() const { return value_.size(0); }
Hongzhi (Steve), Chen's avatar
Hongzhi (Steve), Chen committed
233
  /** @return Non-zero value data type */
234
  inline caffe2::TypeMeta dtype() const { return value_.dtype(); }
Hongzhi (Steve), Chen's avatar
Hongzhi (Steve), Chen committed
235
  /** @return Device of the sparse matrix */
236
237
  inline torch::Device device() const { return value_.device(); }

Hongzhi (Steve), Chen's avatar
Hongzhi (Steve), Chen committed
238
  /** @return COO of the sparse matrix. The COO is created if not exists. */
239
  std::shared_ptr<COO> COOPtr();
Hongzhi (Steve), Chen's avatar
Hongzhi (Steve), Chen committed
240
  /** @return CSR of the sparse matrix. The CSR is created if not exists. */
241
  std::shared_ptr<CSR> CSRPtr();
Hongzhi (Steve), Chen's avatar
Hongzhi (Steve), Chen committed
242
  /** @return CSC of the sparse matrix. The CSC is created if not exists. */
243
  std::shared_ptr<CSR> CSCPtr();
244
245
246
247
248
  /**
   * @return Diagonal format of the sparse matrix. An error will be raised if
   * it does not have a diagonal format.
   */
  std::shared_ptr<Diag> DiagPtr();
249

Hongzhi (Steve), Chen's avatar
Hongzhi (Steve), Chen committed
250
  /** @brief Check whether this sparse matrix has COO format. */
251
  inline bool HasCOO() const { return coo_ != nullptr; }
Hongzhi (Steve), Chen's avatar
Hongzhi (Steve), Chen committed
252
  /** @brief Check whether this sparse matrix has CSR format. */
253
  inline bool HasCSR() const { return csr_ != nullptr; }
Hongzhi (Steve), Chen's avatar
Hongzhi (Steve), Chen committed
254
  /** @brief Check whether this sparse matrix has CSC format. */
255
  inline bool HasCSC() const { return csc_ != nullptr; }
256
257
  /** @brief Check whether this sparse matrix has Diag format. */
  inline bool HasDiag() const { return diag_ != nullptr; }
258

259
260
  /** @return {row, col} tensors in the COO format. */
  std::tuple<torch::Tensor, torch::Tensor> COOTensors();
261
262
  /** @return Stacked row and col tensors in the COO format. */
  torch::Tensor Indices();
263
264
265
266
267
268
  /** @return {row, col, value_indices} tensors in the CSR format. */
  std::tuple<torch::Tensor, torch::Tensor, torch::optional<torch::Tensor>>
  CSRTensors();
  /** @return {row, col, value_indices} tensors in the CSC format. */
  std::tuple<torch::Tensor, torch::Tensor, torch::optional<torch::Tensor>>
  CSCTensors();
269

270
271
272
273
274
  /** @brief Return the transposition of the sparse matrix. It transposes the
   * first existing sparse format by checking COO, CSR, and CSC.
   */
  c10::intrusive_ptr<SparseMatrix> Transpose() const;

275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
  /**
   * @brief Return a new coalesced matrix.
   *
   * A coalesced sparse matrix satisfies the following properties:
   *   - the indices of the non-zero elements are unique,
   *   - the indices are sorted in lexicographical order.
   *
   * @return A coalesced sparse matrix.
   */
  c10::intrusive_ptr<SparseMatrix> Coalesce();

  /**
   * @brief Return true if this sparse matrix contains duplicate indices.
   * @return A bool flag.
   */
  bool HasDuplicate();

292
 private:
Hongzhi (Steve), Chen's avatar
Hongzhi (Steve), Chen committed
293
  /** @brief Create the COO format for the sparse matrix internally */
294
  void _CreateCOO();
Hongzhi (Steve), Chen's avatar
Hongzhi (Steve), Chen committed
295
  /** @brief Create the CSR format for the sparse matrix internally */
296
  void _CreateCSR();
Hongzhi (Steve), Chen's avatar
Hongzhi (Steve), Chen committed
297
  /** @brief Create the CSC format for the sparse matrix internally */
298
299
  void _CreateCSC();

300
  // COO/CSC/CSR/Diag pointers. Nullptr indicates non-existence.
301
302
  std::shared_ptr<COO> coo_;
  std::shared_ptr<CSR> csr_, csc_;
303
  std::shared_ptr<Diag> diag_;
304
305
306
307
308
309
310
311
312
  // Value of the SparseMatrix
  torch::Tensor value_;
  // Shape of the SparseMatrix
  const std::vector<int64_t> shape_;
};
}  // namespace sparse
}  // namespace dgl

#endif  //  SPARSE_SPARSE_MATRIX_H_