pickle.cc 13.1 KB
Newer Older
1
2
3
4
5
/*!
 *  Copyright (c) 2020 by Contributors
 * \file graph/pickle.cc
 * \brief Functions for pickling and unpickling a graph
 */
#include <dgl/graph_serializer.h>
#include <dgl/immutable_graph.h>
#include <dgl/packed_func_ext.h>
#include <dgl/runtime/container.h>
#include <dmlc/memory_io.h>

#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "../c_api_common.h"
#include "./heterograph.h"
#include "unit_graph.h"

using namespace dgl::runtime;

namespace dgl {

/*!
 * \brief Serialize a heterograph into pickle states (format version 2).
 *
 * The meta byte stream receives, in order: the metagraph (as an
 * ImmutableGraph), the per-type vertex counts, the pinned flag, and then for
 * every edge type a SparseFormat tag followed by that format's sorted flags.
 * The index arrays are appended to the states' array list in the same order.
 * A relation whose selected format is CSC is written as CSR.
 *
 * \param graph The heterograph to serialize.
 * \return The pickle states.
 */
HeteroPickleStates HeteroPickle(HeteroGraphPtr graph) {
  HeteroPickleStates result;
  result.version = 2;

  dmlc::MemoryStringStream meta_stream(&result.meta);
  dmlc::Stream *writer = &meta_stream;
  writer->Write(ImmutableGraph::ToImmutable(graph->meta_graph()));
  writer->Write(graph->NumVerticesPerType());
  writer->Write(graph->IsPinned());

  for (dgl_type_t etype = 0; etype < graph->NumEdgeTypes(); ++etype) {
    const SparseFormat selected = graph->SelectFormat(etype, ALL_CODE);
    if (selected == SparseFormat::kCOO) {
      writer->Write(SparseFormat::kCOO);
      const auto &coo = graph->GetCOOMatrix(etype);
      writer->Write(coo.row_sorted);
      writer->Write(coo.col_sorted);
      result.arrays.push_back(coo.row);
      result.arrays.push_back(coo.col);
    } else if (selected == SparseFormat::kCSR ||
               selected == SparseFormat::kCSC) {
      // There is no CSC payload in this format: the relation is stored via
      // its CSR matrix and tagged as kCSR.
      writer->Write(SparseFormat::kCSR);
      const auto &csr = graph->GetCSRMatrix(etype);
      writer->Write(csr.sorted);
      result.arrays.push_back(csr.indptr);
      result.arrays.push_back(csr.indices);
      result.arrays.push_back(csr.data);
    } else {
      LOG(FATAL) << "Unsupported sparse format.";
    }
  }
  return result;
}

/*!
 * \brief Serialize a heterograph keeping every created sparse format
 *        (format version 2).
 *
 * Unlike HeteroPickle, which stores a single format per relation, this
 * writes the graph's created- and allowed-format codes for every edge type,
 * followed by each created format (COO, then CSR, then CSC) with its sorted
 * flags. The matching index arrays are appended to the states in that order.
 * NOTE(review): the name suggests this is used to hand graphs to forked
 * worker processes; confirm with the Python callers.
 *
 * \param graph The heterograph to serialize.
 * \return The pickle states.
 */
HeteroPickleStates HeteroForkingPickle(HeteroGraphPtr graph) {
  HeteroPickleStates result;
  result.version = 2;

  dmlc::MemoryStringStream meta_stream(&result.meta);
  dmlc::Stream *writer = &meta_stream;
  writer->Write(ImmutableGraph::ToImmutable(graph->meta_graph()));
  writer->Write(graph->NumVerticesPerType());
  writer->Write(graph->IsPinned());

  for (dgl_type_t etype = 0; etype < graph->NumEdgeTypes(); ++etype) {
    // Both codes are queried on the whole graph (no etype argument) and are
    // repeated for every edge type in the stream.
    const auto created = graph->GetCreatedFormats();
    const auto allowed = graph->GetAllowedFormats();
    writer->Write(created);
    writer->Write(allowed);

    if ((created & COO_CODE) != 0) {
      const auto &coo = graph->GetCOOMatrix(etype);
      writer->Write(coo.row_sorted);
      writer->Write(coo.col_sorted);
      result.arrays.push_back(coo.row);
      result.arrays.push_back(coo.col);
    }
    if ((created & CSR_CODE) != 0) {
      const auto &csr = graph->GetCSRMatrix(etype);
      writer->Write(csr.sorted);
      result.arrays.push_back(csr.indptr);
      result.arrays.push_back(csr.indices);
      result.arrays.push_back(csr.data);
    }
    if ((created & CSC_CODE) != 0) {
      const auto &csc = graph->GetCSCMatrix(etype);
      writer->Write(csc.sorted);
      result.arrays.push_back(csc.indptr);
      result.arrays.push_back(csc.indices);
      result.arrays.push_back(csc.data);
    }
  }
  return result;
}

/*!
 * \brief Rebuild a heterograph from states produced by HeteroPickle.
 *
 * The meta stream is read in the order HeteroPickle writes it:
 *   1. the metagraph;
 *   2. the per-type vertex counts;
 *   3. the pinned flag (only present when states.version > 1);
 *   4. one SparseFormat tag plus sorted flags per edge type.
 * Index arrays are consumed from states.arrays in the same order.
 *
 * \param states Pickle states holding the meta byte stream and index arrays.
 * \return The reconstructed heterograph. Its memory is pinned when the pinned
 *         flag was set.
 */
HeteroGraphPtr HeteroUnpickle(const HeteroPickleStates &states) {
  // MemoryFixedSizeStream takes a mutable buffer, but this function only
  // reads from it.
  char *buf = const_cast<char *>(states.meta.c_str());
  dmlc::MemoryFixedSizeStream ifs(buf, states.meta.size());
  dmlc::Stream *strm = &ifs;
  auto meta_imgraph = Serializer::make_shared<ImmutableGraph>();
  CHECK(strm->Read(&meta_imgraph)) << "Invalid meta graph";
  GraphPtr metagraph = meta_imgraph;
  std::vector<HeteroGraphPtr> relgraphs(metagraph->NumEdges());
  std::vector<int64_t> num_nodes_per_type;
  CHECK(strm->Read(&num_nodes_per_type)) << "Invalid num_nodes_per_type";

  // The pinned flag is only present for version > 1.
  bool is_pinned = false;
  if (states.version > 1) {
    CHECK(strm->Read(&is_pinned)) << "Invalid flag 'is_pinned'";
  }

  auto array_itr = states.arrays.begin();
  for (dgl_type_t etype = 0; etype < metagraph->NumEdges(); ++etype) {
    const auto &pair = metagraph->FindEdge(etype);
    const dgl_type_t srctype = pair.first;
    const dgl_type_t dsttype = pair.second;
    // Corrupted states could name a vertex type with no recorded count.
    // Fail loudly instead of reading past the end of num_nodes_per_type.
    CHECK_LT(srctype, num_nodes_per_type.size())
        << "No vertex count for source type " << srctype;
    CHECK_LT(dsttype, num_nodes_per_type.size())
        << "No vertex count for destination type " << dsttype;
    const int64_t num_vtypes = (srctype == dsttype) ? 1 : 2;
    const int64_t num_src = num_nodes_per_type[srctype];
    const int64_t num_dst = num_nodes_per_type[dsttype];
    SparseFormat fmt;
    CHECK(strm->Read(&fmt)) << "Invalid SparseFormat";
    HeteroGraphPtr relgraph;
    switch (fmt) {
      case SparseFormat::kCOO: {
        CHECK_GE(states.arrays.end() - array_itr, 2);
        const auto &row = *(array_itr++);
        const auto &col = *(array_itr++);
        bool rsorted = false;
        bool csorted = false;
        CHECK(strm->Read(&rsorted)) << "Invalid flag 'rsorted'";
        CHECK(strm->Read(&csorted)) << "Invalid flag 'csorted'";
        auto coo = aten::COOMatrix(
            num_src, num_dst, row, col, aten::NullArray(), rsorted, csorted);
        // TODO(zihao) fix: every format is allowed (ALL_CODE) on the rebuilt
        // relation graph.
        relgraph = CreateFromCOO(num_vtypes, coo, ALL_CODE);
        break;
      }
      case SparseFormat::kCSR: {
        CHECK_GE(states.arrays.end() - array_itr, 3);
        const auto &indptr = *(array_itr++);
        const auto &indices = *(array_itr++);
        const auto &edge_id = *(array_itr++);
        bool sorted = false;
        CHECK(strm->Read(&sorted)) << "Invalid flag 'sorted'";
        auto csr =
            aten::CSRMatrix(num_src, num_dst, indptr, indices, edge_id, sorted);
        // TODO(zihao) fix: every format is allowed (ALL_CODE) on the rebuilt
        // relation graph.
        relgraph = CreateFromCSR(num_vtypes, csr, ALL_CODE);
        break;
      }
      case SparseFormat::kCSC:
      default:
        LOG(FATAL) << "Unsupported sparse format.";
    }
    relgraphs[etype] = relgraph;
  }

  auto graph = CreateHeteroGraph(metagraph, relgraphs, num_nodes_per_type);
  if (is_pinned) {
    graph->PinMemory_();
  }
  return graph;
}

// For backward compatibility
163
HeteroGraphPtr HeteroUnpickleOld(const HeteroPickleStates &states) {
164
  const auto metagraph = states.metagraph;
165
  const auto &num_nodes_per_type = states.num_nodes_per_type;
166
167
168
  CHECK_EQ(states.adjs.size(), metagraph->NumEdges());
  std::vector<HeteroGraphPtr> relgraphs(metagraph->NumEdges());
  for (dgl_type_t etype = 0; etype < metagraph->NumEdges(); ++etype) {
169
    const auto &pair = metagraph->FindEdge(etype);
170
171
    const dgl_type_t srctype = pair.first;
    const dgl_type_t dsttype = pair.second;
172
173
174
    const int64_t num_vtypes = (srctype == dsttype) ? 1 : 2;
    const SparseFormat fmt =
        static_cast<SparseFormat>(states.adjs[etype]->format);
175
    switch (fmt) {
176
      case SparseFormat::kCOO:
177
178
179
        relgraphs[etype] = UnitGraph::CreateFromCOO(
            num_vtypes, aten::COOMatrix(*states.adjs[etype]));
        break;
180
      case SparseFormat::kCSR:
181
182
183
        relgraphs[etype] = UnitGraph::CreateFromCSR(
            num_vtypes, aten::CSRMatrix(*states.adjs[etype]));
        break;
184
      case SparseFormat::kCSC:
185
186
187
188
      default:
        LOG(FATAL) << "Unsupported sparse format.";
    }
  }
189
  return CreateHeteroGraph(metagraph, relgraphs, num_nodes_per_type);
190
191
}

192
193
194
195
196
197
198
199
200
201
/*!
 * \brief Rebuild a heterograph from states produced by HeteroForkingPickle.
 *
 * For every edge type, the meta stream holds:
 *   - the created-format code;
 *   - the allowed-format code;
 *   - the sorted flags of each created format, in COO, CSR, CSC order.
 * The index arrays are consumed from states.arrays in the same order. Every
 * created format is restored, and the allowed-format code is forwarded to
 * the rebuilt relation graph.
 *
 * \param states Pickle states holding the meta byte stream and index arrays.
 * \return The reconstructed heterograph. Its memory is pinned when the pinned
 *         flag was set.
 */
HeteroGraphPtr HeteroForkingUnpickle(const HeteroPickleStates &states) {
  // MemoryFixedSizeStream takes a mutable buffer, but this function only
  // reads from it.
  char *buf = const_cast<char *>(states.meta.c_str());
  dmlc::MemoryFixedSizeStream ifs(buf, states.meta.size());
  dmlc::Stream *strm = &ifs;
  auto meta_imgraph = Serializer::make_shared<ImmutableGraph>();
  CHECK(strm->Read(&meta_imgraph)) << "Invalid meta graph";
  GraphPtr metagraph = meta_imgraph;
  std::vector<HeteroGraphPtr> relgraphs(metagraph->NumEdges());
  std::vector<int64_t> num_nodes_per_type;
  CHECK(strm->Read(&num_nodes_per_type)) << "Invalid num_nodes_per_type";

  // The pinned flag is only present for version > 1.
  bool is_pinned = false;
  if (states.version > 1) {
    CHECK(strm->Read(&is_pinned)) << "Invalid flag 'is_pinned'";
  }

  auto array_itr = states.arrays.begin();
  for (dgl_type_t etype = 0; etype < metagraph->NumEdges(); ++etype) {
    const auto &pair = metagraph->FindEdge(etype);
    const dgl_type_t srctype = pair.first;
    const dgl_type_t dsttype = pair.second;
    // Corrupted states could name a vertex type with no recorded count.
    // Fail loudly instead of reading past the end of num_nodes_per_type.
    CHECK_LT(srctype, num_nodes_per_type.size())
        << "No vertex count for source type " << srctype;
    CHECK_LT(dsttype, num_nodes_per_type.size())
        << "No vertex count for destination type " << dsttype;
    const int64_t num_vtypes = (srctype == dsttype) ? 1 : 2;
    const int64_t num_src = num_nodes_per_type[srctype];
    const int64_t num_dst = num_nodes_per_type[dsttype];

    dgl_format_code_t created_formats = 0;
    dgl_format_code_t allowed_formats = 0;
    CHECK(strm->Read(&created_formats)) << "Invalid code for created formats";
    CHECK(strm->Read(&allowed_formats)) << "Invalid code for allowed formats";
    const bool has_coo = (created_formats & COO_CODE) != 0;
    const bool has_csr = (created_formats & CSR_CODE) != 0;
    const bool has_csc = (created_formats & CSC_CODE) != 0;

    aten::COOMatrix coo;
    aten::CSRMatrix csr;
    aten::CSRMatrix csc;
    if (has_coo) {
      CHECK_GE(states.arrays.end() - array_itr, 2);
      const auto &row = *(array_itr++);
      const auto &col = *(array_itr++);
      bool rsorted = false;
      bool csorted = false;
      CHECK(strm->Read(&rsorted)) << "Invalid flag 'rsorted'";
      CHECK(strm->Read(&csorted)) << "Invalid flag 'csorted'";
      coo = aten::COOMatrix(
          num_src, num_dst, row, col, aten::NullArray(), rsorted, csorted);
    }
    if (has_csr) {
      CHECK_GE(states.arrays.end() - array_itr, 3);
      const auto &indptr = *(array_itr++);
      const auto &indices = *(array_itr++);
      const auto &edge_id = *(array_itr++);
      bool sorted = false;
      CHECK(strm->Read(&sorted)) << "Invalid flag 'sorted'";
      csr = aten::CSRMatrix(num_src, num_dst, indptr, indices, edge_id, sorted);
    }
    if (has_csc) {
      CHECK_GE(states.arrays.end() - array_itr, 3);
      const auto &indptr = *(array_itr++);
      const auto &indices = *(array_itr++);
      const auto &edge_id = *(array_itr++);
      bool sorted = false;
      CHECK(strm->Read(&sorted)) << "Invalid flag 'sorted'";
      // CSC is held as a CSRMatrix over the transposed shape
      // (num_dst rows, num_src columns).
      csc = aten::CSRMatrix(num_dst, num_src, indptr, indices, edge_id, sorted);
    }

    relgraphs[etype] = UnitGraph::CreateUnitGraphFrom(
        num_vtypes, csc, csr, coo, has_csc, has_csr, has_coo, allowed_formats);
  }

  auto graph = CreateHeteroGraph(metagraph, relgraphs, num_nodes_per_type);
  if (is_pinned) {
    graph->PinMemory_();
  }
  return graph;
}

// Returns the format version recorded in a HeteroPickleStates object.
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroPickleStatesGetVersion")
    .set_body([](DGLArgs args, DGLRetValue *rv) {
      HeteroPickleStatesRef states = args[0];
      const auto version = states->version;
      *rv = version;
    });

// Returns the serialized meta stream of a HeteroPickleStates object as a byte
// array. The byte array points into the states' meta string.
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroPickleStatesGetMeta")
    .set_body([](DGLArgs args, DGLRetValue *rv) {
      HeteroPickleStatesRef states = args[0];
      const std::string &meta = states->meta;
      DGLByteArray bytes;
      bytes.size = meta.size();
      bytes.data = meta.c_str();
      *rv = bytes;
    });

// Returns the index arrays of a HeteroPickleStates object, wrapped by
// ConvertNDArrayVectorToPackedFunc.
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroPickleStatesGetArrays")
    .set_body([](DGLArgs args, DGLRetValue *rv) {
      HeteroPickleStatesRef states = args[0];
      const auto &arrays = states->arrays;
      *rv = ConvertNDArrayVectorToPackedFunc(arrays);
    });

// Returns how many index arrays a HeteroPickleStates object holds, as int64.
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroPickleStatesGetArraysNum")
    .set_body([](DGLArgs args, DGLRetValue *rv) {
      HeteroPickleStatesRef states = args[0];
      const auto count = static_cast<int64_t>(states->arrays.size());
      *rv = count;
    });

// Builds a HeteroPickleStates object from (version, meta bytes, array list).
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLCreateHeteroPickleStates")
    .set_body([](DGLArgs args, DGLRetValue *rv) {
      const int version = args[0];
      std::string meta = args[1];
      const List<Value> arrays = args[2];
      auto st = std::make_shared<HeteroPickleStates>();
      // Version 0 is used by the legacy states built in
      // _CAPI_DGLCreateHeteroPickleStatesOld. States built here always carry
      // a meta stream, so a 0 from the caller is stored as 1.
      st->version = (version == 0) ? 1 : version;
      // `meta` is not used again, so hand its buffer over instead of copying.
      st->meta = std::move(meta);
      st->arrays.reserve(arrays.size());
      for (const auto &ref : arrays) {
        st->arrays.push_back(ref->data);
      }
      *rv = HeteroPickleStatesRef(st);
    });

// Pickles a heterograph with HeteroPickle and returns the resulting states.
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroPickle")
    .set_body([](DGLArgs args, DGLRetValue *rv) {
      HeteroGraphRef graph_ref = args[0];
      auto states =
          std::make_shared<HeteroPickleStates>(HeteroPickle(graph_ref.sptr()));
      *rv = HeteroPickleStatesRef(states);
    });

// Pickles a heterograph with HeteroForkingPickle and returns the states.
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroForkingPickle")
    .set_body([](DGLArgs args, DGLRetValue *rv) {
      HeteroGraphRef graph_ref = args[0];
      auto states = std::make_shared<HeteroPickleStates>(
          HeteroForkingPickle(graph_ref.sptr()));
      *rv = HeteroPickleStatesRef(states);
    });

// Unpickles a heterograph, dispatching on the states' format version:
// 0 goes to HeteroUnpickleOld; 1 and 2 go to HeteroUnpickle.
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroUnpickle")
    .set_body([](DGLArgs args, DGLRetValue *rv) {
      HeteroPickleStatesRef states = args[0];
      const auto version = states->version;
      HeteroGraphPtr graph;
      if (version == 0) {
        graph = HeteroUnpickleOld(*states.sptr());
      } else if (version == 1 || version == 2) {
        graph = HeteroUnpickle(*states.sptr());
      } else {
        LOG(FATAL) << "Version can only be 0 or 1 or 2.";
      }
      *rv = HeteroGraphRef(graph);
    });

// Rebuilds a heterograph from states written by HeteroForkingPickle.
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroForkingUnpickle")
    .set_body([](DGLArgs args, DGLRetValue *rv) {
      HeteroPickleStatesRef states = args[0];
      *rv = HeteroGraphRef(HeteroForkingUnpickle(*states.sptr()));
    });

// Builds version-0 (legacy) pickle states from a metagraph, a per-type vertex
// count array and one sparse adjacency matrix per edge type.
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLCreateHeteroPickleStatesOld")
    .set_body([](DGLArgs args, DGLRetValue *rv) {
      GraphRef meta_ref = args[0];
      IdArray vertex_counts = args[1];
      List<SparseMatrixRef> adj_list = args[2];
      auto states = std::make_shared<HeteroPickleStates>();
      states->version = 0;
      states->metagraph = meta_ref.sptr();
      states->num_nodes_per_type = vertex_counts.ToVector<int64_t>();
      states->adjs.reserve(adj_list.size());
      for (const auto &adj : adj_list) {
        states->adjs.push_back(adj.sptr());
      }
      *rv = HeteroPickleStatesRef(states);
    });
}  // namespace dgl