Unverified Commit 18a26fcf authored by Jinjing Zhou's avatar Jinjing Zhou Committed by GitHub
Browse files

[Data] Support HeteroGraph save/load (#1526)



* 111

* add history version test

* fix

* 111

* save

* ``

* fix1

* 111

* add save heterograph

* lint

* lint

* add tests

* minor fix

* fix

* docs

* add format tets

* use unique_ptr

* fix

* fix interface

* 111

* 111

* fix

* lint

* fix

* add support to s3

* fix

* fix

* fix leak

* fix

* fix docs

* fix

* linlt

* fix

* fix

* fix

* address comment

* address comment
Co-authored-by: default avatarMinjie Wang <wmjlyjemaine@gmail.com>
Co-authored-by: default avatarxiang song(charlie.song) <classicxsong@gmail.com>
parent dd8d5289
...@@ -26,6 +26,7 @@ dgl_option(USE_CUDA "Build with CUDA" OFF) ...@@ -26,6 +26,7 @@ dgl_option(USE_CUDA "Build with CUDA" OFF)
dgl_option(USE_OPENMP "Build with OpenMP" ON) dgl_option(USE_OPENMP "Build with OpenMP" ON)
dgl_option(BUILD_CPP_TEST "Build cpp unittest executables" OFF) dgl_option(BUILD_CPP_TEST "Build cpp unittest executables" OFF)
dgl_option(LIBCXX_ENABLE_PARALLEL_ALGORITHMS "Enable the parallel algorithms library. This requires the PSTL to be available." OFF) dgl_option(LIBCXX_ENABLE_PARALLEL_ALGORITHMS "Enable the parallel algorithms library. This requires the PSTL to be available." OFF)
# Set debug compile option for gdb, only happens when -DCMAKE_BUILD_TYPE=DEBUG # Set debug compile option for gdb, only happens when -DCMAKE_BUILD_TYPE=DEBUG
if (NOT MSVC) if (NOT MSVC)
set(CMAKE_CXX_FLAGS_DEBUG "${CMAKE_CXX_FLAGS_DEBUG} -O0 -g3 -ggdb") set(CMAKE_CXX_FLAGS_DEBUG "${CMAKE_CXX_FLAGS_DEBUG} -O0 -g3 -ggdb")
...@@ -132,6 +133,9 @@ endif(USE_CUDA) ...@@ -132,6 +133,9 @@ endif(USE_CUDA)
# For serialization # For serialization
add_subdirectory("third_party/dmlc-core") add_subdirectory("third_party/dmlc-core")
# dmlc-core options for support S3/HDFS
add_definitions(-DUSE_S3=OFF)
add_definitions(-DUSE_HDFS=OFF)
list(APPEND DGL_LINKER_LIBS dmlc) list(APPEND DGL_LINKER_LIBS dmlc)
set(GOOGLE_TEST 0) # Turn off dmlc-core test set(GOOGLE_TEST 0) # Turn off dmlc-core test
......
...@@ -3,8 +3,10 @@ reference: tvm/python/tvm/collections.py ...@@ -3,8 +3,10 @@ reference: tvm/python/tvm/collections.py
""" """
from __future__ import absolute_import as _abs from __future__ import absolute_import as _abs
from ._ffi.object import ObjectBase, register_object from ._ffi.object import ObjectBase, register_object
from ._ffi.object_generic import convert_to_object
from . import _api_internal from . import _api_internal
@register_object @register_object
class List(ObjectBase): class List(ObjectBase):
"""List container of DGL. """List container of DGL.
...@@ -14,6 +16,7 @@ class List(ObjectBase): ...@@ -14,6 +16,7 @@ class List(ObjectBase):
to List during dgl function call. to List during dgl function call.
You may get List in return values of DGL function call. You may get List in return values of DGL function call.
""" """
def __getitem__(self, i): def __getitem__(self, i):
if isinstance(i, slice): if isinstance(i, slice):
start = i.start if i.start is not None else 0 start = i.start if i.start is not None else 0
...@@ -30,11 +33,15 @@ class List(ObjectBase): ...@@ -30,11 +33,15 @@ class List(ObjectBase):
.format(len(self), i)) .format(len(self), i))
if i < 0: if i < 0:
i += len(self) i += len(self)
return _api_internal._ListGetItem(self, i) ret = _api_internal._ListGetItem(self, i)
if isinstance(ret, Value):
ret = ret.data
return ret
def __len__(self): def __len__(self):
return _api_internal._ListSize(self) return _api_internal._ListSize(self)
@register_object @register_object
class Map(ObjectBase): class Map(ObjectBase):
"""Map container of DGL. """Map container of DGL.
...@@ -43,6 +50,7 @@ class Map(ObjectBase): ...@@ -43,6 +50,7 @@ class Map(ObjectBase):
Normally python dict will be converted automaticall to Map during dgl function call. Normally python dict will be converted automaticall to Map during dgl function call.
You can use convert to create a dict[ObjectBase-> ObjectBase] into a Map You can use convert to create a dict[ObjectBase-> ObjectBase] into a Map
""" """
def __getitem__(self, k): def __getitem__(self, k):
return _api_internal._MapGetItem(self, k) return _api_internal._MapGetItem(self, k)
...@@ -64,10 +72,12 @@ class StrMap(Map): ...@@ -64,10 +72,12 @@ class StrMap(Map):
You can use convert to create a dict[str->ObjectBase] into a Map. You can use convert to create a dict[str->ObjectBase] into a Map.
""" """
def items(self): def items(self):
"""Get the items from the map""" """Get the items from the map"""
akvs = _api_internal._MapItems(self) akvs = _api_internal._MapItems(self)
return [(akvs[i].data, akvs[i+1]) for i in range(0, len(akvs), 2)] return [(akvs[i], akvs[i+1]) for i in range(0, len(akvs), 2)]
@register_object @register_object
class Value(ObjectBase): class Value(ObjectBase):
...@@ -76,3 +86,12 @@ class Value(ObjectBase): ...@@ -76,3 +86,12 @@ class Value(ObjectBase):
def data(self): def data(self):
"""Return the value data.""" """Return the value data."""
return _api_internal._ValueGet(self) return _api_internal._ValueGet(self)
def convert_to_strmap(value):
    """Turn a python dictionary into a dgl.container.StrMap.

    Parameters
    ----------
    value : dict
        The dictionary to convert. Anything that is not a dict is rejected
        by an assertion.

    Returns
    -------
    The object returned by ``_api_internal._EmptyStrMap()`` for an empty
    dictionary, otherwise the result of ``convert_to_object(value)``.
    """
    assert isinstance(value, dict), "Only support dict"
    if value:
        return convert_to_object(value)
    # An empty dict is routed to the dedicated empty-map constructor.
    return _api_internal._EmptyStrMap()
"""For Graph Serialization""" """For Graph Serialization"""
from __future__ import absolute_import from __future__ import absolute_import
from ..graph import DGLGraph from ..graph import DGLGraph
from ..heterograph import DGLHeteroGraph
from .._ffi.object import ObjectBase, register_object from .._ffi.object import ObjectBase, register_object
from .._ffi.function import _init_api from .._ffi.function import _init_api
from .. import backend as F from .. import backend as F
from .heterograph_serialize import HeteroGraphData, save_heterographs
_init_api("dgl.data.graph_serialize") _init_api("dgl.data.graph_serialize")
__all__ = ['save_graphs', "load_graphs", "load_labels"] __all__ = ['save_graphs', "load_graphs", "load_labels"]
@register_object("graph_serialize.StorageMetaData") @register_object("graph_serialize.StorageMetaData")
class StorageMetaData(ObjectBase): class StorageMetaData(ObjectBase):
"""StorageMetaData Object """StorageMetaData Object
...@@ -54,40 +57,36 @@ class GraphData(ObjectBase): ...@@ -54,40 +57,36 @@ class GraphData(ObjectBase):
node_tensors_items = _CAPI_GDataNodeTensors(self).items() node_tensors_items = _CAPI_GDataNodeTensors(self).items()
edge_tensors_items = _CAPI_GDataEdgeTensors(self).items() edge_tensors_items = _CAPI_GDataEdgeTensors(self).items()
for k, v in node_tensors_items: for k, v in node_tensors_items:
g.ndata[k] = F.zerocopy_from_dgl_ndarray(v.data) g.ndata[k] = F.zerocopy_from_dgl_ndarray(v)
for k, v in edge_tensors_items: for k, v in edge_tensors_items:
g.edata[k] = F.zerocopy_from_dgl_ndarray(v.data) g.edata[k] = F.zerocopy_from_dgl_ndarray(v)
return g return g
def save_graphs(filename, g_list, labels=None): def save_graphs(filename, g_list, labels=None):
r""" r"""
Save DGLGraphs and graph labels to file Save DGLGraphs/DGLHeteroGraph and graph labels to file
Parameters Parameters
---------- ----------
filename : str filename : str
File name to store DGLGraphs. File name to store graphs.
g_list: list g_list: list
DGLGraph or list of DGLGraph DGLGraph or list of DGLGraph/DGLHeteroGraph
labels: dict (Default: None) labels: dict[str, tensor]
labels should be dict of tensors/ndarray, with str as keys labels should be dict of tensors, with str as keys
Examples Examples
---------- ----------
>>> import dgl >>> import dgl
>>> import torch as th >>> import torch as th
Create :code:`DGLGraph` objects and initialize node and edge features. Create :code:`DGLGraph`/:code:`DGLHeteroGraph` objects and initialize node
and edge features.
>>> g1 = dgl.DGLGraph() >>> g1 = dgl.graph(([0, 1, 2], [1, 2, 3]))
>>> g1.add_nodes(3) >>> g2 = dgl.graph(([0, 2], [2, 3]))
>>> g1.add_edges([0, 0, 0, 1, 1, 2], [0, 1, 2, 1, 2, 2]) >>> g2.edata["e"] = th.ones(2, 4)
>>> g1.ndata["e"] = th.ones(3, 5)
>>> g2 = dgl.DGLGraph()
>>> g2.add_nodes(3)
>>> g2.add_edges([0, 1, 2], [1, 2, 1])
>>> g2.edata["e"] = th.ones(3, 4)
Save Graphs into file Save Graphs into file
...@@ -96,6 +95,18 @@ def save_graphs(filename, g_list, labels=None): ...@@ -96,6 +95,18 @@ def save_graphs(filename, g_list, labels=None):
>>> save_graphs("./data.bin", [g1, g2], graph_labels) >>> save_graphs("./data.bin", [g1, g2], graph_labels)
""" """
g_sample = g_list[0] if isinstance(g_list, list) else g_list
if isinstance(g_sample, DGLGraph):
save_dglgraphs(filename, g_list, labels)
elif isinstance(g_sample, DGLHeteroGraph):
save_heterographs(filename, g_list, labels)
else:
raise Exception(
"Invalid argument g_list. Must be a DGLGraph or a list of DGLGraphs/DGLHeteroGraphs")
def save_dglgraphs(filename, g_list, labels=None):
"""Internal function to save DGLGraphs"""
if isinstance(g_list, DGLGraph): if isinstance(g_list, DGLGraph):
g_list = [g_list] g_list = [g_list]
if (labels is not None) and (len(labels) != 0): if (labels is not None) and (len(labels) != 0):
...@@ -105,7 +116,7 @@ def save_graphs(filename, g_list, labels=None): ...@@ -105,7 +116,7 @@ def save_graphs(filename, g_list, labels=None):
else: else:
label_dict = None label_dict = None
gdata_list = [GraphData.create(g) for g in g_list] gdata_list = [GraphData.create(g) for g in g_list]
_CAPI_DGLSaveGraphs(filename, gdata_list, label_dict) _CAPI_SaveDGLGraphs_V0(filename, gdata_list, label_dict)
def load_graphs(filename, idx_list=None): def load_graphs(filename, idx_list=None):
...@@ -115,16 +126,19 @@ def load_graphs(filename, idx_list=None): ...@@ -115,16 +126,19 @@ def load_graphs(filename, idx_list=None):
Parameters Parameters
---------- ----------
filename: str filename: str
filename to load DGLGraphs filename to load graphs
idx_list: list of int idx_list: list of int
list of index of graph to be loaded. If not specified, will list of index of graph to be loaded. If not specified, will
load all graphs from file load all graphs from file
Returns Returns
---------- --------
graph_list: list of immutable DGLGraphs graph_list: list of DGLGraphs / DGLHeteroGraph
labels: dict of labels stored in file (empty dict returned if no The loaded graphs.
label stored) labels: dict[str, Tensor]
The graph labels stored in file. If no label is stored, the dictionary is empty.
Regardless of whether the ``idx_list`` argument is given or not, the returned dictionary
always contains labels of all the graphs.
Examples Examples
---------- ----------
...@@ -135,13 +149,34 @@ def load_graphs(filename, idx_list=None): ...@@ -135,13 +149,34 @@ def load_graphs(filename, idx_list=None):
>>> glist, label_dict = load_graphs("./data.bin", [0]) # glist will be [g1] >>> glist, label_dict = load_graphs("./data.bin", [0]) # glist will be [g1]
""" """
version = _CAPI_GetFileVersion(filename)
if version == 1:
return load_graph_v1(filename, idx_list)
elif version == 2:
return load_graph_v2(filename, idx_list)
else:
raise Exception("Invalid DGL Version Number")
def load_graph_v2(filename, idx_list=None):
    """Internal helper that loads graphs and labels from a version-2 file.

    Parameters
    ----------
    filename : str
        File to read from.
    idx_list : list of int, optional
        Indices of the graphs to load. ``None`` is replaced by an empty list
        before the C API call.

    Returns
    -------
    (list, dict)
        The graphs rebuilt via ``get_graph()`` on every loaded item, and the
        label dictionary produced by ``load_labels_v2``.
    """
    selected = [] if idx_list is None else idx_list
    assert isinstance(selected, list)
    loaded = _CAPI_LoadGraphFiles_V2(filename, selected)
    labels = load_labels_v2(filename)
    graphs = [item.get_graph() for item in loaded]
    return graphs, labels
def load_graph_v1(filename, idx_list=None):
""""Internal functions for loading DGLGraphs (V0)."""
if idx_list is None: if idx_list is None:
idx_list = [] idx_list = []
assert isinstance(idx_list, list) assert isinstance(idx_list, list)
metadata = _CAPI_DGLLoadGraphs(filename, idx_list, False) metadata = _CAPI_LoadGraphFiles_V1(filename, idx_list, False)
label_dict = {} label_dict = {}
for k, v in metadata.labels.items(): for k, v in metadata.labels.items():
label_dict[k] = F.zerocopy_from_dgl_ndarray(v.data) label_dict[k] = F.zerocopy_from_dgl_ndarray(v)
return [gdata.get_graph() for gdata in metadata.graph_data], label_dict return [gdata.get_graph() for gdata in metadata.graph_data], label_dict
...@@ -169,8 +204,28 @@ def load_labels(filename): ...@@ -169,8 +204,28 @@ def load_labels(filename):
>>> label_dict = load_graphs("./data.bin") >>> label_dict = load_graphs("./data.bin")
""" """
metadata = _CAPI_DGLLoadGraphs(filename, [], True) version = _CAPI_GetFileVersion(filename)
if version == 1:
return load_labels_v1(filename)
elif version == 2:
return load_labels_v2(filename)
else:
raise Exception("Invalid DGL Version Number")
def load_labels_v2(filename):
    """Internal helper that loads the label dictionary from a version-2 file.

    Every value returned by ``_CAPI_LoadLabels_V2`` is converted with
    ``F.zerocopy_from_dgl_ndarray`` and stored under its original key.
    """
    return {
        name: F.zerocopy_from_dgl_ndarray(array)
        for name, array in _CAPI_LoadLabels_V2(filename).items()
    }
def load_labels_v1(filename):
"""Internal functions for loading labels from V1 format"""
metadata = _CAPI_LoadGraphFiles_V1(filename, [], True)
label_dict = {} label_dict = {}
for k, v in metadata.labels.items(): for k, v in metadata.labels.items():
label_dict[k] = F.zerocopy_from_dgl_ndarray(v.data) label_dict[k] = F.zerocopy_from_dgl_ndarray(v)
return label_dict return label_dict
"""For HeteroGraph Serialization"""
from __future__ import absolute_import
from ..heterograph import DGLHeteroGraph
from ..frame import Frame, FrameRef
from .._ffi.object import ObjectBase, register_object
from .._ffi.function import _init_api
from .. import backend as F
from ..container import convert_to_strmap
_init_api("dgl.data.heterograph_serialize")
def tensor_dict_to_ndarray_dict(tensor_dict):
    """Convert dict[str, tensor] to StrMap[NDArray].

    Each value is passed through ``F.zerocopy_to_dgl_ndarray``; the resulting
    dictionary is then handed to ``convert_to_strmap``.
    """
    converted = {
        name: F.zerocopy_to_dgl_ndarray(tensor)
        for name, tensor in tensor_dict.items()
    }
    return convert_to_strmap(converted)
def save_heterographs(filename, g_list, labels):
    """Save one heterograph, or a list of them, together with labels.

    Parameters
    ----------
    filename : str
        Destination file name.
    g_list : DGLHeteroGraph or list of DGLHeteroGraph
        A single graph is wrapped into a one-element list.
    labels : dict or None
        Label tensors keyed by name; ``None`` is treated as an empty dict.
    """
    label_map = {} if labels is None else labels
    graphs = [g_list] if isinstance(g_list, DGLHeteroGraph) else g_list
    gdata_list = [HeteroGraphData.create(graph) for graph in graphs]
    _CAPI_SaveHeteroGraphData(
        filename, gdata_list, tensor_dict_to_ndarray_dict(label_map))
@register_object("heterograph_serialize.HeteroGraphData")
class HeteroGraphData(ObjectBase):
    """Object to hold the data to be stored for DGLHeteroGraph"""

    @staticmethod
    def create(g):
        """Build a HeteroGraphData object from a DGLHeteroGraph.

        Parameters
        ----------
        g : DGLHeteroGraph
            The graph whose structure and node/edge features are captured.

        Returns
        -------
        The object returned by ``_CAPI_MakeHeteroGraphData``.
        """
        # Look up edge features by canonical edge type (utype, etype, vtype).
        # A bare edge-type name is ambiguous when several relations share the
        # same name, so indexing ``g.edges`` with it can fail for such graphs.
        # NOTE(review): this relies on ``g.canonical_etypes`` being listed in
        # the same order as ``g.etypes`` -- confirm.
        edata_list = [
            tensor_dict_to_ndarray_dict(g.edges[canonical_etype].data)
            for canonical_etype in g.canonical_etypes
        ]
        ndata_list = [
            tensor_dict_to_ndarray_dict(g.nodes[ntype].data)
            for ntype in g.ntypes
        ]
        return _CAPI_MakeHeteroGraphData(g._graph, ndata_list, edata_list, g.ntypes, g.etypes)

    def get_graph(self):
        """Rebuild the DGLHeteroGraph held by this object.

        Returns
        -------
        DGLHeteroGraph
            Constructed from the stored graph index, the node/edge type names
            and one frame per node type and per edge type.
        """
        def to_frames(flat_lists):
            # Each element is a flat sequence [key0, value0, key1, value1, ...];
            # pair the entries up and wrap the resulting dict into a FrameRef.
            frames = []
            for flat in flat_lists:
                columns = {
                    flat[i]: F.zerocopy_from_dgl_ndarray(flat[i + 1])
                    for i in range(0, len(flat), 2)
                }
                frames.append(FrameRef(Frame(columns)))
            return frames

        ntensor_list = list(_CAPI_GetNDataFromHeteroGraphData(self))
        etensor_list = list(_CAPI_GetEDataFromHeteroGraphData(self))
        ntype_names = list(_CAPI_GetNtypesFromHeteroGraphData(self))
        etype_names = list(_CAPI_GetEtypesFromHeteroGraphData(self))
        gidx = _CAPI_GetGindexFromHeteroGraphData(self)
        nframes = to_frames(ntensor_list)
        eframes = to_frames(etensor_list)
        return DGLHeteroGraph(gidx, ntype_names, etype_names, nframes, eframes)
...@@ -61,7 +61,7 @@ def load_tensors(filename, return_dgl_ndarray=False): ...@@ -61,7 +61,7 @@ def load_tensors(filename, return_dgl_ndarray=False):
tensor_dict = {} tensor_dict = {}
for key, value in nd_dict.items(): for key, value in nd_dict.items():
if return_dgl_ndarray: if return_dgl_ndarray:
tensor_dict[key] = value.data tensor_dict[key] = value
else: else:
tensor_dict[key] = F.zerocopy_from_dgl_ndarray(value.data) tensor_dict[key] = F.zerocopy_from_dgl_ndarray(value)
return tensor_dict return tensor_dict
...@@ -515,7 +515,7 @@ class RPCMessage(ObjectBase): ...@@ -515,7 +515,7 @@ class RPCMessage(ObjectBase):
def tensors(self): def tensors(self):
"""Get tensor payloads.""" """Get tensor payloads."""
rst = _CAPI_DGLRPCMessageGetTensors(self) rst = _CAPI_DGLRPCMessageGetTensors(self)
return [F.zerocopy_from_dgl_ndarray(tsor.data) for tsor in rst] return [F.zerocopy_from_dgl_ndarray(tsor) for tsor in rst]
def send_request(target, request): def send_request(target, request):
"""Send one request to the target server. """Send one request to the target server.
......
...@@ -4778,8 +4778,8 @@ def find_src_dst_ntypes(ntypes, metagraph): ...@@ -4778,8 +4778,8 @@ def find_src_dst_ntypes(ntypes, metagraph):
return None return None
else: else:
src, dst = ret src, dst = ret
srctypes = {ntypes[tid.data] : tid.data for tid in src} srctypes = {ntypes[tid] : tid for tid in src}
dsttypes = {ntypes[tid.data] : tid.data for tid in dst} dsttypes = {ntypes[tid] : tid for tid in dst}
return srctypes, dsttypes return srctypes, dsttypes
def infer_ntype_from_dict(graph, etype_dict): def infer_ntype_from_dict(graph, etype_dict):
......
...@@ -1017,7 +1017,7 @@ class HeteroSubgraphIndex(ObjectBase): ...@@ -1017,7 +1017,7 @@ class HeteroSubgraphIndex(ObjectBase):
Induced nodes Induced nodes
""" """
ret = _CAPI_DGLHeteroSubgraphGetInducedVertices(self) ret = _CAPI_DGLHeteroSubgraphGetInducedVertices(self)
return [utils.toindex(v.data, self.graph.dtype) for v in ret] return [utils.toindex(v, self.graph.dtype) for v in ret]
@property @property
def induced_edges(self): def induced_edges(self):
...@@ -1030,7 +1030,7 @@ class HeteroSubgraphIndex(ObjectBase): ...@@ -1030,7 +1030,7 @@ class HeteroSubgraphIndex(ObjectBase):
Induced edges Induced edges
""" """
ret = _CAPI_DGLHeteroSubgraphGetInducedEdges(self) ret = _CAPI_DGLHeteroSubgraphGetInducedEdges(self)
return [utils.toindex(v.data, self.graph.dtype) for v in ret] return [utils.toindex(v, self.graph.dtype) for v in ret]
################################################################# #################################################################
......
...@@ -172,7 +172,7 @@ class SparseMatrix(ObjectBase): ...@@ -172,7 +172,7 @@ class SparseMatrix(ObjectBase):
------- -------
list of boolean list of boolean
""" """
return [v.data for v in _CAPI_DGLSparseMatrixGetFlags(self)] return [v for v in _CAPI_DGLSparseMatrixGetFlags(self)]
def __getstate__(self): def __getstate__(self):
return self.format, self.num_rows, self.num_cols, self.indices, self.flags return self.format, self.num_rows, self.num_cols, self.indices, self.flags
......
...@@ -153,8 +153,8 @@ def random_walk(g, nodes, *, metapath=None, length=None, prob=None, restart_prob ...@@ -153,8 +153,8 @@ def random_walk(g, nodes, *, metapath=None, length=None, prob=None, restart_prob
traces, types = _CAPI_DGLSamplingRandomWalkWithRestart( traces, types = _CAPI_DGLSamplingRandomWalkWithRestart(
gidx, nodes, metapath, p_nd, restart_prob) gidx, nodes, metapath, p_nd, restart_prob)
traces = F.zerocopy_from_dgl_ndarray(traces.data) traces = F.zerocopy_from_dgl_ndarray(traces)
types = F.zerocopy_from_dgl_ndarray(types.data) types = F.zerocopy_from_dgl_ndarray(types)
return traces, types return traces, types
def pack_traces(traces, types): def pack_traces(traces, types):
...@@ -221,10 +221,10 @@ def pack_traces(traces, types): ...@@ -221,10 +221,10 @@ def pack_traces(traces, types):
concat_vids, concat_types, lengths, offsets = _CAPI_DGLSamplingPackTraces(traces, types) concat_vids, concat_types, lengths, offsets = _CAPI_DGLSamplingPackTraces(traces, types)
concat_vids = F.zerocopy_from_dgl_ndarray(concat_vids.data) concat_vids = F.zerocopy_from_dgl_ndarray(concat_vids)
concat_types = F.zerocopy_from_dgl_ndarray(concat_types.data) concat_types = F.zerocopy_from_dgl_ndarray(concat_types)
lengths = F.zerocopy_from_dgl_ndarray(lengths.data) lengths = F.zerocopy_from_dgl_ndarray(lengths)
offsets = F.zerocopy_from_dgl_ndarray(offsets.data) offsets = F.zerocopy_from_dgl_ndarray(offsets)
return concat_vids, concat_types, lengths, offsets return concat_vids, concat_types, lengths, offsets
......
...@@ -906,7 +906,7 @@ def compact_graphs(graphs, always_preserve=None): ...@@ -906,7 +906,7 @@ def compact_graphs(graphs, always_preserve=None):
# Compact and construct heterographs # Compact and construct heterographs
new_graph_indexes, induced_nodes = _CAPI_DGLCompactGraphs( new_graph_indexes, induced_nodes = _CAPI_DGLCompactGraphs(
[g._graph for g in graphs], always_preserve_nd) [g._graph for g in graphs], always_preserve_nd)
induced_nodes = [F.zerocopy_from_dgl_ndarray(nodes.data) for nodes in induced_nodes] induced_nodes = [F.zerocopy_from_dgl_ndarray(nodes) for nodes in induced_nodes]
new_graphs = [ new_graphs = [
DGLHeteroGraph(new_graph_index, graph.ntypes, graph.etypes) DGLHeteroGraph(new_graph_index, graph.ntypes, graph.etypes)
...@@ -1063,7 +1063,7 @@ def to_block(g, dst_nodes=None, include_dst_in_src=True): ...@@ -1063,7 +1063,7 @@ def to_block(g, dst_nodes=None, include_dst_in_src=True):
assert new_graph.is_unibipartite # sanity check assert new_graph.is_unibipartite # sanity check
for i, ntype in enumerate(g.ntypes): for i, ntype in enumerate(g.ntypes):
new_graph.srcnodes[ntype].data[NID] = F.zerocopy_from_dgl_ndarray(src_nodes_nd[i].data) new_graph.srcnodes[ntype].data[NID] = F.zerocopy_from_dgl_ndarray(src_nodes_nd[i])
if ntype in dst_nodes: if ntype in dst_nodes:
new_graph.dstnodes[ntype].data[NID] = dst_nodes[ntype] new_graph.dstnodes[ntype].data[NID] = dst_nodes[ntype]
else: else:
...@@ -1071,7 +1071,7 @@ def to_block(g, dst_nodes=None, include_dst_in_src=True): ...@@ -1071,7 +1071,7 @@ def to_block(g, dst_nodes=None, include_dst_in_src=True):
new_graph.dstnodes[ntype].data[NID] = F.tensor([], dtype=g.idtype) new_graph.dstnodes[ntype].data[NID] = F.tensor([], dtype=g.idtype)
for i, canonical_etype in enumerate(g.canonical_etypes): for i, canonical_etype in enumerate(g.canonical_etypes):
induced_edges = F.zerocopy_from_dgl_ndarray(induced_edges_nd[i].data) induced_edges = F.zerocopy_from_dgl_ndarray(induced_edges_nd[i])
utype, etype, vtype = canonical_etype utype, etype, vtype = canonical_etype
new_canonical_etype = (utype, etype, vtype) new_canonical_etype = (utype, etype, vtype)
new_graph.edges[new_canonical_etype].data[EID] = induced_edges new_graph.edges[new_canonical_etype].data[EID] = induced_edges
...@@ -1114,7 +1114,7 @@ def remove_edges(g, edge_ids): ...@@ -1114,7 +1114,7 @@ def remove_edges(g, edge_ids):
new_graph = DGLHeteroGraph(new_graph_index, g.ntypes, g.etypes) new_graph = DGLHeteroGraph(new_graph_index, g.ntypes, g.etypes)
for i, canonical_etype in enumerate(g.canonical_etypes): for i, canonical_etype in enumerate(g.canonical_etypes):
data = induced_eids_nd[i].data data = induced_eids_nd[i]
if len(data) == 0: if len(data) == 0:
# Empty means that either # Empty means that either
# (1) no edges are removed and edges are not shuffled. # (1) no edges are removed and edges are not shuffled.
...@@ -1256,8 +1256,8 @@ def to_simple(g, return_counts='count', writeback_mapping=None): ...@@ -1256,8 +1256,8 @@ def to_simple(g, return_counts='count', writeback_mapping=None):
""" """
simple_graph_index, counts, edge_maps = _CAPI_DGLToSimpleHetero(g._graph) simple_graph_index, counts, edge_maps = _CAPI_DGLToSimpleHetero(g._graph)
simple_graph = DGLHeteroGraph(simple_graph_index, g.ntypes, g.etypes) simple_graph = DGLHeteroGraph(simple_graph_index, g.ntypes, g.etypes)
counts = [F.zerocopy_from_dgl_ndarray(count.data) for count in counts] counts = [F.zerocopy_from_dgl_ndarray(count) for count in counts]
edge_maps = [F.zerocopy_from_dgl_ndarray(edge_map.data) for edge_map in edge_maps] edge_maps = [F.zerocopy_from_dgl_ndarray(edge_map) for edge_map in edge_maps]
if return_counts is not None: if return_counts is not None:
for count, canonical_etype in zip(counts, g.canonical_etypes): for count, canonical_etype in zip(counts, g.canonical_etypes):
......
...@@ -71,6 +71,13 @@ DGL_REGISTER_GLOBAL("_Map") ...@@ -71,6 +71,13 @@ DGL_REGISTER_GLOBAL("_Map")
} }
}); });
DGL_REGISTER_GLOBAL("_EmptyStrMap").set_body([](DGLArgs args, DGLRetValue* rv) {
  // Return a newly allocated StrMapObject whose container has no entries.
  auto empty_map = std::make_shared<StrMapObject>();
  empty_map->data = StrMapObject::ContainerType();
  *rv = empty_map;
});
DGL_REGISTER_GLOBAL("_MapSize") DGL_REGISTER_GLOBAL("_MapSize")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([] (DGLArgs args, DGLRetValue* rv) {
auto& sptr = args[0].obj_sptr(); auto& sptr = args[0].obj_sptr();
......
/*!
* Copyright (c) 2019 by Contributors
* \file graph/serialize/dglgraph_data.h
* \brief Graph serialization header
*/
#ifndef DGL_GRAPH_SERIALIZE_DGLGRAPH_DATA_H_
#define DGL_GRAPH_SERIALIZE_DGLGRAPH_DATA_H_
#include <dgl/graph.h>
#include <dgl/array.h>
#include <dgl/immutable_graph.h>
#include <dmlc/io.h>
#include <dmlc/type_traits.h>
#include <dgl/runtime/ndarray.h>
#include <dgl/runtime/container.h>
#include <dgl/runtime/object.h>
#include <dgl/packed_func_ext.h>
#include <iostream>
#include <string>
#include <vector>
#include <algorithm>
#include <utility>
#include "../../c_api_common.h"
using dgl::runtime::NDArray;
using dgl::ImmutableGraph;
using namespace dgl::runtime;
namespace dgl {
namespace serialize {
typedef std::pair<std::string, NDArray> NamedTensor;
/*!
 * \brief Holds one graph together with its named node and edge tensors
 *        for serialization.
 */
class GraphDataObject : public runtime::Object {
public:
/*! \brief the graph structure */
ImmutableGraphPtr gptr;
/*! \brief (name, tensor) pairs attached to nodes */
std::vector<NamedTensor> node_tensors;
/*! \brief (name, tensor) pairs attached to edges */
std::vector<NamedTensor> edge_tensors;
static constexpr const char *_type_key = "graph_serialize.GraphData";
/*!
 * \brief Set the graph and append every map entry as a named tensor.
 * \param gptr the graph to store
 * \param node_tensors map from name to a Value wrapping a node tensor
 * \param edge_tensors map from name to a Value wrapping an edge tensor
 */
void SetData(ImmutableGraphPtr gptr,
Map<std::string, Value> node_tensors,
Map<std::string, Value> edge_tensors);
/*!
 * \brief Write the graph's in-CSR arrays followed by the node and edge
 *        tensors to the stream.
 */
void Save(dmlc::Stream *fs) const;
/*!
 * \brief Read the CSR arrays and the node/edge tensors from the stream and
 *        rebuild the graph from the CSR arrays.
 */
bool Load(dmlc::Stream *fs);
DGL_DECLARE_OBJECT_TYPE_INFO(GraphDataObject, runtime::Object);
};
/*! \brief Reference type for GraphDataObject */
class GraphData : public runtime::ObjectRef {
public:
DGL_DEFINE_OBJECT_REF_METHODS(GraphData, runtime::ObjectRef, GraphDataObject);
/*! \brief create a new GraphData reference backed by a newly allocated GraphDataObject */
static GraphData Create() {
return GraphData(std::make_shared<GraphDataObject>());
}
};
ImmutableGraphPtr ToImmutableGraph(GraphPtr g);
} // namespace serialize
} // namespace dgl
#endif // DGL_GRAPH_SERIALIZE_DGLGRAPH_DATA_H_
/*!
* Copyright (c) 2019 by Contributors
* \file graph/serialize/graph_serialize.cc
* \brief Graph serialization implementation
*
* The storage structure is
* {
* // MetaData Section
* uint64_t kDGLSerializeMagic
* uint64_t kVersion
* uint64_t GraphType
* ** Reserved Area till 4kB **
*
* dgl_id_t num_graphs
* vector<dgl_id_t> graph_indices (start address of each graph)
* vector<dgl_id_t> nodes_num_list (list of number of nodes for each graph)
* vector<dgl_id_t> edges_num_list (list of number of edges for each graph)
*
* vector<GraphData> graph_datas;
*
* }
*
* Storage of GraphData is
* {
* // Everything is stored in in-CSR format
* NDArray indptr
* NDArray indices
* NDArray edge_ids
* vector<pair<string, NDArray>> node_tensors;
* vector<pair<string, NDArray>> edge_tensors;
* }
*
*/
#include <dgl/graph_op.h>
#include <dgl/immutable_graph.h>
#include <dgl/runtime/container.h>
#include <dgl/runtime/object.h>
#include <dmlc/io.h>
#include <dmlc/type_traits.h>
#include <algorithm>
#include <iostream>
#include <string>
#include <utility>
#include <vector>
#include "graph_serialize.h"
using namespace dgl::runtime;
using dgl::COO;
using dgl::COOPtr;
using dgl::ImmutableGraph;
using dgl::runtime::NDArray;
using dgl::serialize::GraphData;
using dgl::serialize::GraphDataObject;
using dmlc::SeekStream;
using std::vector;
namespace dmlc {
// Sets the dmlc `has_saveload` trait to true for GraphDataObject.
// NOTE(review): presumably this is what lets dmlc::Stream::Write/Read
// dispatch to GraphDataObject::Save/Load -- confirm against dmlc-core.
DMLC_DECLARE_TRAITS(has_saveload, GraphDataObject, true);
}
namespace dgl {
namespace serialize {
bool SaveDGLGraphs(std::string filename, List<GraphData> graph_data,
std::vector<NamedTensor> labels_list) {
auto fs = std::unique_ptr<SeekStream>(dynamic_cast<SeekStream *>(
SeekStream::Create(filename.c_str(), "w", true)));
CHECK(fs) << "File name is not a valid local file name";
// Write DGL MetaData
const uint64_t kVersion = 1;
fs->Write(kDGLSerializeMagic);
fs->Write(kVersion);
fs->Write(GraphType::kImmutableGraph);
fs->Seek(4096);
// Write Graph Meta Data
dgl_id_t num_graph = graph_data.size();
std::vector<dgl_id_t> graph_indices(num_graph);
std::vector<int64_t> nodes_num_list(num_graph);
std::vector<int64_t> edges_num_list(num_graph);
for (uint64_t i = 0; i < num_graph; ++i) {
nodes_num_list[i] = graph_data[i]->gptr->NumVertices();
edges_num_list[i] = graph_data[i]->gptr->NumEdges();
}
// Reserve spaces for graph indices
fs->Write(num_graph);
dgl_id_t indices_start_ptr = fs->Tell();
fs->Write(graph_indices);
fs->Write(nodes_num_list);
fs->Write(edges_num_list);
fs->Write(labels_list);
// Write GraphData
for (uint64_t i = 0; i < num_graph; ++i) {
graph_indices[i] = fs->Tell();
GraphDataObject gdata = *graph_data[i].as<GraphDataObject>();
fs->Write(gdata);
}
fs->Seek(indices_start_ptr);
fs->Write(graph_indices);
return true;
}
/*!
 * \brief Load graph metadata and, optionally, graphs from a file.
 *
 * \param filename file to read
 * \param idx_list indices of the graphs to load; when empty, all graphs are
 *        read in file order. The returned graphs follow the order of
 *        idx_list.
 * \param onlyMeta when true, return after the metadata (counts and labels)
 *        has been read, without loading any graph
 * \return the metadata object, with graph data attached unless onlyMeta
 */
StorageMetaData LoadDGLGraphs(const std::string &filename,
                              std::vector<dgl_id_t> idx_list, bool onlyMeta) {
  auto fs = std::unique_ptr<SeekStream>(
      SeekStream::CreateForRead(filename.c_str(), true));
  CHECK(fs) << "Filename is invalid";

  // Read DGL MetaData. The header fields are initialized and every read is
  // checked, so a truncated file is reported instead of comparing garbage.
  uint64_t magicNum = 0, graphType = 0, version = 0;
  CHECK(fs->Read(&magicNum)) << "Invalid DGL files";
  CHECK(fs->Read(&version)) << "Invalid DGL files";
  CHECK(fs->Read(&graphType)) << "Invalid DGL files";
  fs->Seek(4096);

  CHECK_EQ(magicNum, kDGLSerializeMagic) << "Invalid DGL files";
  CHECK_EQ(version, 1) << "Invalid DGL files";
  StorageMetaData metadata = StorageMetaData::Create();

  // Read Graph MetaData
  dgl_id_t num_graph;
  CHECK(fs->Read(&num_graph)) << "Invalid num of graph";
  std::vector<dgl_id_t> graph_indices;
  std::vector<int64_t> nodes_num_list;
  std::vector<int64_t> edges_num_list;
  std::vector<NamedTensor> labels_list;
  CHECK(fs->Read(&graph_indices)) << "Invalid graph indices";
  CHECK(fs->Read(&nodes_num_list)) << "Invalid node num list";
  CHECK(fs->Read(&edges_num_list)) << "Invalid edge num list";
  CHECK(fs->Read(&labels_list)) << "Invalid label list";

  metadata->SetMetaData(num_graph, nodes_num_list, edges_num_list, labels_list);

  // Early Return
  if (onlyMeta) {
    return metadata;
  }

  // Reads one GraphData from the current stream position.
  const auto read_graph = [&fs]() {
    GraphData gdata = GraphData::Create();
    GraphDataObject *gdata_ptr =
        const_cast<GraphDataObject *>(gdata.as<GraphDataObject>());
    CHECK(fs->Read(gdata_ptr)) << "Invalid graph data";
    return gdata;
  };

  std::vector<GraphData> gdata_refs;
  if (idx_list.empty()) {
    // Read All Graphs
    gdata_refs.reserve(num_graph);
    for (uint64_t i = 0; i < num_graph; ++i) {
      gdata_refs.push_back(read_graph());
    }
  } else {
    // Read Selected Graphs
    gdata_refs.reserve(idx_list.size());
    // Would be better if idx_list is sorted. However the returned graphs
    // should be in the same order as the idx_list
    for (uint64_t i = 0; i < idx_list.size(); ++i) {
      // Reject indices that do not refer to a stored graph instead of
      // reading past the end of graph_indices.
      CHECK_LT(idx_list[i], graph_indices.size())
          << "Graph index out of range";
      fs->Seek(graph_indices[idx_list[i]]);
      gdata_refs.push_back(read_graph());
    }
  }

  metadata->SetGraphData(gdata_refs);

  return metadata;
}
/*!
 * \brief Store the graph and append all given tensors as (name, NDArray)
 *        pairs.
 * \param gptr the graph to store
 * \param node_tensors name -> Value map of node tensors
 * \param edge_tensors name -> Value map of edge tensors
 */
void GraphDataObject::SetData(ImmutableGraphPtr gptr,
                              Map<std::string, Value> node_tensors,
                              Map<std::string, Value> edge_tensors) {
  this->gptr = gptr;

  // Unwrap each Value into an NDArray and append it under its name.
  for (auto entry : node_tensors) {
    Value boxed = entry.second;
    NDArray tensor = static_cast<NDArray>(boxed->data);
    this->node_tensors.emplace_back(entry.first, tensor);
  }

  for (auto entry : edge_tensors) {
    Value boxed = entry.second;
    NDArray tensor = static_cast<NDArray>(boxed->data);
    this->edge_tensors.emplace_back(entry.first, tensor);
  }
}
/*!
 * \brief Serialize this object to the stream.
 *
 * The graph structure is written as the three in-CSR arrays (indptr,
 * indices, edge ids), followed by the node tensors and the edge tensors.
 */
void GraphDataObject::Save(dmlc::Stream *fs) const {
  const CSRPtr in_csr = gptr->GetInCSR();

  // Graph structure
  fs->Write(in_csr->indptr());
  fs->Write(in_csr->indices());
  fs->Write(in_csr->edge_ids());

  // Features
  fs->Write(this->node_tensors);
  fs->Write(this->edge_tensors);
}
/*!
 * \brief Deserialize a graph and its tensors from a stream.
 *
 * Reads, in order: indptr, indices and edge ids of the in-CSR, then the node
 * tensor list and the edge tensor list.
 *
 * \param fs Source stream.
 * \return true when every field was read successfully; false as soon as one
 *         read fails. The graph is not rebuilt when one of the three CSR
 *         arrays could not be read.
 */
bool GraphDataObject::Load(dmlc::Stream *fs) {
  NDArray indptr, indices, edge_ids;
  // Do not build a graph out of arrays that were never filled in.
  if (!fs->Read(&indptr) || !fs->Read(&indices) || !fs->Read(&edge_ids)) {
    return false;
  }
  this->gptr = ImmutableGraph::CreateFromCSR(indptr, indices, edge_ids, "in");
  return fs->Read(&this->node_tensors) && fs->Read(&this->edge_tensors);
}
/*!
 * \brief Merge the graphs of several GraphData items into one graph.
 *
 * Collects the graph pointer of every item, in list order, and passes them to
 * GraphOp::DisjointUnion.
 *
 * \param gdata_list The loaded graph data items.
 * \return The union downcast to ImmutableGraph (null when the downcast fails).
 */
ImmutableGraphPtr BatchLoadedGraphs(std::vector<GraphData> gdata_list) {
  std::vector<GraphPtr> components;
  components.reserve(gdata_list.size());
  for (const auto &item : gdata_list) {
    components.push_back(static_cast<GraphPtr>(item->gptr));
  }
  return std::dynamic_pointer_cast<ImmutableGraph>(
      GraphOp::DisjointUnion(components));
}
/*!
 * \brief Obtain an immutable graph from a generic graph pointer.
 *
 * A pointer that already refers to an ImmutableGraph is returned as is.
 * Otherwise the pointer must refer to a mutable Graph, whose edges (queried
 * with order "eid") are used to build a new immutable graph from COO.
 *
 * \param g The graph to convert.
 * \return The immutable graph.
 */
ImmutableGraphPtr ToImmutableGraph(GraphPtr g) {
  ImmutableGraphPtr already_immutable =
      std::dynamic_pointer_cast<ImmutableGraph>(g);
  if (already_immutable) {
    return already_immutable;
  }

  MutableGraphPtr mutable_graph = std::dynamic_pointer_cast<Graph>(g);
  CHECK(mutable_graph) << "Invalid Graph Pointer";
  EdgeArray edges = mutable_graph->Edges("eid");
  IdArray src_ids = edges.src;
  IdArray dst_ids = edges.dst;
  return ImmutableGraph::CreateFromCOO(mutable_graph->NumVertices(), src_ids,
                                       dst_ids);
}
/*!
 * \brief Fill in the metadata fields of this object.
 *
 * The two count vectors are converted to id arrays and wrapped as Values;
 * every named label tensor is inserted into the label map under its name.
 *
 * \param num_graph Number of graphs.
 * \param nodes_num_list Node counts.
 * \param edges_num_list Edge counts.
 * \param labels_list Named label tensors.
 */
void StorageMetaDataObject::SetMetaData(dgl_id_t num_graph,
                                        std::vector<int64_t> nodes_num_list,
                                        std::vector<int64_t> edges_num_list,
                                        std::vector<NamedTensor> labels_list) {
  this->num_graph = num_graph;

  auto node_counts = MakeValue(aten::VecToIdArray(nodes_num_list));
  auto edge_counts = MakeValue(aten::VecToIdArray(edges_num_list));
  this->nodes_num_list = Value(node_counts);
  this->edges_num_list = Value(edge_counts);

  for (const auto &label : labels_list) {
    this->labels_list.Set(label.first, Value(MakeValue(label.second)));
  }
}
void StorageMetaDataObject::SetGraphData(std::vector<GraphData> gdata) {
this->graph_data = List<GraphData>(gdata);
}
} // namespace serialize
} // namespace dgl
...@@ -32,28 +32,32 @@ ...@@ -32,28 +32,32 @@
* *
*/ */
#include "graph_serialize.h" #include "graph_serialize.h"
#include <dmlc/io.h>
#include <dmlc/type_traits.h> #include <dgl/graph_op.h>
#include <dgl/runtime/container.h>
#include <dgl/immutable_graph.h> #include <dgl/immutable_graph.h>
#include <dgl/runtime/container.h>
#include <dgl/runtime/object.h> #include <dgl/runtime/object.h>
#include <dgl/graph_op.h> #include <dmlc/io.h>
#include <dmlc/logging.h>
#include <dmlc/type_traits.h>
#include <algorithm>
#include <iostream> #include <iostream>
#include <string> #include <string>
#include <vector>
#include <algorithm>
#include <utility> #include <utility>
#include <vector>
using namespace dgl::runtime; using namespace dgl::runtime;
using dgl::COO; using dgl::COO;
using dgl::COOPtr; using dgl::COOPtr;
using dgl::ImmutableGraph; using dgl::ImmutableGraph;
using dmlc::SeekStream;
using dgl::runtime::NDArray; using dgl::runtime::NDArray;
using std::vector;
using dgl::serialize::GraphData; using dgl::serialize::GraphData;
using dgl::serialize::GraphDataObject; using dgl::serialize::GraphDataObject;
using dmlc::SeekStream;
using dmlc::Stream;
using std::vector;
namespace dmlc { namespace dmlc {
DMLC_DECLARE_TRAITS(has_saveload, GraphDataObject, true); DMLC_DECLARE_TRAITS(has_saveload, GraphDataObject, true);
...@@ -62,13 +66,8 @@ DMLC_DECLARE_TRAITS(has_saveload, GraphDataObject, true); ...@@ -62,13 +66,8 @@ DMLC_DECLARE_TRAITS(has_saveload, GraphDataObject, true);
namespace dgl { namespace dgl {
namespace serialize { namespace serialize {
enum GraphType {
kMutableGraph = 0ull,
kImmutableGraph = 1ull
};
DGL_REGISTER_GLOBAL("data.graph_serialize._CAPI_MakeGraphData") DGL_REGISTER_GLOBAL("data.graph_serialize._CAPI_MakeGraphData")
.set_body([](DGLArgs args, DGLRetValue *rv) { .set_body([](DGLArgs args, DGLRetValue *rv) {
GraphRef gptr = args[0]; GraphRef gptr = args[0];
ImmutableGraphPtr imGPtr = ToImmutableGraph(gptr.sptr()); ImmutableGraphPtr imGPtr = ToImmutableGraph(gptr.sptr());
Map<std::string, Value> node_tensors = args[1]; Map<std::string, Value> node_tensors = args[1];
...@@ -76,10 +75,10 @@ DGL_REGISTER_GLOBAL("data.graph_serialize._CAPI_MakeGraphData") ...@@ -76,10 +75,10 @@ DGL_REGISTER_GLOBAL("data.graph_serialize._CAPI_MakeGraphData")
GraphData gd = GraphData::Create(); GraphData gd = GraphData::Create();
gd->SetData(imGPtr, node_tensors, edge_tensors); gd->SetData(imGPtr, node_tensors, edge_tensors);
*rv = gd; *rv = gd;
}); });
DGL_REGISTER_GLOBAL("data.graph_serialize._CAPI_DGLSaveGraphs") DGL_REGISTER_GLOBAL("data.graph_serialize._CAPI_SaveDGLGraphs_V0")
.set_body([](DGLArgs args, DGLRetValue *rv) { .set_body([](DGLArgs args, DGLRetValue *rv) {
std::string filename = args[0]; std::string filename = args[0];
List<GraphData> graph_data = args[1]; List<GraphData> graph_data = args[1];
Map<std::string, Value> labels = args[2]; Map<std::string, Value> labels = args[2];
...@@ -91,257 +90,67 @@ DGL_REGISTER_GLOBAL("data.graph_serialize._CAPI_DGLSaveGraphs") ...@@ -91,257 +90,67 @@ DGL_REGISTER_GLOBAL("data.graph_serialize._CAPI_DGLSaveGraphs")
labels_list.emplace_back(name, ndarray); labels_list.emplace_back(name, ndarray);
} }
SaveDGLGraphs(filename, graph_data, labels_list); SaveDGLGraphs(filename, graph_data, labels_list);
}); });
DGL_REGISTER_GLOBAL("data.graph_serialize._CAPI_DGLLoadGraphs")
.set_body([](DGLArgs args, DGLRetValue *rv) {
std::string filename = args[0];
List<Value> idxs = args[1];
bool onlyMeta = args[2];
std::vector<dgl_id_t> idx_list(idxs.size());
for (uint64_t i = 0; i < idxs.size(); ++i) {
idx_list[i] = static_cast<dgl_id_t >(idxs[i]->data);
}
*rv = LoadDGLGraphs(filename, idx_list, onlyMeta);
});
DGL_REGISTER_GLOBAL("data.graph_serialize._CAPI_GDataGraphHandle") DGL_REGISTER_GLOBAL("data.graph_serialize._CAPI_GDataGraphHandle")
.set_body([](DGLArgs args, DGLRetValue *rv) { .set_body([](DGLArgs args, DGLRetValue *rv) {
GraphData gdata = args[0]; GraphData gdata = args[0];
*rv = gdata->gptr; *rv = gdata->gptr;
}); });
DGL_REGISTER_GLOBAL("data.graph_serialize._CAPI_GDataNodeTensors") DGL_REGISTER_GLOBAL("data.graph_serialize._CAPI_GDataNodeTensors")
.set_body([](DGLArgs args, DGLRetValue *rv) { .set_body([](DGLArgs args, DGLRetValue *rv) {
GraphData gdata = args[0]; GraphData gdata = args[0];
Map<std::string, Value> rvmap; Map<std::string, Value> rvmap;
for (auto kv : gdata->node_tensors) { for (auto kv : gdata->node_tensors) {
rvmap.Set(kv.first, Value(MakeValue(kv.second))); rvmap.Set(kv.first, Value(MakeValue(kv.second)));
} }
*rv = rvmap; *rv = rvmap;
}); });
DGL_REGISTER_GLOBAL("data.graph_serialize._CAPI_GDataEdgeTensors") DGL_REGISTER_GLOBAL("data.graph_serialize._CAPI_GDataEdgeTensors")
.set_body([](DGLArgs args, DGLRetValue *rv) { .set_body([](DGLArgs args, DGLRetValue *rv) {
GraphData gdata = args[0]; GraphData gdata = args[0];
Map<std::string, Value> rvmap; Map<std::string, Value> rvmap;
for (auto kv : gdata->edge_tensors) { for (auto kv : gdata->edge_tensors) {
rvmap.Set(kv.first, Value(MakeValue(kv.second))); rvmap.Set(kv.first, Value(MakeValue(kv.second)));
} }
*rv = rvmap; *rv = rvmap;
}); });
constexpr uint64_t kDGLSerializeMagic = 0xDD2E4FF046B4A13F;
bool SaveDGLGraphs(std::string filename,
List<GraphData> graph_data,
std::vector<NamedTensor> labels_list) {
auto *fs = dynamic_cast<SeekStream *>(SeekStream::Create(filename.c_str(), "w",
true));
CHECK(fs) << "File name is not a valid local file name";
// Write DGL MetaData
const uint64_t kVersion = 1;
fs->Write(kDGLSerializeMagic);
fs->Write(kVersion);
fs->Write(kImmutableGraph);
fs->Seek(4096);
// Write Graph Meta Data
dgl_id_t num_graph = graph_data.size();
std::vector<dgl_id_t> graph_indices(num_graph);
std::vector<int64_t> nodes_num_list(num_graph);
std::vector<int64_t> edges_num_list(num_graph);
for (uint64_t i = 0; i < num_graph; ++i) {
nodes_num_list[i] = graph_data[i]->gptr->NumVertices();
edges_num_list[i] = graph_data[i]->gptr->NumEdges();
}
// Reserve spaces for graph indices
fs->Write(num_graph);
dgl_id_t indices_start_ptr = fs->Tell();
fs->Write(graph_indices);
fs->Write(nodes_num_list);
fs->Write(edges_num_list);
fs->Write(labels_list);
// Write GraphData uint64_t GetFileVersion(const std::string &filename) {
for (uint64_t i = 0; i < num_graph; ++i) { auto fs = std::unique_ptr<SeekStream>(
graph_indices[i] = fs->Tell(); SeekStream::CreateForRead(filename.c_str(), false));
GraphDataObject gdata = *graph_data[i].as<GraphDataObject>(); CHECK(fs) << "File " << filename << " not found";
fs->Write(gdata); uint64_t magicNum, version;
}
fs->Seek(indices_start_ptr);
fs->Write(graph_indices);
std::vector<dgl_id_t> test;
fs->Seek(indices_start_ptr);
fs->Read(&test);
delete fs;
return true;
}
StorageMetaData LoadDGLGraphs(const std::string &filename,
std::vector<dgl_id_t> idx_list,
bool onlyMeta) {
SeekStream *fs = SeekStream::CreateForRead(filename.c_str(), true);
CHECK(fs) << "Filename is invalid";
StorageMetaData metadata = StorageMetaData::Create();
// Read DGL MetaData
uint64_t magicNum, graphType, version;
fs->Read(&magicNum); fs->Read(&magicNum);
fs->Read(&graphType);
fs->Read(&version); fs->Read(&version);
fs->Seek(4096);
CHECK_EQ(magicNum, kDGLSerializeMagic) << "Invalid DGL files"; CHECK_EQ(magicNum, kDGLSerializeMagic) << "Invalid DGL files";
CHECK_EQ(graphType, kImmutableGraph) << "Invalid DGL files"; return version;
CHECK_EQ(version, 1) << "Invalid Serialization Version";
// Read Graph MetaData
dgl_id_t num_graph;
CHECK(fs->Read(&num_graph)) << "Invalid num of graph";
std::vector<dgl_id_t> graph_indices;
std::vector<int64_t> nodes_num_list;
std::vector<int64_t> edges_num_list;
std::vector<NamedTensor> labels_list;
CHECK(fs->Read(&graph_indices)) << "Invalid graph indices";
CHECK(fs->Read(&nodes_num_list)) << "Invalid node num list";
CHECK(fs->Read(&edges_num_list)) << "Invalid edge num list";
CHECK(fs->Read(&labels_list)) << "Invalid label list";
metadata->SetMetaData(num_graph, nodes_num_list, edges_num_list, labels_list);
std::vector<GraphData> gdata_refs;
// Early Return
if (onlyMeta) {
delete fs;
return metadata;
}
if (idx_list.empty()) {
// Read All Graphs
gdata_refs.reserve(num_graph);
for (uint64_t i = 0; i < num_graph; ++i) {
GraphData gdata = GraphData::Create();
GraphDataObject *gdata_ptr =
const_cast<GraphDataObject *>(gdata.as<GraphDataObject>());
fs->Read(gdata_ptr);
gdata_refs.push_back(gdata);
}
} else {
// Read Selected Graphss
gdata_refs.reserve(idx_list.size());
// Would be better if idx_list is sorted. However the returned the graphs should be the same
// order as the idx_list
for (uint64_t i = 0; i < idx_list.size(); ++i) {
fs->Seek(graph_indices[idx_list[i]]);
GraphData gdata = GraphData::Create();
GraphDataObject *gdata_ptr =
const_cast<GraphDataObject *>(gdata.as<GraphDataObject>());
fs->Read(gdata_ptr);
gdata_refs.push_back(gdata);
}
}
metadata->SetGraphData(gdata_refs);
delete fs;
return metadata;
}
void GraphDataObject::SetData(ImmutableGraphPtr gptr,
Map<std::string, Value> node_tensors,
Map<std::string, Value> edge_tensors) {
this->gptr = gptr;
for (auto kv : node_tensors) {
std::string name = kv.first;
Value v = kv.second;
NDArray ndarray = static_cast<NDArray>(v->data);
this->node_tensors.emplace_back(name, ndarray);
}
for (auto kv : edge_tensors) {
std::string &name = kv.first;
Value v = kv.second;
const NDArray &ndarray = static_cast<NDArray>(v->data);
this->edge_tensors.emplace_back(name, ndarray);
}
}
void GraphDataObject::Save(dmlc::Stream *fs) const {
// Using in csr for storage
const CSRPtr g_csr = this->gptr->GetInCSR();
fs->Write(g_csr->indptr());
fs->Write(g_csr->indices());
fs->Write(g_csr->edge_ids());
fs->Write(node_tensors);
fs->Write(edge_tensors);
}
bool GraphDataObject::Load(dmlc::Stream *fs) {
NDArray indptr, indices, edge_ids;
fs->Read(&indptr);
fs->Read(&indices);
fs->Read(&edge_ids);
this->gptr = ImmutableGraph::CreateFromCSR(indptr, indices, edge_ids, "in");
fs->Read(&this->node_tensors);
fs->Read(&this->edge_tensors);
return true;
} }
ImmutableGraphPtr BatchLoadedGraphs(std::vector<GraphData> gdata_list) { DGL_REGISTER_GLOBAL("data.graph_serialize._CAPI_GetFileVersion")
std::vector<GraphPtr> gptrs; .set_body([](DGLArgs args, DGLRetValue *rv) {
gptrs.reserve(gdata_list.size()); std::string filename = args[0];
for (auto gdata : gdata_list) { *rv = static_cast<int64_t>(GetFileVersion(filename));
gptrs.push_back(static_cast<GraphPtr>(gdata->gptr)); });
}
ImmutableGraphPtr imGPtr = std::dynamic_pointer_cast<ImmutableGraph>(
GraphOp::DisjointUnion(gptrs));
return imGPtr;
}
ImmutableGraphPtr ToImmutableGraph(GraphPtr g) {
ImmutableGraphPtr imgr = std::dynamic_pointer_cast<ImmutableGraph>(g);
if (imgr) {
return imgr;
} else {
MutableGraphPtr mgr = std::dynamic_pointer_cast<Graph>(g);
CHECK(mgr) << "Invalid Graph Pointer";
EdgeArray earray = mgr->Edges("eid");
IdArray srcs_array = earray.src;
IdArray dsts_array = earray.dst;
ImmutableGraphPtr imgptr = ImmutableGraph::CreateFromCOO(mgr->NumVertices(), srcs_array,
dsts_array);
return imgptr;
}
}
void StorageMetaDataObject::SetMetaData(dgl_id_t num_graph, DGL_REGISTER_GLOBAL("data.graph_serialize._CAPI_LoadGraphFiles_V1")
std::vector<int64_t> nodes_num_list, .set_body([](DGLArgs args, DGLRetValue *rv) {
std::vector<int64_t> edges_num_list, std::string filename = args[0];
std::vector<NamedTensor> labels_list) { List<Value> idxs = args[1];
this->num_graph = num_graph; bool onlyMeta = args[2];
this->nodes_num_list = Value(MakeValue(aten::VecToIdArray(nodes_num_list))); auto idx_list = ListValueToVector<dgl_id_t>(idxs);
this->edges_num_list = Value(MakeValue(aten::VecToIdArray(edges_num_list))); *rv = LoadDGLGraphs(filename, idx_list, onlyMeta);
for (auto kv : labels_list) { });
this->labels_list.Set(kv.first, Value(MakeValue(kv.second)));
}
}
void StorageMetaDataObject::SetGraphData(std::vector<GraphData> gdata) { DGL_REGISTER_GLOBAL("data.graph_serialize._CAPI_LoadGraphFiles_V2")
this->graph_data = List<GraphData>(gdata); .set_body([](DGLArgs args, DGLRetValue *rv) {
} std::string filename = args[0];
List<Value> idxs = args[1];
auto idx_list = ListValueToVector<dgl_id_t>(idxs);
*rv = List<HeteroGraphData>(LoadHeteroGraphs(filename, idx_list));
});
} // namespace serialize } // namespace serialize
} // namespace dgl } // namespace dgl
...@@ -6,72 +6,54 @@ ...@@ -6,72 +6,54 @@
#ifndef DGL_GRAPH_SERIALIZE_GRAPH_SERIALIZE_H_ #ifndef DGL_GRAPH_SERIALIZE_GRAPH_SERIALIZE_H_
#define DGL_GRAPH_SERIALIZE_GRAPH_SERIALIZE_H_ #define DGL_GRAPH_SERIALIZE_GRAPH_SERIALIZE_H_
#include <dgl/graph.h>
#include <dgl/array.h> #include <dgl/array.h>
#include <dgl/graph.h>
#include <dgl/immutable_graph.h> #include <dgl/immutable_graph.h>
#include <dmlc/io.h> #include <dgl/packed_func_ext.h>
#include <dmlc/type_traits.h>
#include <dgl/runtime/ndarray.h>
#include <dgl/runtime/container.h> #include <dgl/runtime/container.h>
#include <dgl/runtime/ndarray.h>
#include <dgl/runtime/object.h> #include <dgl/runtime/object.h>
#include <dgl/packed_func_ext.h> #include <dmlc/io.h>
#include <dmlc/type_traits.h>
#include <algorithm>
#include <iostream> #include <iostream>
#include <memory>
#include <string> #include <string>
#include <vector>
#include <algorithm>
#include <utility> #include <utility>
#include <memory> #include <vector>
#include "../../c_api_common.h" #include "../../c_api_common.h"
#include "dglgraph_data.h"
#include "heterograph_data.h"
using dgl::runtime::NDArray;
using dgl::ImmutableGraph; using dgl::ImmutableGraph;
using dgl::runtime::NDArray;
using namespace dgl::runtime; using namespace dgl::runtime;
namespace dgl { namespace dgl {
namespace serialize { namespace serialize {
typedef std::pair<std::string, NDArray> NamedTensor; enum GraphType : uint64_t {
kMutableGraph = 0ull,
class GraphDataObject : public runtime::Object { kImmutableGraph = 1ull,
public: kHeteroGraph = 2ull
ImmutableGraphPtr gptr;
std::vector<NamedTensor> node_tensors;
std::vector<NamedTensor> edge_tensors;
static constexpr const char *_type_key = "graph_serialize.GraphData";
void SetData(ImmutableGraphPtr gptr,
Map<std::string, Value> node_tensors,
Map<std::string, Value> edge_tensors);
void Save(dmlc::Stream *fs) const;
bool Load(dmlc::Stream *fs);
DGL_DECLARE_OBJECT_TYPE_INFO(GraphDataObject, runtime::Object);
}; };
constexpr uint64_t kDGLSerializeMagic = 0xDD2E4FF046B4A13F;
class GraphData : public runtime::ObjectRef {
public:
DGL_DEFINE_OBJECT_REF_METHODS(GraphData, runtime::ObjectRef, GraphDataObject);
/*! \brief create a new GraphData reference */
static GraphData Create() {
return GraphData(std::make_shared<GraphDataObject>());
}
};
class StorageMetaDataObject : public runtime::Object { class StorageMetaDataObject : public runtime::Object {
public: public:
// For saving DGLGraph
dgl_id_t num_graph; dgl_id_t num_graph;
Value nodes_num_list; Value nodes_num_list;
Value edges_num_list; Value edges_num_list;
Map<std::string, Value> labels_list; Map<std::string, Value> labels_list;
List<GraphData> graph_data; List<GraphData> graph_data;
static constexpr const char *_type_key = "graph_serialize.StorageMetaData"; static constexpr const char *_type_key = "graph_serialize.StorageMetaData";
void SetMetaData(dgl_id_t num_graph, void SetMetaData(dgl_id_t num_graph, std::vector<int64_t> nodes_num_list,
std::vector<int64_t> nodes_num_list,
std::vector<int64_t> edges_num_list, std::vector<int64_t> edges_num_list,
std::vector<NamedTensor> labels_list); std::vector<NamedTensor> labels_list);
...@@ -88,10 +70,10 @@ class StorageMetaDataObject : public runtime::Object { ...@@ -88,10 +70,10 @@ class StorageMetaDataObject : public runtime::Object {
DGL_DECLARE_OBJECT_TYPE_INFO(StorageMetaDataObject, runtime::Object); DGL_DECLARE_OBJECT_TYPE_INFO(StorageMetaDataObject, runtime::Object);
}; };
class StorageMetaData : public runtime::ObjectRef { class StorageMetaData : public runtime::ObjectRef {
public: public:
DGL_DEFINE_OBJECT_REF_METHODS(StorageMetaData, runtime::ObjectRef, StorageMetaDataObject); DGL_DEFINE_OBJECT_REF_METHODS(StorageMetaData, runtime::ObjectRef,
StorageMetaDataObject);
/*! \brief create a new StorageMetaData reference */ /*! \brief create a new StorageMetaData reference */
static StorageMetaData Create() { static StorageMetaData Create() {
...@@ -99,14 +81,18 @@ class StorageMetaData : public runtime::ObjectRef { ...@@ -99,14 +81,18 @@ class StorageMetaData : public runtime::ObjectRef {
} }
}; };
StorageMetaData LoadDGLGraphFiles(const std::string &filename,
std::vector<dgl_id_t> idx_list,
bool onlyMeta);
bool SaveDGLGraphs(std::string filename, StorageMetaData LoadDGLGraphs(const std::string &filename,
List<GraphData> graph_data, std::vector<dgl_id_t> idx_list, bool onlyMeta);
bool SaveDGLGraphs(std::string filename, List<GraphData> graph_data,
std::vector<NamedTensor> labels_list); std::vector<NamedTensor> labels_list);
StorageMetaData LoadDGLGraphs(const std::string &filename, std::vector<HeteroGraphData> LoadHeteroGraphs(const std::string &filename,
std::vector<dgl_id_t> idx_list, std::vector<dgl_id_t> idx_list);
bool onlyMeta = false);
ImmutableGraphPtr ToImmutableGraph(GraphPtr g); ImmutableGraphPtr ToImmutableGraph(GraphPtr g);
......
/*!
* Copyright (c) 2019 by Contributors
* \file graph/serialize/heterograph_data.h
* \brief Graph serialization header
*/
#ifndef DGL_GRAPH_SERIALIZE_HETEROGRAPH_DATA_H_
#define DGL_GRAPH_SERIALIZE_HETEROGRAPH_DATA_H_
#include <dgl/array.h>
#include <dgl/graph.h>
#include <dgl/packed_func_ext.h>
#include <dgl/runtime/container.h>
#include <dgl/runtime/ndarray.h>
#include <dgl/runtime/object.h>
#include <dmlc/io.h>
#include <dmlc/type_traits.h>
#include <algorithm>
#include <iostream>
#include <string>
#include <utility>
#include <vector>
#include "../../c_api_common.h"
#include "../heterograph.h"
using dgl::runtime::NDArray;
using namespace dgl::runtime;
namespace dgl {
namespace serialize {
typedef std::pair<std::string, NDArray> NamedTensor;
/*!
 * \brief Serializable bundle of a heterograph, its feature tensors and its
 *        node/edge type names.
 */
class HeteroGraphDataObject : public runtime::Object {
 public:
  /*! \brief The heterograph structure. */
  std::shared_ptr<HeteroGraph> gptr;
  /*! \brief Named node tensors; one inner list per entry of the ndata list. */
  std::vector<std::vector<NamedTensor>> node_tensors;
  /*! \brief Named edge tensors; one inner list per entry of the edata list. */
  std::vector<std::vector<NamedTensor>> edge_tensors;
  /*! \brief Edge type names. */
  std::vector<std::string> etype_names;
  /*! \brief Node type names. */
  std::vector<std::string> ntype_names;

  static constexpr const char *_type_key =
      "heterograph_serialize.HeteroGraphData";

  /*! \brief Create an empty object, to be filled in by Load(). */
  HeteroGraphDataObject() {}

  /*!
   * \brief Build the object from a graph and its frontend-side containers.
   *
   * \param gptr Graph pointer; it must refer to a HeteroGraph, otherwise the
   *             null check below fails.
   * \param ndata List of name->tensor maps for node data.
   * \param edata List of name->tensor maps for edge data.
   * \param ntype_names Node type names.
   * \param etype_names Edge type names.
   */
  HeteroGraphDataObject(HeteroGraphPtr gptr,
                        List<Map<std::string, Value>> ndata,
                        List<Map<std::string, Value>> edata,
                        List<Value> ntype_names, List<Value> etype_names) {
    this->gptr = std::dynamic_pointer_cast<HeteroGraph>(gptr);
    CHECK_NOTNULL(this->gptr);
    for (auto nd_dict : ndata) {
      // Always append a bucket, even when the dict is empty, so bucket i
      // keeps matching entry i of the input list.
      node_tensors.emplace_back();
      auto &bucket = node_tensors.back();
      for (auto kv : nd_dict) {
        NDArray ndarray = kv.second->data;
        bucket.emplace_back(kv.first, ndarray);
      }
    }
    for (auto ed_dict : edata) {
      edge_tensors.emplace_back();
      auto &bucket = edge_tensors.back();
      for (auto kv : ed_dict) {
        NDArray ndarray = kv.second->data;
        bucket.emplace_back(kv.first, ndarray);
      }
    }
    this->ntype_names = ListValueToVector<std::string>(ntype_names);
    this->etype_names = ListValueToVector<std::string>(etype_names);
  }

  /*!
   * \brief Write graph, node tensors, edge tensors, node type names and edge
   *        type names to the stream, in that order.
   * \param fs Destination stream.
   */
  void Save(dmlc::Stream *fs) const {
    fs->Write(gptr);
    fs->Write(node_tensors);
    fs->Write(edge_tensors);
    fs->Write(ntype_names);
    fs->Write(etype_names);
  }

  /*!
   * \brief Read the fields back in the order Save() wrote them.
   * \param fs Source stream.
   * \return true when every field was read successfully; false as soon as
   *         one read fails (the remaining fields are then left untouched).
   */
  bool Load(dmlc::Stream *fs) {
    return fs->Read(&gptr) && fs->Read(&node_tensors) &&
           fs->Read(&edge_tensors) && fs->Read(&ntype_names) &&
           fs->Read(&etype_names);
  }

  DGL_DECLARE_OBJECT_TYPE_INFO(HeteroGraphDataObject, runtime::Object);
};
/*! \brief Reference type wrapping a HeteroGraphDataObject. */
class HeteroGraphData : public runtime::ObjectRef {
 public:
  DGL_DEFINE_OBJECT_REF_METHODS(HeteroGraphData, runtime::ObjectRef,
                                HeteroGraphDataObject);

  /*!
   * \brief Create a HeteroGraphData reference populated from the arguments.
   *
   * All arguments are forwarded unchanged to the HeteroGraphDataObject
   * constructor.
   */
  static HeteroGraphData Create(HeteroGraphPtr gptr,
                                List<Map<std::string, Value>> node_tensors,
                                List<Map<std::string, Value>> edge_tensors,
                                List<Value> ntype_names,
                                List<Value> etype_names) {
    auto object = std::make_shared<HeteroGraphDataObject>(
        gptr, node_tensors, edge_tensors, ntype_names, etype_names);
    return HeteroGraphData(object);
  }

  /*! \brief Create a reference to an empty HeteroGraphDataObject. */
  static HeteroGraphData Create() {
    auto object = std::make_shared<HeteroGraphDataObject>();
    return HeteroGraphData(object);
  }
};
} // namespace serialize
} // namespace dgl
namespace dmlc {
DMLC_DECLARE_TRAITS(has_saveload, dgl::serialize::HeteroGraphDataObject, true);
}
#endif // DGL_GRAPH_SERIALIZE_HETEROGRAPH_DATA_H_
/*!
* Copyright (c) 2019 by Contributors
* \file graph/serialize/heterograph_serialize.cc
* \brief DGLHeteroGraph serialization implementation
*
* The storage structure is
* {
* // MetaData Section
* uint64_t kDGLSerializeMagic
* uint64_t kVersion = 2
* uint64_t GraphType = kHeteroGraph
* dgl_id_t num_graphs
* ** Reserved Area till 4kB **
*
* uint64_t gdata_start_pos (This stores the start position of graph_data,
* which is used to skip label dict part if unnecessary)
* vector<pair<string, NDArray>> label_dict (To store the dict[str, NDArray])
*
* vector<HeteroGraphData> graph_datas;
* vector<dgl_id_t> graph_indices (start address of each graph)
* uint64_t size_of_graph_indices_vector (Used to seek to graph_indices
* vector)
*
* }
*
* Storage of HeteroGraphData is
* {
* HeteroGraphPtr ptr;
* vector<vector<pair<string, NDArray>>> node_tensors;
* vector<vector<pair<string, NDArray>>> edge_tensors;
* vector<string> ntype_name;
* vector<string> etype_name;
* }
*
*/
#include <dgl/graph_op.h>
#include <dgl/immutable_graph.h>
#include <dgl/runtime/container.h>
#include <dgl/runtime/object.h>
#include <dmlc/io.h>
#include <dmlc/type_traits.h>
#include <algorithm>
#include <array>
#include <iostream>
#include <string>
#include <utility>
#include <vector>
#include "../heterograph.h"
#include "./graph_serialize.h"
#include "./streamwithcount.h"
#include "dmlc/memory_io.h"
namespace dgl {
namespace serialize {
using namespace dgl::runtime;
using dmlc::SeekStream;
using dmlc::Stream;
using dmlc::io::FileSystem;
using dmlc::io::URI;
/*!
 * \brief Save a list of heterographs plus a label dict into one file.
 *
 * Bytes are written strictly sequentially (no seeking):
 *   1. a 4096-byte header holding magic number, version, graph type and the
 *      number of graphs, zero-padded;
 *   2. the file offset at which graph data starts, then the label dict;
 *   3. every HeteroGraphData, in list order;
 *   4. the per-graph start offsets, then the byte size of that offset blob.
 *
 * \param filename Destination path/URI.
 * \param hdata The graphs to save.
 * \param nd_list Named label tensors.
 * \return Always true; failures are reported through CHECK.
 */
bool SaveHeteroGraphs(std::string filename, List<HeteroGraphData> hdata,
                      const std::vector<NamedTensor> &nd_list) {
  auto fs = std::unique_ptr<StreamWithCount>(
      StreamWithCount::Create(filename.c_str(), "w", false));
  CHECK(fs->IsValid()) << "File name " << filename << " is not a valid name";

  // Write DGL MetaData
  const uint64_t kVersion = 2;
  // Value-initialized so the unused tail of the reserved header is all zero
  // bytes instead of indeterminate stack contents ending up in the file.
  std::array<char, 4096> meta_buffer{};
  // Write metadata into the fixed-size char buffer
  dmlc::MemoryFixedSizeStream meta_fs_(meta_buffer.data(), meta_buffer.size());
  auto meta_fs = static_cast<Stream *>(&meta_fs_);
  meta_fs->Write(kDGLSerializeMagic);
  meta_fs->Write(kVersion);
  meta_fs->Write(GraphType::kHeteroGraph);
  uint64_t num_graph = hdata.size();
  meta_fs->Write(num_graph);

  // Write metadata into files
  fs->Write(meta_buffer.data(), meta_buffer.size());

  // Serialize the label dict into memory first: its byte size is needed to
  // compute where graph data starts before anything of it hits the file.
  std::string labels_blob;
  dmlc::MemoryStringStream label_fs_(&labels_blob);
  auto label_fs = static_cast<Stream *>(&label_fs_);
  label_fs->Write(nd_list);

  // Bytes written so far + the offset field itself + the label dict.
  uint64_t gdata_start_pos =
      fs->Count() + sizeof(uint64_t) + labels_blob.size();

  // Write the start position of gdata (lets a reader skip the label dict),
  // followed by the label dict itself.
  fs->Write(gdata_start_pos);
  fs->Write(labels_blob.c_str(), labels_blob.size());

  std::vector<uint64_t> graph_indices(num_graph);

  // Write HeteroGraphData, recording the file offset where each one starts
  for (uint64_t i = 0; i < num_graph; ++i) {
    graph_indices[i] = fs->Count();
    auto gdata = hdata[i].sptr();
    fs->Write(gdata);
  }

  // Serialize the offset table into a string to learn its byte size
  std::string indptr_blob;
  dmlc::MemoryStringStream indptr_fs_(&indptr_blob);
  auto indptr_fs = static_cast<Stream *>(&indptr_fs_);
  indptr_fs->Write(graph_indices);

  uint64_t indptr_buffer_size = indptr_blob.size();
  // NOTE(review): this is the generic Write overload for std::string, which
  // presumably emits its own length prefix ahead of the blob; kept as is so
  // the on-disk layout does not change.
  fs->Write(indptr_blob);
  // Trailing size field: a reader locates the offset table from the file end.
  fs->Write(indptr_buffer_size);

  return true;
}
std::vector<HeteroGraphData> LoadHeteroGraphs(const std::string &filename,
std::vector<dgl_id_t> idx_list) {
auto fs = std::unique_ptr<SeekStream>(
SeekStream::CreateForRead(filename.c_str(), false));
CHECK(fs) << "File name " << filename << " is not a valid name";
// Read DGL MetaData
uint64_t magicNum, graphType, version, num_graph;
fs->Read(&magicNum);
fs->Read(&version);
fs->Read(&graphType);
CHECK(fs->Read(&num_graph)) << "Invalid num of graph";
fs->Seek(4096);
CHECK_EQ(magicNum, kDGLSerializeMagic) << "Invalid DGL files";
CHECK_EQ(version, 2) << "Invalid GraphType";
CHECK_EQ(graphType, GraphType::kHeteroGraph) << "Invalid GraphType";
uint64_t gdata_start_pos;
fs->Read(&gdata_start_pos);
// Skip labels part
fs->Seek(gdata_start_pos);
std::vector<HeteroGraphData> gdata_refs;
if (idx_list.empty()) {
// Read All Graphs
gdata_refs.reserve(num_graph);
for (uint64_t i = 0; i < num_graph; ++i) {
HeteroGraphData gdata = HeteroGraphData::Create();
auto hetero_data = gdata.sptr();
fs->Read(&hetero_data);
gdata_refs.push_back(gdata);
}
} else {
uint64_t gdata_start_pos = fs->Tell();
// Read Selected Graphss
gdata_refs.reserve(idx_list.size());
URI uri(filename.c_str());
uint64_t filesize = FileSystem::GetInstance(uri)->GetPathInfo(uri).size;
fs->Seek(filesize - sizeof(uint64_t));
uint64_t indptr_buffer_size;
fs->Read(&indptr_buffer_size);
std::vector<uint64_t> graph_indices(num_graph);
fs->Seek(filesize - sizeof(uint64_t) - indptr_buffer_size);
fs->Read(&graph_indices);
fs->Seek(gdata_start_pos);
// Would be better if idx_list is sorted. However the returned the graphs
// should be the same order as the idx_list
for (uint64_t i = 0; i < idx_list.size(); ++i) {
fs->Seek(graph_indices[idx_list[i]]);
HeteroGraphData gdata = HeteroGraphData::Create();
auto hetero_data = gdata.sptr();
fs->Read(&hetero_data);
gdata_refs.push_back(gdata);
}
}
return gdata_refs;
}
std::vector<NamedTensor> LoadLabels_V2(const std::string &filename) {
auto fs = std::unique_ptr<SeekStream>(
SeekStream::CreateForRead(filename.c_str(), false));
CHECK(fs) << "File name " << filename << " is not a valid name";
// Read DGL MetaData
uint64_t magicNum, graphType, version, num_graph;
fs->Read(&magicNum);
fs->Read(&version);
fs->Read(&graphType);
CHECK(fs->Read(&num_graph)) << "Invalid num of graph";
fs->Seek(4096);
uint64_t gdata_start_pos;
fs->Read(&gdata_start_pos);
std::vector<NamedTensor> labels_list;
fs->Read(&labels_list);
return labels_list;
}
// Packed function: builds a HeteroGraphData from
// (graph, node data, edge data, node type names, edge type names).
DGL_REGISTER_GLOBAL("data.heterograph_serialize._CAPI_MakeHeteroGraphData")
    .set_body([](DGLArgs args, DGLRetValue *rv) {
      HeteroGraphRef graph_ref = args[0];
      List<Map<std::string, Value>> node_data = args[1];
      List<Map<std::string, Value>> edge_data = args[2];
      List<Value> node_type_names = args[3];
      List<Value> edge_type_names = args[4];
      HeteroGraphData packed = HeteroGraphData::Create(
          graph_ref.sptr(), node_data, edge_data, node_type_names,
          edge_type_names);
      *rv = packed;
    });
// Packed function: saves (filename, list of HeteroGraphData, label dict) and
// returns the result of SaveHeteroGraphs.
DGL_REGISTER_GLOBAL("data.heterograph_serialize._CAPI_SaveHeteroGraphData")
    .set_body([](DGLArgs args, DGLRetValue *rv) {
      std::string path = args[0];
      List<HeteroGraphData> graphs = args[1];
      Map<std::string, Value> label_map = args[2];
      // Flatten the label map into the (name, tensor) list the saver takes.
      std::vector<NamedTensor> labels;
      for (auto entry : label_map) {
        labels.emplace_back(entry.first,
                            static_cast<NDArray>(entry.second->data));
      }
      const bool saved = dgl::serialize::SaveHeteroGraphs(path, graphs, labels);
      *rv = saved;
    });
// Packed function: returns the graph held by a HeteroGraphData as a
// HeteroGraphRef.
DGL_REGISTER_GLOBAL(
    "data.heterograph_serialize._CAPI_GetGindexFromHeteroGraphData")
    .set_body([](DGLArgs args, DGLRetValue *rv) {
      HeteroGraphData bundle = args[0];
      HeteroGraphRef graph_ref(bundle->gptr);
      *rv = graph_ref;
    });
// Packed function: returns the edge type names of a HeteroGraphData as a
// list of Values, in stored order.
DGL_REGISTER_GLOBAL(
    "data.heterograph_serialize._CAPI_GetEtypesFromHeteroGraphData")
    .set_body([](DGLArgs args, DGLRetValue *rv) {
      HeteroGraphData bundle = args[0];
      const std::vector<std::string> &stored = bundle->etype_names;
      List<Value> result;
      for (size_t i = 0; i < stored.size(); ++i) {
        result.push_back(Value(MakeValue(stored[i])));
      }
      *rv = result;
    });
// Packed function: returns the node type names of a HeteroGraphData as a
// list of Values, in stored order.
DGL_REGISTER_GLOBAL(
    "data.heterograph_serialize._CAPI_GetNtypesFromHeteroGraphData")
    .set_body([](DGLArgs args, DGLRetValue *rv) {
      HeteroGraphData bundle = args[0];
      const std::vector<std::string> &stored = bundle->ntype_names;
      List<Value> result;
      for (size_t i = 0; i < stored.size(); ++i) {
        result.push_back(Value(MakeValue(stored[i])));
      }
      *rv = result;
    });
// Packed function: returns the node tensors of a HeteroGraphData. Each inner
// list is flattened as [name0, tensor0, name1, tensor1, ...].
DGL_REGISTER_GLOBAL(
    "data.heterograph_serialize._CAPI_GetNDataFromHeteroGraphData")
    .set_body([](DGLArgs args, DGLRetValue *rv) {
      HeteroGraphData bundle = args[0];
      List<List<Value>> result;
      for (const auto &named_tensors : bundle->node_tensors) {
        List<Value> flattened;
        for (const auto &entry : named_tensors) {
          flattened.push_back(Value(MakeValue(entry.first)));
          flattened.push_back(Value(MakeValue(entry.second)));
        }
        result.push_back(flattened);
      }
      *rv = result;
    });
// Packed function: returns the edge tensors of a HeteroGraphData. Each inner
// list is flattened as [name0, tensor0, name1, tensor1, ...].
DGL_REGISTER_GLOBAL(
    "data.heterograph_serialize._CAPI_GetEDataFromHeteroGraphData")
    .set_body([](DGLArgs args, DGLRetValue *rv) {
      HeteroGraphData bundle = args[0];
      List<List<Value>> result;
      for (const auto &named_tensors : bundle->edge_tensors) {
        List<Value> flattened;
        for (const auto &entry : named_tensors) {
          flattened.push_back(Value(MakeValue(entry.first)));
          flattened.push_back(Value(MakeValue(entry.second)));
        }
        result.push_back(flattened);
      }
      *rv = result;
    });
// Packed function: loads the label dict of the given file and returns it as
// a name -> Value map.
DGL_REGISTER_GLOBAL("data.graph_serialize._CAPI_LoadLabels_V2")
    .set_body([](DGLArgs args, DGLRetValue *rv) {
      std::string path = args[0];
      std::vector<NamedTensor> labels = LoadLabels_V2(path);
      Map<std::string, Value> result;
      for (const auto &label : labels) {
        result.Set(label.first, Value(MakeValue(label.second)));
      }
      *rv = result;
    });
} // namespace serialize
} // namespace dgl
/*!
* Copyright (c) 2019 by Contributors
* \file graph/serialize/streamwithcount.h
* \brief Graph serialization header
*/
#ifndef DGL_GRAPH_SERIALIZE_STREAMWITHCOUNT_H_
#define DGL_GRAPH_SERIALIZE_STREAMWITHCOUNT_H_
#include <dmlc/io.h>
#include <dmlc/type_traits.h>
/*!
* \brief StreamWithCount counts the bytes that already written into the
* underlying stream.
*/
class StreamWithCount : dmlc::Stream {
public:
static StreamWithCount *Create(const char *uri, const char *const flag,
bool allow_null = false) {
return new StreamWithCount(uri, flag, allow_null);
}
size_t Read(void *ptr, size_t size) override {
return strm_->Read(ptr, size);
}
void Write(const void *ptr, size_t size) override {
count_ += size;
strm_->Write(ptr, size);
}
using dmlc::Stream::Read;
using dmlc::Stream::Write;
bool IsValid() { return strm_.get(); }
uint64_t Count() const { return count_; }
private:
StreamWithCount(const char *uri, const char *const flag, bool allow_null)
: strm_(dmlc::Stream::Create(uri, flag, allow_null)) {}
std::unique_ptr<dmlc::Stream> strm_;
uint64_t count_ = 0;
};
#endif // DGL_GRAPH_SERIALIZE_STREAMWITHCOUNT_H_
...@@ -24,7 +24,8 @@ constexpr uint64_t kDGLSerialize_Tensors = 0xDD5A9FBE3FA2443F; ...@@ -24,7 +24,8 @@ constexpr uint64_t kDGLSerialize_Tensors = 0xDD5A9FBE3FA2443F;
DGL_REGISTER_GLOBAL("data.tensor_serialize._CAPI_SaveNDArrayDict") DGL_REGISTER_GLOBAL("data.tensor_serialize._CAPI_SaveNDArrayDict")
.set_body([](DGLArgs args, DGLRetValue *rv) { .set_body([](DGLArgs args, DGLRetValue *rv) {
std::string filename = args[0]; std::string filename = args[0];
auto *fs = dmlc::Stream::Create(filename.c_str(), "w"); auto fs = std::unique_ptr<dmlc::Stream>(
dmlc::Stream::Create(filename.c_str(), "w"));
CHECK(fs) << "Filename is invalid"; CHECK(fs) << "Filename is invalid";
fs->Write(kDGLSerialize_Tensors); fs->Write(kDGLSerialize_Tensors);
bool empty_dict = args[2]; bool empty_dict = args[2];
...@@ -40,13 +41,13 @@ DGL_REGISTER_GLOBAL("data.tensor_serialize._CAPI_SaveNDArrayDict") ...@@ -40,13 +41,13 @@ DGL_REGISTER_GLOBAL("data.tensor_serialize._CAPI_SaveNDArrayDict")
} }
fs->Write(namedTensors); fs->Write(namedTensors);
*rv = true; *rv = true;
delete fs;
}); });
DGL_REGISTER_GLOBAL("data.tensor_serialize._CAPI_LoadNDArrayDict") DGL_REGISTER_GLOBAL("data.tensor_serialize._CAPI_LoadNDArrayDict")
.set_body([](DGLArgs args, DGLRetValue *rv) { .set_body([](DGLArgs args, DGLRetValue *rv) {
std::string filename = args[0]; std::string filename = args[0];
auto *fs = dmlc::Stream::Create(filename.c_str(), "r"); auto fs = std::unique_ptr<dmlc::Stream>(
dmlc::Stream::Create(filename.c_str(), "r"));
CHECK(fs) << "Filename is invalid or file doesn't exists"; CHECK(fs) << "Filename is invalid or file doesn't exists";
uint64_t magincNum, num_elements; uint64_t magincNum, num_elements;
CHECK(fs->Read(&magincNum)) << "Invalid file"; CHECK(fs->Read(&magincNum)) << "Invalid file";
...@@ -60,7 +61,6 @@ DGL_REGISTER_GLOBAL("data.tensor_serialize._CAPI_LoadNDArrayDict") ...@@ -60,7 +61,6 @@ DGL_REGISTER_GLOBAL("data.tensor_serialize._CAPI_LoadNDArrayDict")
nd_dict.Set(kv.first, ndarray); nd_dict.Set(kv.first, ndarray);
} }
*rv = nd_dict; *rv = nd_dict;
delete fs;
}); });
} // namespace serialize } // namespace serialize
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment