"docs/git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "2b23ec82e898aa1f0e172da6ef054631634b643d"
Unverified Commit 18a26fcf authored by Jinjing Zhou's avatar Jinjing Zhou Committed by GitHub
Browse files

[Data] Support HeteroGraph save/load (#1526)



* 111

* add history version test

* fix

* 111

* save

* ``

* fix1

* 111

* add save heterograph

* lint

* lint

* add tests

* minor fix

* fix

* docs

* add format tets

* use unique_ptr

* fix

* fix interface

* 111

* 111

* fix

* lint

* fix

* add support to s3

* fix

* fix

* fix leak

* fix

* fix docs

* fix

* linlt

* fix

* fix

* fix

* address comment

* address comment
Co-authored-by: default avatarMinjie Wang <wmjlyjemaine@gmail.com>
Co-authored-by: default avatarxiang song(charlie.song) <classicxsong@gmail.com>
parent dd8d5289
......@@ -26,6 +26,7 @@ dgl_option(USE_CUDA "Build with CUDA" OFF)
dgl_option(USE_OPENMP "Build with OpenMP" ON)
dgl_option(BUILD_CPP_TEST "Build cpp unittest executables" OFF)
dgl_option(LIBCXX_ENABLE_PARALLEL_ALGORITHMS "Enable the parallel algorithms library. This requires the PSTL to be available." OFF)
# Set debug compile option for gdb, only happens when -DCMAKE_BUILD_TYPE=DEBUG
if (NOT MSVC)
set(CMAKE_CXX_FLAGS_DEBUG "${CMAKE_CXX_FLAGS_DEBUG} -O0 -g3 -ggdb")
......@@ -132,6 +133,9 @@ endif(USE_CUDA)
# For serialization
add_subdirectory("third_party/dmlc-core")
# dmlc-core options for support S3/HDFS
add_definitions(-DUSE_S3=OFF)
add_definitions(-DUSE_HDFS=OFF)
list(APPEND DGL_LINKER_LIBS dmlc)
set(GOOGLE_TEST 0) # Turn off dmlc-core test
......
......@@ -3,8 +3,10 @@ reference: tvm/python/tvm/collections.py
"""
from __future__ import absolute_import as _abs
from ._ffi.object import ObjectBase, register_object
from ._ffi.object_generic import convert_to_object
from . import _api_internal
@register_object
class List(ObjectBase):
"""List container of DGL.
......@@ -14,6 +16,7 @@ class List(ObjectBase):
to List during dgl function call.
You may get List in return values of DGL function call.
"""
def __getitem__(self, i):
if isinstance(i, slice):
start = i.start if i.start is not None else 0
......@@ -30,11 +33,15 @@ class List(ObjectBase):
.format(len(self), i))
if i < 0:
i += len(self)
return _api_internal._ListGetItem(self, i)
ret = _api_internal._ListGetItem(self, i)
if isinstance(ret, Value):
ret = ret.data
return ret
def __len__(self):
return _api_internal._ListSize(self)
@register_object
class Map(ObjectBase):
"""Map container of DGL.
......@@ -43,6 +50,7 @@ class Map(ObjectBase):
Normally python dict will be converted automaticall to Map during dgl function call.
You can use convert to create a dict[ObjectBase-> ObjectBase] into a Map
"""
def __getitem__(self, k):
return _api_internal._MapGetItem(self, k)
......@@ -64,10 +72,12 @@ class StrMap(Map):
You can use convert to create a dict[str->ObjectBase] into a Map.
"""
def items(self):
"""Get the items from the map"""
akvs = _api_internal._MapItems(self)
return [(akvs[i].data, akvs[i+1]) for i in range(0, len(akvs), 2)]
return [(akvs[i], akvs[i+1]) for i in range(0, len(akvs), 2)]
@register_object
class Value(ObjectBase):
......@@ -76,3 +86,12 @@ class Value(ObjectBase):
def data(self):
"""Return the value data."""
return _api_internal._ValueGet(self)
def convert_to_strmap(value):
"""Convert a python dictionary to a dgl.contrainer.StrMap"""
assert isinstance(value, dict), "Only support dict"
if len(value) == 0:
return _api_internal._EmptyStrMap()
else:
return convert_to_object(value)
"""For Graph Serialization"""
from __future__ import absolute_import
from ..graph import DGLGraph
from ..heterograph import DGLHeteroGraph
from .._ffi.object import ObjectBase, register_object
from .._ffi.function import _init_api
from .. import backend as F
from .heterograph_serialize import HeteroGraphData, save_heterographs
_init_api("dgl.data.graph_serialize")
__all__ = ['save_graphs', "load_graphs", "load_labels"]
@register_object("graph_serialize.StorageMetaData")
class StorageMetaData(ObjectBase):
"""StorageMetaData Object
......@@ -54,40 +57,36 @@ class GraphData(ObjectBase):
node_tensors_items = _CAPI_GDataNodeTensors(self).items()
edge_tensors_items = _CAPI_GDataEdgeTensors(self).items()
for k, v in node_tensors_items:
g.ndata[k] = F.zerocopy_from_dgl_ndarray(v.data)
g.ndata[k] = F.zerocopy_from_dgl_ndarray(v)
for k, v in edge_tensors_items:
g.edata[k] = F.zerocopy_from_dgl_ndarray(v.data)
g.edata[k] = F.zerocopy_from_dgl_ndarray(v)
return g
def save_graphs(filename, g_list, labels=None):
r"""
Save DGLGraphs and graph labels to file
Save DGLGraphs/DGLHeteroGraph and graph labels to file
Parameters
----------
filename : str
File name to store DGLGraphs.
File name to store graphs.
g_list: list
DGLGraph or list of DGLGraph
labels: dict (Default: None)
labels should be dict of tensors/ndarray, with str as keys
DGLGraph or list of DGLGraph/DGLHeteroGraph
labels: dict[str, tensor]
labels should be dict of tensors, with str as keys
Examples
----------
>>> import dgl
>>> import torch as th
Create :code:`DGLGraph` objects and initialize node and edge features.
Create :code:`DGLGraph`/:code:`DGLHeteroGraph` objects and initialize node
and edge features.
>>> g1 = dgl.DGLGraph()
>>> g1.add_nodes(3)
>>> g1.add_edges([0, 0, 0, 1, 1, 2], [0, 1, 2, 1, 2, 2])
>>> g1.ndata["e"] = th.ones(3, 5)
>>> g2 = dgl.DGLGraph()
>>> g2.add_nodes(3)
>>> g2.add_edges([0, 1, 2], [1, 2, 1])
>>> g2.edata["e"] = th.ones(3, 4)
>>> g1 = dgl.graph(([0, 1, 2], [1, 2, 3])
>>> g2 = dgl.graph(([0, 2], [2, 3])
>>> g2.edata["e"] = th.ones(2, 4)
Save Graphs into file
......@@ -96,6 +95,18 @@ def save_graphs(filename, g_list, labels=None):
>>> save_graphs("./data.bin", [g1, g2], graph_labels)
"""
g_sample = g_list[0] if isinstance(g_list, list) else g_list
if isinstance(g_sample, DGLGraph):
save_dglgraphs(filename, g_list, labels)
elif isinstance(g_sample, DGLHeteroGraph):
save_heterographs(filename, g_list, labels)
else:
raise Exception(
"Invalid argument g_list. Must be a DGLGraph or a list of DGLGraphs/DGLHeteroGraphs")
def save_dglgraphs(filename, g_list, labels=None):
"""Internal function to save DGLGraphs"""
if isinstance(g_list, DGLGraph):
g_list = [g_list]
if (labels is not None) and (len(labels) != 0):
......@@ -105,7 +116,7 @@ def save_graphs(filename, g_list, labels=None):
else:
label_dict = None
gdata_list = [GraphData.create(g) for g in g_list]
_CAPI_DGLSaveGraphs(filename, gdata_list, label_dict)
_CAPI_SaveDGLGraphs_V0(filename, gdata_list, label_dict)
def load_graphs(filename, idx_list=None):
......@@ -115,16 +126,19 @@ def load_graphs(filename, idx_list=None):
Parameters
----------
filename: str
filename to load DGLGraphs
filename to load graphs
idx_list: list of int
list of index of graph to be loaded. If not specified, will
load all graphs from file
Returns
----------
graph_list: list of immutable DGLGraphs
labels: dict of labels stored in file (empty dict returned if no
label stored)
--------
graph_list: list of DGLGraphs / DGLHeteroGraph
The loaded graphs.
labels: dict[str, Tensor]
The graph labels stored in file. If no label is stored, the dictionary is empty.
Regardless of whether the ``idx_list`` argument is given or not, the returned dictionary
always contains labels of all the graphs.
Examples
----------
......@@ -135,13 +149,34 @@ def load_graphs(filename, idx_list=None):
>>> glist, label_dict = load_graphs("./data.bin", [0]) # glist will be [g1]
"""
version = _CAPI_GetFileVersion(filename)
if version == 1:
return load_graph_v1(filename, idx_list)
elif version == 2:
return load_graph_v2(filename, idx_list)
else:
raise Exception("Invalid DGL Version Number")
def load_graph_v2(filename, idx_list=None):
"""Internal functions for loading DGLHeteroGraphs."""
if idx_list is None:
idx_list = []
assert isinstance(idx_list, list)
metadata = _CAPI_DGLLoadGraphs(filename, idx_list, False)
heterograph_list = _CAPI_LoadGraphFiles_V2(filename, idx_list)
label_dict = load_labels_v2(filename)
return [gdata.get_graph() for gdata in heterograph_list], label_dict
def load_graph_v1(filename, idx_list=None):
""""Internal functions for loading DGLGraphs (V0)."""
if idx_list is None:
idx_list = []
assert isinstance(idx_list, list)
metadata = _CAPI_LoadGraphFiles_V1(filename, idx_list, False)
label_dict = {}
for k, v in metadata.labels.items():
label_dict[k] = F.zerocopy_from_dgl_ndarray(v.data)
label_dict[k] = F.zerocopy_from_dgl_ndarray(v)
return [gdata.get_graph() for gdata in metadata.graph_data], label_dict
......@@ -169,8 +204,28 @@ def load_labels(filename):
>>> label_dict = load_graphs("./data.bin")
"""
metadata = _CAPI_DGLLoadGraphs(filename, [], True)
version = _CAPI_GetFileVersion(filename)
if version == 1:
return load_labels_v1(filename)
elif version == 2:
return load_labels_v2(filename)
else:
raise Exception("Invalid DGL Version Number")
def load_labels_v2(filename):
"""Internal functions for loading labels from V2 format"""
label_dict = {}
nd_dict = _CAPI_LoadLabels_V2(filename)
for k, v in nd_dict.items():
label_dict[k] = F.zerocopy_from_dgl_ndarray(v)
return label_dict
def load_labels_v1(filename):
"""Internal functions for loading labels from V1 format"""
metadata = _CAPI_LoadGraphFiles_V1(filename, [], True)
label_dict = {}
for k, v in metadata.labels.items():
label_dict[k] = F.zerocopy_from_dgl_ndarray(v.data)
label_dict[k] = F.zerocopy_from_dgl_ndarray(v)
return label_dict
"""For HeteroGraph Serialization"""
from __future__ import absolute_import
from ..heterograph import DGLHeteroGraph
from ..frame import Frame, FrameRef
from .._ffi.object import ObjectBase, register_object
from .._ffi.function import _init_api
from .. import backend as F
from ..container import convert_to_strmap
_init_api("dgl.data.heterograph_serialize")
def tensor_dict_to_ndarray_dict(tensor_dict):
"""Convert dict[str, tensor] to StrMap[NDArray]"""
ndarray_dict = {}
for key, value in tensor_dict.items():
ndarray_dict[key] = F.zerocopy_to_dgl_ndarray(value)
return convert_to_strmap(ndarray_dict)
def save_heterographs(filename, g_list, labels):
"""Save heterographs into file"""
if labels is None:
labels = {}
if isinstance(g_list, DGLHeteroGraph):
g_list = [g_list]
gdata_list = [HeteroGraphData.create(g) for g in g_list]
_CAPI_SaveHeteroGraphData(filename, gdata_list, tensor_dict_to_ndarray_dict(labels))
@register_object("heterograph_serialize.HeteroGraphData")
class HeteroGraphData(ObjectBase):
"""Object to hold the data to be stored for DGLHeteroGraph"""
@staticmethod
def create(g):
edata_list = []
ndata_list = []
for etype in g.etypes:
edata_list.append(tensor_dict_to_ndarray_dict(g.edges[etype].data))
for ntype in g.ntypes:
ndata_list.append(tensor_dict_to_ndarray_dict(g.nodes[ntype].data))
return _CAPI_MakeHeteroGraphData(g._graph, ndata_list, edata_list, g.ntypes, g.etypes)
def get_graph(self):
ntensor_list = list(_CAPI_GetNDataFromHeteroGraphData(self))
etensor_list = list(_CAPI_GetEDataFromHeteroGraphData(self))
ntype_names = list(_CAPI_GetNtypesFromHeteroGraphData(self))
etype_names = list(_CAPI_GetEtypesFromHeteroGraphData(self))
gidx = _CAPI_GetGindexFromHeteroGraphData(self)
nframes = []
eframes = []
for ntensor in ntensor_list:
ndict = {ntensor[i]: F.zerocopy_from_dgl_ndarray(ntensor[i+1]) for i in range(0, len(ntensor), 2)}
nframes.append(FrameRef(Frame(ndict)))
for etensor in etensor_list:
edict = {etensor[i]: F.zerocopy_from_dgl_ndarray(etensor[i+1]) for i in range(0, len(etensor), 2)}
eframes.append(FrameRef(Frame(edict)))
return DGLHeteroGraph(gidx, ntype_names, etype_names, nframes, eframes)
......@@ -61,7 +61,7 @@ def load_tensors(filename, return_dgl_ndarray=False):
tensor_dict = {}
for key, value in nd_dict.items():
if return_dgl_ndarray:
tensor_dict[key] = value.data
tensor_dict[key] = value
else:
tensor_dict[key] = F.zerocopy_from_dgl_ndarray(value.data)
tensor_dict[key] = F.zerocopy_from_dgl_ndarray(value)
return tensor_dict
......@@ -515,7 +515,7 @@ class RPCMessage(ObjectBase):
def tensors(self):
"""Get tensor payloads."""
rst = _CAPI_DGLRPCMessageGetTensors(self)
return [F.zerocopy_from_dgl_ndarray(tsor.data) for tsor in rst]
return [F.zerocopy_from_dgl_ndarray(tsor) for tsor in rst]
def send_request(target, request):
"""Send one request to the target server.
......
......@@ -4778,8 +4778,8 @@ def find_src_dst_ntypes(ntypes, metagraph):
return None
else:
src, dst = ret
srctypes = {ntypes[tid.data] : tid.data for tid in src}
dsttypes = {ntypes[tid.data] : tid.data for tid in dst}
srctypes = {ntypes[tid] : tid for tid in src}
dsttypes = {ntypes[tid] : tid for tid in dst}
return srctypes, dsttypes
def infer_ntype_from_dict(graph, etype_dict):
......
......@@ -1017,7 +1017,7 @@ class HeteroSubgraphIndex(ObjectBase):
Induced nodes
"""
ret = _CAPI_DGLHeteroSubgraphGetInducedVertices(self)
return [utils.toindex(v.data, self.graph.dtype) for v in ret]
return [utils.toindex(v, self.graph.dtype) for v in ret]
@property
def induced_edges(self):
......@@ -1030,7 +1030,7 @@ class HeteroSubgraphIndex(ObjectBase):
Induced edges
"""
ret = _CAPI_DGLHeteroSubgraphGetInducedEdges(self)
return [utils.toindex(v.data, self.graph.dtype) for v in ret]
return [utils.toindex(v, self.graph.dtype) for v in ret]
#################################################################
......
......@@ -172,7 +172,7 @@ class SparseMatrix(ObjectBase):
-------
list of boolean
"""
return [v.data for v in _CAPI_DGLSparseMatrixGetFlags(self)]
return [v for v in _CAPI_DGLSparseMatrixGetFlags(self)]
def __getstate__(self):
return self.format, self.num_rows, self.num_cols, self.indices, self.flags
......
......@@ -153,8 +153,8 @@ def random_walk(g, nodes, *, metapath=None, length=None, prob=None, restart_prob
traces, types = _CAPI_DGLSamplingRandomWalkWithRestart(
gidx, nodes, metapath, p_nd, restart_prob)
traces = F.zerocopy_from_dgl_ndarray(traces.data)
types = F.zerocopy_from_dgl_ndarray(types.data)
traces = F.zerocopy_from_dgl_ndarray(traces)
types = F.zerocopy_from_dgl_ndarray(types)
return traces, types
def pack_traces(traces, types):
......@@ -221,10 +221,10 @@ def pack_traces(traces, types):
concat_vids, concat_types, lengths, offsets = _CAPI_DGLSamplingPackTraces(traces, types)
concat_vids = F.zerocopy_from_dgl_ndarray(concat_vids.data)
concat_types = F.zerocopy_from_dgl_ndarray(concat_types.data)
lengths = F.zerocopy_from_dgl_ndarray(lengths.data)
offsets = F.zerocopy_from_dgl_ndarray(offsets.data)
concat_vids = F.zerocopy_from_dgl_ndarray(concat_vids)
concat_types = F.zerocopy_from_dgl_ndarray(concat_types)
lengths = F.zerocopy_from_dgl_ndarray(lengths)
offsets = F.zerocopy_from_dgl_ndarray(offsets)
return concat_vids, concat_types, lengths, offsets
......
......@@ -906,7 +906,7 @@ def compact_graphs(graphs, always_preserve=None):
# Compact and construct heterographs
new_graph_indexes, induced_nodes = _CAPI_DGLCompactGraphs(
[g._graph for g in graphs], always_preserve_nd)
induced_nodes = [F.zerocopy_from_dgl_ndarray(nodes.data) for nodes in induced_nodes]
induced_nodes = [F.zerocopy_from_dgl_ndarray(nodes) for nodes in induced_nodes]
new_graphs = [
DGLHeteroGraph(new_graph_index, graph.ntypes, graph.etypes)
......@@ -1063,7 +1063,7 @@ def to_block(g, dst_nodes=None, include_dst_in_src=True):
assert new_graph.is_unibipartite # sanity check
for i, ntype in enumerate(g.ntypes):
new_graph.srcnodes[ntype].data[NID] = F.zerocopy_from_dgl_ndarray(src_nodes_nd[i].data)
new_graph.srcnodes[ntype].data[NID] = F.zerocopy_from_dgl_ndarray(src_nodes_nd[i])
if ntype in dst_nodes:
new_graph.dstnodes[ntype].data[NID] = dst_nodes[ntype]
else:
......@@ -1071,7 +1071,7 @@ def to_block(g, dst_nodes=None, include_dst_in_src=True):
new_graph.dstnodes[ntype].data[NID] = F.tensor([], dtype=g.idtype)
for i, canonical_etype in enumerate(g.canonical_etypes):
induced_edges = F.zerocopy_from_dgl_ndarray(induced_edges_nd[i].data)
induced_edges = F.zerocopy_from_dgl_ndarray(induced_edges_nd[i])
utype, etype, vtype = canonical_etype
new_canonical_etype = (utype, etype, vtype)
new_graph.edges[new_canonical_etype].data[EID] = induced_edges
......@@ -1114,7 +1114,7 @@ def remove_edges(g, edge_ids):
new_graph = DGLHeteroGraph(new_graph_index, g.ntypes, g.etypes)
for i, canonical_etype in enumerate(g.canonical_etypes):
data = induced_eids_nd[i].data
data = induced_eids_nd[i]
if len(data) == 0:
# Empty means that either
# (1) no edges are removed and edges are not shuffled.
......@@ -1256,8 +1256,8 @@ def to_simple(g, return_counts='count', writeback_mapping=None):
"""
simple_graph_index, counts, edge_maps = _CAPI_DGLToSimpleHetero(g._graph)
simple_graph = DGLHeteroGraph(simple_graph_index, g.ntypes, g.etypes)
counts = [F.zerocopy_from_dgl_ndarray(count.data) for count in counts]
edge_maps = [F.zerocopy_from_dgl_ndarray(edge_map.data) for edge_map in edge_maps]
counts = [F.zerocopy_from_dgl_ndarray(count) for count in counts]
edge_maps = [F.zerocopy_from_dgl_ndarray(edge_map) for edge_map in edge_maps]
if return_counts is not None:
for count, canonical_etype in zip(counts, g.canonical_etypes):
......
......@@ -71,6 +71,13 @@ DGL_REGISTER_GLOBAL("_Map")
}
});
DGL_REGISTER_GLOBAL("_EmptyStrMap").set_body([](DGLArgs args, DGLRetValue* rv) {
StrMapObject::ContainerType data;
auto obj = std::make_shared<StrMapObject>();
obj->data = std::move(data);
*rv = obj;
});
DGL_REGISTER_GLOBAL("_MapSize")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
auto& sptr = args[0].obj_sptr();
......
/*!
* Copyright (c) 2019 by Contributors
* \file graph/serialize/dglgraph_data.h
* \brief Graph serialization header
*/
#ifndef DGL_GRAPH_SERIALIZE_DGLGRAPH_DATA_H_
#define DGL_GRAPH_SERIALIZE_DGLGRAPH_DATA_H_
#include <dgl/graph.h>
#include <dgl/array.h>
#include <dgl/immutable_graph.h>
#include <dmlc/io.h>
#include <dmlc/type_traits.h>
#include <dgl/runtime/ndarray.h>
#include <dgl/runtime/container.h>
#include <dgl/runtime/object.h>
#include <dgl/packed_func_ext.h>
#include <iostream>
#include <string>
#include <vector>
#include <algorithm>
#include <utility>
#include "../../c_api_common.h"
using dgl::runtime::NDArray;
using dgl::ImmutableGraph;
using namespace dgl::runtime;
namespace dgl {
namespace serialize {
typedef std::pair<std::string, NDArray> NamedTensor;
class GraphDataObject : public runtime::Object {
public:
ImmutableGraphPtr gptr;
std::vector<NamedTensor> node_tensors;
std::vector<NamedTensor> edge_tensors;
static constexpr const char *_type_key = "graph_serialize.GraphData";
void SetData(ImmutableGraphPtr gptr,
Map<std::string, Value> node_tensors,
Map<std::string, Value> edge_tensors);
void Save(dmlc::Stream *fs) const;
bool Load(dmlc::Stream *fs);
DGL_DECLARE_OBJECT_TYPE_INFO(GraphDataObject, runtime::Object);
};
class GraphData : public runtime::ObjectRef {
public:
DGL_DEFINE_OBJECT_REF_METHODS(GraphData, runtime::ObjectRef, GraphDataObject);
/*! \brief create a new GraphData reference */
static GraphData Create() {
return GraphData(std::make_shared<GraphDataObject>());
}
};
ImmutableGraphPtr ToImmutableGraph(GraphPtr g);
} // namespace serialize
} // namespace dgl
#endif // DGL_GRAPH_SERIALIZE_DGLGRAPH_DATA_H_
/*!
* Copyright (c) 2019 by Contributors
* \file graph/serialize/graph_serialize.cc
* \brief Graph serialization implementation
*
* The storage structure is
* {
* // MetaData Section
* uint64_t kDGLSerializeMagic
* uint64_t kVersion
* uint64_t GraphType
* ** Reserved Area till 4kB **
*
* dgl_id_t num_graphs
* vector<dgl_id_t> graph_indices (start address of each graph)
* vector<dgl_id_t> nodes_num_list (list of number of nodes for each graph)
* vector<dgl_id_t> edges_num_list (list of number of edges for each graph)
*
* vector<GraphData> graph_datas;
*
* }
*
* Storage of GraphData is
* {
* // Everything uses in csr
* NDArray indptr
* NDArray indices
* NDArray edge_ids
* vector<pair<string, NDArray>> node_tensors;
* vector<pair<string, NDArray>> edge_tensors;
* }
*
*/
#include <dgl/graph_op.h>
#include <dgl/immutable_graph.h>
#include <dgl/runtime/container.h>
#include <dgl/runtime/object.h>
#include <dmlc/io.h>
#include <dmlc/type_traits.h>
#include <algorithm>
#include <iostream>
#include <string>
#include <utility>
#include <vector>
#include "graph_serialize.h"
using namespace dgl::runtime;
using dgl::COO;
using dgl::COOPtr;
using dgl::ImmutableGraph;
using dgl::runtime::NDArray;
using dgl::serialize::GraphData;
using dgl::serialize::GraphDataObject;
using dmlc::SeekStream;
using std::vector;
namespace dmlc {
DMLC_DECLARE_TRAITS(has_saveload, GraphDataObject, true);
}
namespace dgl {
namespace serialize {
bool SaveDGLGraphs(std::string filename, List<GraphData> graph_data,
std::vector<NamedTensor> labels_list) {
auto fs = std::unique_ptr<SeekStream>(dynamic_cast<SeekStream *>(
SeekStream::Create(filename.c_str(), "w", true)));
CHECK(fs) << "File name is not a valid local file name";
// Write DGL MetaData
const uint64_t kVersion = 1;
fs->Write(kDGLSerializeMagic);
fs->Write(kVersion);
fs->Write(GraphType::kImmutableGraph);
fs->Seek(4096);
// Write Graph Meta Data
dgl_id_t num_graph = graph_data.size();
std::vector<dgl_id_t> graph_indices(num_graph);
std::vector<int64_t> nodes_num_list(num_graph);
std::vector<int64_t> edges_num_list(num_graph);
for (uint64_t i = 0; i < num_graph; ++i) {
nodes_num_list[i] = graph_data[i]->gptr->NumVertices();
edges_num_list[i] = graph_data[i]->gptr->NumEdges();
}
// Reserve spaces for graph indices
fs->Write(num_graph);
dgl_id_t indices_start_ptr = fs->Tell();
fs->Write(graph_indices);
fs->Write(nodes_num_list);
fs->Write(edges_num_list);
fs->Write(labels_list);
// Write GraphData
for (uint64_t i = 0; i < num_graph; ++i) {
graph_indices[i] = fs->Tell();
GraphDataObject gdata = *graph_data[i].as<GraphDataObject>();
fs->Write(gdata);
}
fs->Seek(indices_start_ptr);
fs->Write(graph_indices);
return true;
}
StorageMetaData LoadDGLGraphs(const std::string &filename,
std::vector<dgl_id_t> idx_list, bool onlyMeta) {
auto fs = std::unique_ptr<SeekStream>(
SeekStream::CreateForRead(filename.c_str(), true));
CHECK(fs) << "Filename is invalid";
// Read DGL MetaData
uint64_t magicNum, graphType, version;
fs->Read(&magicNum);
fs->Read(&version);
fs->Read(&graphType);
fs->Seek(4096);
CHECK_EQ(magicNum, kDGLSerializeMagic) << "Invalid DGL files";
CHECK_EQ(version, 1) << "Invalid DGL files";
StorageMetaData metadata = StorageMetaData::Create();
// Read Graph MetaData
dgl_id_t num_graph;
CHECK(fs->Read(&num_graph)) << "Invalid num of graph";
std::vector<dgl_id_t> graph_indices;
std::vector<int64_t> nodes_num_list;
std::vector<int64_t> edges_num_list;
std::vector<NamedTensor> labels_list;
CHECK(fs->Read(&graph_indices)) << "Invalid graph indices";
CHECK(fs->Read(&nodes_num_list)) << "Invalid node num list";
CHECK(fs->Read(&edges_num_list)) << "Invalid edge num list";
CHECK(fs->Read(&labels_list)) << "Invalid label list";
metadata->SetMetaData(num_graph, nodes_num_list, edges_num_list, labels_list);
std::vector<GraphData> gdata_refs;
// Early Return
if (onlyMeta) {
return metadata;
}
if (idx_list.empty()) {
// Read All Graphs
gdata_refs.reserve(num_graph);
for (uint64_t i = 0; i < num_graph; ++i) {
GraphData gdata = GraphData::Create();
GraphDataObject *gdata_ptr =
const_cast<GraphDataObject *>(gdata.as<GraphDataObject>());
fs->Read(gdata_ptr);
gdata_refs.push_back(gdata);
}
} else {
// Read Selected Graphss
gdata_refs.reserve(idx_list.size());
// Would be better if idx_list is sorted. However the returned the graphs
// should be the same order as the idx_list
for (uint64_t i = 0; i < idx_list.size(); ++i) {
fs->Seek(graph_indices[idx_list[i]]);
GraphData gdata = GraphData::Create();
GraphDataObject *gdata_ptr =
const_cast<GraphDataObject *>(gdata.as<GraphDataObject>());
fs->Read(gdata_ptr);
gdata_refs.push_back(gdata);
}
}
metadata->SetGraphData(gdata_refs);
return metadata;
}
void GraphDataObject::SetData(ImmutableGraphPtr gptr,
Map<std::string, Value> node_tensors,
Map<std::string, Value> edge_tensors) {
this->gptr = gptr;
for (auto kv : node_tensors) {
std::string name = kv.first;
Value v = kv.second;
NDArray ndarray = static_cast<NDArray>(v->data);
this->node_tensors.emplace_back(name, ndarray);
}
for (auto kv : edge_tensors) {
std::string &name = kv.first;
Value v = kv.second;
const NDArray &ndarray = static_cast<NDArray>(v->data);
this->edge_tensors.emplace_back(name, ndarray);
}
}
void GraphDataObject::Save(dmlc::Stream *fs) const {
// Using in csr for storage
const CSRPtr g_csr = this->gptr->GetInCSR();
fs->Write(g_csr->indptr());
fs->Write(g_csr->indices());
fs->Write(g_csr->edge_ids());
fs->Write(node_tensors);
fs->Write(edge_tensors);
}
bool GraphDataObject::Load(dmlc::Stream *fs) {
NDArray indptr, indices, edge_ids;
fs->Read(&indptr);
fs->Read(&indices);
fs->Read(&edge_ids);
this->gptr = ImmutableGraph::CreateFromCSR(indptr, indices, edge_ids, "in");
fs->Read(&this->node_tensors);
fs->Read(&this->edge_tensors);
return true;
}
ImmutableGraphPtr BatchLoadedGraphs(std::vector<GraphData> gdata_list) {
std::vector<GraphPtr> gptrs;
gptrs.reserve(gdata_list.size());
for (auto gdata : gdata_list) {
gptrs.push_back(static_cast<GraphPtr>(gdata->gptr));
}
ImmutableGraphPtr imGPtr =
std::dynamic_pointer_cast<ImmutableGraph>(GraphOp::DisjointUnion(gptrs));
return imGPtr;
}
ImmutableGraphPtr ToImmutableGraph(GraphPtr g) {
ImmutableGraphPtr imgr = std::dynamic_pointer_cast<ImmutableGraph>(g);
if (imgr) {
return imgr;
} else {
MutableGraphPtr mgr = std::dynamic_pointer_cast<Graph>(g);
CHECK(mgr) << "Invalid Graph Pointer";
EdgeArray earray = mgr->Edges("eid");
IdArray srcs_array = earray.src;
IdArray dsts_array = earray.dst;
ImmutableGraphPtr imgptr =
ImmutableGraph::CreateFromCOO(mgr->NumVertices(), srcs_array, dsts_array);
return imgptr;
}
}
void StorageMetaDataObject::SetMetaData(dgl_id_t num_graph,
std::vector<int64_t> nodes_num_list,
std::vector<int64_t> edges_num_list,
std::vector<NamedTensor> labels_list) {
this->num_graph = num_graph;
this->nodes_num_list = Value(MakeValue(aten::VecToIdArray(nodes_num_list)));
this->edges_num_list = Value(MakeValue(aten::VecToIdArray(edges_num_list)));
for (auto kv : labels_list) {
this->labels_list.Set(kv.first, Value(MakeValue(kv.second)));
}
}
void StorageMetaDataObject::SetGraphData(std::vector<GraphData> gdata) {
this->graph_data = List<GraphData>(gdata);
}
} // namespace serialize
} // namespace dgl
......@@ -32,28 +32,32 @@
*
*/
#include "graph_serialize.h"
#include <dmlc/io.h>
#include <dmlc/type_traits.h>
#include <dgl/runtime/container.h>
#include <dgl/graph_op.h>
#include <dgl/immutable_graph.h>
#include <dgl/runtime/container.h>
#include <dgl/runtime/object.h>
#include <dgl/graph_op.h>
#include <dmlc/io.h>
#include <dmlc/logging.h>
#include <dmlc/type_traits.h>
#include <algorithm>
#include <iostream>
#include <string>
#include <vector>
#include <algorithm>
#include <utility>
#include <vector>
using namespace dgl::runtime;
using dgl::COO;
using dgl::COOPtr;
using dgl::ImmutableGraph;
using dmlc::SeekStream;
using dgl::runtime::NDArray;
using std::vector;
using dgl::serialize::GraphData;
using dgl::serialize::GraphDataObject;
using dmlc::SeekStream;
using dmlc::Stream;
using std::vector;
namespace dmlc {
DMLC_DECLARE_TRAITS(has_saveload, GraphDataObject, true);
......@@ -62,13 +66,8 @@ DMLC_DECLARE_TRAITS(has_saveload, GraphDataObject, true);
namespace dgl {
namespace serialize {
enum GraphType {
kMutableGraph = 0ull,
kImmutableGraph = 1ull
};
DGL_REGISTER_GLOBAL("data.graph_serialize._CAPI_MakeGraphData")
.set_body([](DGLArgs args, DGLRetValue *rv) {
.set_body([](DGLArgs args, DGLRetValue *rv) {
GraphRef gptr = args[0];
ImmutableGraphPtr imGPtr = ToImmutableGraph(gptr.sptr());
Map<std::string, Value> node_tensors = args[1];
......@@ -76,10 +75,10 @@ DGL_REGISTER_GLOBAL("data.graph_serialize._CAPI_MakeGraphData")
GraphData gd = GraphData::Create();
gd->SetData(imGPtr, node_tensors, edge_tensors);
*rv = gd;
});
});
DGL_REGISTER_GLOBAL("data.graph_serialize._CAPI_DGLSaveGraphs")
.set_body([](DGLArgs args, DGLRetValue *rv) {
DGL_REGISTER_GLOBAL("data.graph_serialize._CAPI_SaveDGLGraphs_V0")
.set_body([](DGLArgs args, DGLRetValue *rv) {
std::string filename = args[0];
List<GraphData> graph_data = args[1];
Map<std::string, Value> labels = args[2];
......@@ -91,257 +90,67 @@ DGL_REGISTER_GLOBAL("data.graph_serialize._CAPI_DGLSaveGraphs")
labels_list.emplace_back(name, ndarray);
}
SaveDGLGraphs(filename, graph_data, labels_list);
});
DGL_REGISTER_GLOBAL("data.graph_serialize._CAPI_DGLLoadGraphs")
.set_body([](DGLArgs args, DGLRetValue *rv) {
std::string filename = args[0];
List<Value> idxs = args[1];
bool onlyMeta = args[2];
std::vector<dgl_id_t> idx_list(idxs.size());
for (uint64_t i = 0; i < idxs.size(); ++i) {
idx_list[i] = static_cast<dgl_id_t >(idxs[i]->data);
}
*rv = LoadDGLGraphs(filename, idx_list, onlyMeta);
});
});
DGL_REGISTER_GLOBAL("data.graph_serialize._CAPI_GDataGraphHandle")
.set_body([](DGLArgs args, DGLRetValue *rv) {
.set_body([](DGLArgs args, DGLRetValue *rv) {
GraphData gdata = args[0];
*rv = gdata->gptr;
});
});
DGL_REGISTER_GLOBAL("data.graph_serialize._CAPI_GDataNodeTensors")
.set_body([](DGLArgs args, DGLRetValue *rv) {
.set_body([](DGLArgs args, DGLRetValue *rv) {
GraphData gdata = args[0];
Map<std::string, Value> rvmap;
for (auto kv : gdata->node_tensors) {
rvmap.Set(kv.first, Value(MakeValue(kv.second)));
}
*rv = rvmap;
});
});
DGL_REGISTER_GLOBAL("data.graph_serialize._CAPI_GDataEdgeTensors")
.set_body([](DGLArgs args, DGLRetValue *rv) {
.set_body([](DGLArgs args, DGLRetValue *rv) {
GraphData gdata = args[0];
Map<std::string, Value> rvmap;
for (auto kv : gdata->edge_tensors) {
rvmap.Set(kv.first, Value(MakeValue(kv.second)));
}
*rv = rvmap;
});
constexpr uint64_t kDGLSerializeMagic = 0xDD2E4FF046B4A13F;
});
bool SaveDGLGraphs(std::string filename,
List<GraphData> graph_data,
std::vector<NamedTensor> labels_list) {
auto *fs = dynamic_cast<SeekStream *>(SeekStream::Create(filename.c_str(), "w",
true));
CHECK(fs) << "File name is not a valid local file name";
// Write DGL MetaData
const uint64_t kVersion = 1;
fs->Write(kDGLSerializeMagic);
fs->Write(kVersion);
fs->Write(kImmutableGraph);
fs->Seek(4096);
// Write Graph Meta Data
dgl_id_t num_graph = graph_data.size();
std::vector<dgl_id_t> graph_indices(num_graph);
std::vector<int64_t> nodes_num_list(num_graph);
std::vector<int64_t> edges_num_list(num_graph);
for (uint64_t i = 0; i < num_graph; ++i) {
nodes_num_list[i] = graph_data[i]->gptr->NumVertices();
edges_num_list[i] = graph_data[i]->gptr->NumEdges();
}
// Reserve spaces for graph indices
fs->Write(num_graph);
dgl_id_t indices_start_ptr = fs->Tell();
fs->Write(graph_indices);
fs->Write(nodes_num_list);
fs->Write(edges_num_list);
fs->Write(labels_list);
// Write GraphData
for (uint64_t i = 0; i < num_graph; ++i) {
graph_indices[i] = fs->Tell();
GraphDataObject gdata = *graph_data[i].as<GraphDataObject>();
fs->Write(gdata);
}
fs->Seek(indices_start_ptr);
fs->Write(graph_indices);
std::vector<dgl_id_t> test;
fs->Seek(indices_start_ptr);
fs->Read(&test);
delete fs;
return true;
}
StorageMetaData LoadDGLGraphs(const std::string &filename,
std::vector<dgl_id_t> idx_list,
bool onlyMeta) {
SeekStream *fs = SeekStream::CreateForRead(filename.c_str(), true);
CHECK(fs) << "Filename is invalid";
StorageMetaData metadata = StorageMetaData::Create();
// Read DGL MetaData
uint64_t magicNum, graphType, version;
uint64_t GetFileVersion(const std::string &filename) {
auto fs = std::unique_ptr<SeekStream>(
SeekStream::CreateForRead(filename.c_str(), false));
CHECK(fs) << "File " << filename << " not found";
uint64_t magicNum, version;
fs->Read(&magicNum);
fs->Read(&graphType);
fs->Read(&version);
fs->Seek(4096);
CHECK_EQ(magicNum, kDGLSerializeMagic) << "Invalid DGL files";
CHECK_EQ(graphType, kImmutableGraph) << "Invalid DGL files";
CHECK_EQ(version, 1) << "Invalid Serialization Version";
// Read Graph MetaData
dgl_id_t num_graph;
CHECK(fs->Read(&num_graph)) << "Invalid num of graph";
std::vector<dgl_id_t> graph_indices;
std::vector<int64_t> nodes_num_list;
std::vector<int64_t> edges_num_list;
std::vector<NamedTensor> labels_list;
CHECK(fs->Read(&graph_indices)) << "Invalid graph indices";
CHECK(fs->Read(&nodes_num_list)) << "Invalid node num list";
CHECK(fs->Read(&edges_num_list)) << "Invalid edge num list";
CHECK(fs->Read(&labels_list)) << "Invalid label list";
metadata->SetMetaData(num_graph, nodes_num_list, edges_num_list, labels_list);
std::vector<GraphData> gdata_refs;
// Early Return
if (onlyMeta) {
delete fs;
return metadata;
}
if (idx_list.empty()) {
// Read All Graphs
gdata_refs.reserve(num_graph);
for (uint64_t i = 0; i < num_graph; ++i) {
GraphData gdata = GraphData::Create();
GraphDataObject *gdata_ptr =
const_cast<GraphDataObject *>(gdata.as<GraphDataObject>());
fs->Read(gdata_ptr);
gdata_refs.push_back(gdata);
}
} else {
// Read Selected Graphss
gdata_refs.reserve(idx_list.size());
// Would be better if idx_list is sorted. However the returned the graphs should be the same
// order as the idx_list
for (uint64_t i = 0; i < idx_list.size(); ++i) {
fs->Seek(graph_indices[idx_list[i]]);
GraphData gdata = GraphData::Create();
GraphDataObject *gdata_ptr =
const_cast<GraphDataObject *>(gdata.as<GraphDataObject>());
fs->Read(gdata_ptr);
gdata_refs.push_back(gdata);
}
}
metadata->SetGraphData(gdata_refs);
delete fs;
return metadata;
}
void GraphDataObject::SetData(ImmutableGraphPtr gptr,
Map<std::string, Value> node_tensors,
Map<std::string, Value> edge_tensors) {
this->gptr = gptr;
for (auto kv : node_tensors) {
std::string name = kv.first;
Value v = kv.second;
NDArray ndarray = static_cast<NDArray>(v->data);
this->node_tensors.emplace_back(name, ndarray);
}
for (auto kv : edge_tensors) {
std::string &name = kv.first;
Value v = kv.second;
const NDArray &ndarray = static_cast<NDArray>(v->data);
this->edge_tensors.emplace_back(name, ndarray);
}
}
void GraphDataObject::Save(dmlc::Stream *fs) const {
// Using in csr for storage
const CSRPtr g_csr = this->gptr->GetInCSR();
fs->Write(g_csr->indptr());
fs->Write(g_csr->indices());
fs->Write(g_csr->edge_ids());
fs->Write(node_tensors);
fs->Write(edge_tensors);
return version;
}
bool GraphDataObject::Load(dmlc::Stream *fs) {
NDArray indptr, indices, edge_ids;
fs->Read(&indptr);
fs->Read(&indices);
fs->Read(&edge_ids);
this->gptr = ImmutableGraph::CreateFromCSR(indptr, indices, edge_ids, "in");
fs->Read(&this->node_tensors);
fs->Read(&this->edge_tensors);
return true;
}
ImmutableGraphPtr BatchLoadedGraphs(std::vector<GraphData> gdata_list) {
std::vector<GraphPtr> gptrs;
gptrs.reserve(gdata_list.size());
for (auto gdata : gdata_list) {
gptrs.push_back(static_cast<GraphPtr>(gdata->gptr));
}
ImmutableGraphPtr imGPtr = std::dynamic_pointer_cast<ImmutableGraph>(
GraphOp::DisjointUnion(gptrs));
return imGPtr;
}
ImmutableGraphPtr ToImmutableGraph(GraphPtr g) {
ImmutableGraphPtr imgr = std::dynamic_pointer_cast<ImmutableGraph>(g);
if (imgr) {
return imgr;
} else {
MutableGraphPtr mgr = std::dynamic_pointer_cast<Graph>(g);
CHECK(mgr) << "Invalid Graph Pointer";
EdgeArray earray = mgr->Edges("eid");
IdArray srcs_array = earray.src;
IdArray dsts_array = earray.dst;
ImmutableGraphPtr imgptr = ImmutableGraph::CreateFromCOO(mgr->NumVertices(), srcs_array,
dsts_array);
return imgptr;
}
}
DGL_REGISTER_GLOBAL("data.graph_serialize._CAPI_GetFileVersion")
.set_body([](DGLArgs args, DGLRetValue *rv) {
std::string filename = args[0];
*rv = static_cast<int64_t>(GetFileVersion(filename));
});
void StorageMetaDataObject::SetMetaData(dgl_id_t num_graph,
std::vector<int64_t> nodes_num_list,
std::vector<int64_t> edges_num_list,
std::vector<NamedTensor> labels_list) {
this->num_graph = num_graph;
this->nodes_num_list = Value(MakeValue(aten::VecToIdArray(nodes_num_list)));
this->edges_num_list = Value(MakeValue(aten::VecToIdArray(edges_num_list)));
for (auto kv : labels_list) {
this->labels_list.Set(kv.first, Value(MakeValue(kv.second)));
}
}
DGL_REGISTER_GLOBAL("data.graph_serialize._CAPI_LoadGraphFiles_V1")
.set_body([](DGLArgs args, DGLRetValue *rv) {
std::string filename = args[0];
List<Value> idxs = args[1];
bool onlyMeta = args[2];
auto idx_list = ListValueToVector<dgl_id_t>(idxs);
*rv = LoadDGLGraphs(filename, idx_list, onlyMeta);
});
void StorageMetaDataObject::SetGraphData(std::vector<GraphData> gdata) {
this->graph_data = List<GraphData>(gdata);
}
DGL_REGISTER_GLOBAL("data.graph_serialize._CAPI_LoadGraphFiles_V2")
.set_body([](DGLArgs args, DGLRetValue *rv) {
std::string filename = args[0];
List<Value> idxs = args[1];
auto idx_list = ListValueToVector<dgl_id_t>(idxs);
*rv = List<HeteroGraphData>(LoadHeteroGraphs(filename, idx_list));
});
} // namespace serialize
} // namespace dgl
......@@ -6,72 +6,54 @@
#ifndef DGL_GRAPH_SERIALIZE_GRAPH_SERIALIZE_H_
#define DGL_GRAPH_SERIALIZE_GRAPH_SERIALIZE_H_
#include <dgl/graph.h>
#include <dgl/array.h>
#include <dgl/graph.h>
#include <dgl/immutable_graph.h>
#include <dmlc/io.h>
#include <dmlc/type_traits.h>
#include <dgl/runtime/ndarray.h>
#include <dgl/packed_func_ext.h>
#include <dgl/runtime/container.h>
#include <dgl/runtime/ndarray.h>
#include <dgl/runtime/object.h>
#include <dgl/packed_func_ext.h>
#include <dmlc/io.h>
#include <dmlc/type_traits.h>
#include <algorithm>
#include <iostream>
#include <memory>
#include <string>
#include <vector>
#include <algorithm>
#include <utility>
#include <memory>
#include <vector>
#include "../../c_api_common.h"
#include "dglgraph_data.h"
#include "heterograph_data.h"
using dgl::runtime::NDArray;
using dgl::ImmutableGraph;
using dgl::runtime::NDArray;
using namespace dgl::runtime;
namespace dgl {
namespace serialize {
typedef std::pair<std::string, NDArray> NamedTensor;
class GraphDataObject : public runtime::Object {
public:
ImmutableGraphPtr gptr;
std::vector<NamedTensor> node_tensors;
std::vector<NamedTensor> edge_tensors;
static constexpr const char *_type_key = "graph_serialize.GraphData";
void SetData(ImmutableGraphPtr gptr,
Map<std::string, Value> node_tensors,
Map<std::string, Value> edge_tensors);
void Save(dmlc::Stream *fs) const;
bool Load(dmlc::Stream *fs);
DGL_DECLARE_OBJECT_TYPE_INFO(GraphDataObject, runtime::Object);
enum GraphType : uint64_t {
kMutableGraph = 0ull,
kImmutableGraph = 1ull,
kHeteroGraph = 2ull
};
class GraphData : public runtime::ObjectRef {
public:
DGL_DEFINE_OBJECT_REF_METHODS(GraphData, runtime::ObjectRef, GraphDataObject);
/*! \brief create a new GraphData reference */
static GraphData Create() {
return GraphData(std::make_shared<GraphDataObject>());
}
};
constexpr uint64_t kDGLSerializeMagic = 0xDD2E4FF046B4A13F;
class StorageMetaDataObject : public runtime::Object {
public:
// For saving DGLGraph
dgl_id_t num_graph;
Value nodes_num_list;
Value edges_num_list;
Map<std::string, Value> labels_list;
List<GraphData> graph_data;
static constexpr const char *_type_key = "graph_serialize.StorageMetaData";
void SetMetaData(dgl_id_t num_graph,
std::vector<int64_t> nodes_num_list,
void SetMetaData(dgl_id_t num_graph, std::vector<int64_t> nodes_num_list,
std::vector<int64_t> edges_num_list,
std::vector<NamedTensor> labels_list);
......@@ -88,10 +70,10 @@ class StorageMetaDataObject : public runtime::Object {
DGL_DECLARE_OBJECT_TYPE_INFO(StorageMetaDataObject, runtime::Object);
};
class StorageMetaData : public runtime::ObjectRef {
public:
DGL_DEFINE_OBJECT_REF_METHODS(StorageMetaData, runtime::ObjectRef, StorageMetaDataObject);
DGL_DEFINE_OBJECT_REF_METHODS(StorageMetaData, runtime::ObjectRef,
StorageMetaDataObject);
/*! \brief create a new StorageMetaData reference */
static StorageMetaData Create() {
......@@ -99,14 +81,18 @@ class StorageMetaData : public runtime::ObjectRef {
}
};
StorageMetaData LoadDGLGraphFiles(const std::string &filename,
std::vector<dgl_id_t> idx_list,
bool onlyMeta);
StorageMetaData LoadDGLGraphs(const std::string &filename,
std::vector<dgl_id_t> idx_list, bool onlyMeta);
bool SaveDGLGraphs(std::string filename,
List<GraphData> graph_data,
bool SaveDGLGraphs(std::string filename, List<GraphData> graph_data,
std::vector<NamedTensor> labels_list);
StorageMetaData LoadDGLGraphs(const std::string &filename,
std::vector<dgl_id_t> idx_list,
bool onlyMeta = false);
std::vector<HeteroGraphData> LoadHeteroGraphs(const std::string &filename,
std::vector<dgl_id_t> idx_list);
ImmutableGraphPtr ToImmutableGraph(GraphPtr g);
......
/*!
* Copyright (c) 2019 by Contributors
* \file graph/serialize/heterograph_data.h
* \brief Graph serialization header
*/
#ifndef DGL_GRAPH_SERIALIZE_HETEROGRAPH_DATA_H_
#define DGL_GRAPH_SERIALIZE_HETEROGRAPH_DATA_H_
#include <dgl/array.h>
#include <dgl/graph.h>
#include <dgl/packed_func_ext.h>
#include <dgl/runtime/container.h>
#include <dgl/runtime/ndarray.h>
#include <dgl/runtime/object.h>
#include <dmlc/io.h>
#include <dmlc/type_traits.h>
#include <algorithm>
#include <iostream>
#include <string>
#include <utility>
#include <vector>
#include "../../c_api_common.h"
#include "../heterograph.h"
using dgl::runtime::NDArray;
using namespace dgl::runtime;
namespace dgl {
namespace serialize {
typedef std::pair<std::string, NDArray> NamedTensor;
class HeteroGraphDataObject : public runtime::Object {
public:
std::shared_ptr<HeteroGraph> gptr;
std::vector<std::vector<NamedTensor>> node_tensors;
std::vector<std::vector<NamedTensor>> edge_tensors;
std::vector<std::string> etype_names;
std::vector<std::string> ntype_names;
static constexpr const char *_type_key =
"heterograph_serialize.HeteroGraphData";
HeteroGraphDataObject() {}
HeteroGraphDataObject(HeteroGraphPtr gptr,
List<Map<std::string, Value>> ndata,
List<Map<std::string, Value>> edata,
List<Value> ntype_names, List<Value> etype_names) {
this->gptr = std::dynamic_pointer_cast<HeteroGraph>(gptr);
CHECK_NOTNULL(this->gptr);
for (auto nd_dict : ndata) {
node_tensors.emplace_back();
for (auto kv : nd_dict) {
auto last = &node_tensors.back();
NDArray ndarray = kv.second->data;
last->emplace_back(kv.first, ndarray);
}
}
for (auto nd_dict : edata) {
edge_tensors.emplace_back();
for (auto kv : nd_dict) {
auto last = &edge_tensors.back();
NDArray ndarray = kv.second->data;
last->emplace_back(kv.first, ndarray);
}
}
this->ntype_names = ListValueToVector<std::string>(ntype_names);
this->etype_names = ListValueToVector<std::string>(etype_names);
}
void Save(dmlc::Stream *fs) const {
fs->Write(gptr);
fs->Write(node_tensors);
fs->Write(edge_tensors);
fs->Write(ntype_names);
fs->Write(etype_names);
}
bool Load(dmlc::Stream *fs) {
fs->Read(&gptr);
fs->Read(&node_tensors);
fs->Read(&edge_tensors);
fs->Read(&ntype_names);
fs->Read(&etype_names);
return true;
}
DGL_DECLARE_OBJECT_TYPE_INFO(HeteroGraphDataObject, runtime::Object);
};
class HeteroGraphData : public runtime::ObjectRef {
public:
DGL_DEFINE_OBJECT_REF_METHODS(HeteroGraphData, runtime::ObjectRef,
HeteroGraphDataObject);
/*! \brief create a new GraphData reference */
static HeteroGraphData Create(HeteroGraphPtr gptr,
List<Map<std::string, Value>> node_tensors,
List<Map<std::string, Value>> edge_tensors,
List<Value> ntype_names,
List<Value> etype_names) {
return HeteroGraphData(std::make_shared<HeteroGraphDataObject>(
gptr, node_tensors, edge_tensors, ntype_names, etype_names));
}
/*! \brief create an empty GraphData reference */
static HeteroGraphData Create() {
return HeteroGraphData(std::make_shared<HeteroGraphDataObject>());
}
};
} // namespace serialize
} // namespace dgl
namespace dmlc {
DMLC_DECLARE_TRAITS(has_saveload, dgl::serialize::HeteroGraphDataObject, true);
}
#endif // DGL_GRAPH_SERIALIZE_HETEROGRAPH_DATA_H_
/*!
* Copyright (c) 2019 by Contributors
* \file graph/serialize/heterograph_serialize.cc
* \brief DGLHeteroGraph serialization implementation
*
* The storage structure is
* {
* // MetaData Section
* uint64_t kDGLSerializeMagic
* uint64_t kVersion = 2
* uint64_t GraphType = kDGLHeteroGraph
* dgl_id_t num_graphs
* ** Reserved Area till 4kB **
*
* uint64_t gdata_start_pos (This stores the start position of graph_data,
* which is used to skip label dict part if unnecessary)
* vector<pair<string, NDArray>> label_dict (To store the dict[str, NDArray])
*
* vector<HeteroGraphData> graph_datas;
* vector<dgl_id_t> graph_indices (start address of each graph)
* uint64_t size_of_graph_indices_vector (Used to seek to graph_indices
* vector)
*
* }
*
* Storage of HeteroGraphData is
* {
* HeteroGraphPtr ptr;
* vector<vector<pair<string, NDArray>>> node_tensors;
* vector<vector<pair<string, NDArray>>> edge_tensors;
* vector<string> ntype_name;
* vector<string> etype_name;
* }
*
*/
#include <dgl/graph_op.h>
#include <dgl/immutable_graph.h>
#include <dgl/runtime/container.h>
#include <dgl/runtime/object.h>
#include <dmlc/io.h>
#include <dmlc/type_traits.h>
#include <algorithm>
#include <array>
#include <iostream>
#include <string>
#include <utility>
#include <vector>
#include "../heterograph.h"
#include "./graph_serialize.h"
#include "./streamwithcount.h"
#include "dmlc/memory_io.h"
namespace dgl {
namespace serialize {
using namespace dgl::runtime;
using dmlc::SeekStream;
using dmlc::Stream;
using dmlc::io::FileSystem;
using dmlc::io::URI;
bool SaveHeteroGraphs(std::string filename, List<HeteroGraphData> hdata,
const std::vector<NamedTensor> &nd_list) {
auto fs = std::unique_ptr<StreamWithCount>(
StreamWithCount::Create(filename.c_str(), "w", false));
CHECK(fs->IsValid()) << "File name " << filename << " is not a valid name";
// Write DGL MetaData
const uint64_t kVersion = 2;
std::array<char, 4096> meta_buffer;
// Write metadata into char buffer with size 4096
dmlc::MemoryFixedSizeStream meta_fs_(meta_buffer.data(), 4096);
auto meta_fs = static_cast<Stream *>(&meta_fs_);
meta_fs->Write(kDGLSerializeMagic);
meta_fs->Write(kVersion);
meta_fs->Write(GraphType::kHeteroGraph);
uint64_t num_graph = hdata.size();
meta_fs->Write(num_graph);
// Write metadata into files
fs->Write(meta_buffer.data(), 4096);
// Calculate label dict binary size
std::string labels_blob;
dmlc::MemoryStringStream label_fs_(&labels_blob);
auto label_fs = static_cast<Stream *>(&label_fs_);
label_fs->Write(nd_list);
uint64_t gdata_start_pos =
fs->Count() + sizeof(uint64_t) + labels_blob.size();
// Write start position of gdata, which can be skipped when only reading gdata
// And label dict
fs->Write(gdata_start_pos);
fs->Write(labels_blob.c_str(), labels_blob.size());
std::vector<uint64_t> graph_indices(num_graph);
// Write HeteroGraphData
for (uint64_t i = 0; i < num_graph; ++i) {
graph_indices[i] = fs->Count();
auto gdata = hdata[i].sptr();
fs->Write(gdata);
}
// Write indptr into string to count size
std::string indptr_blob;
dmlc::MemoryStringStream indptr_fs_(&indptr_blob);
auto indptr_fs = static_cast<Stream *>(&indptr_fs_);
indptr_fs->Write(graph_indices);
uint64_t indptr_buffer_size = indptr_blob.size();
fs->Write(indptr_blob);
fs->Write(indptr_buffer_size);
return true;
}
std::vector<HeteroGraphData> LoadHeteroGraphs(const std::string &filename,
std::vector<dgl_id_t> idx_list) {
auto fs = std::unique_ptr<SeekStream>(
SeekStream::CreateForRead(filename.c_str(), false));
CHECK(fs) << "File name " << filename << " is not a valid name";
// Read DGL MetaData
uint64_t magicNum, graphType, version, num_graph;
fs->Read(&magicNum);
fs->Read(&version);
fs->Read(&graphType);
CHECK(fs->Read(&num_graph)) << "Invalid num of graph";
fs->Seek(4096);
CHECK_EQ(magicNum, kDGLSerializeMagic) << "Invalid DGL files";
CHECK_EQ(version, 2) << "Invalid GraphType";
CHECK_EQ(graphType, GraphType::kHeteroGraph) << "Invalid GraphType";
uint64_t gdata_start_pos;
fs->Read(&gdata_start_pos);
// Skip labels part
fs->Seek(gdata_start_pos);
std::vector<HeteroGraphData> gdata_refs;
if (idx_list.empty()) {
// Read All Graphs
gdata_refs.reserve(num_graph);
for (uint64_t i = 0; i < num_graph; ++i) {
HeteroGraphData gdata = HeteroGraphData::Create();
auto hetero_data = gdata.sptr();
fs->Read(&hetero_data);
gdata_refs.push_back(gdata);
}
} else {
uint64_t gdata_start_pos = fs->Tell();
// Read Selected Graphss
gdata_refs.reserve(idx_list.size());
URI uri(filename.c_str());
uint64_t filesize = FileSystem::GetInstance(uri)->GetPathInfo(uri).size;
fs->Seek(filesize - sizeof(uint64_t));
uint64_t indptr_buffer_size;
fs->Read(&indptr_buffer_size);
std::vector<uint64_t> graph_indices(num_graph);
fs->Seek(filesize - sizeof(uint64_t) - indptr_buffer_size);
fs->Read(&graph_indices);
fs->Seek(gdata_start_pos);
// Would be better if idx_list is sorted. However the returned the graphs
// should be the same order as the idx_list
for (uint64_t i = 0; i < idx_list.size(); ++i) {
fs->Seek(graph_indices[idx_list[i]]);
HeteroGraphData gdata = HeteroGraphData::Create();
auto hetero_data = gdata.sptr();
fs->Read(&hetero_data);
gdata_refs.push_back(gdata);
}
}
return gdata_refs;
}
std::vector<NamedTensor> LoadLabels_V2(const std::string &filename) {
auto fs = std::unique_ptr<SeekStream>(
SeekStream::CreateForRead(filename.c_str(), false));
CHECK(fs) << "File name " << filename << " is not a valid name";
// Read DGL MetaData
uint64_t magicNum, graphType, version, num_graph;
fs->Read(&magicNum);
fs->Read(&version);
fs->Read(&graphType);
CHECK(fs->Read(&num_graph)) << "Invalid num of graph";
fs->Seek(4096);
uint64_t gdata_start_pos;
fs->Read(&gdata_start_pos);
std::vector<NamedTensor> labels_list;
fs->Read(&labels_list);
return labels_list;
}
DGL_REGISTER_GLOBAL("data.heterograph_serialize._CAPI_MakeHeteroGraphData")
.set_body([](DGLArgs args, DGLRetValue *rv) {
HeteroGraphRef hg = args[0];
List<Map<std::string, Value>> ndata = args[1];
List<Map<std::string, Value>> edata = args[2];
List<Value> ntype_names = args[3];
List<Value> etype_names = args[4];
*rv = HeteroGraphData::Create(hg.sptr(), ndata, edata, ntype_names,
etype_names);
});
DGL_REGISTER_GLOBAL("data.heterograph_serialize._CAPI_SaveHeteroGraphData")
.set_body([](DGLArgs args, DGLRetValue *rv) {
std::string filename = args[0];
List<HeteroGraphData> hgdata = args[1];
Map<std::string, Value> nd_map = args[2];
std::vector<NamedTensor> nd_list;
for (auto kv : nd_map) {
NDArray ndarray = static_cast<NDArray>(kv.second->data);
nd_list.emplace_back(kv.first, ndarray);
}
*rv = dgl::serialize::SaveHeteroGraphs(filename, hgdata, nd_list);
});
DGL_REGISTER_GLOBAL(
"data.heterograph_serialize._CAPI_GetGindexFromHeteroGraphData")
.set_body([](DGLArgs args, DGLRetValue *rv) {
HeteroGraphData hdata = args[0];
*rv = HeteroGraphRef(hdata->gptr);
});
DGL_REGISTER_GLOBAL(
"data.heterograph_serialize._CAPI_GetEtypesFromHeteroGraphData")
.set_body([](DGLArgs args, DGLRetValue *rv) {
HeteroGraphData hdata = args[0];
List<Value> etype_names;
for (const auto &name : hdata->etype_names) {
etype_names.push_back(Value(MakeValue(name)));
}
*rv = etype_names;
});
DGL_REGISTER_GLOBAL(
"data.heterograph_serialize._CAPI_GetNtypesFromHeteroGraphData")
.set_body([](DGLArgs args, DGLRetValue *rv) {
HeteroGraphData hdata = args[0];
List<Value> ntype_names;
for (auto name : hdata->ntype_names) {
ntype_names.push_back(Value(MakeValue(name)));
}
*rv = ntype_names;
});
DGL_REGISTER_GLOBAL(
"data.heterograph_serialize._CAPI_GetNDataFromHeteroGraphData")
.set_body([](DGLArgs args, DGLRetValue *rv) {
HeteroGraphData hdata = args[0];
List<List<Value>> ntensors;
for (auto tensor_list : hdata->node_tensors) {
List<Value> nlist;
for (const auto &kv : tensor_list) {
nlist.push_back(Value(MakeValue(kv.first)));
nlist.push_back(Value(MakeValue(kv.second)));
}
ntensors.push_back(nlist);
}
*rv = ntensors;
});
DGL_REGISTER_GLOBAL(
"data.heterograph_serialize._CAPI_GetEDataFromHeteroGraphData")
.set_body([](DGLArgs args, DGLRetValue *rv) {
HeteroGraphData hdata = args[0];
List<List<Value>> etensors;
for (auto tensor_list : hdata->edge_tensors) {
List<Value> elist;
for (const auto &kv : tensor_list) {
elist.push_back(Value(MakeValue(kv.first)));
elist.push_back(Value(MakeValue(kv.second)));
}
etensors.push_back(elist);
}
*rv = etensors;
});
DGL_REGISTER_GLOBAL("data.graph_serialize._CAPI_LoadLabels_V2")
.set_body([](DGLArgs args, DGLRetValue *rv) {
std::string filename = args[0];
auto labels_list = LoadLabels_V2(filename);
Map<std::string, Value> rvmap;
for (auto kv : labels_list) {
rvmap.Set(kv.first, Value(MakeValue(kv.second)));
}
*rv = rvmap;
});
} // namespace serialize
} // namespace dgl
/*!
* Copyright (c) 2019 by Contributors
* \file graph/serialize/streamwithcount.h
* \brief Graph serialization header
*/
#ifndef DGL_GRAPH_SERIALIZE_STREAMWITHCOUNT_H_
#define DGL_GRAPH_SERIALIZE_STREAMWITHCOUNT_H_
#include <dmlc/io.h>
#include <dmlc/type_traits.h>
/*!
* \brief StreamWithCount counts the bytes that already written into the
* underlying stream.
*/
class StreamWithCount : dmlc::Stream {
public:
static StreamWithCount *Create(const char *uri, const char *const flag,
bool allow_null = false) {
return new StreamWithCount(uri, flag, allow_null);
}
size_t Read(void *ptr, size_t size) override {
return strm_->Read(ptr, size);
}
void Write(const void *ptr, size_t size) override {
count_ += size;
strm_->Write(ptr, size);
}
using dmlc::Stream::Read;
using dmlc::Stream::Write;
bool IsValid() { return strm_.get(); }
uint64_t Count() const { return count_; }
private:
StreamWithCount(const char *uri, const char *const flag, bool allow_null)
: strm_(dmlc::Stream::Create(uri, flag, allow_null)) {}
std::unique_ptr<dmlc::Stream> strm_;
uint64_t count_ = 0;
};
#endif // DGL_GRAPH_SERIALIZE_STREAMWITHCOUNT_H_
......@@ -24,7 +24,8 @@ constexpr uint64_t kDGLSerialize_Tensors = 0xDD5A9FBE3FA2443F;
DGL_REGISTER_GLOBAL("data.tensor_serialize._CAPI_SaveNDArrayDict")
.set_body([](DGLArgs args, DGLRetValue *rv) {
std::string filename = args[0];
auto *fs = dmlc::Stream::Create(filename.c_str(), "w");
auto fs = std::unique_ptr<dmlc::Stream>(
dmlc::Stream::Create(filename.c_str(), "w"));
CHECK(fs) << "Filename is invalid";
fs->Write(kDGLSerialize_Tensors);
bool empty_dict = args[2];
......@@ -40,13 +41,13 @@ DGL_REGISTER_GLOBAL("data.tensor_serialize._CAPI_SaveNDArrayDict")
}
fs->Write(namedTensors);
*rv = true;
delete fs;
});
DGL_REGISTER_GLOBAL("data.tensor_serialize._CAPI_LoadNDArrayDict")
.set_body([](DGLArgs args, DGLRetValue *rv) {
std::string filename = args[0];
auto *fs = dmlc::Stream::Create(filename.c_str(), "r");
auto fs = std::unique_ptr<dmlc::Stream>(
dmlc::Stream::Create(filename.c_str(), "r"));
CHECK(fs) << "Filename is invalid or file doesn't exists";
uint64_t magincNum, num_elements;
CHECK(fs->Read(&magincNum)) << "Invalid file";
......@@ -60,7 +61,6 @@ DGL_REGISTER_GLOBAL("data.tensor_serialize._CAPI_LoadNDArrayDict")
nd_dict.Set(kv.first, ndarray);
}
*rv = nd_dict;
delete fs;
});
} // namespace serialize
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment