Unverified Commit 6694f7b9 authored by Mufei Li's avatar Mufei Li Committed by GitHub
Browse files

[Bug] Fix DGLHeteroGraph.edge_type_subgraph (#1040)

* Update

* Try CI
parent 168fc2cf
......@@ -1752,12 +1752,12 @@ class DGLHeteroGraph(object):
rel_graphs = [self._graph.get_relation_graph(i) for i in etype_ids]
meta_src = meta_src.tonumpy()
meta_dst = meta_dst.tonumpy()
induced_ntype_ids = list(set(meta_src) | set(meta_dst))
mapped_meta_src = [induced_ntype_ids[v] for v in meta_src]
mapped_meta_dst = [induced_ntype_ids[v] for v in meta_dst]
node_frames = [self._node_frames[i] for i in induced_ntype_ids]
ntypes_invmap = {n: i for i, n in enumerate(set(meta_src) | set(meta_dst))}
mapped_meta_src = [ntypes_invmap[v] for v in meta_src]
mapped_meta_dst = [ntypes_invmap[v] for v in meta_dst]
node_frames = [self._node_frames[i] for i in ntypes_invmap]
edge_frames = [self._edge_frames[i] for i in etype_ids]
induced_ntypes = [self._ntypes[i] for i in induced_ntype_ids]
induced_ntypes = [self._ntypes[i] for i in ntypes_invmap]
induced_etypes = [self._etypes[i] for i in etype_ids] # get the "name" of edge type
metagraph = graph_index.from_edge_list((mapped_meta_src, mapped_meta_dst), True, True)
......
......@@ -732,7 +732,7 @@ def test_subgraph():
sg2 = g.edge_subgraph({'follows': [1], 'plays': [1], 'wishes': [1]})
_check_subgraph(g, sg2)
def _check_typed_subgraph(g, sg):
def _check_typed_subgraph1(g, sg):
assert set(sg.ntypes) == {'user', 'game'}
assert set(sg.etypes) == {'follows', 'plays', 'wishes'}
for ntype in sg.ntypes:
......@@ -749,10 +749,23 @@ def test_subgraph():
assert F.array_equal(sg.nodes['user'].data['h'], g.nodes['user'].data['h'])
assert F.array_equal(sg.edges['follows'].data['h'], g.edges['follows'].data['h'])
def _check_typed_subgraph2(g, sg):
assert set(sg.ntypes) == {'developer', 'game'}
assert set(sg.etypes) == {'develops'}
for ntype in sg.ntypes:
assert sg.number_of_nodes(ntype) == g.number_of_nodes(ntype)
for etype in sg.etypes:
src_sg, dst_sg = sg.all_edges(etype=etype, order='eid')
src_g, dst_g = g.all_edges(etype=etype, order='eid')
assert F.array_equal(src_sg, src_g)
assert F.array_equal(dst_sg, dst_g)
sg3 = g.node_type_subgraph(['user', 'game'])
_check_typed_subgraph(g, sg3)
sg4 = g.edge_type_subgraph(['follows', 'plays', 'wishes'])
_check_typed_subgraph(g, sg4)
_check_typed_subgraph1(g, sg3)
sg4 = g.edge_type_subgraph(['develops'])
_check_typed_subgraph2(g, sg4)
sg5 = g.edge_type_subgraph(['follows', 'plays', 'wishes'])
_check_typed_subgraph1(g, sg5)
def test_apply():
def node_udf(nodes):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment