"vscode:/vscode.git/clone" did not exist on "52a4480d70592dde520240b1694184612108ca6f"
sparse_matrix.cc 11.5 KB
Newer Older
sangwzh's avatar
sangwzh committed
1
// !!! This is a file automatically generated by hipify!!!
Hongzhi (Steve), Chen's avatar
Hongzhi (Steve), Chen committed
2
/**
3
4
 *  Copyright (c) 2022 by Contributors
 * @file sparse_matrix.cc
czkkkkkk's avatar
czkkkkkk committed
5
 * @brief DGL C++ sparse matrix implementations.
6
 */
czkkkkkk's avatar
czkkkkkk committed
7
8
9
10
// clang-format off
#include <sparse/dgl_headers.h>
// clang-format on

11
#include <c10/util/Logging.h>
12
13
#include <sparse/elementwise_op.h>
#include <sparse/sparse_matrix.h>
14
#include <torch/script.h>
15

sangwzh's avatar
sangwzh committed
16
#include "utils.h"
17

18
19
20
21
22
namespace dgl {
namespace sparse {

SparseMatrix::SparseMatrix(
    const std::shared_ptr<COO>& coo, const std::shared_ptr<CSR>& csr,
    const std::shared_ptr<CSR>& csc, const std::shared_ptr<Diag>& diag,
    torch::Tensor value, const std::vector<int64_t>& shape)
    : coo_(coo),
      csr_(csr),
      csc_(csc),
      diag_(diag),
      value_(value),
      shape_(shape) {
  TORCH_CHECK(
      coo != nullptr || csr != nullptr || csc != nullptr || diag != nullptr,
      "At least one of CSR/COO/CSC/Diag is required to construct a "
      "SparseMatrix.");
  TORCH_CHECK(
      shape.size() == 2, "The shape of a sparse matrix should be ",
      "2-dimensional.");
  // NOTE: Currently all the tensors of a SparseMatrix should on the same
  // device. Do we allow the graph structure and values are on different
  // devices?
  if (coo != nullptr) {
    // COO indices are a (2, nnz) tensor colocated with the values.
    TORCH_CHECK(coo->indices.dim() == 2);
    TORCH_CHECK(coo->indices.size(0) == 2);
    TORCH_CHECK(coo->indices.size(1) == value.size(0));
    TORCH_CHECK(coo->indices.device() == value.device());
  }
  // Shared validation for the two compressed formats. `num_major` is the
  // extent of the compressed dimension: rows for CSR, columns for CSC.
  auto check_compressed = [&value](
                              const std::shared_ptr<CSR>& mat,
                              int64_t num_major) {
    TORCH_CHECK(mat->indptr.dim() == 1);
    TORCH_CHECK(mat->indices.dim() == 1);
    TORCH_CHECK(mat->indptr.size(0) == num_major + 1);
    TORCH_CHECK(mat->indices.size(0) == value.size(0));
    TORCH_CHECK(mat->indptr.device() == value.device());
    TORCH_CHECK(mat->indices.device() == value.device());
  };
  if (csr != nullptr) {
    check_compressed(csr, shape[0]);
  }
  if (csc != nullptr) {
    check_compressed(csc, shape[1]);
  }
  if (diag != nullptr) {
    // A diagonal matrix stores exactly min(num_rows, num_cols) values.
    TORCH_CHECK(value.size(0) == std::min(diag->num_rows, diag->num_cols));
  }
}

68
/** @brief Wrap an existing COO structure plus values into a SparseMatrix. */
c10::intrusive_ptr<SparseMatrix> SparseMatrix::FromCOOPointer(
    const std::shared_ptr<COO>& coo, torch::Tensor value,
    const std::vector<int64_t>& shape) {
  // Only the COO slot is populated; other formats are derived lazily.
  return c10::make_intrusive<SparseMatrix>(
      coo, /*csr=*/nullptr, /*csc=*/nullptr, /*diag=*/nullptr, value, shape);
}

75
/** @brief Wrap an existing CSR structure plus values into a SparseMatrix. */
c10::intrusive_ptr<SparseMatrix> SparseMatrix::FromCSRPointer(
    const std::shared_ptr<CSR>& csr, torch::Tensor value,
    const std::vector<int64_t>& shape) {
  // Only the CSR slot is populated; other formats are derived lazily.
  return c10::make_intrusive<SparseMatrix>(
      /*coo=*/nullptr, csr, /*csc=*/nullptr, /*diag=*/nullptr, value, shape);
}

82
/**
 * @brief Wrap an existing CSC structure plus values into a SparseMatrix.
 * Note: CSC reuses the CSR struct with rows/columns swapped.
 */
c10::intrusive_ptr<SparseMatrix> SparseMatrix::FromCSCPointer(
    const std::shared_ptr<CSR>& csc, torch::Tensor value,
    const std::vector<int64_t>& shape) {
  // Only the CSC slot is populated; other formats are derived lazily.
  return c10::make_intrusive<SparseMatrix>(
      /*coo=*/nullptr, /*csr=*/nullptr, csc, /*diag=*/nullptr, value, shape);
}

/** @brief Wrap an existing Diag structure plus values into a SparseMatrix. */
c10::intrusive_ptr<SparseMatrix> SparseMatrix::FromDiagPointer(
    const std::shared_ptr<Diag>& diag, torch::Tensor value,
    const std::vector<int64_t>& shape) {
  // Only the Diag slot is populated; other formats are derived lazily.
  return c10::make_intrusive<SparseMatrix>(
      /*coo=*/nullptr, /*csr=*/nullptr, /*csc=*/nullptr, diag, value, shape);
}

96
/** @brief Build a SparseMatrix from a (2, nnz) COO indices tensor. */
c10::intrusive_ptr<SparseMatrix> SparseMatrix::FromCOO(
    torch::Tensor indices, torch::Tensor value,
    const std::vector<int64_t>& shape) {
  // Row/column sortedness is unknown for user-provided indices, so both
  // flags start out false.
  COO coo{shape[0], shape[1], indices, false, false};
  return SparseMatrix::FromCOOPointer(
      std::make_shared<COO>(std::move(coo)), value, shape);
}

/** @brief Build a SparseMatrix from raw CSR indptr/indices tensors. */
c10::intrusive_ptr<SparseMatrix> SparseMatrix::FromCSR(
    torch::Tensor indptr, torch::Tensor indices, torch::Tensor value,
    const std::vector<int64_t>& shape) {
  // Raw inputs carry no value-index permutation and unknown sortedness.
  CSR csr{shape[0], shape[1], indptr, indices,
          torch::optional<torch::Tensor>(), false};
  return SparseMatrix::FromCSRPointer(
      std::make_shared<CSR>(std::move(csr)), value, shape);
}

/** @brief Build a SparseMatrix from raw CSC indptr/indices tensors. */
c10::intrusive_ptr<SparseMatrix> SparseMatrix::FromCSC(
    torch::Tensor indptr, torch::Tensor indices, torch::Tensor value,
    const std::vector<int64_t>& shape) {
  // CSC reuses the CSR struct with the row/column extents swapped.
  CSR csc{shape[1], shape[0], indptr, indices,
          torch::optional<torch::Tensor>(), false};
  return SparseMatrix::FromCSCPointer(
      std::make_shared<CSR>(std::move(csc)), value, shape);
}

122
123
124
125
126
127
/** @brief Build a diagonal SparseMatrix from its diagonal values. */
c10::intrusive_ptr<SparseMatrix> SparseMatrix::FromDiag(
    torch::Tensor value, const std::vector<int64_t>& shape) {
  // Diag only records the matrix extent; the values live in `value`.
  return SparseMatrix::FromDiagPointer(
      std::make_shared<Diag>(Diag{shape[0], shape[1]}), value, shape);
}

128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
/** @brief Select rows (dim == 0) or columns (dim == 1) by an id tensor. */
c10::intrusive_ptr<SparseMatrix> SparseMatrix::IndexSelect(
    int64_t dim, torch::Tensor ids) {
  const bool select_rows = (dim == 0);
  // Column selection is row selection on the CSC (transposed CSR) view.
  auto base = select_rows ? this->CSRPtr() : this->CSCPtr();
  auto sliced =
      dgl::aten::CSRSliceRows(CSRToOldDGLCSR(base), TorchTensorToDGLArray(ids));
  auto picked_value =
      this->value().index_select(0, DGLArrayToTorchTensor(sliced.data));
  // To prevent potential errors in future conversions to the COO format,
  // where this array might be used as an initialization array for
  // constructing COO representations, it is necessary to clear this array.
  sliced.data = dgl::aten::NullArray();
  auto result = CSRFromOldDGLCSR(sliced);
  if (select_rows) {
    return SparseMatrix::FromCSRPointer(
        result, picked_value, {result->num_rows, result->num_cols});
  }
  // The sliced CSC stores the transposed extents, so swap them back.
  return SparseMatrix::FromCSCPointer(
      result, picked_value, {result->num_cols, result->num_rows});
}

/** @brief Select a contiguous [start, end) range of rows or columns. */
c10::intrusive_ptr<SparseMatrix> SparseMatrix::RangeSelect(
    int64_t dim, int64_t start, int64_t end) {
  const bool select_rows = (dim == 0);
  // Column selection is row selection on the CSC (transposed CSR) view.
  auto base = select_rows ? this->CSRPtr() : this->CSCPtr();
  auto sliced = dgl::aten::CSRSliceRows(CSRToOldDGLCSR(base), start, end);
  auto picked_value =
      this->value().index_select(0, DGLArrayToTorchTensor(sliced.data));
  // To prevent potential errors in future conversions to the COO format,
  // where this array might be used as an initialization array for
  // constructing COO representations, it is necessary to clear this array.
  sliced.data = dgl::aten::NullArray();
  auto result = CSRFromOldDGLCSR(sliced);
  if (select_rows) {
    return SparseMatrix::FromCSRPointer(
        result, picked_value, {result->num_rows, result->num_cols});
  }
  // The sliced CSC stores the transposed extents, so swap them back.
  return SparseMatrix::FromCSCPointer(
      result, picked_value, {result->num_cols, result->num_rows});
}

171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
/**
 * @brief Sample up to `fanout` nonzeros per selected row (dim == 0) or
 * column (dim == 1). When `bias` is true, the matrix values are used as
 * (unnormalized) sampling probabilities.
 */
c10::intrusive_ptr<SparseMatrix> SparseMatrix::Sample(
    int64_t dim, int64_t fanout, torch::Tensor ids, bool replace, bool bias) {
  const bool select_rows = (dim == 0);
  auto row_ids = TorchTensorToDGLArray(ids);
  auto base = select_rows ? this->CSRPtr() : this->CSCPtr();
  // First restrict the matrix to the requested rows (or columns via CSC).
  auto sliced = dgl::aten::CSRSliceRows(CSRToOldDGLCSR(base), row_ids);
  auto sliced_value =
      this->value().index_select(0, DGLArrayToTorchTensor(sliced.data));
  // Reset value indices.
  sliced.data = dgl::aten::NullArray();

  // Biased sampling weights come from the already-sliced values.
  auto prob =
      bias ? TorchTensorToDGLArray(sliced_value) : dgl::aten::NullArray();
  auto all_slice_rows =
      dgl::aten::Range(0, row_ids.NumElements(), 64, row_ids->ctx);
  // Sample every row of the sliced matrix.
  auto picked = dgl::aten::CSRRowWiseSampling(
      sliced, all_slice_rows, fanout, prob, replace);
  auto picked_value =
      sliced_value.index_select(0, DGLArrayToTorchTensor(picked.data));
  picked.data = dgl::aten::NullArray();
  auto result = COOFromOldDGLCOO(picked);
  // Column-wise sampling worked on the transpose; undo it here.
  if (!select_rows) result = COOTranspose(result);
  return SparseMatrix::FromCOOPointer(
      result, picked_value, {result->num_rows, result->num_cols});
}

199
200
201
202
203
204
205
206
/**
 * @brief Create a matrix with the same sparsity pattern as `mat` but new
 * values. The new values must match the old ones in length and device.
 */
c10::intrusive_ptr<SparseMatrix> SparseMatrix::ValLike(
    const c10::intrusive_ptr<SparseMatrix>& mat, torch::Tensor value) {
  TORCH_CHECK(
      mat->value().size(0) == value.size(0), "The first dimension of ",
      "the old values and the new values must be the same.");
  TORCH_CHECK(
      mat->value().device() == value.device(), "The device of the ",
      "old values and the new values must be the same.");
  const auto& mat_shape = mat->shape();
  // Reuse whichever format `mat` already has, preferring Diag, then COO,
  // then CSR, then CSC.
  if (mat->HasDiag()) {
    return SparseMatrix::FromDiagPointer(mat->DiagPtr(), value, mat_shape);
  }
  if (mat->HasCOO()) {
    return SparseMatrix::FromCOOPointer(mat->COOPtr(), value, mat_shape);
  }
  if (mat->HasCSR()) {
    return SparseMatrix::FromCSRPointer(mat->CSRPtr(), value, mat_shape);
  }
  TORCH_CHECK(mat->HasCSC(), "Invalid sparse format for ValLike.")
  return SparseMatrix::FromCSCPointer(mat->CSCPtr(), value, mat_shape);
}

221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
/** @brief Get the COO representation, materializing it on first access. */
std::shared_ptr<COO> SparseMatrix::COOPtr() {
  if (!coo_) _CreateCOO();
  return coo_;
}

/** @brief Get the CSR representation, materializing it on first access. */
std::shared_ptr<CSR> SparseMatrix::CSRPtr() {
  if (!csr_) _CreateCSR();
  return csr_;
}

/** @brief Get the CSC representation, materializing it on first access. */
std::shared_ptr<CSR> SparseMatrix::CSCPtr() {
  if (!csc_) _CreateCSC();
  return csc_;
}

242
243
244
245
246
247
248
/** @brief Get the Diag representation; only valid on diagonal matrices. */
std::shared_ptr<Diag> SparseMatrix::DiagPtr() {
  // Unlike COO/CSR/CSC, Diag is never derived from another format.
  TORCH_CHECK(
      diag_ != nullptr,
      "Cannot get Diag sparse format from a non-diagonal sparse matrix");
  return diag_;
}

249
std::tuple<torch::Tensor, torch::Tensor> SparseMatrix::COOTensors() {
250
  auto coo = COOPtr();
251
  return std::make_tuple(coo->indices.index({0}), coo->indices.index({1}));
252
253
}

254
255
256
257
258
/** @brief Return the (2, nnz) COO indices tensor. */
torch::Tensor SparseMatrix::Indices() { return COOPtr()->indices; }

259
260
std::tuple<torch::Tensor, torch::Tensor, torch::optional<torch::Tensor>>
SparseMatrix::CSRTensors() {
261
262
  auto csr = CSRPtr();
  auto val = value();
263
  return std::make_tuple(csr->indptr, csr->indices, csr->value_indices);
264
265
}

266
267
std::tuple<torch::Tensor, torch::Tensor, torch::optional<torch::Tensor>>
SparseMatrix::CSCTensors() {
268
  auto csc = CSCPtr();
269
  return std::make_tuple(csc->indptr, csc->indices, csc->value_indices);
270
271
}

272
273
274
275
/** @brief Return the transposed matrix, reusing cached formats cheaply. */
c10::intrusive_ptr<SparseMatrix> SparseMatrix::Transpose() const {
  // The transpose has the row/column extents swapped; values are shared.
  std::vector<int64_t> t_shape{shape_[1], shape_[0]};
  if (HasDiag()) {
    // A diagonal matrix transposes to a diagonal matrix.
    return SparseMatrix::FromDiag(value_, t_shape);
  }
  if (HasCOO()) {
    return SparseMatrix::FromCOOPointer(COOTranspose(coo_), value_, t_shape);
  }
  if (HasCSR()) {
    // The CSR of this matrix is exactly the CSC of its transpose.
    return SparseMatrix::FromCSCPointer(csr_, value_, t_shape);
  }
  // Symmetrically, the CSC of this matrix is the CSR of its transpose.
  return SparseMatrix::FromCSRPointer(csc_, value_, t_shape);
}

czkkkkkk's avatar
czkkkkkk committed
288
void SparseMatrix::_CreateCOO() {
289
  if (HasCOO()) return;
290
291
292
293
294
295
296
  if (HasDiag()) {
    auto indices_options = torch::TensorOptions()
                               .dtype(torch::kInt64)
                               .layout(torch::kStrided)
                               .device(this->device());
    coo_ = DiagToCOO(diag_, indices_options);
  } else if (HasCSR()) {
297
    coo_ = CSRToCOO(csr_);
czkkkkkk's avatar
czkkkkkk committed
298
  } else if (HasCSC()) {
299
    coo_ = CSCToCOO(csc_);
czkkkkkk's avatar
czkkkkkk committed
300
301
302
303
304
  } else {
    LOG(FATAL) << "SparseMatrix does not have any sparse format";
  }
}

305
306
void SparseMatrix::_CreateCSR() {
  if (HasCSR()) return;
307
308
309
310
311
312
313
  if (HasDiag()) {
    auto indices_options = torch::TensorOptions()
                               .dtype(torch::kInt64)
                               .layout(torch::kStrided)
                               .device(this->device());
    csr_ = DiagToCSR(diag_, indices_options);
  } else if (HasCOO()) {
314
315
316
317
318
319
320
321
322
323
    csr_ = COOToCSR(coo_);
  } else if (HasCSC()) {
    csr_ = CSCToCSR(csc_);
  } else {
    LOG(FATAL) << "SparseMatrix does not have any sparse format";
  }
}

void SparseMatrix::_CreateCSC() {
  if (HasCSC()) return;
324
325
326
327
328
329
330
  if (HasDiag()) {
    auto indices_options = torch::TensorOptions()
                               .dtype(torch::kInt64)
                               .layout(torch::kStrided)
                               .device(this->device());
    csc_ = DiagToCSC(diag_, indices_options);
  } else if (HasCOO()) {
331
332
333
334
335
336
337
    csc_ = COOToCSC(coo_);
  } else if (HasCSR()) {
    csc_ = CSRToCSC(csr_);
  } else {
    LOG(FATAL) << "SparseMatrix does not have any sparse format";
  }
}
338
339
340

}  // namespace sparse
}  // namespace dgl