Unverified Commit 1f6eba9e authored by Quan (Andy) Gan's avatar Quan (Andy) Gan Committed by GitHub
Browse files

[Bug] Heterogeneous graph convolution bugfix (#2578)



* fix heterograph conv

* remove test cases

* fix test

* fix test

* fix test
Co-authored-by: default avatarMinjie Wang <wmjlyjemaine@gmail.com>
parent e4ddafe9
...@@ -40,7 +40,7 @@ class HeteroGraphConv(nn.Block): ...@@ -40,7 +40,7 @@ class HeteroGraphConv(nn.Block):
``'user'`` and ``'game'`` nodes. ``'user'`` and ``'game'`` nodes.
>>> import mxnet.ndarray as nd >>> import mxnet.ndarray as nd
>>> h1 = {'user' : nd.randomrandn(g.number_of_nodes('user'), 5)} >>> h1 = {'user' : nd.random.randn(g.number_of_nodes('user'), 5)}
>>> h2 = conv(g, h1) >>> h2 = conv(g, h1)
>>> print(h2.keys()) >>> print(h2.keys())
dict_keys(['user', 'game']) dict_keys(['user', 'game'])
...@@ -167,7 +167,7 @@ class HeteroGraphConv(nn.Block): ...@@ -167,7 +167,7 @@ class HeteroGraphConv(nn.Block):
continue continue
dstdata = self.mods[etype]( dstdata = self.mods[etype](
rel_graph, rel_graph,
inputs[stype], (inputs[stype], inputs[dtype]),
*mod_args.get(etype, ()), *mod_args.get(etype, ()),
**mod_kwargs.get(etype, {})) **mod_kwargs.get(etype, {}))
outputs[dtype].append(dstdata) outputs[dtype].append(dstdata)
......
...@@ -169,7 +169,7 @@ class HeteroGraphConv(nn.Module): ...@@ -169,7 +169,7 @@ class HeteroGraphConv(nn.Module):
continue continue
dstdata = self.mods[etype]( dstdata = self.mods[etype](
rel_graph, rel_graph,
inputs[stype], (inputs[stype], inputs[dtype]),
*mod_args.get(etype, ()), *mod_args.get(etype, ()),
**mod_kwargs.get(etype, {})) **mod_kwargs.get(etype, {}))
outputs[dtype].append(dstdata) outputs[dtype].append(dstdata)
......
...@@ -164,7 +164,7 @@ class HeteroGraphConv(layers.Layer): ...@@ -164,7 +164,7 @@ class HeteroGraphConv(layers.Layer):
continue continue
dstdata = self.mods[etype]( dstdata = self.mods[etype](
rel_graph, rel_graph,
inputs[stype], (inputs[stype], inputs[dtype]),
*mod_args.get(etype, ()), *mod_args.get(etype, ()),
**mod_kwargs.get(etype, {})) **mod_kwargs.get(etype, {}))
outputs[dtype].append(dstdata) outputs[dtype].append(dstdata)
......
...@@ -707,19 +707,18 @@ def test_hetero_conv(agg, idtype): ...@@ -707,19 +707,18 @@ def test_hetero_conv(agg, idtype):
uf = F.randn((4, 2)) uf = F.randn((4, 2))
gf = F.randn((4, 4)) gf = F.randn((4, 4))
sf = F.randn((2, 3)) sf = F.randn((2, 3))
uf_dst = F.randn((4, 3))
gf_dst = F.randn((4, 4))
h = conv(g, {'user': uf}) h = conv(g, {'user': uf, 'store': sf, 'game': gf})
assert set(h.keys()) == {'user', 'game'} assert set(h.keys()) == {'user', 'game'}
if agg != 'stack': if agg != 'stack':
assert h['user'].shape == (4, 3) assert h['user'].shape == (4, 3)
assert h['game'].shape == (4, 4) assert h['game'].shape == (4, 4)
else: else:
assert h['user'].shape == (4, 1, 3) assert h['user'].shape == (4, 1, 3)
assert h['game'].shape == (4, 1, 4) assert h['game'].shape == (4, 2, 4)
h = conv(g, {'user': uf, 'store': sf}) block = dgl.to_block(g.to(F.cpu()), {'user': [0, 1, 2, 3], 'game': [0, 1, 2, 3], 'store': []}).to(F.ctx())
h = conv(block, ({'user': uf, 'game': gf, 'store': sf}, {'user': uf, 'game': gf, 'store': sf[0:0]}))
assert set(h.keys()) == {'user', 'game'} assert set(h.keys()) == {'user', 'game'}
if agg != 'stack': if agg != 'stack':
assert h['user'].shape == (4, 3) assert h['user'].shape == (4, 3)
...@@ -728,37 +727,14 @@ def test_hetero_conv(agg, idtype): ...@@ -728,37 +727,14 @@ def test_hetero_conv(agg, idtype):
assert h['user'].shape == (4, 1, 3) assert h['user'].shape == (4, 1, 3)
assert h['game'].shape == (4, 2, 4) assert h['game'].shape == (4, 2, 4)
h = conv(g, {'store': sf}) h = conv(block, {'user': uf, 'game': gf, 'store': sf})
assert set(h.keys()) == {'game'}
if agg != 'stack':
assert h['game'].shape == (4, 4)
else:
assert h['game'].shape == (4, 1, 4)
# test with pair input
conv = nn.HeteroGraphConv({
'follows': nn.SAGEConv(2, 3, 'mean'),
'plays': nn.SAGEConv((2, 4), 4, 'mean'),
'sells': nn.SAGEConv(3, 4, 'mean')},
agg)
conv.initialize(ctx=F.ctx())
h = conv(g, ({'user': uf}, {'user' : uf, 'game' : gf}))
assert set(h.keys()) == {'user', 'game'} assert set(h.keys()) == {'user', 'game'}
if agg != 'stack': if agg != 'stack':
assert h['user'].shape == (4, 3) assert h['user'].shape == (4, 3)
assert h['game'].shape == (4, 4) assert h['game'].shape == (4, 4)
else: else:
assert h['user'].shape == (4, 1, 3) assert h['user'].shape == (4, 1, 3)
assert h['game'].shape == (4, 1, 4) assert h['game'].shape == (4, 2, 4)
# pair input requires both src and dst type features to be provided
h = conv(g, ({'user': uf}, {'game' : gf}))
assert set(h.keys()) == {'game'}
if agg != 'stack':
assert h['game'].shape == (4, 4)
else:
assert h['game'].shape == (4, 1, 4)
# test with mod args # test with mod args
class MyMod(mx.gluon.nn.Block): class MyMod(mx.gluon.nn.Block):
...@@ -781,7 +757,7 @@ def test_hetero_conv(agg, idtype): ...@@ -781,7 +757,7 @@ def test_hetero_conv(agg, idtype):
agg) agg)
conv.initialize(ctx=F.ctx()) conv.initialize(ctx=F.ctx())
mod_args = {'follows' : (1,), 'plays' : (1,)} mod_args = {'follows' : (1,), 'plays' : (1,)}
h = conv(g, {'user' : uf, 'store' : sf}, mod_args) h = conv(g, {'user' : uf, 'store' : sf, 'game': gf}, mod_args)
assert mod1.carg1 == 1 assert mod1.carg1 == 1
assert mod2.carg1 == 1 assert mod2.carg1 == 1
assert mod3.carg1 == 0 assert mod3.carg1 == 0
......
...@@ -939,16 +939,17 @@ def test_hetero_conv(agg, idtype): ...@@ -939,16 +939,17 @@ def test_hetero_conv(agg, idtype):
gf = F.randn((4, 4)) gf = F.randn((4, 4))
sf = F.randn((2, 3)) sf = F.randn((2, 3))
h = conv(g, {'user': uf}) h = conv(g, {'user': uf, 'game': gf, 'store': sf})
assert set(h.keys()) == {'user', 'game'} assert set(h.keys()) == {'user', 'game'}
if agg != 'stack': if agg != 'stack':
assert h['user'].shape == (4, 3) assert h['user'].shape == (4, 3)
assert h['game'].shape == (4, 4) assert h['game'].shape == (4, 4)
else: else:
assert h['user'].shape == (4, 1, 3) assert h['user'].shape == (4, 1, 3)
assert h['game'].shape == (4, 1, 4) assert h['game'].shape == (4, 2, 4)
h = conv(g, {'user': uf, 'store': sf}) block = dgl.to_block(g.to(F.cpu()), {'user': [0, 1, 2, 3], 'game': [0, 1, 2, 3], 'store': []}).to(F.ctx())
h = conv(block, ({'user': uf, 'game': gf, 'store': sf}, {'user': uf, 'game': gf, 'store': sf[0:0]}))
assert set(h.keys()) == {'user', 'game'} assert set(h.keys()) == {'user', 'game'}
if agg != 'stack': if agg != 'stack':
assert h['user'].shape == (4, 3) assert h['user'].shape == (4, 3)
...@@ -957,37 +958,14 @@ def test_hetero_conv(agg, idtype): ...@@ -957,37 +958,14 @@ def test_hetero_conv(agg, idtype):
assert h['user'].shape == (4, 1, 3) assert h['user'].shape == (4, 1, 3)
assert h['game'].shape == (4, 2, 4) assert h['game'].shape == (4, 2, 4)
h = conv(g, {'store': sf}) h = conv(block, {'user': uf, 'game': gf, 'store': sf})
assert set(h.keys()) == {'game'}
if agg != 'stack':
assert h['game'].shape == (4, 4)
else:
assert h['game'].shape == (4, 1, 4)
# test with pair input
conv = nn.HeteroGraphConv({
'follows': nn.SAGEConv(2, 3, 'mean'),
'plays': nn.SAGEConv((2, 4), 4, 'mean'),
'sells': nn.SAGEConv(3, 4, 'mean')},
agg)
conv = conv.to(F.ctx())
h = conv(g, ({'user': uf}, {'user' : uf, 'game' : gf}))
assert set(h.keys()) == {'user', 'game'} assert set(h.keys()) == {'user', 'game'}
if agg != 'stack': if agg != 'stack':
assert h['user'].shape == (4, 3) assert h['user'].shape == (4, 3)
assert h['game'].shape == (4, 4) assert h['game'].shape == (4, 4)
else: else:
assert h['user'].shape == (4, 1, 3) assert h['user'].shape == (4, 1, 3)
assert h['game'].shape == (4, 1, 4) assert h['game'].shape == (4, 2, 4)
# pair input requires both src and dst type features to be provided
h = conv(g, ({'user': uf}, {'game' : gf}))
assert set(h.keys()) == {'game'}
if agg != 'stack':
assert h['game'].shape == (4, 4)
else:
assert h['game'].shape == (4, 1, 4)
# test with mod args # test with mod args
class MyMod(th.nn.Module): class MyMod(th.nn.Module):
...@@ -1014,7 +992,7 @@ def test_hetero_conv(agg, idtype): ...@@ -1014,7 +992,7 @@ def test_hetero_conv(agg, idtype):
conv = conv.to(F.ctx()) conv = conv.to(F.ctx())
mod_args = {'follows' : (1,), 'plays' : (1,)} mod_args = {'follows' : (1,), 'plays' : (1,)}
mod_kwargs = {'sells' : {'arg2' : 'abc'}} mod_kwargs = {'sells' : {'arg2' : 'abc'}}
h = conv(g, {'user' : uf, 'store' : sf}, mod_args=mod_args, mod_kwargs=mod_kwargs) h = conv(g, {'user' : uf, 'game': gf, 'store' : sf}, mod_args=mod_args, mod_kwargs=mod_kwargs)
assert mod1.carg1 == 1 assert mod1.carg1 == 1
assert mod1.carg2 == 0 assert mod1.carg2 == 0
assert mod2.carg1 == 1 assert mod2.carg1 == 1
......
...@@ -401,19 +401,18 @@ def test_hetero_conv(agg, idtype): ...@@ -401,19 +401,18 @@ def test_hetero_conv(agg, idtype):
uf = F.randn((4, 2)) uf = F.randn((4, 2))
gf = F.randn((4, 4)) gf = F.randn((4, 4))
sf = F.randn((2, 3)) sf = F.randn((2, 3))
uf_dst = F.randn((4, 3))
gf_dst = F.randn((4, 4))
h = conv(g, {'user': uf}) h = conv(g, {'user': uf, 'store': sf, 'game': gf})
assert set(h.keys()) == {'user', 'game'} assert set(h.keys()) == {'user', 'game'}
if agg != 'stack': if agg != 'stack':
assert h['user'].shape == (4, 3) assert h['user'].shape == (4, 3)
assert h['game'].shape == (4, 4) assert h['game'].shape == (4, 4)
else: else:
assert h['user'].shape == (4, 1, 3) assert h['user'].shape == (4, 1, 3)
assert h['game'].shape == (4, 1, 4) assert h['game'].shape == (4, 2, 4)
h = conv(g, {'user': uf, 'store': sf}) block = dgl.to_block(g.to(F.cpu()), {'user': [0, 1, 2, 3], 'game': [0, 1, 2, 3], 'store': []}).to(F.ctx())
h = conv(block, ({'user': uf, 'game': gf, 'store': sf}, {'user': uf, 'game': gf, 'store': sf[0:0]}))
assert set(h.keys()) == {'user', 'game'} assert set(h.keys()) == {'user', 'game'}
if agg != 'stack': if agg != 'stack':
assert h['user'].shape == (4, 3) assert h['user'].shape == (4, 3)
...@@ -422,36 +421,14 @@ def test_hetero_conv(agg, idtype): ...@@ -422,36 +421,14 @@ def test_hetero_conv(agg, idtype):
assert h['user'].shape == (4, 1, 3) assert h['user'].shape == (4, 1, 3)
assert h['game'].shape == (4, 2, 4) assert h['game'].shape == (4, 2, 4)
h = conv(g, {'store': sf}) h = conv(block, {'user': uf, 'game': gf, 'store': sf})
assert set(h.keys()) == {'game'}
if agg != 'stack':
assert h['game'].shape == (4, 4)
else:
assert h['game'].shape == (4, 1, 4)
# test with pair input
conv = nn.HeteroGraphConv({
'follows': nn.SAGEConv(2, 3, 'mean'),
'plays': nn.SAGEConv((2, 4), 4, 'mean'),
'sells': nn.SAGEConv(3, 4, 'mean')},
agg)
h = conv(g, ({'user': uf}, {'user' : uf, 'game' : gf}))
assert set(h.keys()) == {'user', 'game'} assert set(h.keys()) == {'user', 'game'}
if agg != 'stack': if agg != 'stack':
assert h['user'].shape == (4, 3) assert h['user'].shape == (4, 3)
assert h['game'].shape == (4, 4) assert h['game'].shape == (4, 4)
else: else:
assert h['user'].shape == (4, 1, 3) assert h['user'].shape == (4, 1, 3)
assert h['game'].shape == (4, 1, 4) assert h['game'].shape == (4, 2, 4)
# pair input requires both src and dst type features to be provided
h = conv(g, ({'user': uf}, {'game' : gf}))
assert set(h.keys()) == {'game'}
if agg != 'stack':
assert h['game'].shape == (4, 4)
else:
assert h['game'].shape == (4, 1, 4)
# test with mod args # test with mod args
class MyMod(tf.keras.layers.Layer): class MyMod(tf.keras.layers.Layer):
...@@ -477,7 +454,7 @@ def test_hetero_conv(agg, idtype): ...@@ -477,7 +454,7 @@ def test_hetero_conv(agg, idtype):
agg) agg)
mod_args = {'follows' : (1,), 'plays' : (1,)} mod_args = {'follows' : (1,), 'plays' : (1,)}
mod_kwargs = {'sells' : {'arg2' : 'abc'}} mod_kwargs = {'sells' : {'arg2' : 'abc'}}
h = conv(g, {'user' : uf, 'store' : sf}, mod_args=mod_args, mod_kwargs=mod_kwargs) h = conv(g, {'user' : uf, 'game': gf, 'store' : sf}, mod_args=mod_args, mod_kwargs=mod_kwargs)
assert mod1.carg1 == 1 assert mod1.carg1 == 1
assert mod1.carg2 == 0 assert mod1.carg2 == 0
assert mod2.carg1 == 1 assert mod2.carg1 == 1
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment