/**
 *  Copyright (c) 2020 by Contributors
 * @file graph/gk_ops.cc
 * @brief Graph operation implemented in GKlib
 */

#if !defined(_WIN32)
#include <GKlib.h>
#endif  // !defined(_WIN32)

#include <dgl/graph_op.h>

namespace dgl {

#if !defined(_WIN32)

17
/**
18
19
 * Convert DGL CSR to GKLib CSR.
 * GKLib CSR actually stores a CSR object and a CSC object of a graph.
20
21
22
 * @param mat the DGL CSR matrix.
 * @param is_row the input DGL matrix is CSR or CSC.
 * @return a GKLib CSR.
23
24
25
 */
gk_csr_t *Convert2GKCsr(const aten::CSRMatrix mat, bool is_row) {
  // TODO(zhengda) The conversion will be zero-copy in the future.
26
27
28
29
  CHECK_EQ(mat.indptr->dtype.bits, sizeof(dgl_id_t) * CHAR_BIT);
  CHECK_EQ(mat.indices->dtype.bits, sizeof(dgl_id_t) * CHAR_BIT);
  const dgl_id_t *indptr = static_cast<dgl_id_t *>(mat.indptr->data);
  const dgl_id_t *indices = static_cast<dgl_id_t *>(mat.indices->data);
30
31
32
33
34
35
36
37
38
39

  gk_csr_t *gk_csr = gk_csr_Create();
  gk_csr->nrows = mat.num_rows;
  gk_csr->ncols = mat.num_cols;
  uint64_t nnz = mat.indices->shape[0];
  auto gk_indptr = gk_csr->rowptr;
  auto gk_indices = gk_csr->rowind;
  size_t num_ptrs;
  if (is_row) {
    num_ptrs = gk_csr->nrows + 1;
40
41
42
43
44
    gk_indptr = gk_csr->rowptr = gk_zmalloc(
        gk_csr->nrows + 1,
        const_cast<char *>("gk_csr_ExtractPartition: rowptr"));
    gk_indices = gk_csr->rowind =
        gk_imalloc(nnz, const_cast<char *>("gk_csr_ExtractPartition: rowind"));
45
46
  } else {
    num_ptrs = gk_csr->ncols + 1;
47
48
49
50
51
    gk_indptr = gk_csr->colptr = gk_zmalloc(
        gk_csr->ncols + 1,
        const_cast<char *>("gk_csr_ExtractPartition: colptr"));
    gk_indices = gk_csr->colind =
        gk_imalloc(nnz, const_cast<char *>("gk_csr_ExtractPartition: colind"));
52
53
54
55
56
57
58
59
60
61
62
  }

  for (size_t i = 0; i < num_ptrs; i++) {
    gk_indptr[i] = indptr[i];
  }
  for (size_t i = 0; i < nnz; i++) {
    gk_indices[i] = indices[i];
  }
  return gk_csr;
}

63
/**
64
65
 * Convert GKLib CSR to DGL CSR.
 * GKLib CSR actually stores a CSR object and a CSC object of a graph.
66
67
68
 * @param gk_csr the GKLib CSR.
 * @param is_row specify whether to convert the CSR or CSC object of GKLib CSR.
 * @return a DGL CSR matrix.
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
 */
aten::CSRMatrix Convert2DGLCsr(gk_csr_t *gk_csr, bool is_row) {
  // TODO(zhengda) The conversion will be zero-copy in the future.
  size_t num_ptrs;
  size_t nnz;
  auto gk_indptr = gk_csr->rowptr;
  auto gk_indices = gk_csr->rowind;
  if (is_row) {
    num_ptrs = gk_csr->nrows + 1;
    nnz = gk_csr->rowptr[num_ptrs - 1];
    gk_indptr = gk_csr->rowptr;
    gk_indices = gk_csr->rowind;
  } else {
    num_ptrs = gk_csr->ncols + 1;
    nnz = gk_csr->colptr[num_ptrs - 1];
    gk_indptr = gk_csr->colptr;
    gk_indices = gk_csr->colind;
  }

  IdArray indptr_arr = aten::NewIdArray(num_ptrs);
  IdArray indices_arr = aten::NewIdArray(nnz);
  IdArray eids_arr = aten::NewIdArray(nnz);

  dgl_id_t *indptr = static_cast<dgl_id_t *>(indptr_arr->data);
  dgl_id_t *indices = static_cast<dgl_id_t *>(indices_arr->data);
  dgl_id_t *eids = static_cast<dgl_id_t *>(eids_arr->data);
  for (size_t i = 0; i < num_ptrs; i++) {
    indptr[i] = gk_indptr[i];
  }
  for (size_t i = 0; i < nnz; i++) {
    indices[i] = gk_indices[i];
    eids[i] = i;
  }

103
104
  return aten::CSRMatrix(
      gk_csr->nrows, gk_csr->ncols, indptr_arr, indices_arr, eids_arr);
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
}

#endif  // !defined(_WIN32)

/**
 * Build a bidirected graph from an immutable graph by symmetrizing its in-CSR
 * with GKlib (gk_csr_MakeSymmetric with GK_CSR_SYM_SUM).
 *
 * On Windows builds, where GKlib is not included by this file, an empty
 * GraphPtr is returned instead.
 *
 * @param ig the input immutable graph.
 * @return the symmetrized graph, sharing one CSR for both directions.
 */
GraphPtr GraphOp::ToBidirectedSimpleImmutableGraph(ImmutableGraphPtr ig) {
#if defined(_WIN32)
  return GraphPtr();
#else
  // TODO(zhengda) should we get whatever CSR exists in the graph.
  CSRPtr in_csr = ig->GetInCSR();

  // Round trip through GKlib: DGL -> GKlib, symmetrize, GKlib -> DGL.
  gk_csr_t *input_gk = Convert2GKCsr(in_csr->ToCSRMatrix(), true);
  gk_csr_t *symmetric_gk = gk_csr_MakeSymmetric(input_gk, GK_CSR_SYM_SUM);
  auto symmetric = Convert2DGLCsr(symmetric_gk, true);

  // Both GKlib matrices are temporaries of this function.
  gk_csr_Free(&input_gk);
  gk_csr_Free(&symmetric_gk);

  // This is a symmetric graph now. The in-csr and out-csr are the same.
  CSRPtr shared_csr(
      new CSR(symmetric.indptr, symmetric.indices, symmetric.data));
  return GraphPtr(new ImmutableGraph(shared_csr, shared_csr));
#endif  // defined(_WIN32)
}

}  // namespace dgl