gk_ops.cc 3.91 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
/*!
 *  Copyright (c) 2020 by Contributors
 * \file graph/gk_ops.cc
 * \brief Graph operation implemented in GKlib
 */

#if !defined(_WIN32)
#include <GKlib.h>
#endif  // !defined(_WIN32)

#include <dgl/graph_op.h>

namespace dgl {

#if !defined(_WIN32)

/*!
 * \brief Copy a DGL CSR matrix into a newly allocated GKlib matrix.
 *
 * A gk_csr_t carries a row-oriented array pair (rowptr/rowind) and a
 * column-oriented pair (colptr/colind). Only the pair selected by \a is_row
 * is allocated and filled; the other pair is left as gk_csr_Create() set it.
 * The edge-id array of \a mat (mat.data) is not copied.
 *
 * \param mat the DGL matrix. Its indptr and indices arrays must store
 *        integers of the same width as dgl_id_t.
 * \param is_row if true, fill rowptr/rowind of the result; otherwise fill
 *        colptr/colind.
 * \return a matrix obtained from gk_csr_Create(); the caller is expected to
 *         release it with gk_csr_Free().
 */
gk_csr_t *Convert2GKCsr(const aten::CSRMatrix mat, bool is_row) {
  // TODO(zhengda) The conversion will be zero-copy in the future.
  CHECK_EQ(mat.indptr->dtype.bits, sizeof(dgl_id_t) * CHAR_BIT);
  CHECK_EQ(mat.indices->dtype.bits, sizeof(dgl_id_t) * CHAR_BIT);

  // GKlib stores the dimensions in its own integer type, which may be
  // narrower than DGL's (NOTE(review): int32_t in upstream GKlib -- confirm
  // against the bundled version). Verify that the values survive the
  // conversion before allocating anything, instead of truncating silently.
  using GKRowDim = decltype(gk_csr_t::nrows);
  using GKColDim = decltype(gk_csr_t::ncols);
  const auto nrows = static_cast<GKRowDim>(mat.num_rows);
  const auto ncols = static_cast<GKColDim>(mat.num_cols);
  CHECK(nrows >= 0 && nrows == mat.num_rows)
      << "Convert2GKCsr: number of rows " << mat.num_rows
      << " cannot be represented by GKlib";
  CHECK(ncols >= 0 && ncols == mat.num_cols)
      << "Convert2GKCsr: number of columns " << mat.num_cols
      << " cannot be represented by GKlib";

  const dgl_id_t *indptr = static_cast<dgl_id_t *>(mat.indptr->data);
  const dgl_id_t *indices = static_cast<dgl_id_t *>(mat.indices->data);
  const uint64_t nnz = mat.indices->shape[0];
  // The pointer array has one entry per row (or column) plus a final
  // end-of-data entry.
  const size_t num_ptrs = static_cast<size_t>(is_row ? nrows : ncols) + 1;

  gk_csr_t *gk_csr = gk_csr_Create();
  gk_csr->nrows = nrows;
  gk_csr->ncols = ncols;

  // The string arguments are allocation labels taken by the GKlib
  // allocators; they name this function so diagnostics point here.
  if (is_row) {
    gk_csr->rowptr =
        gk_zmalloc(num_ptrs, const_cast<char *>("Convert2GKCsr: rowptr"));
    gk_csr->rowind =
        gk_imalloc(nnz, const_cast<char *>("Convert2GKCsr: rowind"));
  } else {
    gk_csr->colptr =
        gk_zmalloc(num_ptrs, const_cast<char *>("Convert2GKCsr: colptr"));
    gk_csr->colind =
        gk_imalloc(nnz, const_cast<char *>("Convert2GKCsr: colind"));
  }
  auto *gk_indptr = is_row ? gk_csr->rowptr : gk_csr->colptr;
  auto *gk_indices = is_row ? gk_csr->rowind : gk_csr->colind;

  // Element-wise copies: the GKlib element types differ from dgl_id_t, so
  // each value is converted on assignment.
  for (size_t i = 0; i < num_ptrs; i++) {
    gk_indptr[i] = indptr[i];
  }
  for (size_t i = 0; i < nnz; i++) {
    gk_indices[i] = indices[i];
  }
  return gk_csr;
}

/*!
 * \brief Copy one orientation of a GKlib matrix into a new DGL CSR matrix.
 *
 * A gk_csr_t carries a row-oriented array pair (rowptr/rowind) and a
 * column-oriented pair (colptr/colind); \a is_row selects which pair is read.
 * The number of non-zeros is taken from the last entry of the selected
 * pointer array. Edge ids of the result are freshly assigned as 0..nnz-1 in
 * storage order.
 *
 * \param gk_csr the GKlib matrix; must not be null and must hold the
 *        selected array pair. It is only read, never freed, by this function.
 * \param is_row if true, convert rowptr/rowind; otherwise colptr/colind.
 * \return a DGL CSR matrix with the dimensions of \a gk_csr and newly
 *         allocated indptr, indices and edge-id arrays.
 */
aten::CSRMatrix Convert2DGLCsr(gk_csr_t *gk_csr, bool is_row) {
  // TODO(zhengda) The conversion will be zero-copy in the future.
  CHECK(gk_csr != nullptr) << "Convert2DGLCsr: the GKlib matrix is null";

  const auto *gk_indptr = is_row ? gk_csr->rowptr : gk_csr->colptr;
  const auto *gk_indices = is_row ? gk_csr->rowind : gk_csr->colind;
  // A GKlib matrix may hold only one orientation; fail loudly rather than
  // dereference a missing array.
  CHECK(gk_indptr != nullptr)
      << "Convert2DGLCsr: the GKlib matrix has no "
      << (is_row ? "row" : "column") << " pointer array";

  const size_t num_ptrs =
      static_cast<size_t>(is_row ? gk_csr->nrows : gk_csr->ncols) + 1;
  // The final pointer entry marks the end of the index data, i.e. nnz.
  const size_t nnz = gk_indptr[num_ptrs - 1];
  CHECK(nnz == 0 || gk_indices != nullptr)
      << "Convert2DGLCsr: the GKlib matrix has no "
      << (is_row ? "row" : "column") << " index array";

  IdArray indptr_arr = aten::NewIdArray(num_ptrs);
  IdArray indices_arr = aten::NewIdArray(nnz);
  IdArray eids_arr = aten::NewIdArray(nnz);

  dgl_id_t *indptr = static_cast<dgl_id_t *>(indptr_arr->data);
  dgl_id_t *indices = static_cast<dgl_id_t *>(indices_arr->data);
  dgl_id_t *eids = static_cast<dgl_id_t *>(eids_arr->data);
  for (size_t i = 0; i < num_ptrs; i++) {
    indptr[i] = gk_indptr[i];
  }
  for (size_t i = 0; i < nnz; i++) {
    indices[i] = gk_indices[i];
    // GKlib edge data is not consulted here; ids are simply the positions.
    eids[i] = i;
  }

  return aten::CSRMatrix(
      gk_csr->nrows, gk_csr->ncols, indptr_arr, indices_arr, eids_arr);
}

#endif  // !defined(_WIN32)

/*!
 * \brief Build a symmetric immutable graph from \a ig using GKlib.
 *
 * The in-CSR of \a ig is copied into a GKlib matrix, symmetrized with
 * gk_csr_MakeSymmetric(..., GK_CSR_SYM_SUM), and copied back into a DGL CSR
 * that serves as both the in-CSR and the out-CSR of the returned graph.
 *
 * \param ig the input immutable graph.
 * \return the symmetrized graph. When compiled for Windows (_WIN32 defined)
 *         the GKlib path is compiled out and an empty GraphPtr is returned.
 */
GraphPtr GraphOp::ToBidirectedSimpleImmutableGraph(ImmutableGraphPtr ig) {
#if !defined(_WIN32)
  // Scope-bound owner of a GKlib matrix: the matrix is released on every
  // exit path, including when a later conversion step throws.
  struct GKCsrOwner {
    gk_csr_t *ptr;
    explicit GKCsrOwner(gk_csr_t *p) : ptr(p) {}
    ~GKCsrOwner() {
      if (ptr != nullptr) gk_csr_Free(&ptr);
    }
    GKCsrOwner(const GKCsrOwner &) = delete;
    GKCsrOwner &operator=(const GKCsrOwner &) = delete;
  };

  // TODO(zhengda) should we get whatever CSR exists in the graph.
  CSRPtr csr = ig->GetInCSR();

  // Run the GKlib round trip in its own scope so that both temporary
  // matrices are freed before the DGL graph is assembled.
  const aten::CSRMatrix mat = [&csr]() {
    GKCsrOwner input(Convert2GKCsr(csr->ToCSRMatrix(), true));
    GKCsrOwner symmetric(gk_csr_MakeSymmetric(input.ptr, GK_CSR_SYM_SUM));
    CHECK(symmetric.ptr != nullptr)
        << "gk_csr_MakeSymmetric did not return a matrix";
    return Convert2DGLCsr(symmetric.ptr, true);
  }();

  // This is a symmetric graph now. The in-csr and out-csr are the same.
  csr = CSRPtr(new CSR(mat.indptr, mat.indices, mat.data));
  return GraphPtr(new ImmutableGraph(csr, csr));
#else
  return GraphPtr();
#endif  // !defined(_WIN32)
}

}  // namespace dgl