Unverified Commit 8798872f authored by Rhett Ying's avatar Rhett Ying Committed by GitHub
Browse files

[Bug] Do not skip graphconv even no edge exists (#3416)

parent 7c7b60be
...@@ -165,8 +165,6 @@ class HeteroGraphConv(nn.Block): ...@@ -165,8 +165,6 @@ class HeteroGraphConv(nn.Block):
src_inputs, dst_inputs = inputs src_inputs, dst_inputs = inputs
for stype, etype, dtype in g.canonical_etypes: for stype, etype, dtype in g.canonical_etypes:
rel_graph = g[stype, etype, dtype] rel_graph = g[stype, etype, dtype]
if rel_graph.number_of_edges() == 0:
continue
if stype not in src_inputs or dtype not in dst_inputs: if stype not in src_inputs or dtype not in dst_inputs:
continue continue
dstdata = self.mods[etype]( dstdata = self.mods[etype](
...@@ -178,8 +176,6 @@ class HeteroGraphConv(nn.Block): ...@@ -178,8 +176,6 @@ class HeteroGraphConv(nn.Block):
else: else:
for stype, etype, dtype in g.canonical_etypes: for stype, etype, dtype in g.canonical_etypes:
rel_graph = g[stype, etype, dtype] rel_graph = g[stype, etype, dtype]
if rel_graph.number_of_edges() == 0:
continue
if stype not in inputs: if stype not in inputs:
continue continue
dstdata = self.mods[etype]( dstdata = self.mods[etype](
......
...@@ -169,8 +169,6 @@ class HeteroGraphConv(nn.Module): ...@@ -169,8 +169,6 @@ class HeteroGraphConv(nn.Module):
for stype, etype, dtype in g.canonical_etypes: for stype, etype, dtype in g.canonical_etypes:
rel_graph = g[stype, etype, dtype] rel_graph = g[stype, etype, dtype]
if rel_graph.number_of_edges() == 0:
continue
if stype not in src_inputs or dtype not in dst_inputs: if stype not in src_inputs or dtype not in dst_inputs:
continue continue
dstdata = self.mods[etype]( dstdata = self.mods[etype](
...@@ -182,8 +180,6 @@ class HeteroGraphConv(nn.Module): ...@@ -182,8 +180,6 @@ class HeteroGraphConv(nn.Module):
else: else:
for stype, etype, dtype in g.canonical_etypes: for stype, etype, dtype in g.canonical_etypes:
rel_graph = g[stype, etype, dtype] rel_graph = g[stype, etype, dtype]
if rel_graph.number_of_edges() == 0:
continue
if stype not in inputs: if stype not in inputs:
continue continue
dstdata = self.mods[etype]( dstdata = self.mods[etype](
......
...@@ -169,8 +169,6 @@ class HeteroGraphConv(layers.Layer): ...@@ -169,8 +169,6 @@ class HeteroGraphConv(layers.Layer):
src_inputs, dst_inputs = inputs src_inputs, dst_inputs = inputs
for stype, etype, dtype in g.canonical_etypes: for stype, etype, dtype in g.canonical_etypes:
rel_graph = g[stype, etype, dtype] rel_graph = g[stype, etype, dtype]
if rel_graph.number_of_edges() == 0:
continue
if stype not in src_inputs or dtype not in dst_inputs: if stype not in src_inputs or dtype not in dst_inputs:
continue continue
dstdata = self.mods[etype]( dstdata = self.mods[etype](
...@@ -182,8 +180,6 @@ class HeteroGraphConv(layers.Layer): ...@@ -182,8 +180,6 @@ class HeteroGraphConv(layers.Layer):
else: else:
for stype, etype, dtype in g.canonical_etypes: for stype, etype, dtype in g.canonical_etypes:
rel_graph = g[stype, etype, dtype] rel_graph = g[stype, etype, dtype]
if rel_graph.number_of_edges() == 0:
continue
if stype not in inputs: if stype not in inputs:
continue continue
dstdata = self.mods[etype]( dstdata = self.mods[etype](
......
...@@ -788,6 +788,19 @@ def test_hetero_conv(agg, idtype): ...@@ -788,6 +788,19 @@ def test_hetero_conv(agg, idtype):
assert mod2.carg1 == 1 assert mod2.carg1 == 1
assert mod3.carg1 == 0 assert mod3.carg1 == 0
#conv on graph without any edges
for etype in g.etypes:
g = dgl.remove_edges(g, g.edges(form='eid', etype=etype), etype=etype)
assert g.num_edges() == 0
h = conv(g, {'user': uf, 'game': gf, 'store': sf})
assert set(h.keys()) == {'user', 'game'}
block = dgl.to_block(g.to(F.cpu()), {'user': [0, 1, 2, 3], 'game': [
0, 1, 2, 3], 'store': []}).to(F.ctx())
h = conv(block, ({'user': uf, 'game': gf, 'store': sf},
{'user': uf, 'game': gf, 'store': sf[0:0]}))
assert set(h.keys()) == {'user', 'game'}
if __name__ == '__main__': if __name__ == '__main__':
test_graph_conv() test_graph_conv()
test_gat_conv() test_gat_conv()
...@@ -809,3 +822,4 @@ if __name__ == '__main__': ...@@ -809,3 +822,4 @@ if __name__ == '__main__':
test_simple_pool() test_simple_pool()
test_rgcn() test_rgcn()
test_sequential() test_sequential()
test_hetero_conv()
...@@ -1112,6 +1112,19 @@ def test_hetero_conv(agg, idtype): ...@@ -1112,6 +1112,19 @@ def test_hetero_conv(agg, idtype):
assert mod3.carg1 == 0 assert mod3.carg1 == 0
assert mod3.carg2 == 1 assert mod3.carg2 == 1
#conv on graph without any edges
for etype in g.etypes:
g = dgl.remove_edges(g, g.edges(form='eid', etype=etype), etype=etype)
assert g.num_edges() == 0
h = conv(g, {'user': uf, 'game': gf, 'store': sf})
assert set(h.keys()) == {'user', 'game'}
block = dgl.to_block(g.to(F.cpu()), {'user': [0, 1, 2, 3], 'game': [
0, 1, 2, 3], 'store': []}).to(F.ctx())
h = conv(block, ({'user': uf, 'game': gf, 'store': sf},
{'user': uf, 'game': gf, 'store': sf[0:0]}))
assert set(h.keys()) == {'user', 'game'}
if __name__ == '__main__': if __name__ == '__main__':
test_graph_conv() test_graph_conv()
test_graph_conv_e_weight() test_graph_conv_e_weight()
...@@ -1140,3 +1153,4 @@ if __name__ == '__main__': ...@@ -1140,3 +1153,4 @@ if __name__ == '__main__':
test_sequential() test_sequential()
test_atomic_conv() test_atomic_conv()
test_cf_conv() test_cf_conv()
test_hetero_conv()
...@@ -502,6 +502,19 @@ def test_hetero_conv(agg, idtype): ...@@ -502,6 +502,19 @@ def test_hetero_conv(agg, idtype):
assert mod3.carg1 == 0 assert mod3.carg1 == 0
assert mod3.carg2 == 1 assert mod3.carg2 == 1
#conv on graph without any edges
for etype in g.etypes:
g = dgl.remove_edges(g, g.edges(form='eid', etype=etype), etype=etype)
assert g.num_edges() == 0
h = conv(g, {'user': uf, 'game': gf, 'store': sf})
assert set(h.keys()) == {'user', 'game'}
block = dgl.to_block(g.to(F.cpu()), {'user': [0, 1, 2, 3], 'game': [
0, 1, 2, 3], 'store': []}).to(F.ctx())
h = conv(block, ({'user': uf, 'game': gf, 'store': sf},
{'user': uf, 'game': gf, 'store': sf[0:0]}))
assert set(h.keys()) == {'user', 'game'}
@pytest.mark.parametrize('out_dim', [1, 2]) @pytest.mark.parametrize('out_dim', [1, 2])
def test_dense_cheb_conv(out_dim): def test_dense_cheb_conv(out_dim):
...@@ -549,3 +562,4 @@ if __name__ == '__main__': ...@@ -549,3 +562,4 @@ if __name__ == '__main__':
# test_dense_sage_conv() # test_dense_sage_conv()
test_dense_cheb_conv() test_dense_cheb_conv()
# test_sequential() # test_sequential()
test_hetero_conv()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment