pickle.cc 13 KB
Newer Older
1
2
3
4
5
6
7
/*!
 *  Copyright (c) 2020 by Contributors
 * \file graph/pickle.cc
 * \brief Functions for pickling and unpickling a graph
 */
#include <dgl/packed_func_ext.h>
#include <dgl/runtime/container.h>
8
9
10
#include <dgl/immutable_graph.h>
#include <dgl/graph_serializer.h>
#include <dmlc/memory_io.h>
11
12
#include "./heterograph.h"
#include "../c_api_common.h"
13
#include "unit_graph.h"
14
15
16
17
18
19
20

using namespace dgl::runtime;

namespace dgl {

/*!
 * \brief Serialize a heterograph into pickle states (version 2).
 *
 * The meta stream receives, in order: the metagraph converted to an
 * ImmutableGraph, the per-type vertex counts, the pinned flag, and then for
 * every edge type a format tag followed by that format's sortedness flag(s).
 * The index arrays of every edge type are appended to the states' array list
 * in the same edge-type order.
 *
 * \param graph The heterograph to serialize.
 * \return The populated pickle states.
 */
HeteroPickleStates HeteroPickle(HeteroGraphPtr graph) {
  HeteroPickleStates result;
  result.version = 2;

  // Typed writes go through the dmlc::Stream base interface.
  dmlc::MemoryStringStream meta_stream(&result.meta);
  dmlc::Stream &out = meta_stream;
  out.Write(ImmutableGraph::ToImmutable(graph->meta_graph()));
  out.Write(graph->NumVerticesPerType());
  out.Write(graph->IsPinned());

  for (dgl_type_t etype = 0; etype < graph->NumEdgeTypes(); ++etype) {
    const SparseFormat selected = graph->SelectFormat(etype, ALL_CODE);
    if (selected == SparseFormat::kCOO) {
      out.Write(SparseFormat::kCOO);
      const auto &coo_mat = graph->GetCOOMatrix(etype);
      out.Write(coo_mat.row_sorted);
      out.Write(coo_mat.col_sorted);
      result.arrays.push_back(coo_mat.row);
      result.arrays.push_back(coo_mat.col);
    } else if (selected == SparseFormat::kCSR ||
               selected == SparseFormat::kCSC) {
      // A CSC selection is handled exactly like CSR: the tag written is kCSR
      // and the CSR matrix is the one that gets stored.
      out.Write(SparseFormat::kCSR);
      const auto &csr_mat = graph->GetCSRMatrix(etype);
      out.Write(csr_mat.sorted);
      result.arrays.push_back(csr_mat.indptr);
      result.arrays.push_back(csr_mat.indices);
      result.arrays.push_back(csr_mat.data);
    } else {
      LOG(FATAL) << "Unsupported sparse format.";
    }
  }
  return result;
}

56
57
/*!
 * \brief Serialize a heterograph together with every sparse format it has
 *        already created (version 2).
 *
 * The meta stream starts with the metagraph (as an ImmutableGraph), the
 * per-type vertex counts and the pinned flag. For each edge type it then holds
 * the created-format code, the allowed-format code, and the sortedness flags
 * of each created format in COO, CSR, CSC order. The matching index arrays are
 * appended to the states' array list in that same order.
 *
 * \param graph The heterograph to serialize.
 * \return The populated pickle states.
 */
HeteroPickleStates HeteroForkingPickle(HeteroGraphPtr graph) {
  HeteroPickleStates result;
  result.version = 2;

  // Typed writes go through the dmlc::Stream base interface.
  dmlc::MemoryStringStream meta_stream(&result.meta);
  dmlc::Stream &out = meta_stream;
  out.Write(ImmutableGraph::ToImmutable(graph->meta_graph()));
  out.Write(graph->NumVerticesPerType());
  out.Write(graph->IsPinned());

  for (dgl_type_t etype = 0; etype < graph->NumEdgeTypes(); ++etype) {
    // The format codes are queried from the graph as a whole, but they are
    // written once per edge type; the reader expects them in every entry.
    auto created = graph->GetCreatedFormats();
    auto allowed = graph->GetAllowedFormats();
    out.Write(created);
    out.Write(allowed);

    if ((created & COO_CODE) != 0) {
      const auto &coo_mat = graph->GetCOOMatrix(etype);
      out.Write(coo_mat.row_sorted);
      out.Write(coo_mat.col_sorted);
      result.arrays.push_back(coo_mat.row);
      result.arrays.push_back(coo_mat.col);
    }

    if ((created & CSR_CODE) != 0) {
      const auto &csr_mat = graph->GetCSRMatrix(etype);
      out.Write(csr_mat.sorted);
      result.arrays.push_back(csr_mat.indptr);
      result.arrays.push_back(csr_mat.indices);
      result.arrays.push_back(csr_mat.data);
    }

    if ((created & CSC_CODE) != 0) {
      const auto &csc_mat = graph->GetCSCMatrix(etype);
      out.Write(csc_mat.sorted);
      result.arrays.push_back(csc_mat.indptr);
      result.arrays.push_back(csc_mat.indices);
      result.arrays.push_back(csc_mat.data);
    }
  }
  return result;
}

94
/*!
 * \brief Rebuild a heterograph from pickle states of version 1 or 2.
 *
 * Reads the metagraph and the per-type vertex counts from the meta stream,
 * then (only when states.version > 1) the pinned flag. For every metagraph
 * edge it reads a format tag plus sortedness flag(s) from the stream and takes
 * the index arrays from states.arrays: two arrays for COO, three for CSR.
 * Any malformed field aborts through CHECK / LOG(FATAL).
 *
 * \param states The pickle states to decode.
 * \return The reconstructed heterograph; its memory is pinned when the stream
 *         recorded the pinned flag as true.
 */
HeteroGraphPtr HeteroUnpickle(const HeteroPickleStates& states) {
  char *buf = const_cast<char *>(states.meta.c_str());  // a readonly stream?
  dmlc::MemoryFixedSizeStream ifs(buf, states.meta.size());
  dmlc::Stream *strm = &ifs;
  auto meta_imgraph = Serializer::make_shared<ImmutableGraph>();
  CHECK(strm->Read(&meta_imgraph)) << "Invalid meta graph";
  GraphPtr metagraph = meta_imgraph;
  std::vector<HeteroGraphPtr> relgraphs(metagraph->NumEdges());
  std::vector<int64_t> num_nodes_per_type;
  CHECK(strm->Read(&num_nodes_per_type)) << "Invalid num_nodes_per_type";
  // The pinned flag is only present in streams newer than version 1.
  bool is_pinned = false;
  if (states.version > 1) {
    CHECK(strm->Read(&is_pinned)) << "Invalid flag 'is_pinned'";
  }

  auto array_itr = states.arrays.begin();
  for (dgl_type_t etype = 0; etype < metagraph->NumEdges(); ++etype) {
    const auto& pair = metagraph->FindEdge(etype);
    const dgl_type_t srctype = pair.first;
    const dgl_type_t dsttype = pair.second;
    // Both the metagraph and the vertex counts come from the serialized
    // stream, so verify they agree before indexing the counts.
    CHECK(static_cast<size_t>(srctype) < num_nodes_per_type.size() &&
          static_cast<size_t>(dsttype) < num_nodes_per_type.size())
        << "num_nodes_per_type does not cover the node types of the metagraph";
    const int64_t num_vtypes = (srctype == dsttype)? 1 : 2;
    int64_t num_src = num_nodes_per_type[srctype];
    int64_t num_dst = num_nodes_per_type[dsttype];
    SparseFormat fmt;
    CHECK(strm->Read(&fmt)) << "Invalid SparseFormat";
    HeteroGraphPtr relgraph;
    switch (fmt) {
      case SparseFormat::kCOO: {
        // COO consumes two arrays: row and col.
        CHECK_GE(states.arrays.end() - array_itr, 2);
        const auto &row = *(array_itr++);
        const auto &col = *(array_itr++);
        bool rsorted = false;
        bool csorted = false;
        CHECK(strm->Read(&rsorted)) << "Invalid flag 'rsorted'";
        CHECK(strm->Read(&csorted)) << "Invalid flag 'csorted'";
        auto coo = aten::COOMatrix(num_src, num_dst, row, col, aten::NullArray(), rsorted, csorted);
        // TODO(zihao) fix
        relgraph = CreateFromCOO(num_vtypes, coo, ALL_CODE);
        break;
      }
      case SparseFormat::kCSR: {
        // CSR consumes three arrays: indptr, indices and edge ids.
        CHECK_GE(states.arrays.end() - array_itr, 3);
        const auto &indptr = *(array_itr++);
        const auto &indices = *(array_itr++);
        const auto &edge_id = *(array_itr++);
        bool sorted = false;
        CHECK(strm->Read(&sorted)) << "Invalid flag 'sorted'";
        auto csr = aten::CSRMatrix(num_src, num_dst, indptr, indices, edge_id, sorted);
        // TODO(zihao) fix
        relgraph = CreateFromCSR(num_vtypes, csr, ALL_CODE);
        break;
      }
      case SparseFormat::kCSC:
      default:
        LOG(FATAL) << "Unsupported sparse format.";
    }
    relgraphs[etype] = relgraph;
  }
  auto graph = CreateHeteroGraph(metagraph, relgraphs, num_nodes_per_type);
  if (is_pinned) {
    graph->PinMemory_();
  }
  return graph;
}

/*!
 * \brief Rebuild a heterograph from the old pickle-state layout, kept for
 *        backward compatibility.
 *
 * The old layout stores the metagraph, the per-type vertex counts and one
 * sparse matrix per edge type directly in the states object instead of in a
 * byte stream.
 *
 * \param states Pickle states whose metagraph, num_nodes_per_type and adjs
 *        fields are populated.
 * \return The reconstructed heterograph.
 */
HeteroGraphPtr HeteroUnpickleOld(const HeteroPickleStates& states) {
  const auto metagraph = states.metagraph;
  const auto &node_counts = states.num_nodes_per_type;
  const auto num_etypes = metagraph->NumEdges();

  // There must be exactly one adjacency matrix per metagraph edge.
  CHECK_EQ(states.adjs.size(), num_etypes);
  std::vector<HeteroGraphPtr> relgraphs(num_etypes);

  for (dgl_type_t etype = 0; etype < num_etypes; ++etype) {
    const auto& endpoints = metagraph->FindEdge(etype);
    const dgl_type_t src_vtype = endpoints.first;
    const dgl_type_t dst_vtype = endpoints.second;
    // A relation whose two endpoints share a node type has one vertex type.
    const int64_t num_vtypes = (src_vtype == dst_vtype) ? 1 : 2;

    const auto &adj = states.adjs[etype];
    const SparseFormat fmt = static_cast<SparseFormat>(adj->format);
    if (fmt == SparseFormat::kCOO) {
      relgraphs[etype] = UnitGraph::CreateFromCOO(
          num_vtypes, aten::COOMatrix(*adj));
    } else if (fmt == SparseFormat::kCSR) {
      relgraphs[etype] = UnitGraph::CreateFromCSR(
          num_vtypes, aten::CSRMatrix(*adj));
    } else {
      // kCSC and any unknown tag are rejected.
      LOG(FATAL) << "Unsupported sparse format.";
    }
  }
  return CreateHeteroGraph(metagraph, relgraphs, node_counts);
}

188
189
190
191
192
193
194
195
196
197
/*!
 * \brief Rebuild a heterograph, including all of its created sparse formats,
 *        from pickle states.
 *
 * Reads the metagraph and the per-type vertex counts from the meta stream,
 * then (only when states.version > 1) the pinned flag. For every metagraph
 * edge it reads the created-format and allowed-format codes, followed by the
 * sortedness flags of each created format in COO, CSR, CSC order. The index
 * arrays are taken from states.arrays in that same order: two for COO, three
 * each for CSR and CSC. Any malformed field aborts through CHECK.
 *
 * \param states The pickle states to decode.
 * \return The reconstructed heterograph; its memory is pinned when the stream
 *         recorded the pinned flag as true.
 */
HeteroGraphPtr HeteroForkingUnpickle(const HeteroPickleStates &states) {
  char *buf = const_cast<char *>(states.meta.c_str());  // a readonly stream?
  dmlc::MemoryFixedSizeStream ifs(buf, states.meta.size());
  dmlc::Stream *strm = &ifs;
  auto meta_imgraph = Serializer::make_shared<ImmutableGraph>();
  CHECK(strm->Read(&meta_imgraph)) << "Invalid meta graph";
  GraphPtr metagraph = meta_imgraph;
  std::vector<HeteroGraphPtr> relgraphs(metagraph->NumEdges());
  std::vector<int64_t> num_nodes_per_type;
  CHECK(strm->Read(&num_nodes_per_type)) << "Invalid num_nodes_per_type";
  // The pinned flag is only present in streams newer than version 1.
  bool is_pinned = false;
  if (states.version > 1) {
    CHECK(strm->Read(&is_pinned)) << "Invalid flag 'is_pinned'";
  }

  auto array_itr = states.arrays.begin();
  for (dgl_type_t etype = 0; etype < metagraph->NumEdges(); ++etype) {
    const auto& pair = metagraph->FindEdge(etype);
    const dgl_type_t srctype = pair.first;
    const dgl_type_t dsttype = pair.second;
    // Both the metagraph and the vertex counts come from the serialized
    // stream, so verify they agree before indexing the counts.
    CHECK(static_cast<size_t>(srctype) < num_nodes_per_type.size() &&
          static_cast<size_t>(dsttype) < num_nodes_per_type.size())
        << "num_nodes_per_type does not cover the node types of the metagraph";
    const int64_t num_vtypes = (srctype == dsttype) ? 1 : 2;
    int64_t num_src = num_nodes_per_type[srctype];
    int64_t num_dst = num_nodes_per_type[dsttype];

    dgl_format_code_t created_formats;
    dgl_format_code_t allowed_formats;
    CHECK(strm->Read(&created_formats)) << "Invalid code for created formats";
    CHECK(strm->Read(&allowed_formats)) << "Invalid code for allowed formats";
    // Matrices for formats that were not created stay default-constructed;
    // the has_* flags tell the factory below which ones are meaningful.
    aten::COOMatrix coo;
    aten::CSRMatrix csr;
    aten::CSRMatrix csc;
    const bool has_coo = (created_formats & COO_CODE);
    const bool has_csr = (created_formats & CSR_CODE);
    const bool has_csc = (created_formats & CSC_CODE);

    if (has_coo) {
      CHECK_GE(states.arrays.end() - array_itr, 2);
      const auto &row = *(array_itr++);
      const auto &col = *(array_itr++);
      bool rsorted = false;
      bool csorted = false;
      CHECK(strm->Read(&rsorted)) << "Invalid flag 'rsorted'";
      CHECK(strm->Read(&csorted)) << "Invalid flag 'csorted'";
      coo = aten::COOMatrix(num_src, num_dst, row, col, aten::NullArray(), rsorted, csorted);
    }
    if (has_csr) {
      CHECK_GE(states.arrays.end() - array_itr, 3);
      const auto &indptr = *(array_itr++);
      const auto &indices = *(array_itr++);
      const auto &edge_id = *(array_itr++);
      bool sorted = false;
      CHECK(strm->Read(&sorted)) << "Invalid flag 'sorted'";
      csr = aten::CSRMatrix(num_src, num_dst, indptr, indices, edge_id, sorted);
    }
    if (has_csc) {
      CHECK_GE(states.arrays.end() - array_itr, 3);
      const auto &indptr = *(array_itr++);
      const auto &indices = *(array_itr++);
      const auto &edge_id = *(array_itr++);
      bool sorted = false;
      CHECK(strm->Read(&sorted)) << "Invalid flag 'sorted'";
      // The CSC matrix is built with the dimensions swapped: (num_dst, num_src).
      csc = aten::CSRMatrix(num_dst, num_src, indptr, indices, edge_id, sorted);
    }
    relgraphs[etype] = UnitGraph::CreateUnitGraphFrom(
        num_vtypes, csc, csr, coo, has_csc, has_csr, has_coo, allowed_formats);
  }
  auto graph = CreateHeteroGraph(metagraph, relgraphs, num_nodes_per_type);
  if (is_pinned) {
    graph->PinMemory_();
  }
  return graph;
}

260
261
262
263
264
265
266
// C API: return the version field of the pickle states passed as argument 0.
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroPickleStatesGetVersion")
.set_body([] (DGLArgs args, DGLRetValue* ret) {
    HeteroPickleStatesRef states = args[0];
    *ret = states->version;
  });

// C API: return the meta byte string of the pickle states passed as
// argument 0, wrapped in a DGLByteArray.
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroPickleStatesGetMeta")
.set_body([] (DGLArgs args, DGLRetValue* ret) {
    HeteroPickleStatesRef states = args[0];
    // The byte array points into the string held by the states object.
    DGLByteArray bytes;
    bytes.data = states->meta.c_str();
    bytes.size = states->meta.size();
    *ret = bytes;
  });

275
// C API: return the arrays of the pickle states passed as argument 0, as
// converted by ConvertNDArrayVectorToPackedFunc.
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroPickleStatesGetArrays")
.set_body([] (DGLArgs args, DGLRetValue* ret) {
    HeteroPickleStatesRef states = args[0];
    const auto &arrays = states->arrays;
    *ret = ConvertNDArrayVectorToPackedFunc(arrays);
  });

281
// C API: return, as an int64_t, how many arrays the pickle states passed as
// argument 0 hold.
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroPickleStatesGetArraysNum")
.set_body([] (DGLArgs args, DGLRetValue* ret) {
    HeteroPickleStatesRef states = args[0];
    const int64_t num_arrays = static_cast<int64_t>(states->arrays.size());
    *ret = num_arrays;
  });

// C API: build pickle states from (version, meta bytes, list of arrays).
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLCreateHeteroPickleStates")
.set_body([] (DGLArgs args, DGLRetValue* ret) {
    const int version = args[0];
    std::string meta = args[1];
    const List<Value> arrays = args[2];

    auto states = std::make_shared<HeteroPickleStates>();
    // A version argument of 0 is stored as version 1.
    if (version == 0) {
      states->version = 1;
    } else {
      states->version = version;
    }
    states->meta = meta;
    states->arrays.reserve(arrays.size());
    for (const auto& item : arrays) {
      states->arrays.push_back(item->data);
    }
    *ret = HeteroPickleStatesRef(states);
  });

// C API: pickle the heterograph passed as argument 0 with HeteroPickle and
// return the resulting states.
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroPickle")
.set_body([] (DGLArgs args, DGLRetValue* ret) {
    HeteroGraphRef graph_ref = args[0];
    auto states = std::make_shared<HeteroPickleStates>();
    *states = HeteroPickle(graph_ref.sptr());
    *ret = HeteroPickleStatesRef(states);
  });

310
311
312
313
314
315
316
317
// C API: pickle the heterograph passed as argument 0 with HeteroForkingPickle
// and return the resulting states.
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroForkingPickle")
.set_body([] (DGLArgs args, DGLRetValue* ret) {
    HeteroGraphRef graph_ref = args[0];
    auto states = std::make_shared<HeteroPickleStates>();
    *states = HeteroForkingPickle(graph_ref.sptr());
    *ret = HeteroPickleStatesRef(states);
  });

318
319
320
// C API: unpickle the states passed as argument 0, choosing the decoder by
// the version stored in the states.
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroUnpickle")
.set_body([] (DGLArgs args, DGLRetValue* ret) {
    HeteroPickleStatesRef states_ref = args[0];
    const auto version = states_ref->version;

    HeteroGraphPtr graph;
    if (version == 0) {
      // Version 0 uses the old, non-stream layout.
      graph = HeteroUnpickleOld(*states_ref.sptr());
    } else if (version == 1 || version == 2) {
      // Versions 1 and 2 share the stream-based decoder.
      graph = HeteroUnpickle(*states_ref.sptr());
    } else {
      LOG(FATAL) << "Version can only be 0 or 1 or 2.";
    }
    *ret = HeteroGraphRef(graph);
  });

336
337
338
339
340
341
342
// C API: unpickle the states passed as argument 0 with HeteroForkingUnpickle
// and return the resulting heterograph.
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroForkingUnpickle")
.set_body([] (DGLArgs args, DGLRetValue* ret) {
    HeteroPickleStatesRef states_ref = args[0];
    HeteroGraphPtr restored = HeteroForkingUnpickle(*states_ref.sptr());
    *ret = HeteroGraphRef(restored);
  });

343
344
345
346
347
348
349
350
351
352
353
354
355
356
// C API: build old-layout (version 0) pickle states from
// (metagraph, per-type node counts, list of adjacency matrices).
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLCreateHeteroPickleStatesOld")
.set_body([] (DGLArgs args, DGLRetValue* ret) {
    GraphRef metagraph = args[0];
    IdArray num_nodes_per_type = args[1];
    List<SparseMatrixRef> adjs = args[2];

    auto states = std::make_shared<HeteroPickleStates>();
    states->version = 0;
    states->metagraph = metagraph.sptr();
    states->num_nodes_per_type = num_nodes_per_type.ToVector<int64_t>();
    states->adjs.reserve(adjs.size());
    for (const auto& adj_ref : adjs) {
      states->adjs.push_back(adj_ref.sptr());
    }
    *ret = HeteroPickleStatesRef(states);
  });
357
}  // namespace dgl