pickle.cc 4.04 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
/*!
 *  Copyright (c) 2020 by Contributors
 * \file graph/pickle.cc
 * \brief Functions for pickle and unpickle a graph
 */
#include <dgl/packed_func_ext.h>
#include <dgl/runtime/container.h>
#include "./heterograph.h"
#include "../c_api_common.h"

using namespace dgl::runtime;

namespace dgl {

/*!
 * \brief Capture the state needed to rebuild a heterograph.
 *
 * Records the metagraph, the per-node-type node counts, and one sparse
 * adjacency matrix per edge type. Each edge type's format is the one
 * reported by SelectFormat(etype, kAny).
 * A relation whose selected format is CSC is stored via GetCSRMatrix, so the
 * returned states only ever hold COO or CSR matrices (HeteroUnpickle rejects
 * CSC).
 *
 * \param graph The heterograph to pickle.
 * \return The pickle states; adjs[etype] holds the matrix for edge type etype.
 */
HeteroPickleStates HeteroPickle(HeteroGraphPtr graph) {
  HeteroPickleStates states;
  states.metagraph = graph->meta_graph();
  states.num_nodes_per_type = graph->NumVerticesPerType();

  const auto num_etypes = graph->NumEdgeTypes();
  states.adjs.resize(num_etypes);
  for (dgl_type_t etype = 0; etype < num_etypes; ++etype) {
    const SparseFormat fmt = graph->SelectFormat(etype, SparseFormat::kAny);
    states.adjs[etype] = std::make_shared<SparseMatrix>();
    switch (fmt) {
      case SparseFormat::kCOO:
        *states.adjs[etype] = graph->GetCOOMatrix(etype).ToSparseMatrix();
        break;
      case SparseFormat::kCSR:
      case SparseFormat::kCSC:
        // CSC is serialized as CSR: HeteroUnpickle only rebuilds COO/CSR.
        *states.adjs[etype] = graph->GetCSRMatrix(etype).ToSparseMatrix();
        break;
      default:
        LOG(FATAL) << "Unsupported sparse format.";
    }
  }
  return states;
}

/*!
 * \brief Rebuild a heterograph from pickle states.
 *
 * For every edge type of the metagraph, builds a unit graph from the stored
 * COO or CSR matrix. The unit graph has one vertex type when the relation's
 * source and destination node types coincide, otherwise two. All relation
 * graphs are then combined with CreateHeteroGraph.
 *
 * \param states Metagraph, per-type node counts and per-edge-type adjacency
 *        matrices; adjs must have exactly metagraph->NumEdges() entries.
 * \return The reconstructed heterograph.
 */
HeteroGraphPtr HeteroUnpickle(const HeteroPickleStates& states) {
  const auto& metagraph = states.metagraph;
  const auto& num_nodes_per_type = states.num_nodes_per_type;
  const auto num_etypes = metagraph->NumEdges();
  CHECK_EQ(states.adjs.size(), num_etypes);

  std::vector<HeteroGraphPtr> relgraphs(num_etypes);
  for (dgl_type_t etype = 0; etype < num_etypes; ++etype) {
    const auto& pair = metagraph->FindEdge(etype);
    const dgl_type_t srctype = pair.first;
    const dgl_type_t dsttype = pair.second;
    // A relation within a single node type is a unit graph with one vertex type.
    const int64_t num_vtypes = (srctype == dsttype)? 1 : 2;
    const SparseFormat fmt = static_cast<SparseFormat>(states.adjs[etype]->format);
    switch (fmt) {
      case SparseFormat::kCOO:
        relgraphs[etype] = UnitGraph::CreateFromCOO(
            num_vtypes, aten::COOMatrix(*states.adjs[etype]));
        break;
      case SparseFormat::kCSR:
        relgraphs[etype] = UnitGraph::CreateFromCSR(
            num_vtypes, aten::CSRMatrix(*states.adjs[etype]));
        break;
      case SparseFormat::kCSC:
        // HeteroPickle stores CSC relations as CSR, so CSC is not accepted here.
      default:
        LOG(FATAL) << "Unsupported sparse format.";
    }
  }
  return CreateHeteroGraph(metagraph, relgraphs, num_nodes_per_type);
}

// Frontend getter: wraps the metagraph held by a pickle-states object in a GraphRef.
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroPickleStatesGetMetagraph")
.set_body([] (DGLArgs fargs, DGLRetValue* ret) {
    HeteroPickleStatesRef states_ref = fargs[0];
    *ret = GraphRef(states_ref->metagraph);
  });

72
73
74
75
76
77
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroPickleStatesGetNumVertices")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
    HeteroPickleStatesRef st = args[0];
    *rv = NDArray::FromVector(st->num_nodes_per_type);
  });

78
79
80
81
82
83
84
85
86
87
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroPickleStatesGetAdjs")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
    HeteroPickleStatesRef st = args[0];
    std::vector<SparseMatrixRef> refs(st->adjs.begin(), st->adjs.end());
    *rv = List<SparseMatrixRef>(refs);
  });

// Frontend constructor: assembles a pickle-states object from
//   args[0] metagraph (GraphRef),
//   args[1] per-node-type node counts (IdArray, read as int64_t),
//   args[2] per-edge-type adjacency matrices (List<SparseMatrixRef>).
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLCreateHeteroPickleStates")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
    GraphRef metagraph = args[0];
    IdArray num_nodes_per_type = args[1];
    List<SparseMatrixRef> adjs = args[2];
    auto st = std::make_shared<HeteroPickleStates>();
    st->metagraph = metagraph.sptr();
    st->num_nodes_per_type = num_nodes_per_type.ToVector<int64_t>();
    st->adjs.reserve(adjs.size());
    for (const auto& ref : adjs)
      st->adjs.push_back(ref.sptr());
    *rv = HeteroPickleStatesRef(st);
  });

// Frontend entry point: pickles args[0] (a HeteroGraphRef) with HeteroPickle
// and returns the resulting states wrapped in a HeteroPickleStatesRef.
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroPickle")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
    HeteroGraphRef ref = args[0];
    auto st = std::make_shared<HeteroPickleStates>();
    *st = HeteroPickle(ref.sptr());
    *rv = HeteroPickleStatesRef(st);
  });

// Frontend entry point: rebuilds a heterograph from args[0] (a
// HeteroPickleStatesRef) with HeteroUnpickle and returns it as a HeteroGraphRef.
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroUnpickle")
.set_body([] (DGLArgs fargs, DGLRetValue* ret) {
    HeteroPickleStatesRef states_ref = fargs[0];
    HeteroGraphPtr rebuilt = HeteroUnpickle(*states_ref.sptr());
    *ret = HeteroGraphRef(rebuilt);
  });

}  // namespace dgl