Unverified commit fcfe52ae, authored by Quan (Andy) Gan, committed by GitHub
Browse files

fix residual (#2962)

parent 2ad7a9e9
...@@ -269,6 +269,8 @@ class GATConv(nn.Block): ...@@ -269,6 +269,8 @@ class GATConv(nn.Block):
*src_prefix_shape, self._num_heads, self._out_feats) *src_prefix_shape, self._num_heads, self._out_feats)
if graph.is_block: if graph.is_block:
feat_dst = feat_src[:graph.number_of_dst_nodes()] feat_dst = feat_src[:graph.number_of_dst_nodes()]
h_dst = h_dst[:graph.number_of_dst_nodes()]
dst_prefix_shape = (graph.number_of_dst_nodes(),) + dst_prefix_shape[1:]
# NOTE: GAT paper uses "first concatenation then linear projection" # NOTE: GAT paper uses "first concatenation then linear projection"
# to compute attention scores, while ours is "first projection then # to compute attention scores, while ours is "first projection then
# addition", the two approaches are mathematically equivalent: # addition", the two approaches are mathematically equivalent:
......
...@@ -287,6 +287,8 @@ class GATConv(nn.Module): ...@@ -287,6 +287,8 @@ class GATConv(nn.Module):
*src_prefix_shape, self._num_heads, self._out_feats) *src_prefix_shape, self._num_heads, self._out_feats)
if graph.is_block: if graph.is_block:
feat_dst = feat_src[:graph.number_of_dst_nodes()] feat_dst = feat_src[:graph.number_of_dst_nodes()]
h_dst = h_dst[:graph.number_of_dst_nodes()]
dst_prefix_shape = (graph.number_of_dst_nodes(),) + dst_prefix_shape[1:]
# NOTE: GAT paper uses "first concatenation then linear projection" # NOTE: GAT paper uses "first concatenation then linear projection"
# to compute attention scores, while ours is "first projection then # to compute attention scores, while ours is "first projection then
# addition", the two approaches are mathematically equivalent: # addition", the two approaches are mathematically equivalent:
......
...@@ -238,7 +238,10 @@ class SAGEConv(nn.Module): ...@@ -238,7 +238,10 @@ class SAGEConv(nn.Module):
if isinstance(feat, tuple): # heterogeneous if isinstance(feat, tuple): # heterogeneous
graph.dstdata['h'] = self.fc_neigh(feat_dst) if lin_before_mp else feat_dst graph.dstdata['h'] = self.fc_neigh(feat_dst) if lin_before_mp else feat_dst
else: else:
graph.dstdata['h'] = graph.srcdata['h'] if graph.is_block:
graph.dstdata['h'] = graph.srcdata['h'][:graph.num_dst_nodes()]
else:
graph.dstdata['h'] = graph.srcdata['h']
graph.update_all(msg_fn, fn.sum('m', 'neigh')) graph.update_all(msg_fn, fn.sum('m', 'neigh'))
# divide in_degrees # divide in_degrees
degs = graph.in_degrees().to(feat_dst) degs = graph.in_degrees().to(feat_dst)
......
...@@ -263,6 +263,8 @@ class GATConv(layers.Layer): ...@@ -263,6 +263,8 @@ class GATConv(layers.Layer):
self.fc(h_src), src_prefix_shape + (self._num_heads, self._out_feats)) self.fc(h_src), src_prefix_shape + (self._num_heads, self._out_feats))
if graph.is_block: if graph.is_block:
feat_dst = feat_src[:graph.number_of_dst_nodes()] feat_dst = feat_src[:graph.number_of_dst_nodes()]
h_dst = h_dst[:graph.number_of_dst_nodes()]
dst_prefix_shape = (graph.number_of_dst_nodes(),) + dst_prefix_shape[1:]
# NOTE: GAT paper uses "first concatenation then linear projection" # NOTE: GAT paper uses "first concatenation then linear projection"
# to compute attention scores, while ours is "first projection then # to compute attention scores, while ours is "first projection then
# addition", the two approaches are mathematically equivalent: # addition", the two approaches are mathematically equivalent:
......
...@@ -90,7 +90,8 @@ def test_graph_conv2(idtype, g, norm, weight, bias, out_dim): ...@@ -90,7 +90,8 @@ def test_graph_conv2(idtype, g, norm, weight, bias, out_dim):
conv = nn.GraphConv(5, out_dim, norm=norm, weight=weight, bias=bias) conv = nn.GraphConv(5, out_dim, norm=norm, weight=weight, bias=bias)
conv.initialize(ctx=F.ctx()) conv.initialize(ctx=F.ctx())
ext_w = F.randn((5, out_dim)).as_in_context(F.ctx()) ext_w = F.randn((5, out_dim)).as_in_context(F.ctx())
nsrc = ndst = g.number_of_nodes() nsrc = g.number_of_src_nodes()
ndst = g.number_of_dst_nodes()
h = F.randn((nsrc, 5)).as_in_context(F.ctx()) h = F.randn((nsrc, 5)).as_in_context(F.ctx())
if weight: if weight:
h_out = conv(g, h) h_out = conv(g, h)
...@@ -170,12 +171,17 @@ def test_gat_conv(g, idtype, out_dim, num_heads): ...@@ -170,12 +171,17 @@ def test_gat_conv(g, idtype, out_dim, num_heads):
gat = nn.GATConv(10, out_dim, num_heads) # n_heads = 5 gat = nn.GATConv(10, out_dim, num_heads) # n_heads = 5
gat.initialize(ctx=ctx) gat.initialize(ctx=ctx)
print(gat) print(gat)
feat = F.randn((g.number_of_nodes(), 10)) feat = F.randn((g.number_of_src_nodes(), 10))
h = gat(g, feat) h = gat(g, feat)
assert h.shape == (g.number_of_nodes(), num_heads, out_dim) assert h.shape == (g.number_of_dst_nodes(), num_heads, out_dim)
_, a = gat(g, feat, True) _, a = gat(g, feat, True)
assert a.shape == (g.number_of_edges(), num_heads, 1) assert a.shape == (g.number_of_edges(), num_heads, 1)
# test residual connection
gat = nn.GATConv(10, out_dim, num_heads, residual=True)
gat.initialize(ctx=ctx)
h = gat(g, feat)
@parametrize_dtype @parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['bipartite'], exclude=['zero-degree'])) @pytest.mark.parametrize('g', get_cases(['bipartite'], exclude=['zero-degree']))
@pytest.mark.parametrize('out_dim', [1, 2]) @pytest.mark.parametrize('out_dim', [1, 2])
...@@ -199,7 +205,7 @@ def test_sage_conv(idtype, g, aggre_type, out_dim): ...@@ -199,7 +205,7 @@ def test_sage_conv(idtype, g, aggre_type, out_dim):
g = g.astype(idtype).to(F.ctx()) g = g.astype(idtype).to(F.ctx())
ctx = F.ctx() ctx = F.ctx()
sage = nn.SAGEConv(5, out_dim, aggre_type) sage = nn.SAGEConv(5, out_dim, aggre_type)
feat = F.randn((g.number_of_nodes(), 5)) feat = F.randn((g.number_of_src_nodes(), 5))
sage.initialize(ctx=ctx) sage.initialize(ctx=ctx)
h = sage(g, feat) h = sage(g, feat)
assert h.shape[-1] == out_dim assert h.shape[-1] == out_dim
...@@ -277,9 +283,9 @@ def test_agnn_conv(g, idtype): ...@@ -277,9 +283,9 @@ def test_agnn_conv(g, idtype):
agnn_conv = nn.AGNNConv(0.1, True) agnn_conv = nn.AGNNConv(0.1, True)
agnn_conv.initialize(ctx=ctx) agnn_conv.initialize(ctx=ctx)
print(agnn_conv) print(agnn_conv)
feat = F.randn((g.number_of_nodes(), 10)) feat = F.randn((g.number_of_src_nodes(), 10))
h = agnn_conv(g, feat) h = agnn_conv(g, feat)
assert h.shape == (g.number_of_nodes(), 10) assert h.shape == (g.number_of_dst_nodes(), 10)
@parametrize_dtype @parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['bipartite'], exclude=['zero-degree'])) @pytest.mark.parametrize('g', get_cases(['bipartite'], exclude=['zero-degree']))
...@@ -387,9 +393,9 @@ def test_edge_conv(g, idtype, out_dim): ...@@ -387,9 +393,9 @@ def test_edge_conv(g, idtype, out_dim):
edge_conv.initialize(ctx=ctx) edge_conv.initialize(ctx=ctx)
print(edge_conv) print(edge_conv)
# test #1: basic # test #1: basic
h0 = F.randn((g.number_of_nodes(), 5)) h0 = F.randn((g.number_of_src_nodes(), 5))
h1 = edge_conv(g, h0) h1 = edge_conv(g, h0)
assert h1.shape == (g.number_of_nodes(), out_dim) assert h1.shape == (g.number_of_dst_nodes(), out_dim)
@parametrize_dtype @parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['bipartite'], exclude=['zero-degree'])) @pytest.mark.parametrize('g', get_cases(['bipartite'], exclude=['zero-degree']))
...@@ -418,9 +424,9 @@ def test_gin_conv(g, idtype, aggregator_type): ...@@ -418,9 +424,9 @@ def test_gin_conv(g, idtype, aggregator_type):
print(gin_conv) print(gin_conv)
# test #1: basic # test #1: basic
feat = F.randn((g.number_of_nodes(), 5)) feat = F.randn((g.number_of_src_nodes(), 5))
h = gin_conv(g, feat) h = gin_conv(g, feat)
assert h.shape == (g.number_of_nodes(), 5) assert h.shape == (g.number_of_dst_nodes(), 5)
@parametrize_dtype @parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['bipartite'])) @pytest.mark.parametrize('g', get_cases(['bipartite']))
...@@ -446,10 +452,10 @@ def test_gmm_conv(g, idtype): ...@@ -446,10 +452,10 @@ def test_gmm_conv(g, idtype):
ctx = F.ctx() ctx = F.ctx()
gmm_conv = nn.GMMConv(5, 2, 5, 3, 'max') gmm_conv = nn.GMMConv(5, 2, 5, 3, 'max')
gmm_conv.initialize(ctx=ctx) gmm_conv.initialize(ctx=ctx)
h0 = F.randn((g.number_of_nodes(), 5)) h0 = F.randn((g.number_of_src_nodes(), 5))
pseudo = F.randn((g.number_of_edges(), 5)) pseudo = F.randn((g.number_of_edges(), 5))
h1 = gmm_conv(g, h0, pseudo) h1 = gmm_conv(g, h0, pseudo)
assert h1.shape == (g.number_of_nodes(), 2) assert h1.shape == (g.number_of_dst_nodes(), 2)
@parametrize_dtype @parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['bipartite'], exclude=['zero-degree'])) @pytest.mark.parametrize('g', get_cases(['bipartite'], exclude=['zero-degree']))
...@@ -473,10 +479,10 @@ def test_nn_conv(g, idtype): ...@@ -473,10 +479,10 @@ def test_nn_conv(g, idtype):
nn_conv = nn.NNConv(5, 2, gluon.nn.Embedding(3, 5 * 2), 'max') nn_conv = nn.NNConv(5, 2, gluon.nn.Embedding(3, 5 * 2), 'max')
nn_conv.initialize(ctx=ctx) nn_conv.initialize(ctx=ctx)
# test #1: basic # test #1: basic
h0 = F.randn((g.number_of_nodes(), 5)) h0 = F.randn((g.number_of_src_nodes(), 5))
etypes = nd.random.randint(0, 4, g.number_of_edges()).as_in_context(ctx) etypes = nd.random.randint(0, 4, g.number_of_edges()).as_in_context(ctx)
h1 = nn_conv(g, h0, etypes) h1 = nn_conv(g, h0, etypes)
assert h1.shape == (g.number_of_nodes(), 2) assert h1.shape == (g.number_of_dst_nodes(), 2)
@parametrize_dtype @parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['bipartite'])) @pytest.mark.parametrize('g', get_cases(['bipartite']))
......
...@@ -533,14 +533,14 @@ def test_gat_conv(g, idtype, out_dim, num_heads): ...@@ -533,14 +533,14 @@ def test_gat_conv(g, idtype, out_dim, num_heads):
g = g.astype(idtype).to(F.ctx()) g = g.astype(idtype).to(F.ctx())
ctx = F.ctx() ctx = F.ctx()
gat = nn.GATConv(5, out_dim, num_heads) gat = nn.GATConv(5, out_dim, num_heads)
feat = F.randn((g.number_of_nodes(), 5)) feat = F.randn((g.number_of_src_nodes(), 5))
gat = gat.to(ctx) gat = gat.to(ctx)
h = gat(g, feat) h = gat(g, feat)
# test pickle # test pickle
th.save(gat, tmp_buffer) th.save(gat, tmp_buffer)
assert h.shape == (g.number_of_nodes(), num_heads, out_dim) assert h.shape == (g.number_of_dst_nodes(), num_heads, out_dim)
_, a = gat(g, feat, get_attention=True) _, a = gat(g, feat, get_attention=True)
assert a.shape == (g.number_of_edges(), num_heads, 1) assert a.shape == (g.number_of_edges(), num_heads, 1)
...@@ -570,7 +570,7 @@ def test_gat_conv_bi(g, idtype, out_dim, num_heads): ...@@ -570,7 +570,7 @@ def test_gat_conv_bi(g, idtype, out_dim, num_heads):
def test_sage_conv(idtype, g, aggre_type): def test_sage_conv(idtype, g, aggre_type):
g = g.astype(idtype).to(F.ctx()) g = g.astype(idtype).to(F.ctx())
sage = nn.SAGEConv(5, 10, aggre_type) sage = nn.SAGEConv(5, 10, aggre_type)
feat = F.randn((g.number_of_nodes(), 5)) feat = F.randn((g.number_of_src_nodes(), 5))
sage = sage.to(F.ctx()) sage = sage.to(F.ctx())
# test pickle # test pickle
th.save(sage, tmp_buffer) th.save(sage, tmp_buffer)
...@@ -664,14 +664,14 @@ def test_gin_conv(g, idtype, aggregator_type): ...@@ -664,14 +664,14 @@ def test_gin_conv(g, idtype, aggregator_type):
th.nn.Linear(5, 12), th.nn.Linear(5, 12),
aggregator_type aggregator_type
) )
feat = F.randn((g.number_of_nodes(), 5)) feat = F.randn((g.number_of_src_nodes(), 5))
gin = gin.to(ctx) gin = gin.to(ctx)
h = gin(g, feat) h = gin(g, feat)
# test pickle # test pickle
th.save(h, tmp_buffer) th.save(h, tmp_buffer)
assert h.shape == (g.number_of_nodes(), 12) assert h.shape == (g.number_of_dst_nodes(), 12)
@parametrize_dtype @parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['bipartite'], exclude=['zero-degree'])) @pytest.mark.parametrize('g', get_cases(['bipartite'], exclude=['zero-degree']))
...@@ -694,10 +694,10 @@ def test_agnn_conv(g, idtype): ...@@ -694,10 +694,10 @@ def test_agnn_conv(g, idtype):
g = g.astype(idtype).to(F.ctx()) g = g.astype(idtype).to(F.ctx())
ctx = F.ctx() ctx = F.ctx()
agnn = nn.AGNNConv(1) agnn = nn.AGNNConv(1)
feat = F.randn((g.number_of_nodes(), 5)) feat = F.randn((g.number_of_src_nodes(), 5))
agnn = agnn.to(ctx) agnn = agnn.to(ctx)
h = agnn(g, feat) h = agnn(g, feat)
assert h.shape == (g.number_of_nodes(), 5) assert h.shape == (g.number_of_dst_nodes(), 5)
@parametrize_dtype @parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['bipartite'], exclude=['zero-degree'])) @pytest.mark.parametrize('g', get_cases(['bipartite'], exclude=['zero-degree']))
...@@ -732,7 +732,7 @@ def test_nn_conv(g, idtype): ...@@ -732,7 +732,7 @@ def test_nn_conv(g, idtype):
ctx = F.ctx() ctx = F.ctx()
edge_func = th.nn.Linear(4, 5 * 10) edge_func = th.nn.Linear(4, 5 * 10)
nnconv = nn.NNConv(5, 10, edge_func, 'mean') nnconv = nn.NNConv(5, 10, edge_func, 'mean')
feat = F.randn((g.number_of_nodes(), 5)) feat = F.randn((g.number_of_src_nodes(), 5))
efeat = F.randn((g.number_of_edges(), 4)) efeat = F.randn((g.number_of_edges(), 4))
nnconv = nnconv.to(ctx) nnconv = nnconv.to(ctx)
h = nnconv(g, feat, efeat) h = nnconv(g, feat, efeat)
...@@ -837,9 +837,9 @@ def test_edge_conv(g, idtype, out_dim): ...@@ -837,9 +837,9 @@ def test_edge_conv(g, idtype, out_dim):
# test pickle # test pickle
th.save(edge_conv, tmp_buffer) th.save(edge_conv, tmp_buffer)
h0 = F.randn((g.number_of_nodes(), 5)) h0 = F.randn((g.number_of_src_nodes(), 5))
h1 = edge_conv(g, h0) h1 = edge_conv(g, h0)
assert h1.shape == (g.number_of_nodes(), out_dim) assert h1.shape == (g.number_of_dst_nodes(), out_dim)
@parametrize_dtype @parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['bipartite'], exclude=['zero-degree'])) @pytest.mark.parametrize('g', get_cases(['bipartite'], exclude=['zero-degree']))
...@@ -862,14 +862,14 @@ def test_dotgat_conv(g, idtype, out_dim, num_heads): ...@@ -862,14 +862,14 @@ def test_dotgat_conv(g, idtype, out_dim, num_heads):
g = g.astype(idtype).to(F.ctx()) g = g.astype(idtype).to(F.ctx())
ctx = F.ctx() ctx = F.ctx()
dotgat = nn.DotGatConv(5, out_dim, num_heads) dotgat = nn.DotGatConv(5, out_dim, num_heads)
feat = F.randn((g.number_of_nodes(), 5)) feat = F.randn((g.number_of_src_nodes(), 5))
dotgat = dotgat.to(ctx) dotgat = dotgat.to(ctx)
# test pickle # test pickle
th.save(dotgat, tmp_buffer) th.save(dotgat, tmp_buffer)
h = dotgat(g, feat) h = dotgat(g, feat)
assert h.shape == (g.number_of_nodes(), num_heads, out_dim) assert h.shape == (g.number_of_dst_nodes(), num_heads, out_dim)
_, a = dotgat(g, feat, get_attention=True) _, a = dotgat(g, feat, get_attention=True)
assert a.shape == (g.number_of_edges(), num_heads, 1) assert a.shape == (g.number_of_edges(), num_heads, 1)
......
...@@ -270,12 +270,16 @@ def test_gat_conv(g, idtype, out_dim, num_heads): ...@@ -270,12 +270,16 @@ def test_gat_conv(g, idtype, out_dim, num_heads):
g = g.astype(idtype).to(F.ctx()) g = g.astype(idtype).to(F.ctx())
ctx = F.ctx() ctx = F.ctx()
gat = nn.GATConv(5, out_dim, num_heads) gat = nn.GATConv(5, out_dim, num_heads)
feat = F.randn((g.number_of_nodes(), 5)) feat = F.randn((g.number_of_src_nodes(), 5))
h = gat(g, feat) h = gat(g, feat)
assert h.shape == (g.number_of_nodes(), num_heads, out_dim) assert h.shape == (g.number_of_dst_nodes(), num_heads, out_dim)
_, a = gat(g, feat, get_attention=True) _, a = gat(g, feat, get_attention=True)
assert a.shape == (g.number_of_edges(), num_heads, 1) assert a.shape == (g.number_of_edges(), num_heads, 1)
# test residual connection
gat = nn.GATConv(5, out_dim, num_heads, residual=True)
h = gat(g, feat)
@parametrize_dtype @parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['bipartite'], exclude=['zero-degree'])) @pytest.mark.parametrize('g', get_cases(['bipartite'], exclude=['zero-degree']))
@pytest.mark.parametrize('out_dim', [1, 2]) @pytest.mark.parametrize('out_dim', [1, 2])
...@@ -297,7 +301,7 @@ def test_gat_conv_bi(g, idtype, out_dim, num_heads): ...@@ -297,7 +301,7 @@ def test_gat_conv_bi(g, idtype, out_dim, num_heads):
def test_sage_conv(idtype, g, aggre_type, out_dim): def test_sage_conv(idtype, g, aggre_type, out_dim):
g = g.astype(idtype).to(F.ctx()) g = g.astype(idtype).to(F.ctx())
sage = nn.SAGEConv(5, out_dim, aggre_type) sage = nn.SAGEConv(5, out_dim, aggre_type)
feat = F.randn((g.number_of_nodes(), 5)) feat = F.randn((g.number_of_src_nodes(), 5))
h = sage(g, feat) h = sage(g, feat)
assert h.shape[-1] == out_dim assert h.shape[-1] == out_dim
...@@ -374,9 +378,9 @@ def test_gin_conv(g, idtype, aggregator_type): ...@@ -374,9 +378,9 @@ def test_gin_conv(g, idtype, aggregator_type):
tf.keras.layers.Dense(12), tf.keras.layers.Dense(12),
aggregator_type aggregator_type
) )
feat = F.randn((g.number_of_nodes(), 5)) feat = F.randn((g.number_of_src_nodes(), 5))
h = gin(g, feat) h = gin(g, feat)
assert h.shape == (g.number_of_nodes(), 12) assert h.shape == (g.number_of_dst_nodes(), 12)
@parametrize_dtype @parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['bipartite'])) @pytest.mark.parametrize('g', get_cases(['bipartite']))
...@@ -398,9 +402,9 @@ def test_edge_conv(g, idtype, out_dim): ...@@ -398,9 +402,9 @@ def test_edge_conv(g, idtype, out_dim):
g = g.astype(idtype).to(F.ctx()) g = g.astype(idtype).to(F.ctx())
edge_conv = nn.EdgeConv(out_dim) edge_conv = nn.EdgeConv(out_dim)
h0 = F.randn((g.number_of_nodes(), 5)) h0 = F.randn((g.number_of_src_nodes(), 5))
h1 = edge_conv(g, h0) h1 = edge_conv(g, h0)
assert h1.shape == (g.number_of_nodes(), out_dim) assert h1.shape == (g.number_of_dst_nodes(), out_dim)
@parametrize_dtype @parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['bipartite'], exclude=['zero-degree'])) @pytest.mark.parametrize('g', get_cases(['bipartite'], exclude=['zero-degree']))
......
...@@ -96,7 +96,7 @@ def batched_graph0(): ...@@ -96,7 +96,7 @@ def batched_graph0():
g3 = dgl.add_self_loop(dgl.graph(([0], [1]))) g3 = dgl.add_self_loop(dgl.graph(([0], [1])))
return dgl.batch([g1, g2, g3]) return dgl.batch([g1, g2, g3])
@register_case(['block', 'bipartite', 'block-biparitite']) @register_case(['block', 'bipartite', 'block-bipartite'])
def block_graph0(): def block_graph0():
g = dgl.graph(([2, 3, 4], [5, 6, 7]), num_nodes=100) g = dgl.graph(([2, 3, 4], [5, 6, 7]), num_nodes=100)
g = g.to(F.cpu()) g = g.to(F.cpu())
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment