sparse_matrix.h 6.46 KB
Newer Older
Hongzhi (Steve), Chen's avatar
Hongzhi (Steve), Chen committed
1
/**
 *  Copyright (c) 2022 by Contributors
 * @file sparse/sparse_matrix.h
 * @brief DGL C++ sparse matrix header
 */
#ifndef SPARSE_SPARSE_MATRIX_H_
#define SPARSE_SPARSE_MATRIX_H_

#include <torch/custom_class.h>
#include <torch/script.h>

#include <memory>
#include <vector>

namespace dgl {
namespace sparse {

Hongzhi (Steve), Chen's avatar
Hongzhi (Steve), Chen committed
18
/**
 * @brief SparseFormat enumeration.
 *
 * Identifies one of the sparse storage formats a SparseMatrix can hold.
 */
enum SparseFormat {
  kCOO,  // COO format
  kCSR,  // CSR format
  kCSC   // CSC format
};

Hongzhi (Steve), Chen's avatar
Hongzhi (Steve), Chen committed
21
/** @brief CSR sparse structure */
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
struct CSR {
  // CSR format index pointer array of the matrix
  torch::Tensor indptr;
  // CSR format index array of the matrix
  torch::Tensor indices;
  // The element order of the sparse format. In the SparseMatrix, we have data
  // (value_) for each non-zero value. The order of non-zero values in (value_)
  // may differ from the order of non-zero entries in CSR. So we store
  // `value_indices` in CSR to indicate its relative non-zero value order to the
  // SparseMatrix. With `value_indices`, we can retrieve the correct value for
  // CSR, i.e., `value_[value_indices]`. If `value_indices` is not defined, this
  // CSR follows the same non-zero value order as the SparseMatrix.
  torch::optional<torch::Tensor> value_indices;
};

Hongzhi (Steve), Chen's avatar
Hongzhi (Steve), Chen committed
37
/** @brief COO sparse structure */
38
39
40
41
42
43
44
struct COO {
  // COO format row array of the matrix
  torch::Tensor row;
  // COO format column array of the matrix
  torch::Tensor col;
};

Hongzhi (Steve), Chen's avatar
Hongzhi (Steve), Chen committed
45
/** @brief SparseMatrix bound to Python  */
46
47
class SparseMatrix : public torch::CustomClassHolder {
 public:
Hongzhi (Steve), Chen's avatar
Hongzhi (Steve), Chen committed
48
  /**
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
   * @brief General constructor to construct a sparse matrix for different
   * sparse formats. At least one of the sparse formats should be provided,
   * while others could be nullptrs.
   *
   * @param coo The COO format.
   * @param csr The CSR format.
   * @param csc The CSC format.
   * @param value Value of the sparse matrix.
   * @param shape Shape of the sparse matrix.
   */
  SparseMatrix(
      const std::shared_ptr<COO>& coo, const std::shared_ptr<CSR>& csr,
      const std::shared_ptr<CSR>& csc, torch::Tensor value,
      const std::vector<int64_t>& shape);

Hongzhi (Steve), Chen's avatar
Hongzhi (Steve), Chen committed
64
  /**
65
66
67
68
69
70
71
72
73
74
75
   * @brief Construct a SparseMatrix from a COO format.
   * @param coo The COO format
   * @param value Values of the sparse matrix
   * @param shape Shape of the sparse matrix
   *
   * @return SparseMatrix
   */
  static c10::intrusive_ptr<SparseMatrix> FromCOO(
      const std::shared_ptr<COO>& coo, torch::Tensor value,
      const std::vector<int64_t>& shape);

Hongzhi (Steve), Chen's avatar
Hongzhi (Steve), Chen committed
76
  /**
77
78
79
80
81
82
83
84
85
86
87
   * @brief Construct a SparseMatrix from a CSR format.
   * @param csr The CSR format
   * @param value Values of the sparse matrix
   * @param shape Shape of the sparse matrix
   *
   * @return SparseMatrix
   */
  static c10::intrusive_ptr<SparseMatrix> FromCSR(
      const std::shared_ptr<CSR>& csr, torch::Tensor value,
      const std::vector<int64_t>& shape);

Hongzhi (Steve), Chen's avatar
Hongzhi (Steve), Chen committed
88
  /**
89
90
91
92
93
94
95
96
97
98
99
   * @brief Construct a SparseMatrix from a CSC format.
   * @param csc The CSC format
   * @param value Values of the sparse matrix
   * @param shape Shape of the sparse matrix
   *
   * @return SparseMatrix
   */
  static c10::intrusive_ptr<SparseMatrix> FromCSC(
      const std::shared_ptr<CSR>& csc, torch::Tensor value,
      const std::vector<int64_t>& shape);

Hongzhi (Steve), Chen's avatar
Hongzhi (Steve), Chen committed
100
  /** @return Value of the sparse matrix. */
101
  inline torch::Tensor value() const { return value_; }
Hongzhi (Steve), Chen's avatar
Hongzhi (Steve), Chen committed
102
  /** @return Shape of the sparse matrix. */
103
  inline const std::vector<int64_t>& shape() const { return shape_; }
Hongzhi (Steve), Chen's avatar
Hongzhi (Steve), Chen committed
104
  /** @return Number of non-zero values */
105
  inline int64_t nnz() const { return value_.size(0); }
Hongzhi (Steve), Chen's avatar
Hongzhi (Steve), Chen committed
106
  /** @return Non-zero value data type */
107
  inline caffe2::TypeMeta dtype() const { return value_.dtype(); }
Hongzhi (Steve), Chen's avatar
Hongzhi (Steve), Chen committed
108
  /** @return Device of the sparse matrix */
109
110
  inline torch::Device device() const { return value_.device(); }

Hongzhi (Steve), Chen's avatar
Hongzhi (Steve), Chen committed
111
  /** @return COO of the sparse matrix. The COO is created if not exists. */
112
  std::shared_ptr<COO> COOPtr();
Hongzhi (Steve), Chen's avatar
Hongzhi (Steve), Chen committed
113
  /** @return CSR of the sparse matrix. The CSR is created if not exists. */
114
  std::shared_ptr<CSR> CSRPtr();
Hongzhi (Steve), Chen's avatar
Hongzhi (Steve), Chen committed
115
  /** @return CSC of the sparse matrix. The CSC is created if not exists. */
116
117
  std::shared_ptr<CSR> CSCPtr();

Hongzhi (Steve), Chen's avatar
Hongzhi (Steve), Chen committed
118
  /** @brief Check whether this sparse matrix has COO format. */
119
  inline bool HasCOO() const { return coo_ != nullptr; }
Hongzhi (Steve), Chen's avatar
Hongzhi (Steve), Chen committed
120
  /** @brief Check whether this sparse matrix has CSR format. */
121
  inline bool HasCSR() const { return csr_ != nullptr; }
Hongzhi (Steve), Chen's avatar
Hongzhi (Steve), Chen committed
122
  /** @brief Check whether this sparse matrix has CSC format. */
123
124
  inline bool HasCSC() const { return csc_ != nullptr; }

Hongzhi (Steve), Chen's avatar
Hongzhi (Steve), Chen committed
125
  /** @return {row, col, value} tensors in the COO format. */
126
  std::vector<torch::Tensor> COOTensors();
Hongzhi (Steve), Chen's avatar
Hongzhi (Steve), Chen committed
127
  /** @return {row, col, value} tensors in the CSR format. */
128
  std::vector<torch::Tensor> CSRTensors();
Hongzhi (Steve), Chen's avatar
Hongzhi (Steve), Chen committed
129
  /** @return {row, col, value} tensors in the CSC format. */
130
131
132
  std::vector<torch::Tensor> CSCTensors();

 private:
Hongzhi (Steve), Chen's avatar
Hongzhi (Steve), Chen committed
133
  /** @brief Create the COO format for the sparse matrix internally */
134
  void _CreateCOO();
Hongzhi (Steve), Chen's avatar
Hongzhi (Steve), Chen committed
135
  /** @brief Create the CSR format for the sparse matrix internally */
136
  void _CreateCSR();
Hongzhi (Steve), Chen's avatar
Hongzhi (Steve), Chen committed
137
  /** @brief Create the CSC format for the sparse matrix internally */
138
139
140
141
142
143
144
145
146
147
148
  void _CreateCSC();

  // COO/CSC/CSR pointers. Nullptr indicates non-existence.
  std::shared_ptr<COO> coo_;
  std::shared_ptr<CSR> csr_, csc_;
  // Value of the SparseMatrix
  torch::Tensor value_;
  // Shape of the SparseMatrix
  const std::vector<int64_t> shape_;
};

Hongzhi (Steve), Chen's avatar
Hongzhi (Steve), Chen committed
149
/**
 * @brief Build a SparseMatrix from COO-format tensors.
 *
 * @param row Row index of each non-zero entry.
 * @param col Column index of each non-zero entry.
 * @param value Values of the non-zero entries.
 * @param shape Shape of the sparse matrix.
 *
 * @return The constructed SparseMatrix.
 */
c10::intrusive_ptr<SparseMatrix> CreateFromCOO(
    torch::Tensor row,
    torch::Tensor col,
    torch::Tensor value,
    const std::vector<int64_t>& shape);

Hongzhi (Steve), Chen's avatar
Hongzhi (Steve), Chen committed
162
/**
 * @brief Build a SparseMatrix from CSR-format tensors.
 *
 * @param indptr Index pointer array of the CSR.
 * @param indices Index array of the CSR (the column index of each non-zero
 * entry).
 * @param value Values of the non-zero entries.
 * @param shape Shape of the sparse matrix.
 *
 * @return The constructed SparseMatrix.
 */
c10::intrusive_ptr<SparseMatrix> CreateFromCSR(
    torch::Tensor indptr,
    torch::Tensor indices,
    torch::Tensor value,
    const std::vector<int64_t>& shape);

Hongzhi (Steve), Chen's avatar
Hongzhi (Steve), Chen committed
175
/**
 * @brief Build a SparseMatrix from CSC-format tensors.
 *
 * @param indptr Index pointer array of the CSC.
 * @param indices Index array of the CSC (the row index of each non-zero
 * entry).
 * @param value Values of the non-zero entries.
 * @param shape Shape of the sparse matrix.
 *
 * @return The constructed SparseMatrix.
 */
c10::intrusive_ptr<SparseMatrix> CreateFromCSC(
    torch::Tensor indptr,
    torch::Tensor indices,
    torch::Tensor value,
    const std::vector<int64_t>& shape);

}  // namespace sparse
}  // namespace dgl

#endif  //  SPARSE_SPARSE_MATRIX_H_