/**
 *  Copyright (c) 2022 by Contributors
 * @file sparse/sparse_matrix.h
 * @brief DGL C++ sparse matrix header.
 */
#ifndef SPARSE_SPARSE_MATRIX_H_
#define SPARSE_SPARSE_MATRIX_H_

czkkkkkk's avatar
czkkkkkk committed
9
10
11
12
13
// clang-format off
#include <sparse/dgl_headers.h>
// clang-format on

#include <sparse/sparse_format.h>
14
15
16
17
#include <torch/custom_class.h>
#include <torch/script.h>

#include <memory>
18
19
#include <tuple>
#include <utility>
20
21
22
23
24
#include <vector>

namespace dgl {
namespace sparse {

25
/** @brief SparseMatrix bound to Python.  */
26
27
class SparseMatrix : public torch::CustomClassHolder {
 public:
Hongzhi (Steve), Chen's avatar
Hongzhi (Steve), Chen committed
28
  /**
29
30
31
32
33
34
35
36
37
38
39
40
   * @brief General constructor to construct a sparse matrix for different
   * sparse formats. At least one of the sparse formats should be provided,
   * while others could be nullptrs.
   *
   * @param coo The COO format.
   * @param csr The CSR format.
   * @param csc The CSC format.
   * @param value Value of the sparse matrix.
   * @param shape Shape of the sparse matrix.
   */
  SparseMatrix(
      const std::shared_ptr<COO>& coo, const std::shared_ptr<CSR>& csr,
41
42
      const std::shared_ptr<CSR>& csc, const std::shared_ptr<Diag>& diag,
      torch::Tensor value, const std::vector<int64_t>& shape);
43

Hongzhi (Steve), Chen's avatar
Hongzhi (Steve), Chen committed
44
  /**
45
46
47
48
49
50
51
   * @brief Construct a SparseMatrix from a COO format.
   * @param coo The COO format
   * @param value Values of the sparse matrix
   * @param shape Shape of the sparse matrix
   *
   * @return SparseMatrix
   */
52
  static c10::intrusive_ptr<SparseMatrix> FromCOOPointer(
53
54
55
      const std::shared_ptr<COO>& coo, torch::Tensor value,
      const std::vector<int64_t>& shape);

Hongzhi (Steve), Chen's avatar
Hongzhi (Steve), Chen committed
56
  /**
57
58
59
60
61
62
63
   * @brief Construct a SparseMatrix from a CSR format.
   * @param csr The CSR format
   * @param value Values of the sparse matrix
   * @param shape Shape of the sparse matrix
   *
   * @return SparseMatrix
   */
64
  static c10::intrusive_ptr<SparseMatrix> FromCSRPointer(
65
66
67
      const std::shared_ptr<CSR>& csr, torch::Tensor value,
      const std::vector<int64_t>& shape);

Hongzhi (Steve), Chen's avatar
Hongzhi (Steve), Chen committed
68
  /**
69
70
71
72
73
74
75
   * @brief Construct a SparseMatrix from a CSC format.
   * @param csc The CSC format
   * @param value Values of the sparse matrix
   * @param shape Shape of the sparse matrix
   *
   * @return SparseMatrix
   */
76
  static c10::intrusive_ptr<SparseMatrix> FromCSCPointer(
77
78
79
      const std::shared_ptr<CSR>& csc, torch::Tensor value,
      const std::vector<int64_t>& shape);

80
81
82
83
84
85
86
87
88
89
90
91
  /**
   * @brief Construct a SparseMatrix from a Diag format.
   * @param diag The Diag format
   * @param value Values of the sparse matrix
   * @param shape Shape of the sparse matrix
   *
   * @return SparseMatrix
   */
  static c10::intrusive_ptr<SparseMatrix> FromDiagPointer(
      const std::shared_ptr<Diag>& diag, torch::Tensor value,
      const std::vector<int64_t>& shape);

92
93
  /**
   * @brief Create a SparseMatrix from tensors in COO format.
94
   * @param indices COO coordinates with shape (2, nnz).
95
96
97
98
99
100
   * @param value Values of the sparse matrix.
   * @param shape Shape of the sparse matrix.
   *
   * @return SparseMatrix
   */
  static c10::intrusive_ptr<SparseMatrix> FromCOO(
101
      torch::Tensor indices, torch::Tensor value,
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
      const std::vector<int64_t>& shape);

  /**
   * @brief Create a SparseMatrix from tensors in CSR format.
   * @param indptr Index pointer array of the CSR
   * @param indices Indices array of the CSR
   * @param value Values of the sparse matrix
   * @param shape Shape of the sparse matrix
   *
   * @return SparseMatrix
   */
  static c10::intrusive_ptr<SparseMatrix> FromCSR(
      torch::Tensor indptr, torch::Tensor indices, torch::Tensor value,
      const std::vector<int64_t>& shape);

  /**
   * @brief Create a SparseMatrix from tensors in CSC format.
   * @param indptr Index pointer array of the CSC
   * @param indices Indices array of the CSC
   * @param value Values of the sparse matrix
   * @param shape Shape of the sparse matrix
   *
   * @return SparseMatrix
   */
  static c10::intrusive_ptr<SparseMatrix> FromCSC(
      torch::Tensor indptr, torch::Tensor indices, torch::Tensor value,
      const std::vector<int64_t>& shape);

130
131
132
133
134
135
136
137
138
139
  /**
   * @brief Create a SparseMatrix with Diag format.
   * @param value Values of the sparse matrix
   * @param shape Shape of the sparse matrix
   *
   * @return SparseMatrix
   */
  static c10::intrusive_ptr<SparseMatrix> FromDiag(
      torch::Tensor value, const std::vector<int64_t>& shape);

140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
  /**
   * @brief Create a SparseMatrix by selecting rows or columns based on provided
   * indices.
   *
   * This function allows you to create a new SparseMatrix by selecting specific
   * rows or columns from the original SparseMatrix based on the provided
   * indices. The selection can be performed either row-wise or column-wise,
   * determined by the 'dim' parameter.
   *
   * @param dim Select rows (dim=0) or columns (dim=1).
   * @param ids A tensor containing the indices of the selected rows or columns.
   *
   * @return A new SparseMatrix containing the selected rows or columns.
   *
   * @note The 'dim' parameter should be either 0 (for row-wise selection) or 1
   * (for column-wise selection).
   * @note The 'ids' tensor should contain valid indices within the range of the
   * original SparseMatrix's dimensions.
   */
  c10::intrusive_ptr<SparseMatrix> IndexSelect(int64_t dim, torch::Tensor ids);

  /**
   * @brief Create a SparseMatrix by selecting a range of rows or columns based
   * on provided indices.
   *
   * This function allows you to create a new SparseMatrix by selecting a range
   * of specific rows or columns from the original SparseMatrix based on the
   * provided indices. The selection can be performed either row-wise or
   * column-wise, determined by the 'dim' parameter.
   *
   * @param dim Select rows (dim=0) or columns (dim=1).
   * @param start The starting index (inclusive) of the range.
   * @param end The ending index (exclusive) of the range.
   *
   * @return A new SparseMatrix containing the selected range of rows or
   * columns.
   *
   * @note The 'dim' parameter should be either 0 (for row-wise selection) or 1
   * (for column-wise selection).
   * @note The 'start' and 'end' indices should be valid indices within
   * the valid range of the original SparseMatrix's dimensions.
   */
  c10::intrusive_ptr<SparseMatrix> RangeSelect(
      int64_t dim, int64_t start, int64_t end);

185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
  /**
   * @brief Create a SparseMatrix by sampling elements based on the specified
   * dimension and sample count.
   *
   * If `ids` is provided, this function samples elements from the specified
   * set of row or column IDs, resulting in a sparse matrix containing only
   * the sampled rows or columns.
   *
   * @param dim Select rows (dim=0) or columns (dim=1) for sampling.
   * @param fanout The number of elements to randomly sample from each row or
   * column.
   * @param ids An optional tensor containing row or column IDs from which to
   * sample elements.
   * @param replace Indicates whether repeated sampling of the same element
   * is allowed. If True, repeated sampling is allowed; otherwise, it is not
   * allowed.
   * @param bias An optional boolean flag indicating whether to enable biasing
   * during sampling. If True, the values of the sparse matrix will be used as
   * bias weights, meaning that elements with higher values will be more likely
   * to be sampled. Otherwise, all elements will be sampled uniformly,
   * regardless of their value.
   *
   * @return A new SparseMatrix with the same shape as the original matrix
   * containing the sampled elements.
   *
   * @note If 'replace = false' and there are fewer elements than 'fanout',
   * all non-zero elements will be sampled.
   * @note If 'ids' is not provided, the function will sample from
   * all rows or columns.
   */
  c10::intrusive_ptr<SparseMatrix> Sample(
      int64_t dim, int64_t fanout, torch::Tensor ids, bool replace, bool bias);

218
219
220
221
222
223
224
225
226
227
  /**
   * @brief Create a SparseMatrix from a SparseMatrix using new values.
   * @param mat An existing sparse matrix
   * @param value New values of the sparse matrix
   *
   * @return SparseMatrix
   */
  static c10::intrusive_ptr<SparseMatrix> ValLike(
      const c10::intrusive_ptr<SparseMatrix>& mat, torch::Tensor value);

Hongzhi (Steve), Chen's avatar
Hongzhi (Steve), Chen committed
228
  /** @return Value of the sparse matrix. */
229
  inline torch::Tensor value() const { return value_; }
Hongzhi (Steve), Chen's avatar
Hongzhi (Steve), Chen committed
230
  /** @return Shape of the sparse matrix. */
231
  inline const std::vector<int64_t>& shape() const { return shape_; }
Hongzhi (Steve), Chen's avatar
Hongzhi (Steve), Chen committed
232
  /** @return Number of non-zero values */
233
  inline int64_t nnz() const { return value_.size(0); }
Hongzhi (Steve), Chen's avatar
Hongzhi (Steve), Chen committed
234
  /** @return Non-zero value data type */
235
  inline caffe2::TypeMeta dtype() const { return value_.dtype(); }
Hongzhi (Steve), Chen's avatar
Hongzhi (Steve), Chen committed
236
  /** @return Device of the sparse matrix */
237
238
  inline torch::Device device() const { return value_.device(); }

Hongzhi (Steve), Chen's avatar
Hongzhi (Steve), Chen committed
239
  /** @return COO of the sparse matrix. The COO is created if not exists. */
240
  std::shared_ptr<COO> COOPtr();
Hongzhi (Steve), Chen's avatar
Hongzhi (Steve), Chen committed
241
  /** @return CSR of the sparse matrix. The CSR is created if not exists. */
242
  std::shared_ptr<CSR> CSRPtr();
Hongzhi (Steve), Chen's avatar
Hongzhi (Steve), Chen committed
243
  /** @return CSC of the sparse matrix. The CSC is created if not exists. */
244
  std::shared_ptr<CSR> CSCPtr();
245
246
247
248
249
  /**
   * @return Diagonal format of the sparse matrix. An error will be raised if
   * it does not have a diagonal format.
   */
  std::shared_ptr<Diag> DiagPtr();
250

Hongzhi (Steve), Chen's avatar
Hongzhi (Steve), Chen committed
251
  /** @brief Check whether this sparse matrix has COO format. */
252
  inline bool HasCOO() const { return coo_ != nullptr; }
Hongzhi (Steve), Chen's avatar
Hongzhi (Steve), Chen committed
253
  /** @brief Check whether this sparse matrix has CSR format. */
254
  inline bool HasCSR() const { return csr_ != nullptr; }
Hongzhi (Steve), Chen's avatar
Hongzhi (Steve), Chen committed
255
  /** @brief Check whether this sparse matrix has CSC format. */
256
  inline bool HasCSC() const { return csc_ != nullptr; }
257
258
  /** @brief Check whether this sparse matrix has Diag format. */
  inline bool HasDiag() const { return diag_ != nullptr; }
259

260
261
  /** @return {row, col} tensors in the COO format. */
  std::tuple<torch::Tensor, torch::Tensor> COOTensors();
262
263
  /** @return Stacked row and col tensors in the COO format. */
  torch::Tensor Indices();
264
265
266
267
268
269
  /** @return {row, col, value_indices} tensors in the CSR format. */
  std::tuple<torch::Tensor, torch::Tensor, torch::optional<torch::Tensor>>
  CSRTensors();
  /** @return {row, col, value_indices} tensors in the CSC format. */
  std::tuple<torch::Tensor, torch::Tensor, torch::optional<torch::Tensor>>
  CSCTensors();
270

271
272
273
274
275
  /** @brief Return the transposition of the sparse matrix. It transposes the
   * first existing sparse format by checking COO, CSR, and CSC.
   */
  c10::intrusive_ptr<SparseMatrix> Transpose() const;

276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
  /**
   * @brief Return a new coalesced matrix.
   *
   * A coalesced sparse matrix satisfies the following properties:
   *   - the indices of the non-zero elements are unique,
   *   - the indices are sorted in lexicographical order.
   *
   * @return A coalesced sparse matrix.
   */
  c10::intrusive_ptr<SparseMatrix> Coalesce();

  /**
   * @brief Return true if this sparse matrix contains duplicate indices.
   * @return A bool flag.
   */
  bool HasDuplicate();

293
 private:
Hongzhi (Steve), Chen's avatar
Hongzhi (Steve), Chen committed
294
  /** @brief Create the COO format for the sparse matrix internally */
295
  void _CreateCOO();
Hongzhi (Steve), Chen's avatar
Hongzhi (Steve), Chen committed
296
  /** @brief Create the CSR format for the sparse matrix internally */
297
  void _CreateCSR();
Hongzhi (Steve), Chen's avatar
Hongzhi (Steve), Chen committed
298
  /** @brief Create the CSC format for the sparse matrix internally */
299
300
  void _CreateCSC();

301
  // COO/CSC/CSR/Diag pointers. Nullptr indicates non-existence.
302
303
  std::shared_ptr<COO> coo_;
  std::shared_ptr<CSR> csr_, csc_;
304
  std::shared_ptr<Diag> diag_;
305
306
307
308
309
310
311
312
313
  // Value of the SparseMatrix
  torch::Tensor value_;
  // Shape of the SparseMatrix
  const std::vector<int64_t> shape_;
};
}  // namespace sparse
}  // namespace dgl

#endif  //  SPARSE_SPARSE_MATRIX_H_