sparse_matrix.cc 8.48 KB
Newer Older
Hongzhi (Steve), Chen's avatar
Hongzhi (Steve), Chen committed
1
/**
 *  Copyright (c) 2022 by Contributors
 * @file sparse_matrix.cc
 * @brief DGL C++ sparse matrix implementations.
 */
czkkkkkk's avatar
czkkkkkk committed
6
7
8
9
10
// clang-format off
#include <sparse/dgl_headers.h>
// clang-format on

#include <c10/util/Logging.h>
11
12
#include <sparse/elementwise_op.h>
#include <sparse/sparse_matrix.h>
czkkkkkk's avatar
czkkkkkk committed
13
#include <torch/script.h>
14
15
16
17
18
19

namespace dgl {
namespace sparse {

/**
 * @brief Construct a sparse matrix from any combination of sparse formats.
 *
 * Stores the given format pointers, the value tensor and the shape, then
 * validates them against each other. Any violated requirement is reported
 * through TORCH_CHECK.
 *
 * @param coo COO format, may be nullptr.
 * @param csr CSR format, may be nullptr.
 * @param csc CSC format (carried in a CSR struct), may be nullptr.
 * @param diag Diag format, may be nullptr.
 * @param value Value tensor. Its first dimension is compared with the number
 * of entries of every format that is provided.
 * @param shape Shape of the sparse matrix; must contain exactly two entries.
 */
SparseMatrix::SparseMatrix(
    const std::shared_ptr<COO>& coo, const std::shared_ptr<CSR>& csr,
    const std::shared_ptr<CSR>& csc, const std::shared_ptr<Diag>& diag,
    torch::Tensor value, const std::vector<int64_t>& shape)
    : coo_(coo),
      csr_(csr),
      csc_(csc),
      diag_(diag),
      value_(value),
      shape_(shape) {
  TORCH_CHECK(
      coo != nullptr || csr != nullptr || csc != nullptr || diag != nullptr,
      "At least one of CSR/COO/CSC/Diag is required to construct a "
      "SparseMatrix.");
  TORCH_CHECK(
      shape.size() == 2, "The shape of a sparse matrix should be ",
      "2-dimensional.");
  // NOTE: Currently all the tensors of a SparseMatrix should be on the same
  // device. Should the graph structure and the values be allowed to live on
  // different devices?
  if (coo != nullptr) {
    TORCH_CHECK(
        coo->indices.dim() == 2,
        "The COO indices of a sparse matrix should be a 2-dimensional "
        "tensor.");
    TORCH_CHECK(
        coo->indices.size(0) == 2,
        "The first dimension of the COO indices should have size 2.");
    TORCH_CHECK(
        coo->indices.size(1) == value.size(0),
        "The second dimension of the COO indices should match the first "
        "dimension of the values.");
    TORCH_CHECK(
        coo->indices.device() == value.device(),
        "The COO indices and the values should be on the same device.");
  }
  if (csr != nullptr) {
    TORCH_CHECK(
        csr->indptr.dim() == 1,
        "The CSR indptr of a sparse matrix should be a 1-dimensional tensor.");
    TORCH_CHECK(
        csr->indices.dim() == 1,
        "The CSR indices of a sparse matrix should be a 1-dimensional "
        "tensor.");
    TORCH_CHECK(
        csr->indptr.size(0) == shape[0] + 1,
        "The length of the CSR indptr should be shape[0] + 1.");
    TORCH_CHECK(
        csr->indices.size(0) == value.size(0),
        "The length of the CSR indices should match the first dimension of "
        "the values.");
    TORCH_CHECK(
        csr->indptr.device() == value.device(),
        "The CSR indptr and the values should be on the same device.");
    TORCH_CHECK(
        csr->indices.device() == value.device(),
        "The CSR indices and the values should be on the same device.");
  }
  if (csc != nullptr) {
    TORCH_CHECK(
        csc->indptr.dim() == 1,
        "The CSC indptr of a sparse matrix should be a 1-dimensional tensor.");
    TORCH_CHECK(
        csc->indices.dim() == 1,
        "The CSC indices of a sparse matrix should be a 1-dimensional "
        "tensor.");
    // The CSC indptr is compared against the second entry of the shape.
    TORCH_CHECK(
        csc->indptr.size(0) == shape[1] + 1,
        "The length of the CSC indptr should be shape[1] + 1.");
    TORCH_CHECK(
        csc->indices.size(0) == value.size(0),
        "The length of the CSC indices should match the first dimension of "
        "the values.");
    TORCH_CHECK(
        csc->indptr.device() == value.device(),
        "The CSC indptr and the values should be on the same device.");
    TORCH_CHECK(
        csc->indices.device() == value.device(),
        "The CSC indices and the values should be on the same device.");
  }
  if (diag != nullptr) {
    TORCH_CHECK(
        value.size(0) == std::min(diag->num_rows, diag->num_cols),
        "The first dimension of the values of a diagonal sparse matrix "
        "should equal min(num_rows, num_cols).");
  }
}

65
/**
 * @brief Create a SparseMatrix that is given only a COO format.
 *
 * @param coo COO format to store.
 * @param value Value tensor of the matrix.
 * @param shape Shape of the matrix.
 * @return The newly constructed SparseMatrix.
 */
c10::intrusive_ptr<SparseMatrix> SparseMatrix::FromCOOPointer(
    const std::shared_ptr<COO>& coo, torch::Tensor value,
    const std::vector<int64_t>& shape) {
  // Every format other than COO is handed to the constructor as empty.
  const std::shared_ptr<CSR> no_compressed;
  const std::shared_ptr<Diag> no_diag;
  return c10::make_intrusive<SparseMatrix>(
      coo, no_compressed, no_compressed, no_diag, value, shape);
}

72
/**
 * @brief Create a SparseMatrix that is given only a CSR format.
 *
 * @param csr CSR format to store.
 * @param value Value tensor of the matrix.
 * @param shape Shape of the matrix.
 * @return The newly constructed SparseMatrix.
 */
c10::intrusive_ptr<SparseMatrix> SparseMatrix::FromCSRPointer(
    const std::shared_ptr<CSR>& csr, torch::Tensor value,
    const std::vector<int64_t>& shape) {
  // Every format other than CSR is handed to the constructor as empty.
  const std::shared_ptr<COO> no_coo;
  const std::shared_ptr<CSR> no_csc;
  const std::shared_ptr<Diag> no_diag;
  return c10::make_intrusive<SparseMatrix>(
      no_coo, csr, no_csc, no_diag, value, shape);
}

79
/**
 * @brief Create a SparseMatrix that is given only a CSC format.
 *
 * @param csc CSC format to store (carried in a CSR struct).
 * @param value Value tensor of the matrix.
 * @param shape Shape of the matrix.
 * @return The newly constructed SparseMatrix.
 */
c10::intrusive_ptr<SparseMatrix> SparseMatrix::FromCSCPointer(
    const std::shared_ptr<CSR>& csc, torch::Tensor value,
    const std::vector<int64_t>& shape) {
  // Every format other than CSC is handed to the constructor as empty.
  const std::shared_ptr<COO> no_coo;
  const std::shared_ptr<CSR> no_csr;
  const std::shared_ptr<Diag> no_diag;
  return c10::make_intrusive<SparseMatrix>(
      no_coo, no_csr, csc, no_diag, value, shape);
}

/**
 * @brief Create a SparseMatrix that is given only a Diag format.
 *
 * @param diag Diag format to store.
 * @param value Value tensor of the matrix.
 * @param shape Shape of the matrix.
 * @return The newly constructed SparseMatrix.
 */
c10::intrusive_ptr<SparseMatrix> SparseMatrix::FromDiagPointer(
    const std::shared_ptr<Diag>& diag, torch::Tensor value,
    const std::vector<int64_t>& shape) {
  // Every format other than Diag is handed to the constructor as empty.
  const std::shared_ptr<COO> no_coo;
  const std::shared_ptr<CSR> no_compressed;
  return c10::make_intrusive<SparseMatrix>(
      no_coo, no_compressed, no_compressed, diag, value, shape);
}

93
/**
 * @brief Create a SparseMatrix from a COO indices tensor.
 *
 * @param indices COO indices tensor stored in the new COO struct.
 * @param value Value tensor of the matrix.
 * @param shape Shape of the matrix; must contain exactly two entries.
 * @return The newly constructed SparseMatrix.
 */
c10::intrusive_ptr<SparseMatrix> SparseMatrix::FromCOO(
    torch::Tensor indices, torch::Tensor value,
    const std::vector<int64_t>& shape) {
  // shape[0] and shape[1] are read right below, i.e. before the SparseMatrix
  // constructor gets to validate the shape, so reject a malformed shape here
  // rather than indexing out of bounds.
  TORCH_CHECK(
      shape.size() == 2, "The shape of a sparse matrix should be ",
      "2-dimensional.");
  // The two trailing flags are passed as false. NOTE(review): presumably the
  // row/column sorted flags -- confirm against the COO declaration.
  auto coo =
      std::make_shared<COO>(COO{shape[0], shape[1], indices, false, false});
  return SparseMatrix::FromCOOPointer(coo, value, shape);
}

/**
 * @brief Create a SparseMatrix from CSR tensors.
 *
 * @param indptr CSR index pointer tensor.
 * @param indices CSR indices tensor.
 * @param value Value tensor of the matrix.
 * @param shape Shape of the matrix; must contain exactly two entries.
 * @return The newly constructed SparseMatrix.
 */
c10::intrusive_ptr<SparseMatrix> SparseMatrix::FromCSR(
    torch::Tensor indptr, torch::Tensor indices, torch::Tensor value,
    const std::vector<int64_t>& shape) {
  // shape[0] and shape[1] are read right below, i.e. before the SparseMatrix
  // constructor gets to validate the shape, so reject a malformed shape here
  // rather than indexing out of bounds.
  TORCH_CHECK(
      shape.size() == 2, "The shape of a sparse matrix should be ",
      "2-dimensional.");
  // The optional tensor is left empty and the trailing flag is false.
  auto csr = std::make_shared<CSR>(
      CSR{shape[0], shape[1], indptr, indices, torch::optional<torch::Tensor>(),
          false});
  return SparseMatrix::FromCSRPointer(csr, value, shape);
}

/**
 * @brief Create a SparseMatrix from CSC tensors.
 *
 * @param indptr CSC index pointer tensor.
 * @param indices CSC indices tensor.
 * @param value Value tensor of the matrix.
 * @param shape Shape of the matrix; must contain exactly two entries.
 * @return The newly constructed SparseMatrix.
 */
c10::intrusive_ptr<SparseMatrix> SparseMatrix::FromCSC(
    torch::Tensor indptr, torch::Tensor indices, torch::Tensor value,
    const std::vector<int64_t>& shape) {
  // shape[0] and shape[1] are read right below, i.e. before the SparseMatrix
  // constructor gets to validate the shape, so reject a malformed shape here
  // rather than indexing out of bounds.
  TORCH_CHECK(
      shape.size() == 2, "The shape of a sparse matrix should be ",
      "2-dimensional.");
  // The CSC format is carried in a CSR struct whose first two fields receive
  // the shape entries in swapped order (shape[1], then shape[0]).
  auto csc = std::make_shared<CSR>(
      CSR{shape[1], shape[0], indptr, indices, torch::optional<torch::Tensor>(),
          false});
  return SparseMatrix::FromCSCPointer(csc, value, shape);
}

119
120
121
122
123
124
/**
 * @brief Create a diagonal SparseMatrix from its values and shape.
 *
 * @param value Value tensor of the matrix.
 * @param shape Shape of the matrix; must contain exactly two entries.
 * @return The newly constructed SparseMatrix.
 */
c10::intrusive_ptr<SparseMatrix> SparseMatrix::FromDiag(
    torch::Tensor value, const std::vector<int64_t>& shape) {
  // shape[0] and shape[1] are read right below, i.e. before the SparseMatrix
  // constructor gets to validate the shape, so reject a malformed shape here
  // rather than indexing out of bounds.
  TORCH_CHECK(
      shape.size() == 2, "The shape of a sparse matrix should be ",
      "2-dimensional.");
  auto diag = std::make_shared<Diag>(Diag{shape[0], shape[1]});
  return SparseMatrix::FromDiagPointer(diag, value, shape);
}

125
126
127
128
129
130
131
132
/**
 * @brief Create a SparseMatrix that reuses the sparse format of an existing
 * matrix but carries new values.
 *
 * The first format found on `mat`, probed in the order Diag, COO, CSR, CSC,
 * is the one handed to the new matrix.
 *
 * @param mat Matrix whose sparse format and shape are reused.
 * @param value New value tensor; must agree with the old values on the size
 * of the first dimension and on the device.
 * @return The newly constructed SparseMatrix.
 */
c10::intrusive_ptr<SparseMatrix> SparseMatrix::ValLike(
    const c10::intrusive_ptr<SparseMatrix>& mat, torch::Tensor value) {
  const auto& old_value = mat->value();
  TORCH_CHECK(
      old_value.size(0) == value.size(0), "The first dimension of ",
      "the old values and the new values must be the same.");
  TORCH_CHECK(
      old_value.device() == value.device(), "The device of the ",
      "old values and the new values must be the same.");

  const auto& mat_shape = mat->shape();
  if (mat->HasDiag()) {
    return SparseMatrix::FromDiagPointer(mat->DiagPtr(), value, mat_shape);
  } else if (mat->HasCOO()) {
    return SparseMatrix::FromCOOPointer(mat->COOPtr(), value, mat_shape);
  } else if (mat->HasCSR()) {
    return SparseMatrix::FromCSRPointer(mat->CSRPtr(), value, mat_shape);
  }
  // CSC is the only remaining candidate.
  TORCH_CHECK(mat->HasCSC(), "Invalid sparse format for ValLike.");
  return SparseMatrix::FromCSCPointer(mat->CSCPtr(), value, mat_shape);
}

147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
/**
 * @brief Get the COO format, creating it first if it is not stored yet.
 * @return Shared pointer to the COO format.
 */
std::shared_ptr<COO> SparseMatrix::COOPtr() {
  if (coo_) return coo_;
  // Not stored yet: let _CreateCOO() fill in coo_.
  _CreateCOO();
  return coo_;
}

/**
 * @brief Get the CSR format, creating it first if it is not stored yet.
 * @return Shared pointer to the CSR format.
 */
std::shared_ptr<CSR> SparseMatrix::CSRPtr() {
  if (csr_) return csr_;
  // Not stored yet: let _CreateCSR() fill in csr_.
  _CreateCSR();
  return csr_;
}

/**
 * @brief Get the CSC format, creating it first if it is not stored yet.
 * @return Shared pointer to the CSC format (carried in a CSR struct).
 */
std::shared_ptr<CSR> SparseMatrix::CSCPtr() {
  if (csc_) return csc_;
  // Not stored yet: let _CreateCSC() fill in csc_.
  _CreateCSC();
  return csc_;
}

168
169
170
171
172
173
174
/**
 * @brief Get the Diag format of the matrix.
 *
 * Unlike the other format getters in this file, nothing is created here: the
 * call fails when no Diag format is stored.
 *
 * @return Shared pointer to the Diag format.
 */
std::shared_ptr<Diag> SparseMatrix::DiagPtr() {
  const bool has_diag_format = (diag_ != nullptr);
  TORCH_CHECK(
      has_diag_format,
      "Cannot get Diag sparse format from a non-diagonal sparse matrix");
  return diag_;
}

175
std::tuple<torch::Tensor, torch::Tensor> SparseMatrix::COOTensors() {
176
  auto coo = COOPtr();
177
  return std::make_tuple(coo->indices.index({0}), coo->indices.index({1}));
178
179
}

180
181
182
183
184
/**
 * @brief Get the indices tensor of the COO format.
 * @return The `indices` member of the COO format returned by COOPtr().
 */
torch::Tensor SparseMatrix::Indices() {
  return COOPtr()->indices;
}

185
186
std::tuple<torch::Tensor, torch::Tensor, torch::optional<torch::Tensor>>
SparseMatrix::CSRTensors() {
187
188
  auto csr = CSRPtr();
  auto val = value();
189
  return std::make_tuple(csr->indptr, csr->indices, csr->value_indices);
190
191
}

192
193
std::tuple<torch::Tensor, torch::Tensor, torch::optional<torch::Tensor>>
SparseMatrix::CSCTensors() {
194
  auto csc = CSCPtr();
195
  return std::make_tuple(csc->indptr, csc->indices, csc->value_indices);
196
197
}

198
199
200
201
/**
 * @brief Create the transpose of this sparse matrix.
 *
 * The values are reused and the two shape entries are swapped. The first
 * stored format, probed in the order Diag, COO, CSR, determines how the result
 * is built; if none of them is stored, the CSC format is used.
 *
 * @return The transposed SparseMatrix.
 */
c10::intrusive_ptr<SparseMatrix> SparseMatrix::Transpose() const {
  auto transposed_shape = shape_;
  std::swap(transposed_shape[0], transposed_shape[1]);

  if (HasDiag()) {
    return SparseMatrix::FromDiag(value_, transposed_shape);
  }
  if (HasCOO()) {
    auto transposed_coo = COOTranspose(coo_);
    return SparseMatrix::FromCOOPointer(
        transposed_coo, value_, transposed_shape);
  }
  if (HasCSR()) {
    // The stored CSR format is handed over as the CSC format of the result.
    return SparseMatrix::FromCSCPointer(csr_, value_, transposed_shape);
  }
  // Likewise, the stored CSC format becomes the CSR format of the result.
  return SparseMatrix::FromCSRPointer(csc_, value_, transposed_shape);
}

czkkkkkk's avatar
czkkkkkk committed
214
void SparseMatrix::_CreateCOO() {
215
  if (HasCOO()) return;
216
217
218
219
220
221
222
  if (HasDiag()) {
    auto indices_options = torch::TensorOptions()
                               .dtype(torch::kInt64)
                               .layout(torch::kStrided)
                               .device(this->device());
    coo_ = DiagToCOO(diag_, indices_options);
  } else if (HasCSR()) {
223
    coo_ = CSRToCOO(csr_);
czkkkkkk's avatar
czkkkkkk committed
224
  } else if (HasCSC()) {
225
    coo_ = CSCToCOO(csc_);
czkkkkkk's avatar
czkkkkkk committed
226
227
228
229
230
  } else {
    LOG(FATAL) << "SparseMatrix does not have any sparse format";
  }
}

231
232
void SparseMatrix::_CreateCSR() {
  if (HasCSR()) return;
233
234
235
236
237
238
239
  if (HasDiag()) {
    auto indices_options = torch::TensorOptions()
                               .dtype(torch::kInt64)
                               .layout(torch::kStrided)
                               .device(this->device());
    csr_ = DiagToCSR(diag_, indices_options);
  } else if (HasCOO()) {
240
241
242
243
244
245
246
247
248
249
    csr_ = COOToCSR(coo_);
  } else if (HasCSC()) {
    csr_ = CSCToCSR(csc_);
  } else {
    LOG(FATAL) << "SparseMatrix does not have any sparse format";
  }
}

void SparseMatrix::_CreateCSC() {
  if (HasCSC()) return;
250
251
252
253
254
255
256
  if (HasDiag()) {
    auto indices_options = torch::TensorOptions()
                               .dtype(torch::kInt64)
                               .layout(torch::kStrided)
                               .device(this->device());
    csc_ = DiagToCSC(diag_, indices_options);
  } else if (HasCOO()) {
257
258
259
260
261
262
263
    csc_ = COOToCSC(coo_);
  } else if (HasCSR()) {
    csc_ = CSRToCSC(csr_);
  } else {
    LOG(FATAL) << "SparseMatrix does not have any sparse format";
  }
}
264
265
266

}  // namespace sparse
}  // namespace dgl