/**
 *  Copyright (c) 2022 by Contributors
 * @file sparse_matrix.cc
 * @brief DGL C++ sparse matrix implementations.
 */
// clang-format off
#include <sparse/dgl_headers.h>
// clang-format on

#include <c10/util/Logging.h>
11
12
#include <sparse/elementwise_op.h>
#include <sparse/sparse_matrix.h>
czkkkkkk's avatar
czkkkkkk committed
13
#include <torch/script.h>
14

15
16
#include "./utils.h"

17
18
19
20
21
namespace dgl {
namespace sparse {

/**
 * @brief Construct a SparseMatrix from any combination of pre-built formats.
 *
 * At least one of `coo`, `csr`, `csc`, `diag` must be non-null; the missing
 * formats are materialized lazily on demand. `value` holds one row per
 * non-zero element and `shape` must be 2-dimensional.
 */
SparseMatrix::SparseMatrix(
    const std::shared_ptr<COO>& coo, const std::shared_ptr<CSR>& csr,
    const std::shared_ptr<CSR>& csc, const std::shared_ptr<Diag>& diag,
    torch::Tensor value, const std::vector<int64_t>& shape)
    : coo_(coo),
      csr_(csr),
      csc_(csc),
      diag_(diag),
      value_(value),
      shape_(shape) {
  TORCH_CHECK(
      coo != nullptr || csr != nullptr || csc != nullptr || diag != nullptr,
      "At least one of CSR/COO/CSC/Diag is required to construct a "
      "SparseMatrix.");
  TORCH_CHECK(
      shape.size() == 2, "The shape of a sparse matrix should be ",
      "2-dimensional.");
  // NOTE: Currently all the tensors of a SparseMatrix should on the same
  // device. Do we allow the graph structure and values are on different
  // devices?
  if (coo != nullptr) {
    // COO indices form a (2, nnz) tensor colocated with the value tensor.
    TORCH_CHECK(coo->indices.dim() == 2);
    TORCH_CHECK(coo->indices.size(0) == 2);
    TORCH_CHECK(coo->indices.size(1) == value.size(0));
    TORCH_CHECK(coo->indices.device() == value.device());
  }
  if (csr != nullptr) {
    // CSR: indptr has one entry per row plus one; indices has one entry per
    // non-zero; both live on the value tensor's device.
    TORCH_CHECK(csr->indptr.dim() == 1);
    TORCH_CHECK(csr->indices.dim() == 1);
    TORCH_CHECK(csr->indptr.size(0) == shape[0] + 1);
    TORCH_CHECK(csr->indices.size(0) == value.size(0));
    TORCH_CHECK(csr->indptr.device() == value.device());
    TORCH_CHECK(csr->indices.device() == value.device());
  }
  if (csc != nullptr) {
    // CSC is stored as the CSR of the transpose, so its indptr length is
    // keyed on the column count.
    TORCH_CHECK(csc->indptr.dim() == 1);
    TORCH_CHECK(csc->indices.dim() == 1);
    TORCH_CHECK(csc->indptr.size(0) == shape[1] + 1);
    TORCH_CHECK(csc->indices.size(0) == value.size(0));
    TORCH_CHECK(csc->indptr.device() == value.device());
    TORCH_CHECK(csc->indices.device() == value.device());
  }
  if (diag != nullptr) {
    // A diagonal matrix has exactly min(num_rows, num_cols) stored values.
    TORCH_CHECK(value.size(0) == std::min(diag->num_rows, diag->num_cols));
  }
}

67
// Wraps an existing COO structure into a SparseMatrix; other formats are
// left unmaterialized.
c10::intrusive_ptr<SparseMatrix> SparseMatrix::FromCOOPointer(
    const std::shared_ptr<COO>& coo, torch::Tensor value,
    const std::vector<int64_t>& shape) {
  return c10::make_intrusive<SparseMatrix>(
      coo, /*csr=*/nullptr, /*csc=*/nullptr, /*diag=*/nullptr, value, shape);
}

74
// Wraps an existing CSR structure into a SparseMatrix; other formats are
// left unmaterialized.
c10::intrusive_ptr<SparseMatrix> SparseMatrix::FromCSRPointer(
    const std::shared_ptr<CSR>& csr, torch::Tensor value,
    const std::vector<int64_t>& shape) {
  return c10::make_intrusive<SparseMatrix>(
      /*coo=*/nullptr, csr, /*csc=*/nullptr, /*diag=*/nullptr, value, shape);
}

81
// Wraps an existing CSC structure (a CSR of the transpose) into a
// SparseMatrix; other formats are left unmaterialized.
c10::intrusive_ptr<SparseMatrix> SparseMatrix::FromCSCPointer(
    const std::shared_ptr<CSR>& csc, torch::Tensor value,
    const std::vector<int64_t>& shape) {
  return c10::make_intrusive<SparseMatrix>(
      /*coo=*/nullptr, /*csr=*/nullptr, csc, /*diag=*/nullptr, value, shape);
}

// Wraps an existing Diag structure into a SparseMatrix; other formats are
// left unmaterialized.
c10::intrusive_ptr<SparseMatrix> SparseMatrix::FromDiagPointer(
    const std::shared_ptr<Diag>& diag, torch::Tensor value,
    const std::vector<int64_t>& shape) {
  return c10::make_intrusive<SparseMatrix>(
      /*coo=*/nullptr, /*csr=*/nullptr, /*csc=*/nullptr, diag, value, shape);
}

95
// Builds a SparseMatrix from raw COO indices. The freshly constructed COO
// is flagged as neither row- nor column-sorted.
c10::intrusive_ptr<SparseMatrix> SparseMatrix::FromCOO(
    torch::Tensor indices, torch::Tensor value,
    const std::vector<int64_t>& shape) {
  COO coo_data{shape[0], shape[1], indices, false, false};
  return SparseMatrix::FromCOOPointer(
      std::make_shared<COO>(std::move(coo_data)), value, shape);
}

// Builds a SparseMatrix from raw CSR arrays. A fresh CSR carries no
// value-index remapping.
c10::intrusive_ptr<SparseMatrix> SparseMatrix::FromCSR(
    torch::Tensor indptr, torch::Tensor indices, torch::Tensor value,
    const std::vector<int64_t>& shape) {
  CSR csr_data{
      shape[0], shape[1], indptr, indices, torch::optional<torch::Tensor>(),
      false};
  return SparseMatrix::FromCSRPointer(
      std::make_shared<CSR>(std::move(csr_data)), value, shape);
}

// Builds a SparseMatrix from raw CSC arrays. Internally CSC is represented
// as the CSR of the transpose, hence the swapped shape dimensions.
c10::intrusive_ptr<SparseMatrix> SparseMatrix::FromCSC(
    torch::Tensor indptr, torch::Tensor indices, torch::Tensor value,
    const std::vector<int64_t>& shape) {
  CSR csc_data{
      shape[1], shape[0], indptr, indices, torch::optional<torch::Tensor>(),
      false};
  return SparseMatrix::FromCSCPointer(
      std::make_shared<CSR>(std::move(csc_data)), value, shape);
}

121
122
123
124
125
126
// Builds a diagonal SparseMatrix. A Diag carries only the matrix
// dimensions; the non-zero entries live in `value`.
c10::intrusive_ptr<SparseMatrix> SparseMatrix::FromDiag(
    torch::Tensor value, const std::vector<int64_t>& shape) {
  return SparseMatrix::FromDiagPointer(
      std::make_shared<Diag>(Diag{shape[0], shape[1]}), value, shape);
}

127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
// Selects rows (dim == 0) or columns (dim == 1) of the matrix by id.
// Column selection works on the CSC, which is stored as the CSR of the
// transpose, so both cases reduce to a row slice.
c10::intrusive_ptr<SparseMatrix> SparseMatrix::IndexSelect(
    int64_t dim, torch::Tensor ids) {
  const bool select_rows = (dim == 0);
  auto fmt = select_rows ? this->CSRPtr() : this->CSCPtr();
  auto sliced = dgl::aten::CSRSliceRows(
      CSRToOldDGLCSR(fmt), TorchTensorToDGLArray(ids));
  auto picked_values =
      this->value().index_select(0, DGLArrayToTorchTensor(sliced.data));
  // To prevent potential errors in future conversions to the COO format,
  // where this array might be used as an initialization array for
  // constructing COO representations, it is necessary to clear this array.
  sliced.data = dgl::aten::NullArray();
  auto result = CSRFromOldDGLCSR(sliced);
  if (select_rows) {
    return SparseMatrix::FromCSRPointer(
        result, picked_values, {result->num_rows, result->num_cols});
  }
  return SparseMatrix::FromCSCPointer(
      result, picked_values, {result->num_cols, result->num_rows});
}

// Selects the contiguous row range [start, end) when dim == 0, or the
// corresponding column range when dim == 1. Column ranges slice the CSC,
// which is stored as the CSR of the transpose.
c10::intrusive_ptr<SparseMatrix> SparseMatrix::RangeSelect(
    int64_t dim, int64_t start, int64_t end) {
  const bool select_rows = (dim == 0);
  auto fmt = select_rows ? this->CSRPtr() : this->CSCPtr();
  auto sliced = dgl::aten::CSRSliceRows(CSRToOldDGLCSR(fmt), start, end);
  auto picked_values =
      this->value().index_select(0, DGLArrayToTorchTensor(sliced.data));
  // To prevent potential errors in future conversions to the COO format,
  // where this array might be used as an initialization array for
  // constructing COO representations, it is necessary to clear this array.
  sliced.data = dgl::aten::NullArray();
  auto result = CSRFromOldDGLCSR(sliced);
  if (select_rows) {
    return SparseMatrix::FromCSRPointer(
        result, picked_values, {result->num_rows, result->num_cols});
  }
  return SparseMatrix::FromCSCPointer(
      result, picked_values, {result->num_cols, result->num_rows});
}

170
171
172
173
174
175
176
177
// Creates a new SparseMatrix sharing `mat`'s sparsity structure but with a
// fresh value tensor. Reuses whichever format `mat` already has,
// preferring Diag, then COO, CSR, CSC.
c10::intrusive_ptr<SparseMatrix> SparseMatrix::ValLike(
    const c10::intrusive_ptr<SparseMatrix>& mat, torch::Tensor value) {
  TORCH_CHECK(
      mat->value().size(0) == value.size(0), "The first dimension of ",
      "the old values and the new values must be the same.");
  TORCH_CHECK(
      mat->value().device() == value.device(), "The device of the ",
      "old values and the new values must be the same.");
  const auto& shape = mat->shape();
  if (mat->HasDiag()) {
    return SparseMatrix::FromDiagPointer(mat->DiagPtr(), value, shape);
  } else if (mat->HasCOO()) {
    return SparseMatrix::FromCOOPointer(mat->COOPtr(), value, shape);
  } else if (mat->HasCSR()) {
    return SparseMatrix::FromCSRPointer(mat->CSRPtr(), value, shape);
  } else {
    TORCH_CHECK(mat->HasCSC(), "Invalid sparse format for ValLike.");
    return SparseMatrix::FromCSCPointer(mat->CSCPtr(), value, shape);
  }
}

192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
// Returns the COO format, materializing it lazily on first access.
std::shared_ptr<COO> SparseMatrix::COOPtr() {
  if (!coo_) _CreateCOO();
  return coo_;
}

// Returns the CSR format, materializing it lazily on first access.
std::shared_ptr<CSR> SparseMatrix::CSRPtr() {
  if (!csr_) _CreateCSR();
  return csr_;
}

// Returns the CSC format (stored as a CSR of the transpose), materializing
// it lazily on first access.
std::shared_ptr<CSR> SparseMatrix::CSCPtr() {
  if (!csc_) _CreateCSC();
  return csc_;
}

213
214
215
216
217
218
219
// Returns the Diag format. Unlike COO/CSR/CSC, Diag is never derived from
// another format, so it must already exist.
std::shared_ptr<Diag> SparseMatrix::DiagPtr() {
  TORCH_CHECK(
      diag_ != nullptr,
      "Cannot get Diag sparse format from a non-diagonal sparse matrix");
  return diag_;
}

220
std::tuple<torch::Tensor, torch::Tensor> SparseMatrix::COOTensors() {
221
  auto coo = COOPtr();
222
  return std::make_tuple(coo->indices.index({0}), coo->indices.index({1}));
223
224
}

225
226
227
228
229
// Returns the raw (2, nnz) COO indices tensor, materializing COO if needed.
torch::Tensor SparseMatrix::Indices() { return COOPtr()->indices; }

230
231
std::tuple<torch::Tensor, torch::Tensor, torch::optional<torch::Tensor>>
SparseMatrix::CSRTensors() {
232
233
  auto csr = CSRPtr();
  auto val = value();
234
  return std::make_tuple(csr->indptr, csr->indices, csr->value_indices);
235
236
}

237
238
std::tuple<torch::Tensor, torch::Tensor, torch::optional<torch::Tensor>>
SparseMatrix::CSCTensors() {
239
  auto csc = CSCPtr();
240
  return std::make_tuple(csc->indptr, csc->indices, csc->value_indices);
241
242
}

243
244
245
246
// Returns the transposed matrix, reusing whichever format already exists.
c10::intrusive_ptr<SparseMatrix> SparseMatrix::Transpose() const {
  const std::vector<int64_t> flipped{shape_[1], shape_[0]};
  auto value = value_;
  if (HasDiag()) {
    // A diagonal matrix transposes to a diagonal matrix of swapped shape.
    return SparseMatrix::FromDiag(value, flipped);
  }
  if (HasCOO()) {
    return SparseMatrix::FromCOOPointer(COOTranspose(coo_), value, flipped);
  }
  if (HasCSR()) {
    // The CSR of a matrix is the CSC of its transpose (and vice versa), so
    // transposition is just a format relabeling.
    return SparseMatrix::FromCSCPointer(csr_, value, flipped);
  }
  return SparseMatrix::FromCSRPointer(csc_, value, flipped);
}

czkkkkkk's avatar
czkkkkkk committed
259
void SparseMatrix::_CreateCOO() {
260
  if (HasCOO()) return;
261
262
263
264
265
266
267
  if (HasDiag()) {
    auto indices_options = torch::TensorOptions()
                               .dtype(torch::kInt64)
                               .layout(torch::kStrided)
                               .device(this->device());
    coo_ = DiagToCOO(diag_, indices_options);
  } else if (HasCSR()) {
268
    coo_ = CSRToCOO(csr_);
czkkkkkk's avatar
czkkkkkk committed
269
  } else if (HasCSC()) {
270
    coo_ = CSCToCOO(csc_);
czkkkkkk's avatar
czkkkkkk committed
271
272
273
274
275
  } else {
    LOG(FATAL) << "SparseMatrix does not have any sparse format";
  }
}

276
277
void SparseMatrix::_CreateCSR() {
  if (HasCSR()) return;
278
279
280
281
282
283
284
  if (HasDiag()) {
    auto indices_options = torch::TensorOptions()
                               .dtype(torch::kInt64)
                               .layout(torch::kStrided)
                               .device(this->device());
    csr_ = DiagToCSR(diag_, indices_options);
  } else if (HasCOO()) {
285
286
287
288
289
290
291
292
293
294
    csr_ = COOToCSR(coo_);
  } else if (HasCSC()) {
    csr_ = CSCToCSR(csc_);
  } else {
    LOG(FATAL) << "SparseMatrix does not have any sparse format";
  }
}

void SparseMatrix::_CreateCSC() {
  if (HasCSC()) return;
295
296
297
298
299
300
301
  if (HasDiag()) {
    auto indices_options = torch::TensorOptions()
                               .dtype(torch::kInt64)
                               .layout(torch::kStrided)
                               .device(this->device());
    csc_ = DiagToCSC(diag_, indices_options);
  } else if (HasCOO()) {
302
303
304
305
306
307
308
    csc_ = COOToCSC(coo_);
  } else if (HasCSR()) {
    csc_ = CSRToCSC(csr_);
  } else {
    LOG(FATAL) << "SparseMatrix does not have any sparse format";
  }
}
309
310
311

}  // namespace sparse
}  // namespace dgl