/**
 *  Copyright (c) 2022 by Contributors
 * @file sparse_matrix.cc
 * @brief DGL C++ sparse matrix implementations
 */
#include <dmlc/logging.h>
#include <sparse/elementwise_op.h>
#include <sparse/sparse_matrix.h>

namespace dgl {
namespace sparse {

/**
 * @brief Construct a SparseMatrix from any combination of COO, CSR and CSC
 * formats that share a single value tensor.
 *
 * @param coo COO format (row/col index tensors), or nullptr.
 * @param csr CSR format (indptr over rows, indices), or nullptr.
 * @param csc CSC format stored in a CSR struct (indptr over columns,
 * indices), or nullptr.
 * @param value Values of the non-zero entries; dimension 0 indexes the
 * non-zero entries.
 * @param shape Dense shape of the matrix, {num_rows, num_cols}.
 *
 * At least one format must be non-null. For every provided format, the index
 * tensors must be 1-D, hold one entry per value along dimension 0, and live on
 * the same device as `value`; otherwise a dmlc CHECK fails.
 */
SparseMatrix::SparseMatrix(
    const std::shared_ptr<COO>& coo, const std::shared_ptr<CSR>& csr,
    const std::shared_ptr<CSR>& csc, torch::Tensor value,
    const std::vector<int64_t>& shape)
    : coo_(coo), csr_(csr), csc_(csc), value_(value), shape_(shape) {
  CHECK(coo != nullptr || csr != nullptr || csc != nullptr)
      << "At least one of CSR/COO/CSC should be provided to construct a "
         "SparseMatrix";
  // Compare unsigned with unsigned to avoid a sign-compare warning.
  CHECK_EQ(shape.size(), 2u)
      << "The shape of a sparse matrix should be 2-dimensional";
  // Each format check below compares its non-zero count with value.size(0);
  // reject a 0-dim value here with a readable message instead of letting
  // size(0) raise an index error.
  CHECK_GE(value.dim(), 1)
      << "The value of a sparse matrix should be at least 1-dimensional";
  // NOTE: Currently all the tensors of a SparseMatrix should be on the same
  // device. Should the graph structure and the values be allowed to live on
  // different devices?
  if (coo != nullptr) {
    CHECK_EQ(coo->row.dim(), 1);
    CHECK_EQ(coo->col.dim(), 1);
    CHECK_EQ(coo->row.size(0), coo->col.size(0));
    CHECK_EQ(coo->row.size(0), value.size(0));
    CHECK_EQ(coo->row.device(), value.device());
    CHECK_EQ(coo->col.device(), value.device());
  }
  if (csr != nullptr) {
    CHECK_EQ(csr->indptr.dim(), 1);
    CHECK_EQ(csr->indices.dim(), 1);
    // One offset per row plus the trailing end offset.
    CHECK_EQ(csr->indptr.size(0), shape[0] + 1);
    CHECK_EQ(csr->indices.size(0), value.size(0));
    CHECK_EQ(csr->indptr.device(), value.device());
    CHECK_EQ(csr->indices.device(), value.device());
  }
  if (csc != nullptr) {
    CHECK_EQ(csc->indptr.dim(), 1);
    CHECK_EQ(csc->indices.dim(), 1);
    // One offset per column plus the trailing end offset.
    CHECK_EQ(csc->indptr.size(0), shape[1] + 1);
    CHECK_EQ(csc->indices.size(0), value.size(0));
    CHECK_EQ(csc->indptr.device(), value.device());
    CHECK_EQ(csc->indices.device(), value.device());
  }
}

/**
 * @brief Build a SparseMatrix that initially holds only the COO format.
 *
 * CSR and CSC are left unset here; CSRPtr()/CSCPtr() attempt to create them
 * when they are requested.
 */
c10::intrusive_ptr<SparseMatrix> SparseMatrix::FromCOO(
    const std::shared_ptr<COO>& coo, torch::Tensor value,
    const std::vector<int64_t>& shape) {
  const std::shared_ptr<CSR> no_csr;
  const std::shared_ptr<CSR> no_csc;
  return c10::make_intrusive<SparseMatrix>(coo, no_csr, no_csc, value, shape);
}

/**
 * @brief Build a SparseMatrix that initially holds only the CSR format.
 *
 * COO and CSC are left unset here; COOPtr()/CSCPtr() attempt to create them
 * when they are requested.
 */
c10::intrusive_ptr<SparseMatrix> SparseMatrix::FromCSR(
    const std::shared_ptr<CSR>& csr, torch::Tensor value,
    const std::vector<int64_t>& shape) {
  const std::shared_ptr<COO> no_coo;
  const std::shared_ptr<CSR> no_csc;
  return c10::make_intrusive<SparseMatrix>(no_coo, csr, no_csc, value, shape);
}

/**
 * @brief Build a SparseMatrix that initially holds only the CSC format
 * (stored in a CSR struct whose indptr runs over columns).
 *
 * COO and CSR are left unset here; COOPtr()/CSRPtr() attempt to create them
 * when they are requested.
 */
c10::intrusive_ptr<SparseMatrix> SparseMatrix::FromCSC(
    const std::shared_ptr<CSR>& csc, torch::Tensor value,
    const std::vector<int64_t>& shape) {
  const std::shared_ptr<COO> no_coo;
  const std::shared_ptr<CSR> no_csr;
  return c10::make_intrusive<SparseMatrix>(no_coo, no_csr, csc, value, shape);
}

/**
 * @brief Get the COO format, calling _CreateCOO() first if it is not set.
 *
 * @return The stored COO pointer. NOTE(review): this can still be null if
 * _CreateCOO() does not populate coo_.
 */
std::shared_ptr<COO> SparseMatrix::COOPtr() {
  if (!coo_) _CreateCOO();
  return coo_;
}

/**
 * @brief Get the CSR format, calling _CreateCSR() first if it is not set.
 *
 * @return The stored CSR pointer. NOTE(review): this can still be null if
 * _CreateCSR() does not populate csr_.
 */
std::shared_ptr<CSR> SparseMatrix::CSRPtr() {
  if (!csr_) _CreateCSR();
  return csr_;
}

/**
 * @brief Get the CSC format, calling _CreateCSC() first if it is not set.
 *
 * @return The stored CSC pointer. NOTE(review): this can still be null if
 * _CreateCSC() does not populate csc_.
 */
std::shared_ptr<CSR> SparseMatrix::CSCPtr() {
  if (!csc_) _CreateCSC();
  return csc_;
}

std::vector<torch::Tensor> SparseMatrix::COOTensors() {
  auto coo = COOPtr();
  auto val = value();
  return {coo->row, coo->col, val};
}

std::vector<torch::Tensor> SparseMatrix::CSRTensors() {
  auto csr = CSRPtr();
  auto val = value();
  if (csr->value_indices.has_value()) {
    val = val[csr->value_indices.value()];
  }
  return {csr->indptr, csr->indices, val};
}

std::vector<torch::Tensor> SparseMatrix::CSCTensors() {
  auto csc = CSCPtr();
  auto val = value();
  if (csc->value_indices.has_value()) {
    val = val[csc->value_indices.value()];
  }
  return {csc->indptr, csc->indices, val};
}

// TODO(zhenkun): format conversion
// These hooks are invoked by COOPtr()/CSRPtr()/CSCPtr() when the requested
// format has not been set. They are currently empty, so the corresponding
// member pointer (coo_/csr_/csc_) is left unchanged (still null) after the
// call.
void SparseMatrix::_CreateCOO() {}
void SparseMatrix::_CreateCSR() {}
void SparseMatrix::_CreateCSC() {}

/**
 * @brief Create a SparseMatrix from COO row/col index tensors and values.
 *
 * Shape and device consistency are validated by the SparseMatrix
 * constructor reached through SparseMatrix::FromCOO.
 */
c10::intrusive_ptr<SparseMatrix> CreateFromCOO(
    torch::Tensor row, torch::Tensor col, torch::Tensor value,
    const std::vector<int64_t>& shape) {
  return SparseMatrix::FromCOO(
      std::make_shared<COO>(COO{row, col}), value, shape);
}

/**
 * @brief Create a SparseMatrix from CSR indptr/indices tensors and values.
 *
 * No value_indices are attached, so CSRTensors() returns `value` without
 * permuting it.
 */
c10::intrusive_ptr<SparseMatrix> CreateFromCSR(
    torch::Tensor indptr, torch::Tensor indices, torch::Tensor value,
    const std::vector<int64_t>& shape) {
  const torch::optional<torch::Tensor> no_value_indices;
  return SparseMatrix::FromCSR(
      std::make_shared<CSR>(CSR{indptr, indices, no_value_indices}), value,
      shape);
}

/**
 * @brief Create a SparseMatrix from CSC indptr/indices tensors and values.
 *
 * The CSC data is stored in a CSR struct with no value_indices attached, so
 * CSCTensors() returns `value` without permuting it.
 */
c10::intrusive_ptr<SparseMatrix> CreateFromCSC(
    torch::Tensor indptr, torch::Tensor indices, torch::Tensor value,
    const std::vector<int64_t>& shape) {
  const torch::optional<torch::Tensor> no_value_indices;
  return SparseMatrix::FromCSC(
      std::make_shared<CSR>(CSR{indptr, indices, no_value_indices}), value,
      shape);
}

}  // namespace sparse
}  // namespace dgl