Unverified Commit c13903bf authored by Quan (Andy) Gan's avatar Quan (Andy) Gan Committed by GitHub
Browse files

fix batched heterograph serializations (#1794)

parent 200340ab
...@@ -302,6 +302,14 @@ class BatchedDGLHeteroGraph(DGLHeteroGraph): ...@@ -302,6 +302,14 @@ class BatchedDGLHeteroGraph(DGLHeteroGraph):
batch_num_nodes=self._batch_num_nodes, batch_num_nodes=self._batch_num_nodes,
batch_num_edges=self._batch_num_edges) batch_num_edges=self._batch_num_edges)
def __getstate__(self):
state = super().__getstate__()
return state, self._batch_size, self._batch_num_nodes, self._batch_num_edges
def __setstate__(self, state):
state, self._batch_size, self._batch_num_nodes, self._batch_num_edges = state
super().__setstate__(state)
def unbatch_hetero(graph): def unbatch_hetero(graph):
"""Return the list of heterographs in this batch. """Return the list of heterographs in this batch.
......
...@@ -78,6 +78,13 @@ def _assert_is_identical_batchedgraph(bg1, bg2): ...@@ -78,6 +78,13 @@ def _assert_is_identical_batchedgraph(bg1, bg2):
assert bg1.batch_num_nodes == bg2.batch_num_nodes assert bg1.batch_num_nodes == bg2.batch_num_nodes
assert bg1.batch_num_edges == bg2.batch_num_edges assert bg1.batch_num_edges == bg2.batch_num_edges
def _assert_is_identical_batchedhetero(bg1, bg2):
_assert_is_identical_hetero(bg1, bg2)
for ntype in bg1.ntypes:
assert bg1.batch_num_nodes(ntype) == bg2.batch_num_nodes(ntype)
for canonical_etype in bg1.canonical_etypes:
assert bg1.batch_num_edges(canonical_etype) == bg2.batch_num_edges(canonical_etype)
def _assert_is_identical_index(i1, i2): def _assert_is_identical_index(i1, i2):
assert i1.slice_data() == i2.slice_data() assert i1.slice_data() == i2.slice_data()
assert F.array_equal(i1.tousertensor(), i2.tousertensor()) assert F.array_equal(i1.tousertensor(), i2.tousertensor())
...@@ -258,6 +265,33 @@ def test_pickling_heterograph(): ...@@ -258,6 +265,33 @@ def test_pickling_heterograph():
new_g = _reconstruct_pickle(g) new_g = _reconstruct_pickle(g)
_assert_is_identical_hetero(g, new_g) _assert_is_identical_hetero(g, new_g)
def test_pickling_batched_heterograph():
# copied from test_heterograph.create_test_heterograph()
plays_spmat = ssp.coo_matrix(([1, 1, 1, 1], ([0, 1, 2, 1], [0, 0, 1, 1])))
wishes_nx = nx.DiGraph()
wishes_nx.add_nodes_from(['u0', 'u1', 'u2'], bipartite=0)
wishes_nx.add_nodes_from(['g0', 'g1'], bipartite=1)
wishes_nx.add_edge('u0', 'g1', id=0)
wishes_nx.add_edge('u2', 'g0', id=1)
follows_g = dgl.graph([(0, 1), (1, 2)], 'user', 'follows')
plays_g = dgl.bipartite(plays_spmat, 'user', 'plays', 'game')
wishes_g = dgl.bipartite(wishes_nx, 'user', 'wishes', 'game')
develops_g = dgl.bipartite([(0, 0), (1, 1)], 'developer', 'develops', 'game')
g = dgl.hetero_from_relations([follows_g, plays_g, wishes_g, develops_g])
g2 = dgl.hetero_from_relations([follows_g, plays_g, wishes_g, develops_g])
g.nodes['user'].data['u_h'] = F.randn((3, 4))
g.nodes['game'].data['g_h'] = F.randn((2, 5))
g.edges['plays'].data['p_h'] = F.randn((4, 6))
g2.nodes['user'].data['u_h'] = F.randn((3, 4))
g2.nodes['game'].data['g_h'] = F.randn((2, 5))
g2.edges['plays'].data['p_h'] = F.randn((4, 6))
bg = dgl.batch_hetero([g, g2])
new_bg = _reconstruct_pickle(bg)
_assert_is_identical_batchedhetero(bg, new_bg)
@unittest.skipIf(dgl.backend.backend_name != "pytorch", reason="Only test for pytorch format file") @unittest.skipIf(dgl.backend.backend_name != "pytorch", reason="Only test for pytorch format file")
def test_pickling_heterograph_index_compatibility(): def test_pickling_heterograph_index_compatibility():
plays_spmat = ssp.coo_matrix(([1, 1, 1, 1], ([0, 1, 2, 1], [0, 0, 1, 1]))) plays_spmat = ssp.coo_matrix(([1, 1, 1, 1], ([0, 1, 2, 1], [0, 0, 1, 1])))
...@@ -287,3 +321,4 @@ if __name__ == '__main__': ...@@ -287,3 +321,4 @@ if __name__ == '__main__':
test_pickling_nodeflow() test_pickling_nodeflow()
test_pickling_batched_graph() test_pickling_batched_graph()
test_pickling_heterograph() test_pickling_heterograph()
test_pickling_batched_heterograph()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment