graph_serialize.cc 4.47 KB
Newer Older
VoVAllen's avatar
VoVAllen committed
1
2
/*!
 *  Copyright (c) 2019 by Contributors
3
 * \file graph/serialize/graph_serialize.cc
VoVAllen's avatar
VoVAllen committed
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
 * \brief Graph serialization implementation
 *
 * The storage structure is
 * {
 *   // MetaData Section
 *   uint64_t kDGLSerializeMagic
 *   uint64_t kVersion
 *   uint64_t GraphType
 *   ** Reserved Area till 4kB **
 *
 *   dgl_id_t num_graphs
 *   vector<dgl_id_t> graph_indices (start address of each graph)
 *   vector<dgl_id_t> nodes_num_list (list of number of nodes for each graph)
 *   vector<dgl_id_t> edges_num_list (list of number of edges for each graph)
 *
 *   vector<GraphData> graph_datas;
 *
 * }
 *
 * Storage of GraphData is
 * {
 *   // The graph structure is stored in CSR format
 *   NDArray indptr
 *   NDArray indices
 *   NDArray edge_ids
 *   vector<pair<string, NDArray>> node_tensors;
 *   vector<pair<string, NDArray>> edge_tensors;
 * }
 *
 */
#include "graph_serialize.h"
35
36

#include <dgl/graph_op.h>
VoVAllen's avatar
VoVAllen committed
37
#include <dgl/immutable_graph.h>
38
#include <dgl/runtime/container.h>
VoVAllen's avatar
VoVAllen committed
39
#include <dgl/runtime/object.h>
40
41
42
43
44
#include <dmlc/io.h>
#include <dmlc/logging.h>
#include <dmlc/type_traits.h>

#include <algorithm>
VoVAllen's avatar
VoVAllen committed
45
46
47
#include <iostream>
#include <string>
#include <utility>
48
#include <vector>
VoVAllen's avatar
VoVAllen committed
49
50
51
52
53
54
55
56
57

using namespace dgl::runtime;

using dgl::COO;
using dgl::COOPtr;
using dgl::ImmutableGraph;
using dgl::runtime::NDArray;
using dgl::serialize::GraphData;
using dgl::serialize::GraphDataObject;
58
59
60
using dmlc::SeekStream;
using dmlc::Stream;
using std::vector;
VoVAllen's avatar
VoVAllen committed
61
62
63
64
65
66
67
68
69

namespace dmlc {
// Declares the dmlc `has_saveload` trait as true for GraphDataObject.
// NOTE(review): presumably this makes dmlc's stream serializer dispatch to
// the object's own Save/Load members -- confirm against dmlc/type_traits.h.
DMLC_DECLARE_TRAITS(has_saveload, GraphDataObject, true);
}

namespace dgl {
namespace serialize {

/*!
 * \brief C API entry that bundles a graph and its tensors into a GraphData.
 *
 * Arguments (positional):
 *   args[0] - graph reference; converted with ToImmutableGraph.
 *   args[1] - map from name to Value, passed to SetData as node tensors.
 *   args[2] - map from name to Value, passed to SetData as edge tensors.
 *
 * The freshly created GraphData is handed back through rv.
 */
DGL_REGISTER_GLOBAL("data.graph_serialize._CAPI_MakeGraphData")
  .set_body([](DGLArgs args, DGLRetValue *rv) {
    GraphRef graph_ref = args[0];
    ImmutableGraphPtr immutable_graph = ToImmutableGraph(graph_ref.sptr());
    Map<std::string, Value> node_feats = args[1];
    Map<std::string, Value> edge_feats = args[2];

    GraphData bundle = GraphData::Create();
    bundle->SetData(immutable_graph, node_feats, edge_feats);
    *rv = bundle;
  });
VoVAllen's avatar
VoVAllen committed
79

80
81
/*!
 * \brief C API entry that writes a list of GraphData plus labels to a file.
 *
 * Arguments (positional):
 *   args[0] - destination file name.
 *   args[1] - list of GraphData objects to save.
 *   args[2] - map from label name to Value; each Value's data is converted
 *             to an NDArray before being forwarded.
 *
 * Nothing is written to rv; the work is delegated to SaveDGLGraphs.
 */
DGL_REGISTER_GLOBAL("data.graph_serialize._CAPI_SaveDGLGraphs_V0")
  .set_body([](DGLArgs args, DGLRetValue *rv) {
    std::string path = args[0];
    List<GraphData> graphs = args[1];
    Map<std::string, Value> label_map = args[2];

    // Flatten the name -> Value map into (name, NDArray) pairs.
    std::vector<NamedTensor> named_labels;
    for (auto entry : label_map) {
      named_labels.emplace_back(
        entry.first, static_cast<NDArray>(entry.second->data));
    }
    SaveDGLGraphs(path, graphs, named_labels);
  });
VoVAllen's avatar
VoVAllen committed
94
95

/*!
 * \brief C API entry that returns the graph pointer held by a GraphData.
 *
 * args[0] is the GraphData object; its gptr member is written to rv.
 */
DGL_REGISTER_GLOBAL("data.graph_serialize._CAPI_GDataGraphHandle")
  .set_body([](DGLArgs args, DGLRetValue *rv) {
    GraphData graph_data = args[0];
    *rv = graph_data->gptr;
  });
VoVAllen's avatar
VoVAllen committed
100
101

/*!
 * \brief C API entry that exposes the node tensors of a GraphData as a map.
 *
 * args[0] is the GraphData object. Every (name, tensor) pair in its
 * node_tensors member is wrapped in a Value and inserted into a
 * Map<std::string, Value>, which is returned through rv.
 */
DGL_REGISTER_GLOBAL("data.graph_serialize._CAPI_GDataNodeTensors")
  .set_body([](DGLArgs args, DGLRetValue *rv) {
    GraphData graph_data = args[0];
    Map<std::string, Value> tensor_map;
    for (const auto &named : graph_data->node_tensors) {
      const Value boxed = Value(MakeValue(named.second));
      tensor_map.Set(named.first, boxed);
    }
    *rv = tensor_map;
  });
VoVAllen's avatar
VoVAllen committed
110
111

/*!
 * \brief C API entry that exposes the edge tensors of a GraphData as a map.
 *
 * args[0] is the GraphData object. Every (name, tensor) pair in its
 * edge_tensors member is wrapped in a Value and inserted into a
 * Map<std::string, Value>, which is returned through rv.
 */
DGL_REGISTER_GLOBAL("data.graph_serialize._CAPI_GDataEdgeTensors")
  .set_body([](DGLArgs args, DGLRetValue *rv) {
    GraphData graph_data = args[0];
    Map<std::string, Value> tensor_map;
    for (const auto &named : graph_data->edge_tensors) {
      const Value boxed = Value(MakeValue(named.second));
      tensor_map.Set(named.first, boxed);
    }
    *rv = tensor_map;
  });
VoVAllen's avatar
VoVAllen committed
120

121
122
123
124
125
/*!
 * \brief Read the format version stored in the header of a DGL graph file.
 *
 * The header starts with a 64-bit magic number followed by a 64-bit version.
 * Both reads are checked, so a truncated file fails with a clear error
 * instead of comparing or returning an uninitialized value.
 *
 * \param filename Path of the file to inspect.
 * \return The version number stored right after the magic number.
 * \note Fails a CHECK if the file cannot be opened, is too short to hold the
 *       header, or does not start with kDGLSerializeMagic.
 */
uint64_t GetFileVersion(const std::string &filename) {
  auto fs = std::unique_ptr<SeekStream>(
    SeekStream::CreateForRead(filename.c_str(), false));
  CHECK(fs) << "File " << filename << " not found";

  uint64_t magicNum = 0;
  uint64_t version = 0;
  // Validate the magic number before trusting anything else in the header.
  CHECK(fs->Read(&magicNum))
    << "Invalid DGL files: " << filename << " is too short to hold a header";
  CHECK_EQ(magicNum, kDGLSerializeMagic) << "Invalid DGL files";
  CHECK(fs->Read(&version))
    << "Invalid DGL files: " << filename << " has no version field";
  return version;
}

132
133
134
135
136
/*!
 * \brief C API entry that reports the version of a serialized graph file.
 *
 * args[0] is the file name. The unsigned version obtained from
 * GetFileVersion is cast to int64_t before being stored in rv.
 */
DGL_REGISTER_GLOBAL("data.graph_serialize._CAPI_GetFileVersion")
  .set_body([](DGLArgs args, DGLRetValue *rv) {
    std::string path = args[0];
    const uint64_t file_version = GetFileVersion(path);
    *rv = static_cast<int64_t>(file_version);
  });
VoVAllen's avatar
VoVAllen committed
137

138
139
140
141
142
143
144
145
/*!
 * \brief C API entry that loads graphs from a file via LoadDGLGraphs.
 *
 * Arguments (positional):
 *   args[0] - file name to read from.
 *   args[1] - list of Values, converted to a vector of dgl_id_t.
 *   args[2] - boolean flag forwarded unchanged as the third argument of
 *             LoadDGLGraphs (named onlyMeta by the original author).
 *
 * Whatever LoadDGLGraphs returns is stored in rv.
 */
DGL_REGISTER_GLOBAL("data.graph_serialize._CAPI_LoadGraphFiles_V1")
  .set_body([](DGLArgs args, DGLRetValue *rv) {
    std::string path = args[0];
    List<Value> requested = args[1];
    bool meta_only = args[2];
    auto graph_ids = ListValueToVector<dgl_id_t>(requested);
    *rv = LoadDGLGraphs(path, graph_ids, meta_only);
  });
VoVAllen's avatar
VoVAllen committed
146

147
148
149
150
151
152
153
/*!
 * \brief C API entry that loads heterographs from a file.
 *
 * Arguments (positional):
 *   args[0] - file name to read from.
 *   args[1] - list of Values, converted to a vector of dgl_id_t.
 *
 * The result of LoadHeteroGraphs is wrapped in a List<HeteroGraphData> and
 * stored in rv.
 */
DGL_REGISTER_GLOBAL("data.graph_serialize._CAPI_LoadGraphFiles_V2")
  .set_body([](DGLArgs args, DGLRetValue *rv) {
    std::string path = args[0];
    List<Value> requested = args[1];
    auto graph_ids = ListValueToVector<dgl_id_t>(requested);
    auto loaded = LoadHeteroGraphs(path, graph_ids);
    *rv = List<HeteroGraphData>(loaded);
  });
VoVAllen's avatar
VoVAllen committed
154
155
156

}  // namespace serialize
}  // namespace dgl