Unverified Commit 3ef757db, authored by Quan (Andy) Gan, committed by GitHub
Browse files

[Feature] Pickle support for heterograph (#1177)



* [Feature] Pickle support for heterograph

* fix lint
Co-authored-by: VoVAllen <VoVAllen@users.noreply.github.com>
parent ca302a13
...@@ -25,12 +25,27 @@ class HeteroGraphIndex(ObjectBase): ...@@ -25,12 +25,27 @@ class HeteroGraphIndex(ObjectBase):
return obj return obj
def __getstate__(self):
    """Return the state used to pickle this heterograph index.

    Returns
    -------
    tuple
        ``(metagraph, number_of_nodes, edges)`` where

        * ``metagraph`` is ``self.metagraph``,
        * ``number_of_nodes`` is a list holding ``self.number_of_nodes(i)``
          for every node type ID ``i``,
        * ``edges`` is a list holding ``self.edges(i, order='eid')`` for
          every edge type ID ``i``.
    """
    metagraph = self.metagraph
    number_of_nodes = [self.number_of_nodes(i) for i in range(self.number_of_ntypes())]
    # Edges are requested in edge-ID order for every edge type.
    edges = [self.edges(i, order='eid') for i in range(self.number_of_etypes())]
    # multigraph and readonly are not used.
    return metagraph, number_of_nodes, edges
def __setstate__(self, state):
    """Rebuild this heterograph index from a pickled state.

    Parameters
    ----------
    state : tuple
        The ``(metagraph, number_of_nodes, edges)`` triple produced by
        ``__getstate__``: ``number_of_nodes`` is indexed by node type ID and
        ``edges`` holds one ``(src, dst, eid)`` triple per edge type ID.
    """
    metagraph, number_of_nodes, edges = state

    # Start from an empty cache on the restored object.
    self._cache = {}
    # loop over etypes and recover unit graphs
    rel_graphs = []
    for i, edges_per_type in enumerate(edges):
        # Edge i of the metagraph gives the source/destination node types
        # of edge type i.
        src_ntype, dst_ntype = metagraph.find_edge(i)
        num_src = number_of_nodes[src_ntype]
        num_dst = number_of_nodes[dst_ntype]
        # The third element (edge IDs) is dropped when rebuilding.
        src_id, dst_id, _ = edges_per_type
        # NOTE(review): the first argument is presumably the number of node
        # types of the unit graph (1 when source and destination types
        # coincide, 2 otherwise) -- confirm against create_unitgraph_from_coo.
        rel_graphs.append(create_unitgraph_from_coo(
            1 if src_ntype == dst_ntype else 2, num_src, num_dst, src_id, dst_id))
    self.__init_handle_by_constructor__(
        _CAPI_DGLHeteroCreateHeteroGraph, metagraph, rel_graphs)
@property @property
def metagraph(self): def metagraph(self):
......
...@@ -442,11 +442,14 @@ def cached_member(cache, prefix): ...@@ -442,11 +442,14 @@ def cached_member(cache, prefix):
""" """
def _creator(func): def _creator(func):
@wraps(func) @wraps(func)
def wrapper(self, *args): def wrapper(self, *args, **kwargs):
dic = getattr(self, cache) dic = getattr(self, cache)
key = '%s-%s' % (prefix, '-'.join([str(a) for a in args])) key = '%s-%s-%s' % (
prefix,
'-'.join([str(a) for a in args]),
'-'.join([str(k) + ':' + str(v) for k, v in kwargs.items()]))
if key not in dic: if key not in dic:
dic[key] = func(self, *args) dic[key] = func(self, *args, **kwargs)
return dic[key] return dic[key]
return wrapper return wrapper
return _creator return _creator
......
import networkx as nx import networkx as nx
import scipy.sparse as ssp
import dgl import dgl
import dgl.contrib as contrib import dgl.contrib as contrib
from dgl.frame import Frame, FrameRef, Column from dgl.frame import Frame, FrameRef, Column
...@@ -13,8 +14,8 @@ def _assert_is_identical(g, g2): ...@@ -13,8 +14,8 @@ def _assert_is_identical(g, g2):
assert g.is_multigraph == g2.is_multigraph assert g.is_multigraph == g2.is_multigraph
assert g.is_readonly == g2.is_readonly assert g.is_readonly == g2.is_readonly
assert g.number_of_nodes() == g2.number_of_nodes() assert g.number_of_nodes() == g2.number_of_nodes()
src, dst = g.all_edges() src, dst = g.all_edges(order='eid')
src2, dst2 = g2.all_edges() src2, dst2 = g2.all_edges(order='eid')
assert F.array_equal(src, src2) assert F.array_equal(src, src2)
assert F.array_equal(dst, dst2) assert F.array_equal(dst, dst2)
...@@ -25,6 +26,32 @@ def _assert_is_identical(g, g2): ...@@ -25,6 +26,32 @@ def _assert_is_identical(g, g2):
for k in g.edata: for k in g.edata:
assert F.allclose(g.edata[k], g2.edata[k]) assert F.allclose(g.edata[k], g2.edata[k])
def _assert_is_identical_hetero(g, g2):
    """Assert that two heterographs agree in structure and in features.

    Compares the multigraph/readonly flags, node types, canonical edge
    types, metagraph edges, per-type node counts, per-type edge lists (in
    edge-ID order) and every node/edge feature.

    Parameters
    ----------
    g, g2
        The two heterographs to compare.

    Raises
    ------
    AssertionError
        If any of the compared properties differ.
    """
    assert g.is_multigraph == g2.is_multigraph
    assert g.is_readonly == g2.is_readonly
    assert g.ntypes == g2.ntypes
    assert g.canonical_etypes == g2.canonical_etypes

    # check if two metagraphs are identical
    # (the edge view of g2 does not change inside the loop, so fetch it once)
    g2_meta_edges = g2.metagraph.edges(keys=True)
    for edges, features in g.metagraph.edges(keys=True).items():
        assert g2_meta_edges[edges] == features

    # check if node ID spaces and feature spaces are equal
    for ntype in g.ntypes:
        assert g.number_of_nodes(ntype) == g2.number_of_nodes(ntype)
        assert len(g.nodes[ntype].data) == len(g2.nodes[ntype].data)
        for k in g.nodes[ntype].data:
            assert F.allclose(g.nodes[ntype].data[k], g2.nodes[ntype].data[k])

    # check if edge ID spaces and feature spaces are equal
    for etype in g.canonical_etypes:
        src, dst = g.all_edges(etype=etype, order='eid')
        src2, dst2 = g2.all_edges(etype=etype, order='eid')
        assert F.array_equal(src, src2)
        assert F.array_equal(dst, dst2)
        # Mirror the node check above: the loop below only walks the keys of
        # g, so without this a surplus edge feature on g2 would go unnoticed.
        assert len(g.edges[etype].data) == len(g2.edges[etype].data)
        for k in g.edges[etype].data:
            assert F.allclose(g.edges[etype].data[k], g2.edges[etype].data[k])
def _assert_is_identical_nodeflow(nf1, nf2): def _assert_is_identical_nodeflow(nf1, nf2):
assert nf1.is_multigraph == nf2.is_multigraph assert nf1.is_multigraph == nf2.is_multigraph
assert nf1.is_readonly == nf2.is_readonly assert nf1.is_readonly == nf2.is_readonly
...@@ -211,6 +238,29 @@ def test_pickling_batched_graph(): ...@@ -211,6 +238,29 @@ def test_pickling_batched_graph():
new_bg = _reconstruct_pickle(bg) new_bg = _reconstruct_pickle(bg)
_assert_is_identical_batchedgraph(bg, new_bg) _assert_is_identical_batchedgraph(bg, new_bg)
def test_pickling_heterograph():
    """Pickle a small heterograph and check the reconstructed copy matches.

    The graph has four relations -- user-follows-user, user-plays-game,
    user-wishes-game and developer-develops-game -- built from an edge list,
    a SciPy COO matrix and a NetworkX graph, plus random node features on
    'user' and 'game' and random edge features on 'plays'.
    """
    # copied from test_heterograph.create_test_heterograph()
    plays_coo = ssp.coo_matrix(([1, 1, 1, 1], ([0, 1, 2, 1], [0, 0, 1, 1])))

    wishes_graph = nx.DiGraph()
    wishes_graph.add_nodes_from(['u0', 'u1', 'u2'], bipartite=0)
    wishes_graph.add_nodes_from(['g0', 'g1'], bipartite=1)
    wishes_graph.add_edge('u0', 'g1', id=0)
    wishes_graph.add_edge('u2', 'g0', id=1)

    # One relation graph per edge type, in the order they are combined.
    relations = [
        dgl.graph([(0, 1), (1, 2)], 'user', 'follows'),
        dgl.bipartite(plays_coo, 'user', 'plays', 'game'),
        dgl.bipartite(wishes_graph, 'user', 'wishes', 'game'),
        dgl.bipartite([(0, 0), (1, 1)], 'developer', 'develops', 'game'),
    ]
    g = dgl.hetero_from_relations(relations)

    # Attach features so the round trip covers feature storage as well.
    g.nodes['user'].data['u_h'] = F.randn((3, 4))
    g.nodes['game'].data['g_h'] = F.randn((2, 5))
    g.edges['plays'].data['p_h'] = F.randn((4, 6))

    restored = _reconstruct_pickle(g)
    _assert_is_identical_hetero(g, restored)
if __name__ == '__main__': if __name__ == '__main__':
test_pickling_index() test_pickling_index()
test_pickling_graph_index() test_pickling_graph_index()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment