pickle.cc 13.2 KB
Newer Older
sangwzh's avatar
sangwzh committed
1
// !!! This is a file automatically generated by hipify!!!
2
/**
 *  Copyright (c) 2020 by Contributors
 * @file graph/pickle.cc
 * @brief Functions for pickling and unpickling a graph.
 */
7
8
#include <dgl/graph_serializer.h>
#include <dgl/immutable_graph.h>
9
10
#include <dgl/packed_func_ext.h>
#include <dgl/runtime/container.h>
11
#include <dmlc/memory_io.h>
12

13
#include "../c_api_common.h"
sangwzh's avatar
sangwzh committed
14
#include "heterograph.h"
15
#include "unit_graph.h"
16
17
18
19
20
21
22

using namespace dgl::runtime;

namespace dgl {

/**
 * @brief Serialize a heterograph into pickle states (format version 2).
 *
 * The meta stream receives, in order: the metagraph (as an ImmutableGraph),
 * the per-type node counts, the pinned flag, and then for every edge type a
 * SparseFormat tag followed by that format's sortedness flag(s). The matching
 * index arrays are appended to states.arrays in edge-type order.
 *
 * Only COO and CSR are written; an edge type whose selected format is CSC is
 * written as CSR obtained through GetCSRMatrix.
 *
 * @param graph The graph to serialize.
 * @return The populated pickle states.
 */
HeteroPickleStates HeteroPickle(HeteroGraphPtr graph) {
  HeteroPickleStates states;
  states.version = 2;

  dmlc::MemoryStringStream out_stream(&states.meta);
  dmlc::Stream *writer = &out_stream;
  writer->Write(ImmutableGraph::ToImmutable(graph->meta_graph()));
  writer->Write(graph->NumVerticesPerType());
  writer->Write(graph->IsPinned());

  for (dgl_type_t etype = 0; etype < graph->NumEdgeTypes(); ++etype) {
    const SparseFormat selected = graph->SelectFormat(etype, ALL_CODE);
    if (selected == SparseFormat::kCOO) {
      writer->Write(SparseFormat::kCOO);
      const auto &coo = graph->GetCOOMatrix(etype);
      writer->Write(coo.row_sorted);
      writer->Write(coo.col_sorted);
      states.arrays.push_back(coo.row);
      states.arrays.push_back(coo.col);
    } else if (
        selected == SparseFormat::kCSR || selected == SparseFormat::kCSC) {
      // CSC is not a pickled format here: the relation is emitted as CSR.
      writer->Write(SparseFormat::kCSR);
      const auto &csr = graph->GetCSRMatrix(etype);
      writer->Write(csr.sorted);
      states.arrays.push_back(csr.indptr);
      states.arrays.push_back(csr.indices);
      states.arrays.push_back(csr.data);
    } else {
      LOG(FATAL) << "Unsupported sparse format.";
    }
  }
  return states;
}

/**
 * @brief Serialize a heterograph, keeping every created sparse format
 *        (format version 2).
 *
 * The header matches HeteroPickle: metagraph, per-type node counts, pinned
 * flag. Then, for every edge type, the meta stream receives the created and
 * allowed format codes. After those come the sortedness flag(s) of each
 * created format, in COO, CSR, CSC order. The corresponding arrays are
 * appended to states.arrays in that same order.
 *
 * @param graph The graph to serialize.
 * @return The populated pickle states.
 */
HeteroPickleStates HeteroForkingPickle(HeteroGraphPtr graph) {
  HeteroPickleStates states;
  states.version = 2;

  dmlc::MemoryStringStream out_stream(&states.meta);
  dmlc::Stream *writer = &out_stream;
  writer->Write(ImmutableGraph::ToImmutable(graph->meta_graph()));
  writer->Write(graph->NumVerticesPerType());
  writer->Write(graph->IsPinned());

  // CSR and CSC matrices are both aten::CSRMatrix and share one layout:
  // sortedness flag in the meta stream, then indptr/indices/data arrays.
  auto emit_compressed = [&](const aten::CSRMatrix &mat) {
    writer->Write(mat.sorted);
    states.arrays.push_back(mat.indptr);
    states.arrays.push_back(mat.indices);
    states.arrays.push_back(mat.data);
  };

  for (dgl_type_t etype = 0; etype < graph->NumEdgeTypes(); ++etype) {
    // NOTE(review): both codes are queried on the whole graph, not on the
    // relation `etype`; confirm this is intended for multi-relation graphs.
    const auto created = graph->GetCreatedFormats();
    const auto allowed = graph->GetAllowedFormats();
    writer->Write(created);
    writer->Write(allowed);

    if (created & COO_CODE) {
      const auto &coo = graph->GetCOOMatrix(etype);
      writer->Write(coo.row_sorted);
      writer->Write(coo.col_sorted);
      states.arrays.push_back(coo.row);
      states.arrays.push_back(coo.col);
    }
    if (created & CSR_CODE) {
      emit_compressed(graph->GetCSRMatrix(etype));
    }
    if (created & CSC_CODE) {
      emit_compressed(graph->GetCSCMatrix(etype));
    }
  }
  return states;
}

/**
 * @brief Rebuild a heterograph from states produced by HeteroPickle.
 *
 * states.meta is read in this order:
 *  - the metagraph;
 *  - the per-type node counts;
 *  - the pinned flag (only when states.version > 1);
 *  - for each edge type, a SparseFormat tag and that format's sortedness
 *    flag(s).
 * The index arrays of each edge type are consumed from states.arrays in the
 * same order.
 *
 * @param states Pickle states (this path handles versions 1 and 2).
 * @return The reconstructed graph; pinned if the pinned flag was set.
 */
HeteroGraphPtr HeteroUnpickle(const HeteroPickleStates &states) {
  // dmlc::MemoryFixedSizeStream takes a mutable buffer; it is only read here.
  char *buf = const_cast<char *>(states.meta.c_str());
  dmlc::MemoryFixedSizeStream ifs(buf, states.meta.size());
  dmlc::Stream *strm = &ifs;
  auto meta_imgraph = Serializer::make_shared<ImmutableGraph>();
  CHECK(strm->Read(&meta_imgraph)) << "Invalid meta graph";
  GraphPtr metagraph = meta_imgraph;
  std::vector<HeteroGraphPtr> relgraphs(metagraph->NumEdges());
  std::vector<int64_t> num_nodes_per_type;
  CHECK(strm->Read(&num_nodes_per_type)) << "Invalid num_nodes_per_type";
  // num_nodes_per_type is indexed by node-type ID below; reject malformed
  // states here instead of reading out of bounds.
  CHECK_EQ(
      static_cast<int64_t>(num_nodes_per_type.size()),
      static_cast<int64_t>(metagraph->NumVertices()))
      << "Invalid num_nodes_per_type: expected one entry per node type";

  bool is_pinned = false;
  if (states.version > 1) {
    CHECK(strm->Read(&is_pinned)) << "Invalid flag 'is_pinned'";
  }

  auto array_itr = states.arrays.begin();
  for (dgl_type_t etype = 0; etype < metagraph->NumEdges(); ++etype) {
    const auto &pair = metagraph->FindEdge(etype);
    const dgl_type_t srctype = pair.first;
    const dgl_type_t dsttype = pair.second;
    // A relation whose source and destination types coincide is built with a
    // single vertex type.
    const int64_t num_vtypes = (srctype == dsttype) ? 1 : 2;
    int64_t num_src = num_nodes_per_type[srctype];
    int64_t num_dst = num_nodes_per_type[dsttype];
    SparseFormat fmt;
    CHECK(strm->Read(&fmt)) << "Invalid SparseFormat";
    HeteroGraphPtr relgraph;
    switch (fmt) {
      case SparseFormat::kCOO: {
        CHECK_GE(states.arrays.end() - array_itr, 2);
        const auto &row = *(array_itr++);
        const auto &col = *(array_itr++);
        bool rsorted;
        bool csorted;
        CHECK(strm->Read(&rsorted)) << "Invalid flag 'rsorted'";
        CHECK(strm->Read(&csorted)) << "Invalid flag 'csorted'";
        auto coo = aten::COOMatrix(
            num_src, num_dst, row, col, aten::NullArray(), rsorted, csorted);
        // TODO(zihao) fix
        relgraph = CreateFromCOO(num_vtypes, coo, ALL_CODE);
        break;
      }
      case SparseFormat::kCSR: {
        CHECK_GE(states.arrays.end() - array_itr, 3);
        const auto &indptr = *(array_itr++);
        const auto &indices = *(array_itr++);
        const auto &edge_id = *(array_itr++);
        bool sorted;
        CHECK(strm->Read(&sorted)) << "Invalid flag 'sorted'";
        auto csr =
            aten::CSRMatrix(num_src, num_dst, indptr, indices, edge_id, sorted);
        // TODO(zihao) fix
        relgraph = CreateFromCSR(num_vtypes, csr, ALL_CODE);
        break;
      }
      case SparseFormat::kCSC:
      default:
        LOG(FATAL) << "Unsupported sparse format.";
    }
    relgraphs[etype] = relgraph;
  }

  auto graph = CreateHeteroGraph(metagraph, relgraphs, num_nodes_per_type);
  if (is_pinned) {
    graph->PinMemory_();
  }
  return graph;
}

// For backward compatibility
164
HeteroGraphPtr HeteroUnpickleOld(const HeteroPickleStates &states) {
165
  const auto metagraph = states.metagraph;
166
  const auto &num_nodes_per_type = states.num_nodes_per_type;
167
168
169
  CHECK_EQ(states.adjs.size(), metagraph->NumEdges());
  std::vector<HeteroGraphPtr> relgraphs(metagraph->NumEdges());
  for (dgl_type_t etype = 0; etype < metagraph->NumEdges(); ++etype) {
170
    const auto &pair = metagraph->FindEdge(etype);
171
172
    const dgl_type_t srctype = pair.first;
    const dgl_type_t dsttype = pair.second;
173
174
175
    const int64_t num_vtypes = (srctype == dsttype) ? 1 : 2;
    const SparseFormat fmt =
        static_cast<SparseFormat>(states.adjs[etype]->format);
176
    switch (fmt) {
177
      case SparseFormat::kCOO:
178
179
180
        relgraphs[etype] = UnitGraph::CreateFromCOO(
            num_vtypes, aten::COOMatrix(*states.adjs[etype]));
        break;
181
      case SparseFormat::kCSR:
182
183
184
        relgraphs[etype] = UnitGraph::CreateFromCSR(
            num_vtypes, aten::CSRMatrix(*states.adjs[etype]));
        break;
185
      case SparseFormat::kCSC:
186
187
188
189
      default:
        LOG(FATAL) << "Unsupported sparse format.";
    }
  }
190
  return CreateHeteroGraph(metagraph, relgraphs, num_nodes_per_type);
191
192
}

193
194
195
196
197
198
199
200
201
202
/**
 * @brief Rebuild a heterograph from states produced by HeteroForkingPickle.
 *
 * The header (metagraph, per-type node counts, pinned flag when
 * states.version > 1) is read as in HeteroUnpickle. For each edge type, the
 * meta stream then supplies the created and allowed format codes, followed by
 * the sortedness flag(s) of every created format in COO, CSR, CSC order. The
 * matching arrays are consumed from states.arrays in the same order. The
 * relation graph is rebuilt with exactly the created formats via
 * UnitGraph::CreateUnitGraphFrom.
 *
 * @param states Pickle states written by HeteroForkingPickle.
 * @return The reconstructed graph; pinned if the pinned flag was set.
 */
HeteroGraphPtr HeteroForkingUnpickle(const HeteroPickleStates &states) {
  // dmlc::MemoryFixedSizeStream takes a mutable buffer; it is only read here.
  char *buf = const_cast<char *>(states.meta.c_str());
  dmlc::MemoryFixedSizeStream ifs(buf, states.meta.size());
  dmlc::Stream *strm = &ifs;
  auto meta_imgraph = Serializer::make_shared<ImmutableGraph>();
  CHECK(strm->Read(&meta_imgraph)) << "Invalid meta graph";
  GraphPtr metagraph = meta_imgraph;
  std::vector<HeteroGraphPtr> relgraphs(metagraph->NumEdges());
  std::vector<int64_t> num_nodes_per_type;
  CHECK(strm->Read(&num_nodes_per_type)) << "Invalid num_nodes_per_type";
  // num_nodes_per_type is indexed by node-type ID below; reject malformed
  // states here instead of reading out of bounds.
  CHECK_EQ(
      static_cast<int64_t>(num_nodes_per_type.size()),
      static_cast<int64_t>(metagraph->NumVertices()))
      << "Invalid num_nodes_per_type: expected one entry per node type";

  bool is_pinned = false;
  if (states.version > 1) {
    CHECK(strm->Read(&is_pinned)) << "Invalid flag 'is_pinned'";
  }

  auto array_itr = states.arrays.begin();
  for (dgl_type_t etype = 0; etype < metagraph->NumEdges(); ++etype) {
    const auto &pair = metagraph->FindEdge(etype);
    const dgl_type_t srctype = pair.first;
    const dgl_type_t dsttype = pair.second;
    // A relation whose source and destination types coincide is built with a
    // single vertex type.
    const int64_t num_vtypes = (srctype == dsttype) ? 1 : 2;
    int64_t num_src = num_nodes_per_type[srctype];
    int64_t num_dst = num_nodes_per_type[dsttype];

    dgl_format_code_t created_formats = 0;
    dgl_format_code_t allowed_formats = 0;
    CHECK(strm->Read(&created_formats)) << "Invalid code for created formats";
    CHECK(strm->Read(&allowed_formats)) << "Invalid code for allowed formats";
    const bool has_coo = (created_formats & COO_CODE);
    const bool has_csr = (created_formats & CSR_CODE);
    const bool has_csc = (created_formats & CSC_CODE);

    // Matrices for formats that were not created stay default-constructed;
    // the has_* flags tell CreateUnitGraphFrom which ones are valid.
    aten::COOMatrix coo;
    aten::CSRMatrix csr;
    aten::CSRMatrix csc;

    if (has_coo) {
      CHECK_GE(states.arrays.end() - array_itr, 2);
      const auto &row = *(array_itr++);
      const auto &col = *(array_itr++);
      bool rsorted;
      bool csorted;
      CHECK(strm->Read(&rsorted)) << "Invalid flag 'rsorted'";
      CHECK(strm->Read(&csorted)) << "Invalid flag 'csorted'";
      coo = aten::COOMatrix(
          num_src, num_dst, row, col, aten::NullArray(), rsorted, csorted);
    }
    if (has_csr) {
      CHECK_GE(states.arrays.end() - array_itr, 3);
      const auto &indptr = *(array_itr++);
      const auto &indices = *(array_itr++);
      const auto &edge_id = *(array_itr++);
      bool sorted;
      CHECK(strm->Read(&sorted)) << "Invalid flag 'sorted'";
      csr = aten::CSRMatrix(num_src, num_dst, indptr, indices, edge_id, sorted);
    }
    if (has_csc) {
      CHECK_GE(states.arrays.end() - array_itr, 3);
      const auto &indptr = *(array_itr++);
      const auto &indices = *(array_itr++);
      const auto &edge_id = *(array_itr++);
      bool sorted;
      CHECK(strm->Read(&sorted)) << "Invalid flag 'sorted'";
      // CSC is held in a CSRMatrix with the row/column counts swapped.
      csc = aten::CSRMatrix(num_dst, num_src, indptr, indices, edge_id, sorted);
    }

    relgraphs[etype] = UnitGraph::CreateUnitGraphFrom(
        num_vtypes, csc, csr, coo, has_csc, has_csr, has_coo, allowed_formats);
  }

  auto graph = CreateHeteroGraph(metagraph, relgraphs, num_nodes_per_type);
  if (is_pinned) {
    graph->PinMemory_();
  }
  return graph;
}

// Returns the format version recorded in the given pickle states.
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroPickleStatesGetVersion")
    .set_body([](DGLArgs args, DGLRetValue *rv) {
      HeteroPickleStatesRef states = args[0];
      const auto version = states->version;
      *rv = version;
    });

// Returns the serialized meta stream of the pickle states as a byte array.
// The DGLByteArray built here points into the states object's meta string.
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroPickleStatesGetMeta")
    .set_body([](DGLArgs args, DGLRetValue *rv) {
      HeteroPickleStatesRef states = args[0];
      const std::string &meta = states->meta;
      DGLByteArray bytes;
      bytes.size = meta.size();
      bytes.data = meta.c_str();
      *rv = bytes;
    });

// Returns the pickled NDArrays, wrapped by ConvertNDArrayVectorToPackedFunc.
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroPickleStatesGetArrays")
    .set_body([](DGLArgs args, DGLRetValue *rv) {
      HeteroPickleStatesRef states = args[0];
      const auto &arrays = states->arrays;
      *rv = ConvertNDArrayVectorToPackedFunc(arrays);
    });

// Returns how many NDArrays the pickle states hold, as an int64.
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroPickleStatesGetArraysNum")
    .set_body([](DGLArgs args, DGLRetValue *rv) {
      HeteroPickleStatesRef states = args[0];
      const int64_t num_arrays = static_cast<int64_t>(states->arrays.size());
      *rv = num_arrays;
    });

// Builds pickle states from (version, meta bytes, list of NDArray values).
// A requested version of 0 is stored as 1. In this file, version 0 is what
// _CAPI_DGLCreateHeteroPickleStatesOld produces, and _CAPI_DGLHeteroUnpickle
// routes it to HeteroUnpickleOld.
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLCreateHeteroPickleStates")
    .set_body([](DGLArgs args, DGLRetValue *rv) {
      const int version = args[0];
      std::string meta = args[1];
      const List<Value> arrays = args[2];

      std::shared_ptr<HeteroPickleStates> states(new HeteroPickleStates);
      if (version == 0) {
        states->version = 1;
      } else {
        states->version = version;
      }
      states->meta = meta;
      states->arrays.reserve(arrays.size());
      for (const auto &item : arrays) {
        states->arrays.push_back(item->data);
      }
      *rv = HeteroPickleStatesRef(states);
    });

// Pickles a graph with HeteroPickle and returns the resulting states object.
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroPickle")
    .set_body([](DGLArgs args, DGLRetValue *rv) {
      HeteroGraphRef graph_ref = args[0];
      std::shared_ptr<HeteroPickleStates> states(
          new HeteroPickleStates(HeteroPickle(graph_ref.sptr())));
      *rv = HeteroPickleStatesRef(states);
    });

// Pickles a graph with HeteroForkingPickle (keeping all created formats) and
// returns the resulting states object.
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroForkingPickle")
    .set_body([](DGLArgs args, DGLRetValue *rv) {
      HeteroGraphRef graph_ref = args[0];
      std::shared_ptr<HeteroPickleStates> states(
          new HeteroPickleStates(HeteroForkingPickle(graph_ref.sptr())));
      *rv = HeteroPickleStatesRef(states);
    });

// Unpickles states by version: 0 goes through the legacy HeteroUnpickleOld
// path, 1 and 2 through HeteroUnpickle; any other version is fatal.
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroUnpickle")
    .set_body([](DGLArgs args, DGLRetValue *rv) {
      HeteroPickleStatesRef states_ref = args[0];
      const auto version = states_ref->version;
      HeteroGraphPtr graph;
      if (version == 0) {
        graph = HeteroUnpickleOld(*states_ref.sptr());
      } else if (version == 1 || version == 2) {
        graph = HeteroUnpickle(*states_ref.sptr());
      } else {
        LOG(FATAL) << "Version can only be 0 or 1 or 2.";
      }
      *rv = HeteroGraphRef(graph);
    });

// Unpickles states written by HeteroForkingPickle.
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroForkingUnpickle")
    .set_body([](DGLArgs args, DGLRetValue *rv) {
      HeteroPickleStatesRef states_ref = args[0];
      const HeteroGraphPtr restored = HeteroForkingUnpickle(*states_ref.sptr());
      *rv = HeteroGraphRef(restored);
    });

// Builds legacy (version 0) pickle states from a metagraph, the per-type node
// counts (as an IdArray) and one SparseMatrix per edge type.
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLCreateHeteroPickleStatesOld")
    .set_body([](DGLArgs args, DGLRetValue *rv) {
      GraphRef meta_ref = args[0];
      IdArray node_counts = args[1];
      List<SparseMatrixRef> adj_list = args[2];

      std::shared_ptr<HeteroPickleStates> states(new HeteroPickleStates);
      states->version = 0;
      states->metagraph = meta_ref.sptr();
      states->num_nodes_per_type = node_counts.ToVector<int64_t>();
      states->adjs.reserve(adj_list.size());
      for (const auto &adj : adj_list) {
        states->adjs.push_back(adj.sptr());
      }
      *rv = HeteroPickleStatesRef(states);
    });
}  // namespace dgl