Unverified Commit e18c2ab4 authored by Mufei Li's avatar Mufei Li Committed by GitHub
Browse files

[NN] Fix GATConv for Broadcasting with Residual Connections (#2867)



* Update

* update
Co-authored-by: default avatarUbuntu <ubuntu@ip-172-31-59-108.us-west-2.compute.internal>
parent b2b531e0
......@@ -198,6 +198,7 @@ class GATConv(nn.Module):
nn.init.xavier_normal_(self.fc_dst.weight, gain=gain)
nn.init.xavier_normal_(self.attn_l, gain=gain)
nn.init.xavier_normal_(self.attn_r, gain=gain)
if self.bias is not None:
nn.init.constant_(self.bias, 0)
if isinstance(self.res_fc, nn.Linear):
nn.init.xavier_normal_(self.res_fc.weight, gain=gain)
......@@ -304,7 +305,8 @@ class GATConv(nn.Module):
rst = graph.dstdata['ft']
# residual
if self.res_fc is not None:
resval = self.res_fc(h_dst).view(h_dst.shape[0], self._num_heads, self._out_feats)
# Use -1 rather than self._num_heads to handle broadcasting
resval = self.res_fc(h_dst).view(h_dst.shape[0], -1, self._out_feats)
rst = rst + resval
# bias
if self.bias is not None:
......
......@@ -527,7 +527,7 @@ def test_rgcn_sorted(O):
@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['homo', 'block-bipartite'], exclude=['zero-degree']))
@pytest.mark.parametrize('out_dim', [1, 2])
@pytest.mark.parametrize('out_dim', [1, 5])
@pytest.mark.parametrize('num_heads', [1, 4])
def test_gat_conv(g, idtype, out_dim, num_heads):
g = g.astype(idtype).to(F.ctx())
......@@ -544,6 +544,11 @@ def test_gat_conv(g, idtype, out_dim, num_heads):
_, a = gat(g, feat, get_attention=True)
assert a.shape == (g.number_of_edges(), num_heads, 1)
# test residual connection
gat = nn.GATConv(5, out_dim, num_heads, residual=True)
gat = gat.to(ctx)
h = gat(g, feat)
@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['bipartite'], exclude=['zero-degree']))
@pytest.mark.parametrize('out_dim', [1, 2])
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment