/**
 *  Copyright (c) 2022 by Contributors
 * @file sparse_matrix.cc
 * @brief DGL C++ sparse matrix implementations.
 */
// clang-format off
#include <sparse/dgl_headers.h>
// clang-format on

#include <sparse/elementwise_op.h>
#include <sparse/sparse_matrix.h>

#include <algorithm>  // for std::min

#include "./utils.h"

namespace dgl {
namespace sparse {

SparseMatrix::SparseMatrix(
    const std::shared_ptr<COO>& coo, const std::shared_ptr<CSR>& csr,
    const std::shared_ptr<CSR>& csc, const std::shared_ptr<Diag>& diag,
    torch::Tensor value, const std::vector<int64_t>& shape)
    : coo_(coo),
      csr_(csr),
      csc_(csc),
      diag_(diag),
      value_(value),
      shape_(shape) {
  TORCH_CHECK(
      coo != nullptr || csr != nullptr || csc != nullptr || diag != nullptr,
      "At least one of CSR/COO/CSC/Diag is required to construct a "
      "SparseMatrix.");
  TORCH_CHECK(
      shape.size() == 2, "The shape of a sparse matrix should be ",
      "2-dimensional.");
  // NOTE: Currently all the tensors of a SparseMatrix should be on the same
  // device. Should we allow the graph structure and the values to be on
  // different devices?
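  // Validate whichever formats were supplied: sizes must be consistent with
  // `shape`, and every index tensor must live on the same device as `value`.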
  if (coo != nullptr) {
    TORCH_CHECK(coo->indices.dim() == 2);
    TORCH_CHECK(coo->indices.size(0) == 2);
    TORCH_CHECK(coo->indices.size(1) == value.size(0));
    TORCH_CHECK(coo->indices.device() == value.device());
  }
  if (csr != nullptr) {
    TORCH_CHECK(csr->indptr.dim() == 1);
    TORCH_CHECK(csr->indices.dim() == 1);
    TORCH_CHECK(csr->indptr.size(0) == shape[0] + 1);
    TORCH_CHECK(csr->indices.size(0) == value.size(0));
    TORCH_CHECK(csr->indptr.device() == value.device());
    TORCH_CHECK(csr->indices.device() == value.device());
  }
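  // For CSC the roles of rows and columns are swapped, so its indptr has
  // shape[1] + 1 entries (one pointer per column).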
  if (csc != nullptr) {
    TORCH_CHECK(csc->indptr.dim() == 1);
    TORCH_CHECK(csc->indices.dim() == 1);
    TORCH_CHECK(csc->indptr.size(0) == shape[1] + 1);
    TORCH_CHECK(csc->indices.size(0) == value.size(0));
    TORCH_CHECK(csc->indptr.device() == value.device());
    TORCH_CHECK(csc->indices.device() == value.device());
  }
  if (diag != nullptr) {
    TORCH_CHECK(value.size(0) == std::min(diag->num_rows, diag->num_cols));
  }
}

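// The FromXPointer helpers wrap an already-built format into a SparseMatrix
// without conversion; the remaining format slots stay null and are only
// materialized on demand (see _CreateCOO/_CreateCSR/_CreateCSC below).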
c10::intrusive_ptr<SparseMatrix> SparseMatrix::FromCOOPointer(
    const std::shared_ptr<COO>& coo, torch::Tensor value,
    const std::vector<int64_t>& shape) {
  return c10::make_intrusive<SparseMatrix>(
      coo, nullptr, nullptr, nullptr, value, shape);
}

c10::intrusive_ptr<SparseMatrix> SparseMatrix::FromCSRPointer(
    const std::shared_ptr<CSR>& csr, torch::Tensor value,
    const std::vector<int64_t>& shape) {
  return c10::make_intrusive<SparseMatrix>(
      nullptr, csr, nullptr, nullptr, value, shape);
}

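// A CSC matrix reuses the CSR struct: it is simply the CSR representation of
// the transposed matrix, which is why `csc` is typed std::shared_ptr<CSR>.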
c10::intrusive_ptr<SparseMatrix> SparseMatrix::FromCSCPointer(
    const std::shared_ptr<CSR>& csc, torch::Tensor value,
    const std::vector<int64_t>& shape) {
  return c10::make_intrusive<SparseMatrix>(
      nullptr, nullptr, csc, nullptr, value, shape);
}

c10::intrusive_ptr<SparseMatrix> SparseMatrix::FromDiagPointer(
    const std::shared_ptr<Diag>& diag, torch::Tensor value,
    const std::vector<int64_t>& shape) {
  return c10::make_intrusive<SparseMatrix>(
      nullptr, nullptr, nullptr, diag, value, shape);
}

c10::intrusive_ptr<SparseMatrix> SparseMatrix::FromCOO(
    torch::Tensor indices, torch::Tensor value,
    const std::vector<int64_t>& shape) {
  auto coo =
      std::make_shared<COO>(COO{shape[0], shape[1], indices, false, false});
  return SparseMatrix::FromCOOPointer(coo, value, shape);
}
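
// Example usage of FromCOO (a minimal sketch; the concrete tensors are
// illustrative only):
//   // A 3 x 3 matrix with nonzeros at (0, 1), (1, 2) and (2, 0).
//   auto indices = torch::tensor({{0, 1, 2}, {1, 2, 0}}, torch::kInt64);
//   auto value = torch::ones({3});
//   auto mat = SparseMatrix::FromCOO(indices, value, {3, 3});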

c10::intrusive_ptr<SparseMatrix> SparseMatrix::FromCSR(
    torch::Tensor indptr, torch::Tensor indices, torch::Tensor value,
    const std::vector<int64_t>& shape) {
  auto csr = std::make_shared<CSR>(
      CSR{shape[0], shape[1], indptr, indices, torch::optional<torch::Tensor>(),
          false});
  return SparseMatrix::FromCSRPointer(csr, value, shape);
}

c10::intrusive_ptr<SparseMatrix> SparseMatrix::FromCSC(
    torch::Tensor indptr, torch::Tensor indices, torch::Tensor value,
    const std::vector<int64_t>& shape) {
  auto csc = std::make_shared<CSR>(
      CSR{shape[1], shape[0], indptr, indices, torch::optional<torch::Tensor>(),
          false});
  return SparseMatrix::FromCSCPointer(csc, value, shape);
}

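// A diagonal matrix stores no explicit indices: `value` holds the
// min(num_rows, num_cols) diagonal entries directly.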
c10::intrusive_ptr<SparseMatrix> SparseMatrix::FromDiag(
    torch::Tensor value, const std::vector<int64_t>& shape) {
  auto diag = std::make_shared<Diag>(Diag{shape[0], shape[1]});
  return SparseMatrix::FromDiagPointer(diag, value, shape);
}

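// Selects rows (dim == 0) or columns (dim == 1) by id. Implemented by slicing
// the CSR (or CSC) representation and gathering the corresponding values
// through the slice's data array.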
c10::intrusive_ptr<SparseMatrix> SparseMatrix::IndexSelect(
    int64_t dim, torch::Tensor ids) {
  auto id_array = TorchTensorToDGLArray(ids);
  bool rowwise = dim == 0;
  auto csr = rowwise ? this->CSRPtr() : this->CSCPtr();
  auto slice_csr = dgl::aten::CSRSliceRows(CSRToOldDGLCSR(csr), id_array);
  auto slice_value =
      this->value().index_select(0, DGLArrayToTorchTensor(slice_csr.data));
  // Clear the data array: a later conversion to the COO format could
  // otherwise misinterpret it as an initialization array for the COO
  // representation.
  slice_csr.data = dgl::aten::NullArray();
  auto ret = CSRFromOldDGLCSR(slice_csr);
  if (rowwise) {
    return SparseMatrix::FromCSRPointer(
        ret, slice_value, {ret->num_rows, ret->num_cols});
  } else {
    return SparseMatrix::FromCSCPointer(
        ret, slice_value, {ret->num_cols, ret->num_rows});
  }
}

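// The contiguous counterpart of IndexSelect: keeps the rows (or columns) in
// the half-open interval [start, end).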
c10::intrusive_ptr<SparseMatrix> SparseMatrix::RangeSelect(
    int64_t dim, int64_t start, int64_t end) {
  bool rowwise = dim == 0;
  auto csr = rowwise ? this->CSRPtr() : this->CSCPtr();
  auto slice_csr = dgl::aten::CSRSliceRows(CSRToOldDGLCSR(csr), start, end);
  auto slice_value =
      this->value().index_select(0, DGLArrayToTorchTensor(slice_csr.data));
  // Clear the data array: a later conversion to the COO format could
  // otherwise misinterpret it as an initialization array for the COO
  // representation.
  slice_csr.data = dgl::aten::NullArray();
  auto ret = CSRFromOldDGLCSR(slice_csr);
  if (rowwise) {
    return SparseMatrix::FromCSRPointer(
        ret, slice_value, {ret->num_rows, ret->num_cols});
  } else {
    return SparseMatrix::FromCSCPointer(
        ret, slice_value, {ret->num_cols, ret->num_rows});
  }
}

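// For each of the rows (dim == 0) or columns (dim == 1) given by `ids`,
// samples `fanout` nonzero entries, with or without replacement. The matrix
// is first sliced to those ids; when `bias` is true, the sliced values are
// used as unnormalized sampling weights.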
c10::intrusive_ptr<SparseMatrix> SparseMatrix::Sample(
    int64_t dim, int64_t fanout, torch::Tensor ids, bool replace, bool bias) {
  bool rowwise = dim == 0;
  auto id_array = TorchTensorToDGLArray(ids);
  auto csr = rowwise ? this->CSRPtr() : this->CSCPtr();
  // Slicing matrix.
  auto slice_csr = dgl::aten::CSRSliceRows(CSRToOldDGLCSR(csr), id_array);
  auto slice_value =
      this->value().index_select(0, DGLArrayToTorchTensor(slice_csr.data));
  // Reset value indices.
  slice_csr.data = dgl::aten::NullArray();

  auto prob =
      bias ? TorchTensorToDGLArray(slice_value) : dgl::aten::NullArray();
  auto slice_id =
      dgl::aten::Range(0, id_array.NumElements(), 64, id_array->ctx);
  // Sampling all rows on sliced matrix.
  auto sample_coo =
      dgl::aten::CSRRowWiseSampling(slice_csr, slice_id, fanout, prob, replace);
  auto sample_value =
      slice_value.index_select(0, DGLArrayToTorchTensor(sample_coo.data));
  sample_coo.data = dgl::aten::NullArray();
  auto ret = COOFromOldDGLCOO(sample_coo);
  if (!rowwise) ret = COOTranspose(ret);
  return SparseMatrix::FromCOOPointer(
      ret, sample_value, {ret->num_rows, ret->num_cols});
}

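// Creates a matrix sharing the sparsity structure of `mat` but carrying new
// values. The existing formats are reused as-is, preferring Diag, then COO,
// CSR and CSC.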
c10::intrusive_ptr<SparseMatrix> SparseMatrix::ValLike(
    const c10::intrusive_ptr<SparseMatrix>& mat, torch::Tensor value) {
  TORCH_CHECK(
      mat->value().size(0) == value.size(0), "The first dimension of ",
      "the old values and the new values must be the same.");
  TORCH_CHECK(
      mat->value().device() == value.device(), "The device of the ",
      "old values and the new values must be the same.");
  const auto& shape = mat->shape();
  if (mat->HasDiag()) {
    return SparseMatrix::FromDiagPointer(mat->DiagPtr(), value, shape);
  }
  if (mat->HasCOO()) {
    return SparseMatrix::FromCOOPointer(mat->COOPtr(), value, shape);
  }
  if (mat->HasCSR()) {
    return SparseMatrix::FromCSRPointer(mat->CSRPtr(), value, shape);
  }
  TORCH_CHECK(mat->HasCSC(), "Invalid sparse format for ValLike.");
  return SparseMatrix::FromCSCPointer(mat->CSCPtr(), value, shape);
}

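// The format accessors below are lazy: the requested format is created from
// an existing one on first access and cached for subsequent calls.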
std::shared_ptr<COO> SparseMatrix::COOPtr() {
  if (coo_ == nullptr) {
    _CreateCOO();
  }
  return coo_;
}

std::shared_ptr<CSR> SparseMatrix::CSRPtr() {
  if (csr_ == nullptr) {
    _CreateCSR();
  }
  return csr_;
}

std::shared_ptr<CSR> SparseMatrix::CSCPtr() {
  if (csc_ == nullptr) {
    _CreateCSC();
  }
  return csc_;
}

std::shared_ptr<Diag> SparseMatrix::DiagPtr() {
  TORCH_CHECK(
      diag_ != nullptr,
      "Cannot get Diag sparse format from a non-diagonal sparse matrix");
  return diag_;
}

std::tuple<torch::Tensor, torch::Tensor> SparseMatrix::COOTensors() {
  auto coo = COOPtr();
  return std::make_tuple(coo->indices.index({0}), coo->indices.index({1}));
}

torch::Tensor SparseMatrix::Indices() {
  auto coo = COOPtr();
  return coo->indices;
}

std::tuple<torch::Tensor, torch::Tensor, torch::optional<torch::Tensor>>
SparseMatrix::CSRTensors() {
  auto csr = CSRPtr();
  return std::make_tuple(csr->indptr, csr->indices, csr->value_indices);
}

std::tuple<torch::Tensor, torch::Tensor, torch::optional<torch::Tensor>>
SparseMatrix::CSCTensors() {
  auto csc = CSCPtr();
  return std::make_tuple(csc->indptr, csc->indices, csc->value_indices);
}

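// Transposition swaps the shape and reuses the stored structure where
// possible: the CSR of a matrix is exactly the CSC of its transpose (and
// vice versa), so only the COO case needs an explicit COOTranspose.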
c10::intrusive_ptr<SparseMatrix> SparseMatrix::Transpose() const {
  auto shape = shape_;
  std::swap(shape[0], shape[1]);
  auto value = value_;
  if (HasDiag()) {
    return SparseMatrix::FromDiag(value, shape);
  } else if (HasCOO()) {
    auto coo = COOTranspose(coo_);
    return SparseMatrix::FromCOOPointer(coo, value, shape);
  } else if (HasCSR()) {
    return SparseMatrix::FromCSCPointer(csr_, value, shape);
  } else {
    return SparseMatrix::FromCSRPointer(csc_, value, shape);
  }
}

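// The _CreateX helpers materialize a missing format from whichever format is
// already present. Diagonal matrices are expanded with explicit int64
// indices; otherwise the existing COO/CSR/CSC representation is converted.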
void SparseMatrix::_CreateCOO() {
  if (HasCOO()) return;
  if (HasDiag()) {
    auto indices_options = torch::TensorOptions()
                               .dtype(torch::kInt64)
                               .layout(torch::kStrided)
                               .device(this->device());
    coo_ = DiagToCOO(diag_, indices_options);
  } else if (HasCSR()) {
    coo_ = CSRToCOO(csr_);
  } else if (HasCSC()) {
    coo_ = CSCToCOO(csc_);
  } else {
    LOG(FATAL) << "SparseMatrix does not have any sparse format";
  }
}

void SparseMatrix::_CreateCSR() {
  if (HasCSR()) return;
  if (HasDiag()) {
    auto indices_options = torch::TensorOptions()
                               .dtype(torch::kInt64)
                               .layout(torch::kStrided)
                               .device(this->device());
    csr_ = DiagToCSR(diag_, indices_options);
  } else if (HasCOO()) {
    csr_ = COOToCSR(coo_);
  } else if (HasCSC()) {
    csr_ = CSCToCSR(csc_);
  } else {
    LOG(FATAL) << "SparseMatrix does not have any sparse format";
  }
}

void SparseMatrix::_CreateCSC() {
  if (HasCSC()) return;
  if (HasDiag()) {
    auto indices_options = torch::TensorOptions()
                               .dtype(torch::kInt64)
                               .layout(torch::kStrided)
                               .device(this->device());
    csc_ = DiagToCSC(diag_, indices_options);
  } else if (HasCOO()) {
    csc_ = COOToCSC(coo_);
  } else if (HasCSR()) {
    csc_ = CSRToCSC(csr_);
  } else {
    LOG(FATAL) << "SparseMatrix does not have any sparse format";
  }
}

}  // namespace sparse
}  // namespace dgl