Unverified Commit 8798872f authored by Rhett Ying's avatar Rhett Ying Committed by GitHub
Browse files

[Bug] Do not skip graphconv even no edge exists (#3416)

parent 7c7b60be
......@@ -165,8 +165,6 @@ class HeteroGraphConv(nn.Block):
src_inputs, dst_inputs = inputs
for stype, etype, dtype in g.canonical_etypes:
rel_graph = g[stype, etype, dtype]
if rel_graph.number_of_edges() == 0:
continue
if stype not in src_inputs or dtype not in dst_inputs:
continue
dstdata = self.mods[etype](
......@@ -178,8 +176,6 @@ class HeteroGraphConv(nn.Block):
else:
for stype, etype, dtype in g.canonical_etypes:
rel_graph = g[stype, etype, dtype]
if rel_graph.number_of_edges() == 0:
continue
if stype not in inputs:
continue
dstdata = self.mods[etype](
......
......@@ -169,8 +169,6 @@ class HeteroGraphConv(nn.Module):
for stype, etype, dtype in g.canonical_etypes:
rel_graph = g[stype, etype, dtype]
if rel_graph.number_of_edges() == 0:
continue
if stype not in src_inputs or dtype not in dst_inputs:
continue
dstdata = self.mods[etype](
......@@ -182,8 +180,6 @@ class HeteroGraphConv(nn.Module):
else:
for stype, etype, dtype in g.canonical_etypes:
rel_graph = g[stype, etype, dtype]
if rel_graph.number_of_edges() == 0:
continue
if stype not in inputs:
continue
dstdata = self.mods[etype](
......
......@@ -169,8 +169,6 @@ class HeteroGraphConv(layers.Layer):
src_inputs, dst_inputs = inputs
for stype, etype, dtype in g.canonical_etypes:
rel_graph = g[stype, etype, dtype]
if rel_graph.number_of_edges() == 0:
continue
if stype not in src_inputs or dtype not in dst_inputs:
continue
dstdata = self.mods[etype](
......@@ -182,8 +180,6 @@ class HeteroGraphConv(layers.Layer):
else:
for stype, etype, dtype in g.canonical_etypes:
rel_graph = g[stype, etype, dtype]
if rel_graph.number_of_edges() == 0:
continue
if stype not in inputs:
continue
dstdata = self.mods[etype](
......
......@@ -788,6 +788,19 @@ def test_hetero_conv(agg, idtype):
assert mod2.carg1 == 1
assert mod3.carg1 == 0
#conv on graph without any edges
for etype in g.etypes:
g = dgl.remove_edges(g, g.edges(form='eid', etype=etype), etype=etype)
assert g.num_edges() == 0
h = conv(g, {'user': uf, 'game': gf, 'store': sf})
assert set(h.keys()) == {'user', 'game'}
block = dgl.to_block(g.to(F.cpu()), {'user': [0, 1, 2, 3], 'game': [
0, 1, 2, 3], 'store': []}).to(F.ctx())
h = conv(block, ({'user': uf, 'game': gf, 'store': sf},
{'user': uf, 'game': gf, 'store': sf[0:0]}))
assert set(h.keys()) == {'user', 'game'}
if __name__ == '__main__':
test_graph_conv()
test_gat_conv()
......@@ -809,3 +822,4 @@ if __name__ == '__main__':
test_simple_pool()
test_rgcn()
test_sequential()
test_hetero_conv()
......@@ -1112,6 +1112,19 @@ def test_hetero_conv(agg, idtype):
assert mod3.carg1 == 0
assert mod3.carg2 == 1
#conv on graph without any edges
for etype in g.etypes:
g = dgl.remove_edges(g, g.edges(form='eid', etype=etype), etype=etype)
assert g.num_edges() == 0
h = conv(g, {'user': uf, 'game': gf, 'store': sf})
assert set(h.keys()) == {'user', 'game'}
block = dgl.to_block(g.to(F.cpu()), {'user': [0, 1, 2, 3], 'game': [
0, 1, 2, 3], 'store': []}).to(F.ctx())
h = conv(block, ({'user': uf, 'game': gf, 'store': sf},
{'user': uf, 'game': gf, 'store': sf[0:0]}))
assert set(h.keys()) == {'user', 'game'}
if __name__ == '__main__':
test_graph_conv()
test_graph_conv_e_weight()
......@@ -1140,3 +1153,4 @@ if __name__ == '__main__':
test_sequential()
test_atomic_conv()
test_cf_conv()
test_hetero_conv()
......@@ -502,6 +502,19 @@ def test_hetero_conv(agg, idtype):
assert mod3.carg1 == 0
assert mod3.carg2 == 1
#conv on graph without any edges
for etype in g.etypes:
g = dgl.remove_edges(g, g.edges(form='eid', etype=etype), etype=etype)
assert g.num_edges() == 0
h = conv(g, {'user': uf, 'game': gf, 'store': sf})
assert set(h.keys()) == {'user', 'game'}
block = dgl.to_block(g.to(F.cpu()), {'user': [0, 1, 2, 3], 'game': [
0, 1, 2, 3], 'store': []}).to(F.ctx())
h = conv(block, ({'user': uf, 'game': gf, 'store': sf},
{'user': uf, 'game': gf, 'store': sf[0:0]}))
assert set(h.keys()) == {'user', 'game'}
@pytest.mark.parametrize('out_dim', [1, 2])
def test_dense_cheb_conv(out_dim):
......@@ -549,3 +562,4 @@ if __name__ == '__main__':
# test_dense_sage_conv()
test_dense_cheb_conv()
# test_sequential()
test_hetero_conv()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment