pickle.cc 7.77 KB
Newer Older
1
2
3
4
5
6
7
/*!
 *  Copyright (c) 2020 by Contributors
 * \file graph/pickle.cc
 * \brief Functions for pickling and unpickling a graph
 */
#include <dgl/packed_func_ext.h>
#include <dgl/runtime/container.h>
#include <dgl/immutable_graph.h>
#include <dgl/graph_serializer.h>
#include <dmlc/memory_io.h>
#include <memory>
#include "./heterograph.h"
#include "../c_api_common.h"
#include "unit_graph.h"
14
15
16
17
18
19
20

using namespace dgl::runtime;

namespace dgl {

/*!
 * \brief Serialize a heterograph into pickle states.
 *
 * The following is written to `states.meta`, in order: the metagraph
 * (converted with ImmutableGraph::ToImmutable), the number of vertices per
 * type and then, for each edge type, a SparseFormat tag followed by the
 * sortedness flag(s) of the selected matrix. The index arrays are appended to
 * `states.arrays` in the same edge-type order: (row, col) for COO and
 * (indptr, indices, data) for CSR.
 *
 * An edge type whose selected format is CSC is serialized through its CSR
 * matrix and tagged as CSR.
 *
 * \param graph The heterograph to pickle.
 * \return The pickle states holding the meta bytes and the index arrays.
 */
HeteroPickleStates HeteroPickle(HeteroGraphPtr graph) {
  HeteroPickleStates states;
  dmlc::MemoryStringStream meta_stream(&states.meta);
  dmlc::Stream *out = &meta_stream;

  // Header: metagraph followed by the per-type vertex counts.
  out->Write(ImmutableGraph::ToImmutable(graph->meta_graph()));
  out->Write(graph->NumVerticesPerType());

  for (dgl_type_t et = 0; et < graph->NumEdgeTypes(); ++et) {
    const SparseFormat chosen = graph->SelectFormat(et, ALL_CODE);
    if (chosen == SparseFormat::kCOO) {
      out->Write(SparseFormat::kCOO);
      const auto &coo_mat = graph->GetCOOMatrix(et);
      out->Write(coo_mat.row_sorted);
      out->Write(coo_mat.col_sorted);
      states.arrays.push_back(coo_mat.row);
      states.arrays.push_back(coo_mat.col);
    } else if (chosen == SparseFormat::kCSR || chosen == SparseFormat::kCSC) {
      // Both CSR and CSC selections are stored as a CSR matrix.
      out->Write(SparseFormat::kCSR);
      const auto &csr_mat = graph->GetCSRMatrix(et);
      out->Write(csr_mat.sorted);
      states.arrays.push_back(csr_mat.indptr);
      states.arrays.push_back(csr_mat.indices);
      states.arrays.push_back(csr_mat.data);
    } else {
      LOG(FATAL) << "Unsupported sparse format.";
    }
  }
  return states;
}

/*!
 * \brief Rebuild a heterograph from pickle states that carry a meta byte
 *        string and a flat list of index arrays.
 *
 * Reads from `states.meta`, in order: the metagraph, the number of nodes per
 * type and then, for every edge of the metagraph, a SparseFormat tag followed
 * by its sortedness flag(s). Index arrays are consumed sequentially from
 * `states.arrays`: two (row, col) for a COO entry, three
 * (indptr, indices, edge ids) for a CSR entry.
 *
 * Any malformed input (unreadable field, too few arrays, node type outside
 * `num_nodes_per_type`, unsupported format tag) triggers a CHECK / LOG(FATAL).
 *
 * \param states The pickle states to read from.
 * \return The reconstructed heterograph.
 */
HeteroGraphPtr HeteroUnpickle(const HeteroPickleStates& states) {
  // The stream constructor is handed a non-const pointer, hence the
  // const_cast; only Read() is called on the stream in this function.
  char *buf = const_cast<char *>(states.meta.c_str());
  dmlc::MemoryFixedSizeStream ifs(buf, states.meta.size());
  dmlc::Stream *strm = &ifs;
  auto meta_imgraph = Serializer::make_shared<ImmutableGraph>();
  CHECK(strm->Read(&meta_imgraph)) << "Invalid meta graph";
  GraphPtr metagraph = meta_imgraph;
  std::vector<HeteroGraphPtr> relgraphs(metagraph->NumEdges());
  std::vector<int64_t> num_nodes_per_type;
  CHECK(strm->Read(&num_nodes_per_type)) << "Invalid num_nodes_per_type";

  auto array_itr = states.arrays.begin();
  for (dgl_type_t etype = 0; etype < metagraph->NumEdges(); ++etype) {
    const auto& pair = metagraph->FindEdge(etype);
    const dgl_type_t srctype = pair.first;
    const dgl_type_t dsttype = pair.second;
    // Both endpoints come from deserialized data; make sure they index into
    // num_nodes_per_type before using them.
    CHECK_LT(srctype, num_nodes_per_type.size())
      << "Source node type is out of range of num_nodes_per_type";
    CHECK_LT(dsttype, num_nodes_per_type.size())
      << "Destination node type is out of range of num_nodes_per_type";
    const int64_t num_vtypes = (srctype == dsttype)? 1 : 2;
    int64_t num_src = num_nodes_per_type[srctype];
    int64_t num_dst = num_nodes_per_type[dsttype];
    SparseFormat fmt;
    CHECK(strm->Read(&fmt)) << "Invalid SparseFormat";
    HeteroGraphPtr relgraph;
    switch (fmt) {
      case SparseFormat::kCOO: {
        // A COO entry consumes two arrays: row and col.
        CHECK_GE(states.arrays.end() - array_itr, 2);
        const auto &row = *(array_itr++);
        const auto &col = *(array_itr++);
        bool rsorted;
        bool csorted;
        CHECK(strm->Read(&rsorted)) << "Invalid flag 'rsorted'";
        CHECK(strm->Read(&csorted)) << "Invalid flag 'csorted'";
        auto coo = aten::COOMatrix(num_src, num_dst, row, col, aten::NullArray(), rsorted, csorted);
        // TODO(zihao) fix
        relgraph = CreateFromCOO(num_vtypes, coo, ALL_CODE);
        break;
      }
      case SparseFormat::kCSR: {
        // A CSR entry consumes three arrays: indptr, indices and edge ids.
        CHECK_GE(states.arrays.end() - array_itr, 3);
        const auto &indptr = *(array_itr++);
        const auto &indices = *(array_itr++);
        const auto &edge_id = *(array_itr++);
        bool sorted;
        CHECK(strm->Read(&sorted)) << "Invalid flag 'sorted'";
        auto csr = aten::CSRMatrix(num_src, num_dst, indptr, indices, edge_id, sorted);
        // TODO(zihao) fix
        relgraph = CreateFromCSR(num_vtypes, csr, ALL_CODE);
        break;
      }
      case SparseFormat::kCSC:
      default:
        LOG(FATAL) << "Unsupported sparse format.";
    }
    relgraphs[etype] = relgraph;
  }
  return CreateHeteroGraph(metagraph, relgraphs, num_nodes_per_type);
}

/*!
 * \brief Rebuild a heterograph from pickle states that store the metagraph,
 *        the per-type node counts and one sparse matrix per edge type
 *        directly. Kept for backward compatibility.
 *
 * \param states The pickle states; `adjs` must have exactly one entry per
 *        metagraph edge.
 * \return The reconstructed heterograph.
 */
HeteroGraphPtr HeteroUnpickleOld(const HeteroPickleStates& states) {
  const auto metagraph = states.metagraph;
  const auto &num_nodes_per_type = states.num_nodes_per_type;
  CHECK_EQ(states.adjs.size(), metagraph->NumEdges());

  std::vector<HeteroGraphPtr> relgraphs(metagraph->NumEdges());
  for (dgl_type_t et = 0; et < metagraph->NumEdges(); ++et) {
    const auto& endpoints = metagraph->FindEdge(et);
    const dgl_type_t src_ty = endpoints.first;
    const dgl_type_t dst_ty = endpoints.second;
    // One vertex type when source and destination types coincide, two otherwise.
    const int64_t num_vtypes = (src_ty == dst_ty) ? 1 : 2;

    const auto& adj = states.adjs[et];
    const SparseFormat fmt = static_cast<SparseFormat>(adj->format);
    if (fmt == SparseFormat::kCOO) {
      relgraphs[et] = UnitGraph::CreateFromCOO(num_vtypes, aten::COOMatrix(*adj));
    } else if (fmt == SparseFormat::kCSR) {
      relgraphs[et] = UnitGraph::CreateFromCSR(num_vtypes, aten::CSRMatrix(*adj));
    } else {
      // Includes SparseFormat::kCSC.
      LOG(FATAL) << "Unsupported sparse format.";
    }
  }
  return CreateHeteroGraph(metagraph, relgraphs, num_nodes_per_type);
}

140
141
142
143
144
145
146
// Returns the `version` field of the pickle states passed as args[0].
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroPickleStatesGetVersion")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
    HeteroPickleStatesRef states = args[0];
    *rv = states->version;
  });

// Returns the `meta` bytes of the pickle states passed as args[0] as a
// DGLByteArray (pointer + length), so the content is not treated as a
// NUL-terminated string.
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroPickleStatesGetMeta")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
    HeteroPickleStatesRef states = args[0];
    const std::string& meta = states->meta;
    DGLByteArray bytes;
    bytes.data = meta.c_str();
    bytes.size = meta.size();
    *rv = bytes;
  });

155
// Returns the `arrays` of the pickle states passed as args[0], wrapped by
// ConvertNDArrayVectorToPackedFunc.
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroPickleStatesGetArrays")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
    HeteroPickleStatesRef states = args[0];
    const auto& arrays = states->arrays;
    *rv = ConvertNDArrayVectorToPackedFunc(arrays);
  });

161
// Returns the number of entries in `arrays` of the pickle states passed as
// args[0], as an int64_t.
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroPickleStatesGetArraysNum")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
    HeteroPickleStatesRef states = args[0];
    const int64_t num_arrays = static_cast<int64_t>(states->arrays.size());
    *rv = num_arrays;
  });

// Builds version-1 pickle states from a meta byte string (args[0]) and a list
// of values (args[1]); the `data` member of every list element is appended to
// the states' `arrays` in list order.
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLCreateHeteroPickleStates")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
    std::string meta = args[0];
    const List<Value> arrays = args[1];
    // make_shared instead of a naked `new`: one allocation, no raw owning pointer.
    auto st = std::make_shared<HeteroPickleStates>();
    st->version = 1;
    st->meta = meta;
    st->arrays.reserve(arrays.size());
    for (const auto& ref : arrays) {
      st->arrays.push_back(ref->data);
    }
    *rv = HeteroPickleStatesRef(st);
  });

// Pickles the heterograph passed as args[0] with HeteroPickle() and returns
// the resulting states wrapped in a HeteroPickleStatesRef.
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroPickle")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
    HeteroGraphRef ref = args[0];
    // make_shared instead of a naked `new`: one allocation, no raw owning pointer.
    auto st = std::make_shared<HeteroPickleStates>();
    *st = HeteroPickle(ref.sptr());
    *rv = HeteroPickleStatesRef(st);
  });

// Rebuilds a heterograph from the pickle states passed as args[0], choosing
// the reader by the states' `version`: 0 -> HeteroUnpickleOld,
// 1 -> HeteroUnpickle. Any other version is fatal.
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroUnpickle")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
    HeteroPickleStatesRef states = args[0];
    HeteroGraphPtr result;
    const auto version = states->version;
    if (version == 0) {
      result = HeteroUnpickleOld(*states.sptr());
    } else if (version == 1) {
      result = HeteroUnpickle(*states.sptr());
    } else {
      LOG(FATAL) << "Version can only be 0 or 1.";
    }
    *rv = HeteroGraphRef(result);
  });

206
207
208
209
210
211
212
213
214
215
216
217
218
219
// Builds version-0 pickle states from a metagraph (args[0]), an id array of
// node counts per type (args[1]) and a list of sparse matrices (args[2]) that
// is stored in `adjs` in list order.
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLCreateHeteroPickleStatesOld")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
    GraphRef metagraph = args[0];
    IdArray num_nodes_per_type = args[1];
    List<SparseMatrixRef> adjs = args[2];
    // make_shared instead of a naked `new`: one allocation, no raw owning pointer.
    auto st = std::make_shared<HeteroPickleStates>();
    st->version = 0;
    st->metagraph = metagraph.sptr();
    st->num_nodes_per_type = num_nodes_per_type.ToVector<int64_t>();
    st->adjs.reserve(adjs.size());
    for (const auto& ref : adjs)
      st->adjs.push_back(ref.sptr());
    *rv = HeteroPickleStatesRef(st);
  });
220
}  // namespace dgl