sparse_matrix.cc 5.62 KB
Newer Older
Hongzhi (Steve), Chen's avatar
Hongzhi (Steve), Chen committed
1
/**
 *  Copyright (c) 2022 by Contributors
 * @file sparse_matrix.cc
 * @brief DGL C++ sparse matrix implementations.
 */
czkkkkkk's avatar
czkkkkkk committed
6
7
8
9
10
// clang-format off
#include <sparse/dgl_headers.h>
// clang-format on

#include <c10/util/Logging.h>
11
12
#include <sparse/elementwise_op.h>
#include <sparse/sparse_matrix.h>
czkkkkkk's avatar
czkkkkkk committed
13
#include <torch/script.h>
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95

namespace dgl {
namespace sparse {

/**
 * @brief Construct a SparseMatrix from up to three sparse formats.
 *
 * Stores the given format pointers, value tensor and shape, then validates
 * them with CHECK macros: at least one format must be non-null, the shape
 * must have exactly two entries, and every provided format must be made of
 * 1-D tensors whose sizes and devices agree with `value` and `shape`.
 *
 * @param coo COO format, or nullptr if not available.
 * @param csr CSR format, or nullptr if not available.
 * @param csc CSC format (held in a CSR struct), or nullptr if not available.
 * @param value Value tensor; its first dimension is compared against the
 * number of stored indices of each provided format.
 * @param shape Shape of the matrix; must contain exactly two entries.
 */
SparseMatrix::SparseMatrix(
    const std::shared_ptr<COO>& coo, const std::shared_ptr<CSR>& csr,
    const std::shared_ptr<CSR>& csc, torch::Tensor value,
    const std::vector<int64_t>& shape)
    : coo_(coo), csr_(csr), csc_(csc), value_(value), shape_(shape) {
  CHECK(coo != nullptr || csr != nullptr || csc != nullptr)
      << "At least one of CSR/COO/CSC is provided to construct a "
         "SparseMatrix";
  CHECK_EQ(shape.size(), 2)
      << "The shape of a sparse matrix should be 2-dimensional";
  // NOTE: Currently all the tensors of a SparseMatrix should be on the same
  // device. Should we allow the graph structure and the values to be on
  // different devices?
  if (coo != nullptr) {
    // Row and column index tensors are 1-D, of equal length, and have one
    // entry per element along the first dimension of `value`.
    CHECK_EQ(coo->row.dim(), 1);
    CHECK_EQ(coo->col.dim(), 1);
    CHECK_EQ(coo->row.size(0), coo->col.size(0));
    CHECK_EQ(coo->row.size(0), value.size(0));
    CHECK_EQ(coo->row.device(), value.device());
    CHECK_EQ(coo->col.device(), value.device());
  }
  if (csr != nullptr) {
    // indptr has one entry per row plus one (shape[0] + 1).
    CHECK_EQ(csr->indptr.dim(), 1);
    CHECK_EQ(csr->indices.dim(), 1);
    CHECK_EQ(csr->indptr.size(0), shape[0] + 1);
    CHECK_EQ(csr->indices.size(0), value.size(0));
    CHECK_EQ(csr->indptr.device(), value.device());
    CHECK_EQ(csr->indices.device(), value.device());
  }
  if (csc != nullptr) {
    // Same checks as CSR, but indptr is sized by the column count
    // (shape[1] + 1).
    CHECK_EQ(csc->indptr.dim(), 1);
    CHECK_EQ(csc->indices.dim(), 1);
    CHECK_EQ(csc->indptr.size(0), shape[1] + 1);
    CHECK_EQ(csc->indices.size(0), value.size(0));
    CHECK_EQ(csc->indptr.device(), value.device());
    CHECK_EQ(csc->indices.device(), value.device());
  }
}

/**
 * @brief Build a SparseMatrix that initially holds only a COO format.
 * @param coo The COO structure to store.
 * @param value The value tensor of the matrix.
 * @param shape The shape of the matrix.
 * @return A new SparseMatrix constructed with null CSR and CSC formats.
 */
c10::intrusive_ptr<SparseMatrix> SparseMatrix::FromCOO(
    const std::shared_ptr<COO>& coo, torch::Tensor value,
    const std::vector<int64_t>& shape) {
  // Only the COO slot is filled in; the other two formats are passed as null.
  const std::shared_ptr<CSR> no_csr = nullptr;
  const std::shared_ptr<CSR> no_csc = nullptr;
  return c10::make_intrusive<SparseMatrix>(coo, no_csr, no_csc, value, shape);
}

/**
 * @brief Build a SparseMatrix that initially holds only a CSR format.
 * @param csr The CSR structure to store.
 * @param value The value tensor of the matrix.
 * @param shape The shape of the matrix.
 * @return A new SparseMatrix constructed with null COO and CSC formats.
 */
c10::intrusive_ptr<SparseMatrix> SparseMatrix::FromCSR(
    const std::shared_ptr<CSR>& csr, torch::Tensor value,
    const std::vector<int64_t>& shape) {
  // Only the CSR slot is filled in; the other two formats are passed as null.
  const std::shared_ptr<COO> no_coo = nullptr;
  const std::shared_ptr<CSR> no_csc = nullptr;
  return c10::make_intrusive<SparseMatrix>(no_coo, csr, no_csc, value, shape);
}

/**
 * @brief Build a SparseMatrix that initially holds only a CSC format.
 * @param csc The CSC structure to store (represented with the CSR struct).
 * @param value The value tensor of the matrix.
 * @param shape The shape of the matrix.
 * @return A new SparseMatrix constructed with null COO and CSR formats.
 */
c10::intrusive_ptr<SparseMatrix> SparseMatrix::FromCSC(
    const std::shared_ptr<CSR>& csc, torch::Tensor value,
    const std::vector<int64_t>& shape) {
  // Only the CSC slot is filled in; the other two formats are passed as null.
  const std::shared_ptr<COO> no_coo = nullptr;
  const std::shared_ptr<CSR> no_csr = nullptr;
  return c10::make_intrusive<SparseMatrix>(no_coo, no_csr, csc, value, shape);
}

/**
 * @brief Return the COO format, creating it first if it is not stored yet.
 * @return Shared pointer to the COO format held by this matrix.
 */
std::shared_ptr<COO> SparseMatrix::COOPtr() {
  // Build the COO format on first request only.
  if (!coo_) _CreateCOO();
  return coo_;
}

/**
 * @brief Return the CSR format, creating it first if it is not stored yet.
 * @return Shared pointer to the CSR format held by this matrix.
 */
std::shared_ptr<CSR> SparseMatrix::CSRPtr() {
  // Build the CSR format on first request only.
  if (!csr_) _CreateCSR();
  return csr_;
}

/**
 * @brief Return the CSC format, creating it first if it is not stored yet.
 * @return Shared pointer to the CSC format held by this matrix.
 */
std::shared_ptr<CSR> SparseMatrix::CSCPtr() {
  // Build the CSC format on first request only.
  if (!csc_) _CreateCSC();
  return csc_;
}

96
std::tuple<torch::Tensor, torch::Tensor> SparseMatrix::COOTensors() {
97
98
  auto coo = COOPtr();
  auto val = value();
99
  return {coo->row, coo->col};
100
101
}

102
103
std::tuple<torch::Tensor, torch::Tensor, torch::optional<torch::Tensor>>
SparseMatrix::CSRTensors() {
104
105
  auto csr = CSRPtr();
  auto val = value();
106
  return {csr->indptr, csr->indices, csr->value_indices};
107
108
}

109
110
std::tuple<torch::Tensor, torch::Tensor, torch::optional<torch::Tensor>>
SparseMatrix::CSCTensors() {
111
  auto csc = CSCPtr();
112
  return {csc->indptr, csc->indices, csc->value_indices};
113
114
}

115
116
/**
 * @brief Replace the value tensor of this matrix.
 *
 * The new tensor is stored as is; this function performs no validation of
 * it against the sparse formats.
 *
 * @param value The new value tensor.
 */
void SparseMatrix::SetValue(torch::Tensor value) {
  value_ = value;
}

117
118
119
120
121
122
123
124
125
126
127
128
129
130
/**
 * @brief Create the transpose of this sparse matrix.
 *
 * Exactly one format is used, preferring COO, then CSR, then CSC. The value
 * tensor is passed on unchanged and the two shape entries are swapped.
 *
 * @return A new SparseMatrix representing the transpose.
 */
c10::intrusive_ptr<SparseMatrix> SparseMatrix::Transpose() const {
  auto transposed_shape = shape_;
  std::swap(transposed_shape[0], transposed_shape[1]);
  auto transposed_value = value_;

  if (HasCOO()) {
    return SparseMatrix::FromCOO(
        COOTranspose(coo_), transposed_value, transposed_shape);
  }
  if (HasCSR()) {
    // The CSR structure of this matrix is handed over as the CSC structure
    // of the transpose.
    return SparseMatrix::FromCSC(csr_, transposed_value, transposed_shape);
  }
  // Otherwise the CSC structure becomes the CSR structure of the transpose.
  return SparseMatrix::FromCSR(csc_, transposed_value, transposed_shape);
}

czkkkkkk's avatar
czkkkkkk committed
131
void SparseMatrix::_CreateCOO() {
132
  if (HasCOO()) return;
czkkkkkk's avatar
czkkkkkk committed
133
  if (HasCSR()) {
134
    coo_ = CSRToCOO(csr_);
czkkkkkk's avatar
czkkkkkk committed
135
  } else if (HasCSC()) {
136
    coo_ = CSCToCOO(csc_);
czkkkkkk's avatar
czkkkkkk committed
137
138
139
140
141
  } else {
    LOG(FATAL) << "SparseMatrix does not have any sparse format";
  }
}

142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
/**
 * @brief Populate the CSR format from another stored format.
 *
 * Does nothing when a CSR format already exists. Otherwise converts from
 * COO if present, else from CSC; logs a fatal error when neither exists.
 */
void SparseMatrix::_CreateCSR() {
  if (HasCSR()) return;
  if (HasCOO()) {
    csr_ = COOToCSR(coo_);
    return;
  }
  if (HasCSC()) {
    csr_ = CSCToCSR(csc_);
    return;
  }
  LOG(FATAL) << "SparseMatrix does not have any sparse format";
}

/**
 * @brief Populate the CSC format from another stored format.
 *
 * Does nothing when a CSC format already exists. Otherwise converts from
 * COO if present, else from CSR; logs a fatal error when neither exists.
 */
void SparseMatrix::_CreateCSC() {
  if (HasCSC()) return;
  if (HasCOO()) {
    csc_ = COOToCSC(coo_);
    return;
  }
  if (HasCSR()) {
    csc_ = CSRToCSC(csr_);
    return;
  }
  LOG(FATAL) << "SparseMatrix does not have any sparse format";
}
163
164
165
166

/**
 * @brief Create a SparseMatrix from COO row/column index tensors.
 *
 * @param row Row index tensor.
 * @param col Column index tensor.
 * @param value Value tensor of the matrix.
 * @param shape Shape of the matrix; must contain exactly two entries.
 * @return A new SparseMatrix holding the COO format.
 */
c10::intrusive_ptr<SparseMatrix> CreateFromCOO(
    torch::Tensor row, torch::Tensor col, torch::Tensor value,
    const std::vector<int64_t>& shape) {
  // Validate before indexing: shape[0]/shape[1] below would read out of
  // bounds for a shape with fewer than two entries.
  CHECK_EQ(shape.size(), 2)
      << "The shape of a sparse matrix should be 2-dimensional";
  // NOTE(review): the two trailing `false` fields are presumably the
  // sortedness flags of COO -- confirm against the COO declaration.
  auto coo =
      std::make_shared<COO>(COO{shape[0], shape[1], row, col, false, false});
  return SparseMatrix::FromCOO(coo, value, shape);
}

/**
 * @brief Create a SparseMatrix from CSR index tensors.
 *
 * @param indptr Index pointer tensor of the CSR format.
 * @param indices Indices tensor of the CSR format.
 * @param value Value tensor of the matrix.
 * @param shape Shape of the matrix; must contain exactly two entries.
 * @return A new SparseMatrix holding the CSR format.
 */
c10::intrusive_ptr<SparseMatrix> CreateFromCSR(
    torch::Tensor indptr, torch::Tensor indices, torch::Tensor value,
    const std::vector<int64_t>& shape) {
  // Validate before indexing: shape[0]/shape[1] below would read out of
  // bounds for a shape with fewer than two entries.
  CHECK_EQ(shape.size(), 2)
      << "The shape of a sparse matrix should be 2-dimensional";
  // The structure is created with an empty optional for value_indices.
  // NOTE(review): the trailing `false` is presumably a sortedness flag --
  // confirm against the CSR declaration.
  auto csr = std::make_shared<CSR>(
      CSR{shape[0], shape[1], indptr, indices, torch::optional<torch::Tensor>(),
          false});
  return SparseMatrix::FromCSR(csr, value, shape);
}

/**
 * @brief Create a SparseMatrix from CSC index tensors.
 *
 * @param indptr Index pointer tensor of the CSC format.
 * @param indices Indices tensor of the CSC format.
 * @param value Value tensor of the matrix.
 * @param shape Shape of the matrix; must contain exactly two entries.
 * @return A new SparseMatrix holding the CSC format.
 */
c10::intrusive_ptr<SparseMatrix> CreateFromCSC(
    torch::Tensor indptr, torch::Tensor indices, torch::Tensor value,
    const std::vector<int64_t>& shape) {
  // Validate before indexing: shape[1]/shape[0] below would read out of
  // bounds for a shape with fewer than two entries.
  CHECK_EQ(shape.size(), 2)
      << "The shape of a sparse matrix should be 2-dimensional";
  // CSC is stored in a CSR struct with the two dimensions swapped, and with
  // an empty optional for value_indices.
  // NOTE(review): the trailing `false` is presumably a sortedness flag --
  // confirm against the CSR declaration.
  auto csc = std::make_shared<CSR>(
      CSR{shape[1], shape[0], indptr, indices, torch::optional<torch::Tensor>(),
          false});
  return SparseMatrix::FromCSC(csc, value, shape);
}

}  // namespace sparse
}  // namespace dgl