Unverified Commit 1f6eba9e authored by Quan (Andy) Gan's avatar Quan (Andy) Gan Committed by GitHub
Browse files

[Bug] Heterogeneous graph convolution bugfix (#2578)



* fix heterograph conv

* remove test cases

* fix test

* fix test

* fix test
Co-authored-by: default avatarMinjie Wang <wmjlyjemaine@gmail.com>
parent e4ddafe9
......@@ -40,7 +40,7 @@ class HeteroGraphConv(nn.Block):
``'user'`` and ``'game'`` nodes.
>>> import mxnet.ndarray as nd
>>> h1 = {'user' : nd.randomrandn(g.number_of_nodes('user'), 5)}
>>> h1 = {'user' : nd.random.randn(g.number_of_nodes('user'), 5)}
>>> h2 = conv(g, h1)
>>> print(h2.keys())
dict_keys(['user', 'game'])
......@@ -167,7 +167,7 @@ class HeteroGraphConv(nn.Block):
continue
dstdata = self.mods[etype](
rel_graph,
inputs[stype],
(inputs[stype], inputs[dtype]),
*mod_args.get(etype, ()),
**mod_kwargs.get(etype, {}))
outputs[dtype].append(dstdata)
......
......@@ -169,7 +169,7 @@ class HeteroGraphConv(nn.Module):
continue
dstdata = self.mods[etype](
rel_graph,
inputs[stype],
(inputs[stype], inputs[dtype]),
*mod_args.get(etype, ()),
**mod_kwargs.get(etype, {}))
outputs[dtype].append(dstdata)
......
......@@ -164,7 +164,7 @@ class HeteroGraphConv(layers.Layer):
continue
dstdata = self.mods[etype](
rel_graph,
inputs[stype],
(inputs[stype], inputs[dtype]),
*mod_args.get(etype, ()),
**mod_kwargs.get(etype, {}))
outputs[dtype].append(dstdata)
......
......@@ -707,19 +707,18 @@ def test_hetero_conv(agg, idtype):
uf = F.randn((4, 2))
gf = F.randn((4, 4))
sf = F.randn((2, 3))
uf_dst = F.randn((4, 3))
gf_dst = F.randn((4, 4))
h = conv(g, {'user': uf})
h = conv(g, {'user': uf, 'store': sf, 'game': gf})
assert set(h.keys()) == {'user', 'game'}
if agg != 'stack':
assert h['user'].shape == (4, 3)
assert h['game'].shape == (4, 4)
else:
assert h['user'].shape == (4, 1, 3)
assert h['game'].shape == (4, 1, 4)
assert h['game'].shape == (4, 2, 4)
h = conv(g, {'user': uf, 'store': sf})
block = dgl.to_block(g.to(F.cpu()), {'user': [0, 1, 2, 3], 'game': [0, 1, 2, 3], 'store': []}).to(F.ctx())
h = conv(block, ({'user': uf, 'game': gf, 'store': sf}, {'user': uf, 'game': gf, 'store': sf[0:0]}))
assert set(h.keys()) == {'user', 'game'}
if agg != 'stack':
assert h['user'].shape == (4, 3)
......@@ -728,37 +727,14 @@ def test_hetero_conv(agg, idtype):
assert h['user'].shape == (4, 1, 3)
assert h['game'].shape == (4, 2, 4)
h = conv(g, {'store': sf})
assert set(h.keys()) == {'game'}
if agg != 'stack':
assert h['game'].shape == (4, 4)
else:
assert h['game'].shape == (4, 1, 4)
# test with pair input
conv = nn.HeteroGraphConv({
'follows': nn.SAGEConv(2, 3, 'mean'),
'plays': nn.SAGEConv((2, 4), 4, 'mean'),
'sells': nn.SAGEConv(3, 4, 'mean')},
agg)
conv.initialize(ctx=F.ctx())
h = conv(g, ({'user': uf}, {'user' : uf, 'game' : gf}))
h = conv(block, {'user': uf, 'game': gf, 'store': sf})
assert set(h.keys()) == {'user', 'game'}
if agg != 'stack':
assert h['user'].shape == (4, 3)
assert h['game'].shape == (4, 4)
else:
assert h['user'].shape == (4, 1, 3)
assert h['game'].shape == (4, 1, 4)
# pair input requires both src and dst type features to be provided
h = conv(g, ({'user': uf}, {'game' : gf}))
assert set(h.keys()) == {'game'}
if agg != 'stack':
assert h['game'].shape == (4, 4)
else:
assert h['game'].shape == (4, 1, 4)
assert h['game'].shape == (4, 2, 4)
# test with mod args
class MyMod(mx.gluon.nn.Block):
......@@ -781,7 +757,7 @@ def test_hetero_conv(agg, idtype):
agg)
conv.initialize(ctx=F.ctx())
mod_args = {'follows' : (1,), 'plays' : (1,)}
h = conv(g, {'user' : uf, 'store' : sf}, mod_args)
h = conv(g, {'user' : uf, 'store' : sf, 'game': gf}, mod_args)
assert mod1.carg1 == 1
assert mod2.carg1 == 1
assert mod3.carg1 == 0
......
......@@ -939,16 +939,17 @@ def test_hetero_conv(agg, idtype):
gf = F.randn((4, 4))
sf = F.randn((2, 3))
h = conv(g, {'user': uf})
h = conv(g, {'user': uf, 'game': gf, 'store': sf})
assert set(h.keys()) == {'user', 'game'}
if agg != 'stack':
assert h['user'].shape == (4, 3)
assert h['game'].shape == (4, 4)
else:
assert h['user'].shape == (4, 1, 3)
assert h['game'].shape == (4, 1, 4)
assert h['game'].shape == (4, 2, 4)
h = conv(g, {'user': uf, 'store': sf})
block = dgl.to_block(g.to(F.cpu()), {'user': [0, 1, 2, 3], 'game': [0, 1, 2, 3], 'store': []}).to(F.ctx())
h = conv(block, ({'user': uf, 'game': gf, 'store': sf}, {'user': uf, 'game': gf, 'store': sf[0:0]}))
assert set(h.keys()) == {'user', 'game'}
if agg != 'stack':
assert h['user'].shape == (4, 3)
......@@ -957,37 +958,14 @@ def test_hetero_conv(agg, idtype):
assert h['user'].shape == (4, 1, 3)
assert h['game'].shape == (4, 2, 4)
h = conv(g, {'store': sf})
assert set(h.keys()) == {'game'}
if agg != 'stack':
assert h['game'].shape == (4, 4)
else:
assert h['game'].shape == (4, 1, 4)
# test with pair input
conv = nn.HeteroGraphConv({
'follows': nn.SAGEConv(2, 3, 'mean'),
'plays': nn.SAGEConv((2, 4), 4, 'mean'),
'sells': nn.SAGEConv(3, 4, 'mean')},
agg)
conv = conv.to(F.ctx())
h = conv(g, ({'user': uf}, {'user' : uf, 'game' : gf}))
h = conv(block, {'user': uf, 'game': gf, 'store': sf})
assert set(h.keys()) == {'user', 'game'}
if agg != 'stack':
assert h['user'].shape == (4, 3)
assert h['game'].shape == (4, 4)
else:
assert h['user'].shape == (4, 1, 3)
assert h['game'].shape == (4, 1, 4)
# pair input requires both src and dst type features to be provided
h = conv(g, ({'user': uf}, {'game' : gf}))
assert set(h.keys()) == {'game'}
if agg != 'stack':
assert h['game'].shape == (4, 4)
else:
assert h['game'].shape == (4, 1, 4)
assert h['game'].shape == (4, 2, 4)
# test with mod args
class MyMod(th.nn.Module):
......@@ -1014,7 +992,7 @@ def test_hetero_conv(agg, idtype):
conv = conv.to(F.ctx())
mod_args = {'follows' : (1,), 'plays' : (1,)}
mod_kwargs = {'sells' : {'arg2' : 'abc'}}
h = conv(g, {'user' : uf, 'store' : sf}, mod_args=mod_args, mod_kwargs=mod_kwargs)
h = conv(g, {'user' : uf, 'game': gf, 'store' : sf}, mod_args=mod_args, mod_kwargs=mod_kwargs)
assert mod1.carg1 == 1
assert mod1.carg2 == 0
assert mod2.carg1 == 1
......
......@@ -401,19 +401,18 @@ def test_hetero_conv(agg, idtype):
uf = F.randn((4, 2))
gf = F.randn((4, 4))
sf = F.randn((2, 3))
uf_dst = F.randn((4, 3))
gf_dst = F.randn((4, 4))
h = conv(g, {'user': uf})
h = conv(g, {'user': uf, 'store': sf, 'game': gf})
assert set(h.keys()) == {'user', 'game'}
if agg != 'stack':
assert h['user'].shape == (4, 3)
assert h['game'].shape == (4, 4)
else:
assert h['user'].shape == (4, 1, 3)
assert h['game'].shape == (4, 1, 4)
assert h['game'].shape == (4, 2, 4)
h = conv(g, {'user': uf, 'store': sf})
block = dgl.to_block(g.to(F.cpu()), {'user': [0, 1, 2, 3], 'game': [0, 1, 2, 3], 'store': []}).to(F.ctx())
h = conv(block, ({'user': uf, 'game': gf, 'store': sf}, {'user': uf, 'game': gf, 'store': sf[0:0]}))
assert set(h.keys()) == {'user', 'game'}
if agg != 'stack':
assert h['user'].shape == (4, 3)
......@@ -422,36 +421,14 @@ def test_hetero_conv(agg, idtype):
assert h['user'].shape == (4, 1, 3)
assert h['game'].shape == (4, 2, 4)
h = conv(g, {'store': sf})
assert set(h.keys()) == {'game'}
if agg != 'stack':
assert h['game'].shape == (4, 4)
else:
assert h['game'].shape == (4, 1, 4)
# test with pair input
conv = nn.HeteroGraphConv({
'follows': nn.SAGEConv(2, 3, 'mean'),
'plays': nn.SAGEConv((2, 4), 4, 'mean'),
'sells': nn.SAGEConv(3, 4, 'mean')},
agg)
h = conv(g, ({'user': uf}, {'user' : uf, 'game' : gf}))
h = conv(block, {'user': uf, 'game': gf, 'store': sf})
assert set(h.keys()) == {'user', 'game'}
if agg != 'stack':
assert h['user'].shape == (4, 3)
assert h['game'].shape == (4, 4)
else:
assert h['user'].shape == (4, 1, 3)
assert h['game'].shape == (4, 1, 4)
# pair input requires both src and dst type features to be provided
h = conv(g, ({'user': uf}, {'game' : gf}))
assert set(h.keys()) == {'game'}
if agg != 'stack':
assert h['game'].shape == (4, 4)
else:
assert h['game'].shape == (4, 1, 4)
assert h['game'].shape == (4, 2, 4)
# test with mod args
class MyMod(tf.keras.layers.Layer):
......@@ -477,7 +454,7 @@ def test_hetero_conv(agg, idtype):
agg)
mod_args = {'follows' : (1,), 'plays' : (1,)}
mod_kwargs = {'sells' : {'arg2' : 'abc'}}
h = conv(g, {'user' : uf, 'store' : sf}, mod_args=mod_args, mod_kwargs=mod_kwargs)
h = conv(g, {'user' : uf, 'game': gf, 'store' : sf}, mod_args=mod_args, mod_kwargs=mod_kwargs)
assert mod1.carg1 == 1
assert mod1.carg2 == 0
assert mod2.carg1 == 1
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment