sparse_format.cc 4.86 KB
Newer Older
czkkkkkk's avatar
czkkkkkk committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
/**
 *  Copyright (c) 2022 by Contributors
 * @file sparse_format.cc
 * @brief DGL C++ sparse format implementations.
 */
// clang-format off
#include <sparse/dgl_headers.h>
// clang-format on

#include <sparse/sparse_format.h>

#include "./utils.h"

namespace dgl {
namespace sparse {

17
/**
 * @brief Wrap a legacy aten COO matrix as a sparse-library COO.
 *
 * The row and column arrays are converted to torch tensors and stacked along
 * a new leading dimension to form the COO indices tensor.
 *
 * @param dgl_coo Legacy COO matrix. Its data array must be null.
 * @return New COO carrying over the shape and the row/column sortedness
 *         flags of @p dgl_coo.
 * @throws c10::Error (via TORCH_CHECK) if dgl_coo.data is not a null array.
 */
std::shared_ptr<COO> COOFromOldDGLCOO(const aten::COOMatrix& dgl_coo) {
  // The COO built below only receives shape, indices and sortedness flags, so
  // a data (permutation) array could not be carried over; reject it up front
  // before doing any conversion work.
  TORCH_CHECK(
      aten::IsNullArray(dgl_coo.data),
      "COOFromOldDGLCOO: expected a DGL COO matrix without a data array.");
  auto row = DGLArrayToTorchTensor(dgl_coo.row);
  auto col = DGLArrayToTorchTensor(dgl_coo.col);
  auto indices = torch::stack({row, col});
  return std::make_shared<COO>(
      COO{dgl_coo.num_rows, dgl_coo.num_cols, indices, dgl_coo.row_sorted,
          dgl_coo.col_sorted});
}

/**
 * @brief Convert a sparse-library COO into a legacy aten COO matrix.
 *
 * Row 0 of coo->indices becomes the aten row array and row 1 the column
 * array. The result has no data array (aten::NullArray()), and the
 * row/column sortedness flags are copied over unchanged.
 *
 * @param coo Source COO structure.
 * @return Equivalent aten::COOMatrix.
 */
aten::COOMatrix COOToOldDGLCOO(const std::shared_ptr<COO>& coo) {
  auto row = TorchTensorToDGLArray(coo->indices.index({0}));
  auto col = TorchTensorToDGLArray(coo->indices.index({1}));
  return aten::COOMatrix(
      coo->num_rows, coo->num_cols, row, col, aten::NullArray(),
      coo->row_sorted, coo->col_sorted);
}

/**
 * @brief Wrap a legacy aten CSR matrix as a sparse-library CSR.
 *
 * indptr and indices are converted to torch tensors. The aten data array is
 * converted with DGLArrayToOptionalTorchTensor and stored as value_indices
 * (presumably empty when the data array is null; confirm in utils.h).
 *
 * @param dgl_csr Legacy CSR matrix.
 * @return New CSR with the same shape and sortedness flag.
 */
std::shared_ptr<CSR> CSRFromOldDGLCSR(const aten::CSRMatrix& dgl_csr) {
  // Braced initializers are evaluated left to right, so the conversions run
  // in the order indptr, indices, data.
  return std::make_shared<CSR>(CSR{
      dgl_csr.num_rows, dgl_csr.num_cols,
      DGLArrayToTorchTensor(dgl_csr.indptr),
      DGLArrayToTorchTensor(dgl_csr.indices),
      DGLArrayToOptionalTorchTensor(dgl_csr.data), dgl_csr.sorted});
}

/**
 * @brief Convert a sparse-library CSR into a legacy aten CSR matrix.
 *
 * value_indices is handed to OptionalTorchTensorToDGLArray and becomes the
 * aten data array. Shape and sortedness flag are copied over unchanged.
 *
 * @param csr Source CSR structure.
 * @return Equivalent aten::CSRMatrix.
 */
aten::CSRMatrix CSRToOldDGLCSR(const std::shared_ptr<CSR>& csr) {
  auto& src = *csr;
  auto dgl_indptr = TorchTensorToDGLArray(src.indptr);
  auto dgl_indices = TorchTensorToDGLArray(src.indices);
  auto dgl_data = OptionalTorchTensorToDGLArray(src.value_indices);
  return aten::CSRMatrix(
      src.num_rows, src.num_cols, dgl_indptr, dgl_indices, dgl_data,
      src.sorted);
}

52
53
/**
 * @brief Build a torch sparse COO tensor from a COO structure and its values.
 *
 * @p value is passed as the values argument of torch::sparse_coo_tensor, so
 * its first dimension indexes the non-zeros. Any trailing dimensions of
 * @p value become dense dimensions of the result, whose shape is
 * (num_rows, num_cols, value.size(1), value.size(2), ...). For 1-D values the
 * shape is (num_rows, num_cols); for 2-D values it is
 * (num_rows, num_cols, value.size(1)).
 *
 * @param coo Sparse structure supplying the indices and the sparse shape.
 * @param value Non-zero values with at least one dimension.
 * @return A (possibly hybrid) torch sparse COO tensor.
 * @throws c10::Error if @p value is 0-dimensional.
 */
torch::Tensor COOToTorchCOO(
    const std::shared_ptr<COO>& coo, torch::Tensor value) {
  TORCH_CHECK(
      value.dim() >= 1,
      "COOToTorchCOO: value must have at least one dimension.");
  // Start from value's shape (nnz, d1, d2, ...) and replace the leading nnz
  // dimension with the two sparse dimensions (num_rows, num_cols).
  auto shape = value.sizes().vec();
  shape[0] = coo->num_cols;
  shape.insert(shape.begin(), coo->num_rows);
  return torch::sparse_coo_tensor(coo->indices, value, shape);
}

64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
/**
 * @brief Convert a CSR matrix to COO via the aten implementation.
 *
 * Whether value_indices is present is forwarded as the second argument of
 * aten::CSRToCOO. This is presumably its data-as-order flag; confirm against
 * the aten API. COOFromOldDGLCOO rejects an aten COO that still has a data
 * array.
 *
 * @param csr Source CSR structure.
 * @return Equivalent COO structure.
 */
std::shared_ptr<COO> CSRToCOO(const std::shared_ptr<CSR>& csr) {
  const bool has_value_indices = csr->value_indices.has_value();
  const auto coo_matrix =
      aten::CSRToCOO(CSRToOldDGLCSR(csr), has_value_indices);
  return COOFromOldDGLCOO(coo_matrix);
}

/**
 * @brief Convert a CSC matrix (held in a CSR struct) to COO.
 *
 * The CSC is expanded to COO with aten::CSRToCOO, forwarding whether
 * value_indices is present as its second argument. The result is then
 * transposed with aten::COOTranspose before being wrapped as a COO.
 *
 * @param csc Source CSC structure.
 * @return Equivalent COO structure.
 */
std::shared_ptr<COO> CSCToCOO(const std::shared_ptr<CSR>& csc) {
  const bool has_value_indices = csc->value_indices.has_value();
  auto transposed_coo =
      aten::CSRToCOO(CSRToOldDGLCSR(csc), has_value_indices);
  // Flip back so the COO's rows/cols refer to the untransposed matrix.
  return COOFromOldDGLCOO(aten::COOTranspose(transposed_coo));
}

/**
 * @brief Convert a COO matrix to CSR via the aten implementation.
 *
 * @param coo Source COO structure.
 * @return Equivalent CSR structure (value_indices taken from the aten data
 *         array produced by aten::COOToCSR).
 */
std::shared_ptr<CSR> COOToCSR(const std::shared_ptr<COO>& coo) {
  return CSRFromOldDGLCSR(aten::COOToCSR(COOToOldDGLCOO(coo)));
}

/**
 * @brief Convert a CSC matrix (held in a CSR struct) to CSR.
 *
 * The conversion is a single aten::CSRTranspose of the CSC data.
 *
 * @param csc Source CSC structure.
 * @return Equivalent CSR structure.
 */
std::shared_ptr<CSR> CSCToCSR(const std::shared_ptr<CSR>& csc) {
  return CSRFromOldDGLCSR(aten::CSRTranspose(CSRToOldDGLCSR(csc)));
}

/**
 * @brief Convert a COO matrix to CSC (returned in a CSR struct).
 *
 * The COO is transposed with aten::COOTranspose and then compressed with
 * aten::COOToCSR, giving the CSR layout of the transpose.
 *
 * @param coo Source COO structure.
 * @return CSC structure for the matrix.
 */
std::shared_ptr<CSR> COOToCSC(const std::shared_ptr<COO>& coo) {
  const auto transposed = aten::COOTranspose(COOToOldDGLCOO(coo));
  return CSRFromOldDGLCSR(aten::COOToCSR(transposed));
}

/**
 * @brief Convert a CSR matrix to CSC (returned in a CSR struct).
 *
 * The CSC is produced by aten::CSRTranspose of the CSR data and wrapped back
 * into the CSR structure type.
 *
 * @param csr Source CSR structure.
 * @return CSC structure for the matrix.
 */
std::shared_ptr<CSR> CSRToCSC(const std::shared_ptr<CSR>& csr) {
  auto dgl_csr = CSRToOldDGLCSR(csr);
  auto dgl_csc = aten::CSRTranspose(dgl_csr);
  return CSRFromOldDGLCSR(dgl_csc);
}

102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
/**
 * @brief Materialize a diagonal matrix as COO.
 *
 * With k = min(num_rows, num_cols), the entries are (0,0), (1,1), ...,
 * (k-1,k-1), so both index rows are 0..k-1. The result is flagged as both
 * row-sorted and column-sorted.
 *
 * @param diag Diagonal matrix supplying the shape.
 * @param indices_options Tensor options used to create the index tensor.
 * @return COO structure with a 2 x k indices tensor.
 */
std::shared_ptr<COO> DiagToCOO(
    const std::shared_ptr<Diag>& diag,
    const c10::TensorOptions& indices_options) {
  const int64_t num_diag = std::min(diag->num_rows, diag->num_cols);
  const auto diag_ids = torch::arange(num_diag, indices_options);
  // Row indices and column indices are identical on the diagonal.
  auto indices = torch::stack({diag_ids, diag_ids});
  const bool kSorted = true;
  return std::make_shared<COO>(
      COO{diag->num_rows, diag->num_cols, indices, kSorted, kSorted});
}

/**
 * @brief Materialize a diagonal matrix as CSR.
 *
 * With nnz = min(num_rows, num_cols), rows 0..nnz-1 each hold the single
 * entry at column i and any remaining rows are empty. indptr is therefore
 * [0, 1, ..., nnz, nnz, ..., nnz] with num_rows + 1 entries, and indices is
 * 0..nnz-1.
 *
 * @param diag Diagonal matrix supplying the shape.
 * @param indices_options Tensor options used to create indptr and indices.
 * @return CSR structure flagged as sorted, without value_indices.
 */
std::shared_ptr<CSR> DiagToCSR(
    const std::shared_ptr<Diag>& diag,
    const c10::TensorOptions& indices_options) {
  int64_t nnz = std::min(diag->num_rows, diag->num_cols);
  // Fill every slot with nnz (covers trailing empty rows), then overwrite the
  // first nnz + 1 slots with 0..nnz. The write goes through a narrowed view:
  // torch::arange_out(indptr, nnz + 1) would resize indptr to nnz + 1
  // elements and drop the trailing entries when num_rows > num_cols.
  auto indptr = torch::full({diag->num_rows + 1}, nnz, indices_options);
  indptr.narrow(0, 0, nnz + 1).copy_(torch::arange(nnz + 1, indices_options));
  auto indices = torch::arange(nnz, indices_options);
  return std::make_shared<CSR>(
      CSR{diag->num_rows, diag->num_cols, indptr, indices,
          torch::optional<torch::Tensor>(), true});
}

/**
 * @brief Materialize a diagonal matrix as CSC (returned in a CSR struct).
 *
 * The returned struct swaps the dimensions: its num_rows is diag->num_cols
 * and its num_cols is diag->num_rows. With nnz = min(num_rows, num_cols),
 * columns 0..nnz-1 each hold one diagonal entry and any remaining columns
 * are empty. indptr is therefore [0, 1, ..., nnz, nnz, ..., nnz] with
 * num_cols + 1 entries, and indices is 0..nnz-1.
 *
 * @param diag Diagonal matrix supplying the shape.
 * @param indices_options Tensor options used to create indptr and indices.
 * @return CSC structure flagged as sorted, without value_indices.
 */
std::shared_ptr<CSR> DiagToCSC(
    const std::shared_ptr<Diag>& diag,
    const c10::TensorOptions& indices_options) {
  int64_t nnz = std::min(diag->num_rows, diag->num_cols);
  // Fill every slot with nnz (covers trailing empty columns), then overwrite
  // the first nnz + 1 slots with 0..nnz. The write goes through a narrowed
  // view: torch::arange_out(indptr, nnz + 1) would resize indptr to nnz + 1
  // elements and drop the trailing entries when num_cols > num_rows.
  auto indptr = torch::full({diag->num_cols + 1}, nnz, indices_options);
  indptr.narrow(0, 0, nnz + 1).copy_(torch::arange(nnz + 1, indices_options));
  auto indices = torch::arange(nnz, indices_options);
  return std::make_shared<CSR>(
      CSR{diag->num_cols, diag->num_rows, indptr, indices,
          torch::optional<torch::Tensor>(), true});
}

135
136
137
138
139
140
/**
 * @brief Transpose a COO matrix.
 *
 * The COO is converted to the aten representation, transposed with
 * aten::COOTranspose, and wrapped back into a COO.
 *
 * @param coo Source COO structure.
 * @return COO structure of the transposed matrix.
 */
std::shared_ptr<COO> COOTranspose(const std::shared_ptr<COO>& coo) {
  return COOFromOldDGLCOO(aten::COOTranspose(COOToOldDGLCOO(coo)));
}

czkkkkkk's avatar
czkkkkkk committed
141
142
}  // namespace sparse
}  // namespace dgl