/**
 *  Copyright (c) 2022 by Contributors
 * @file sparse/sparse_format.h
 * @brief DGL C++ sparse format header.
 */
#ifndef SPARSE_SPARSE_FORMAT_H_
#define SPARSE_SPARSE_FORMAT_H_

// clang-format off
#include <sparse/dgl_headers.h>
// clang-format on

#include <torch/custom_class.h>
#include <torch/script.h>

#include <memory>
#include <utility>

namespace dgl {
namespace sparse {

/**
 * @brief Sparse storage formats supported by the sparse library.
 *
 * kCOO corresponds to the `COO` structure and kDiag to the `Diag` structure.
 * Both kCSR and kCSC data are held in the `CSR` structure (the CSC conversion
 * functions in this header take and return `std::shared_ptr<CSR>`).
 */
enum SparseFormat { kCOO, kCSR, kCSC, kDiag };

/** @brief COO sparse structure. */
struct COO {
  /** @brief The shape of the matrix. */
  int64_t num_rows = 0, num_cols = 0;
29
30
31
32
  /**
   * @brief COO tensor of shape (2, nnz), stacking the row and column indices.
   */
  torch::Tensor indices;
33
34
35
36
37
38
39
  /** @brief Whether the row indices are sorted. */
  bool row_sorted = false;
  /** @brief Whether the column indices per row are sorted. */
  bool col_sorted = false;
};

/** @brief CSR sparse structure. */
czkkkkkk's avatar
czkkkkkk committed
40
struct CSR {
41
42
43
  /** @brief The dense shape of the matrix. */
  int64_t num_rows = 0, num_cols = 0;
  /** @brief CSR format index pointer array of the matrix. */
czkkkkkk's avatar
czkkkkkk committed
44
  torch::Tensor indptr;
45
  /** @brief CSR format index array of the matrix. */
czkkkkkk's avatar
czkkkkkk committed
46
  torch::Tensor indices;
47
48
  /** @brief Data index tensor. When it is null, assume it is from 0 to NNZ - 1.
   */
czkkkkkk's avatar
czkkkkkk committed
49
  torch::optional<torch::Tensor> value_indices;
50
51
  /** @brief Whether the column indices per row are sorted. */
  bool sorted = false;
czkkkkkk's avatar
czkkkkkk committed
52
53
};

/**
 * @brief Diagonal sparse structure.
 *
 * Only the dense shape is stored; unlike COO and CSR there are no index
 * tensors in this structure.
 */
struct Diag {
  /** @brief Number of rows in the dense shape of the matrix. */
  int64_t num_rows = 0;
  /** @brief Number of columns in the dense shape of the matrix. */
  int64_t num_cols = 0;
};

/**
 * @brief Builds a sparse-library COO from a COO matrix of the old DGL API.
 * @param dgl_coo The old DGL COO matrix to convert.
 * @return A shared pointer to the resulting sparse-library COO.
 */
std::shared_ptr<COO> COOFromOldDGLCOO(const aten::COOMatrix& dgl_coo);

/**
 * @brief Turns a sparse-library COO back into an old DGL COO matrix.
 * @param coo The sparse-library COO to convert.
 * @return The converted old DGL COO matrix.
 */
aten::COOMatrix COOToOldDGLCOO(const std::shared_ptr<COO>& coo);

/**
 * @brief Builds a sparse-library CSR from a CSR matrix of the old DGL API.
 * @param dgl_csr The old DGL CSR matrix to convert.
 * @return A shared pointer to the resulting sparse-library CSR.
 */
std::shared_ptr<CSR> CSRFromOldDGLCSR(const aten::CSRMatrix& dgl_csr);

/**
 * @brief Turns a sparse-library CSR back into an old DGL CSR matrix.
 * @param csr The sparse-library CSR to convert.
 * @return The converted old DGL CSR matrix.
 */
aten::CSRMatrix CSRToOldDGLCSR(const std::shared_ptr<CSR>& csr);

/**
 * @brief Builds a Torch sparse COO tensor from a sparse-library COO and the
 * values of its nonzero entries.
 * @param coo The COO format in the sparse library.
 * @param value Values of the sparse matrix.
 * @return Torch sparse tensor in COO format.
 */
torch::Tensor COOToTorchCOO(
    const std::shared_ptr<COO>& coo, torch::Tensor value);

/**
 * @brief Converts a matrix from CSR to COO format.
 * @param csr The CSR input.
 */
std::shared_ptr<COO> CSRToCOO(const std::shared_ptr<CSR>& csr);

/**
 * @brief Converts a matrix from CSC to COO format.
 * @param csc The CSC input, held in a `CSR` structure.
 */
std::shared_ptr<COO> CSCToCOO(const std::shared_ptr<CSR>& csc);

/**
 * @brief Converts a matrix from COO to CSR format.
 * @param coo The COO input.
 */
std::shared_ptr<CSR> COOToCSR(const std::shared_ptr<COO>& coo);

/**
 * @brief Converts a matrix from CSC to CSR format.
 * @param csc The CSC input, held in a `CSR` structure.
 */
std::shared_ptr<CSR> CSCToCSR(const std::shared_ptr<CSR>& csc);

/**
 * @brief Converts a matrix from COO to CSC format.
 * @param coo The COO input.
 * @return The CSC result, held in a `CSR` structure.
 */
std::shared_ptr<CSR> COOToCSC(const std::shared_ptr<COO>& coo);

/**
 * @brief Converts a matrix from CSR to CSC format.
 * @param csr The CSR input.
 * @return The CSC result, held in a `CSR` structure.
 */
std::shared_ptr<CSR> CSRToCSC(const std::shared_ptr<CSR>& csr);

/**
 * @brief Converts a diagonal matrix to COO format.
 * @param diag The Diag input.
 * @param indices_options Tensor options presumably applied to the index
 *        tensors created by the conversion (e.g. dtype/device) — confirm.
 */
std::shared_ptr<COO> DiagToCOO(
    const std::shared_ptr<Diag>& diag,
    const c10::TensorOptions& indices_options);

/**
 * @brief Converts a diagonal matrix to CSR format.
 * @param diag The Diag input.
 * @param indices_options Tensor options presumably applied to the index
 *        tensors created by the conversion (e.g. dtype/device) — confirm.
 */
std::shared_ptr<CSR> DiagToCSR(
    const std::shared_ptr<Diag>& diag,
    const c10::TensorOptions& indices_options);

/**
 * @brief Converts a diagonal matrix to CSC format (held in a `CSR` structure).
 * @param diag The Diag input.
 * @param indices_options Tensor options presumably applied to the index
 *        tensors created by the conversion (e.g. dtype/device) — confirm.
 */
std::shared_ptr<CSR> DiagToCSC(
    const std::shared_ptr<Diag>& diag,
    const c10::TensorOptions& indices_options);

/**
 * @brief Transposes a matrix stored in COO format.
 * @param coo The COO matrix to transpose.
 * @return A shared pointer to the transposed COO.
 */
std::shared_ptr<COO> COOTranspose(const std::shared_ptr<COO>& coo);

/**
 * @brief Sorts a COO matrix by its row and column indices.
 * @param coo The COO matrix to sort.
 * @return A pair whose first element is the sorted COO matrix and whose second
 *         element is the permutation indices.
 */
std::pair<std::shared_ptr<COO>, torch::Tensor> COOSort(
    const std::shared_ptr<COO>& coo);

}  // namespace sparse
}  // namespace dgl

#endif  // SPARSE_SPARSE_FORMAT_H_