Unverified commit 3ef757db, authored by Quan (Andy) Gan and committed by GitHub
Browse files

[Feature] Pickle support for heterograph (#1177)



* [Feature] Pickle support for heterograph

* fix lint
Co-authored-by: VoVAllen <VoVAllen@users.noreply.github.com>
parent ca302a13
......@@ -25,12 +25,27 @@ class HeteroGraphIndex(ObjectBase):
return obj
def __getstate__(self):
    """Return the picklable state of this heterograph index.

    Returns
    -------
    tuple
        ``(metagraph, number_of_nodes, edges)`` where ``number_of_nodes[i]``
        is ``self.number_of_nodes(i)`` for node type ``i`` and ``edges[i]``
        is ``self.edges(i, order='eid')`` for edge type ``i``.
    """
    metagraph = self.metagraph
    number_of_nodes = [self.number_of_nodes(i) for i in range(self.number_of_ntypes())]
    # Edges are requested in edge-ID order; presumably so that edge IDs
    # survive the rebuild in __setstate__ -- TODO confirm.
    edges = [self.edges(i, order='eid') for i in range(self.number_of_etypes())]
    # multigraph and readonly are not used.
    return metagraph, number_of_nodes, edges
def __setstate__(self, state):
    """Rebuild this heterograph index from a pickled state.

    Parameters
    ----------
    state : tuple
        ``(metagraph, number_of_nodes, edges)``: the metagraph, a list of
        node counts indexed by node type, and a list with one
        ``(src_ids, dst_ids, <unused>)`` triple per edge type.
    """
    metagraph, number_of_nodes, edges = state

    # Start from an empty member cache on the restored object.
    self._cache = {}
    # loop over etypes and recover unit graphs
    rel_graphs = []
    for i, edges_per_type in enumerate(edges):
        # Endpoint node types of edge type i, looked up in the metagraph.
        src_ntype, dst_ntype = metagraph.find_edge(i)
        num_src = number_of_nodes[src_ntype]
        num_dst = number_of_nodes[dst_ntype]
        # The third element of the triple is not needed for reconstruction.
        src_id, dst_id, _ = edges_per_type
        # First argument is 1 when source and destination types coincide,
        # else 2 -- presumably the number of node types in the unit graph
        # (TODO confirm against create_unitgraph_from_coo).
        rel_graphs.append(create_unitgraph_from_coo(
            1 if src_ntype == dst_ntype else 2, num_src, num_dst, src_id, dst_id))
    self.__init_handle_by_constructor__(
        _CAPI_DGLHeteroCreateHeteroGraph, metagraph, rel_graphs)
@property
def metagraph(self):
......
......@@ -442,11 +442,14 @@ def cached_member(cache, prefix):
"""
def _creator(func):
    """Wrap ``func`` so that its results are memoized per argument set.

    The wrapper looks up the dict stored on ``self`` under the attribute
    named by ``cache`` and stores each result under a string key built from
    ``prefix``, the positional arguments and the keyword arguments.
    """
    @wraps(func)
    def wrapper(self, *args, **kwargs):
        dic = getattr(self, cache)
        # Keyword arguments are sorted by name so that the same call spelled
        # with a different keyword order maps to the same cache entry.
        key = '%s-%s-%s' % (
            prefix,
            '-'.join([str(a) for a in args]),
            '-'.join([str(k) + ':' + str(v) for k, v in sorted(kwargs.items())]))
        if key not in dic:
            # Forward keyword arguments too; the key already accounts for them.
            dic[key] = func(self, *args, **kwargs)
        return dic[key]
    return wrapper
return _creator
......
import networkx as nx
import scipy.sparse as ssp
import dgl
import dgl.contrib as contrib
from dgl.frame import Frame, FrameRef, Column
......@@ -13,8 +14,8 @@ def _assert_is_identical(g, g2):
assert g.is_multigraph == g2.is_multigraph
assert g.is_readonly == g2.is_readonly
assert g.number_of_nodes() == g2.number_of_nodes()
src, dst = g.all_edges()
src2, dst2 = g2.all_edges()
src, dst = g.all_edges(order='eid')
src2, dst2 = g2.all_edges(order='eid')
assert F.array_equal(src, src2)
assert F.array_equal(dst, dst2)
......@@ -25,6 +26,32 @@ def _assert_is_identical(g, g2):
for k in g.edata:
assert F.allclose(g.edata[k], g2.edata[k])
def _assert_is_identical_hetero(g, g2):
    """Assert that heterographs ``g`` and ``g2`` have the same structure and features.

    Compares the multigraph/readonly flags, node types, canonical edge types,
    metagraph edges, per-type node counts, per-type edge lists in edge-ID
    order, and all node and edge feature tensors.
    """
    assert g.is_multigraph == g2.is_multigraph
    assert g.is_readonly == g2.is_readonly

    assert g.ntypes == g2.ntypes
    assert g.canonical_etypes == g2.canonical_etypes

    # check if two metagraphs are identical
    for edges, features in g.metagraph.edges(keys=True).items():
        assert g2.metagraph.edges(keys=True)[edges] == features

    # check if node ID spaces and feature spaces are equal
    for ntype in g.ntypes:
        assert g.number_of_nodes(ntype) == g2.number_of_nodes(ntype)
        assert len(g.nodes[ntype].data) == len(g2.nodes[ntype].data)
        for k in g.nodes[ntype].data:
            assert F.allclose(g.nodes[ntype].data[k], g2.nodes[ntype].data[k])

    # check if edge ID spaces and feature spaces are equal
    for etype in g.canonical_etypes:
        src, dst = g.all_edges(etype=etype, order='eid')
        src2, dst2 = g2.all_edges(etype=etype, order='eid')
        assert F.array_equal(src, src2)
        assert F.array_equal(dst, dst2)
        # Same number of edge features on both sides, mirroring the node
        # check above, so extra features on g2 are not silently accepted.
        assert len(g.edges[etype].data) == len(g2.edges[etype].data)
        for k in g.edges[etype].data:
            assert F.allclose(g.edges[etype].data[k], g2.edges[etype].data[k])
def _assert_is_identical_nodeflow(nf1, nf2):
assert nf1.is_multigraph == nf2.is_multigraph
assert nf1.is_readonly == nf2.is_readonly
......@@ -211,6 +238,29 @@ def test_pickling_batched_graph():
new_bg = _reconstruct_pickle(bg)
_assert_is_identical_batchedgraph(bg, new_bg)
def test_pickling_heterograph():
    """Round-trip a heterograph carrying node and edge features through pickle."""
    # Graph construction mirrors test_heterograph.create_test_heterograph().
    plays_adj = ssp.coo_matrix(([1, 1, 1, 1], ([0, 1, 2, 1], [0, 0, 1, 1])))

    wishes_digraph = nx.DiGraph()
    wishes_digraph.add_nodes_from(['u0', 'u1', 'u2'], bipartite=0)
    wishes_digraph.add_nodes_from(['g0', 'g1'], bipartite=1)
    wishes_digraph.add_edge('u0', 'g1', id=0)
    wishes_digraph.add_edge('u2', 'g0', id=1)

    # One relation graph per edge type, built from four different input kinds.
    relations = [
        dgl.graph([(0, 1), (1, 2)], 'user', 'follows'),
        dgl.bipartite(plays_adj, 'user', 'plays', 'game'),
        dgl.bipartite(wishes_digraph, 'user', 'wishes', 'game'),
        dgl.bipartite([(0, 0), (1, 1)], 'developer', 'develops', 'game'),
    ]
    hg = dgl.hetero_from_relations(relations)

    # Attach features so the round trip covers node and edge data as well.
    hg.nodes['user'].data['u_h'] = F.randn((3, 4))
    hg.nodes['game'].data['g_h'] = F.randn((2, 5))
    hg.edges['plays'].data['p_h'] = F.randn((4, 6))

    restored = _reconstruct_pickle(hg)
    _assert_is_identical_hetero(hg, restored)
if __name__ == '__main__':
test_pickling_index()
test_pickling_graph_index()
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment