/**
 *  Copyright (c) 2022 by Contributors
 * @file sparse/sparse_matrix.h
 * @brief DGL C++ sparse matrix header.
 */
#ifndef SPARSE_SPARSE_MATRIX_H_
#define SPARSE_SPARSE_MATRIX_H_

// clang-format off
#include <sparse/dgl_headers.h>
// clang-format on

#include <sparse/sparse_format.h>
#include <torch/custom_class.h>
#include <torch/script.h>

#include <memory>
#include <tuple>
#include <utility>
#include <vector>

namespace dgl {
namespace sparse {

25
/** @brief SparseMatrix bound to Python.  */
26
27
class SparseMatrix : public torch::CustomClassHolder {
 public:
Hongzhi (Steve), Chen's avatar
Hongzhi (Steve), Chen committed
28
  /**
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
   * @brief General constructor to construct a sparse matrix for different
   * sparse formats. At least one of the sparse formats should be provided,
   * while others could be nullptrs.
   *
   * @param coo The COO format.
   * @param csr The CSR format.
   * @param csc The CSC format.
   * @param value Value of the sparse matrix.
   * @param shape Shape of the sparse matrix.
   */
  SparseMatrix(
      const std::shared_ptr<COO>& coo, const std::shared_ptr<CSR>& csr,
      const std::shared_ptr<CSR>& csc, torch::Tensor value,
      const std::vector<int64_t>& shape);

Hongzhi (Steve), Chen's avatar
Hongzhi (Steve), Chen committed
44
  /**
45
46
47
48
49
50
51
   * @brief Construct a SparseMatrix from a COO format.
   * @param coo The COO format
   * @param value Values of the sparse matrix
   * @param shape Shape of the sparse matrix
   *
   * @return SparseMatrix
   */
52
  static c10::intrusive_ptr<SparseMatrix> FromCOOPointer(
53
54
55
      const std::shared_ptr<COO>& coo, torch::Tensor value,
      const std::vector<int64_t>& shape);

Hongzhi (Steve), Chen's avatar
Hongzhi (Steve), Chen committed
56
  /**
57
58
59
60
61
62
63
   * @brief Construct a SparseMatrix from a CSR format.
   * @param csr The CSR format
   * @param value Values of the sparse matrix
   * @param shape Shape of the sparse matrix
   *
   * @return SparseMatrix
   */
64
  static c10::intrusive_ptr<SparseMatrix> FromCSRPointer(
65
66
67
      const std::shared_ptr<CSR>& csr, torch::Tensor value,
      const std::vector<int64_t>& shape);

Hongzhi (Steve), Chen's avatar
Hongzhi (Steve), Chen committed
68
  /**
69
70
71
72
73
74
75
   * @brief Construct a SparseMatrix from a CSC format.
   * @param csc The CSC format
   * @param value Values of the sparse matrix
   * @param shape Shape of the sparse matrix
   *
   * @return SparseMatrix
   */
76
  static c10::intrusive_ptr<SparseMatrix> FromCSCPointer(
77
78
79
      const std::shared_ptr<CSR>& csc, torch::Tensor value,
      const std::vector<int64_t>& shape);

80
81
  /**
   * @brief Create a SparseMatrix from tensors in COO format.
82
   * @param indices COO coordinates with shape (2, nnz).
83
84
85
86
87
88
   * @param value Values of the sparse matrix.
   * @param shape Shape of the sparse matrix.
   *
   * @return SparseMatrix
   */
  static c10::intrusive_ptr<SparseMatrix> FromCOO(
89
      torch::Tensor indices, torch::Tensor value,
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
      const std::vector<int64_t>& shape);

  /**
   * @brief Create a SparseMatrix from tensors in CSR format.
   * @param indptr Index pointer array of the CSR
   * @param indices Indices array of the CSR
   * @param value Values of the sparse matrix
   * @param shape Shape of the sparse matrix
   *
   * @return SparseMatrix
   */
  static c10::intrusive_ptr<SparseMatrix> FromCSR(
      torch::Tensor indptr, torch::Tensor indices, torch::Tensor value,
      const std::vector<int64_t>& shape);

  /**
   * @brief Create a SparseMatrix from tensors in CSC format.
   * @param indptr Index pointer array of the CSC
   * @param indices Indices array of the CSC
   * @param value Values of the sparse matrix
   * @param shape Shape of the sparse matrix
   *
   * @return SparseMatrix
   */
  static c10::intrusive_ptr<SparseMatrix> FromCSC(
      torch::Tensor indptr, torch::Tensor indices, torch::Tensor value,
      const std::vector<int64_t>& shape);

  /**
   * @brief Create a SparseMatrix from a SparseMatrix using new values.
   * @param mat An existing sparse matrix
   * @param value New values of the sparse matrix
   *
   * @return SparseMatrix
   */
  static c10::intrusive_ptr<SparseMatrix> ValLike(
      const c10::intrusive_ptr<SparseMatrix>& mat, torch::Tensor value);

Hongzhi (Steve), Chen's avatar
Hongzhi (Steve), Chen committed
128
  /** @return Value of the sparse matrix. */
129
  inline torch::Tensor value() const { return value_; }
Hongzhi (Steve), Chen's avatar
Hongzhi (Steve), Chen committed
130
  /** @return Shape of the sparse matrix. */
131
  inline const std::vector<int64_t>& shape() const { return shape_; }
Hongzhi (Steve), Chen's avatar
Hongzhi (Steve), Chen committed
132
  /** @return Number of non-zero values */
133
  inline int64_t nnz() const { return value_.size(0); }
Hongzhi (Steve), Chen's avatar
Hongzhi (Steve), Chen committed
134
  /** @return Non-zero value data type */
135
  inline caffe2::TypeMeta dtype() const { return value_.dtype(); }
Hongzhi (Steve), Chen's avatar
Hongzhi (Steve), Chen committed
136
  /** @return Device of the sparse matrix */
137
138
  inline torch::Device device() const { return value_.device(); }

Hongzhi (Steve), Chen's avatar
Hongzhi (Steve), Chen committed
139
  /** @return COO of the sparse matrix. The COO is created if not exists. */
140
  std::shared_ptr<COO> COOPtr();
Hongzhi (Steve), Chen's avatar
Hongzhi (Steve), Chen committed
141
  /** @return CSR of the sparse matrix. The CSR is created if not exists. */
142
  std::shared_ptr<CSR> CSRPtr();
Hongzhi (Steve), Chen's avatar
Hongzhi (Steve), Chen committed
143
  /** @return CSC of the sparse matrix. The CSC is created if not exists. */
144
145
  std::shared_ptr<CSR> CSCPtr();

Hongzhi (Steve), Chen's avatar
Hongzhi (Steve), Chen committed
146
  /** @brief Check whether this sparse matrix has COO format. */
147
  inline bool HasCOO() const { return coo_ != nullptr; }
Hongzhi (Steve), Chen's avatar
Hongzhi (Steve), Chen committed
148
  /** @brief Check whether this sparse matrix has CSR format. */
149
  inline bool HasCSR() const { return csr_ != nullptr; }
Hongzhi (Steve), Chen's avatar
Hongzhi (Steve), Chen committed
150
  /** @brief Check whether this sparse matrix has CSC format. */
151
152
  inline bool HasCSC() const { return csc_ != nullptr; }

153
154
  /** @return {row, col} tensors in the COO format. */
  std::tuple<torch::Tensor, torch::Tensor> COOTensors();
155
156
  /** @return Stacked row and col tensors in the COO format. */
  torch::Tensor Indices();
157
158
159
160
161
162
  /** @return {row, col, value_indices} tensors in the CSR format. */
  std::tuple<torch::Tensor, torch::Tensor, torch::optional<torch::Tensor>>
  CSRTensors();
  /** @return {row, col, value_indices} tensors in the CSC format. */
  std::tuple<torch::Tensor, torch::Tensor, torch::optional<torch::Tensor>>
  CSCTensors();
163

164
165
166
167
168
  /** @brief Return the transposition of the sparse matrix. It transposes the
   * first existing sparse format by checking COO, CSR, and CSC.
   */
  c10::intrusive_ptr<SparseMatrix> Transpose() const;

169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
  /**
   * @brief Return a new coalesced matrix.
   *
   * A coalesced sparse matrix satisfies the following properties:
   *   - the indices of the non-zero elements are unique,
   *   - the indices are sorted in lexicographical order.
   *
   * @return A coalesced sparse matrix.
   */
  c10::intrusive_ptr<SparseMatrix> Coalesce();

  /**
   * @brief Return true if this sparse matrix contains duplicate indices.
   * @return A bool flag.
   */
  bool HasDuplicate();

186
 private:
Hongzhi (Steve), Chen's avatar
Hongzhi (Steve), Chen committed
187
  /** @brief Create the COO format for the sparse matrix internally */
188
  void _CreateCOO();
Hongzhi (Steve), Chen's avatar
Hongzhi (Steve), Chen committed
189
  /** @brief Create the CSR format for the sparse matrix internally */
190
  void _CreateCSR();
Hongzhi (Steve), Chen's avatar
Hongzhi (Steve), Chen committed
191
  /** @brief Create the CSC format for the sparse matrix internally */
192
193
194
195
196
197
198
199
200
201
202
203
204
205
  void _CreateCSC();

  // COO/CSC/CSR pointers. Nullptr indicates non-existence.
  std::shared_ptr<COO> coo_;
  std::shared_ptr<CSR> csr_, csc_;
  // Value of the SparseMatrix
  torch::Tensor value_;
  // Shape of the SparseMatrix
  const std::vector<int64_t> shape_;
};
}  // namespace sparse
}  // namespace dgl

#endif  //  SPARSE_SPARSE_MATRIX_H_