Unverified Commit e4ddafe9 authored by Chen Sirui, committed by GitHub

[NN] add multihead in DotGatConv (#2549)



* add multihead in DotGatConv

* Fix spacing issue

* Add Unit test for dotgat

* Modified Unit test for dotgat

* Add transformer like divisor

* Update dotgatconv.py
Co-authored-by: Chen <chesirui@3c22fbe5458c.ant.amazon.com>
Co-authored-by: Zihao Ye <expye@outlook.com>
parent 4ca706e1
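For quick orientation, here is a small usage sketch of the multi-head API this commit introduces, condensed from the updated docstring example in the diff below. The import path (dgl.nn.pytorch) is an assumption and not shown on this page.

import dgl
import torch as th
from dgl.nn.pytorch import DotGatConv  # import path assumed, not part of this diff

# 6-node graph; self-loops guarantee every node has at least one in-edge
g = dgl.add_self_loop(dgl.graph(([0, 1, 2, 3, 2, 5], [1, 2, 3, 4, 0, 3])))
feat = th.ones(6, 10)

# 10-dim inputs, 3 attention heads, 2-dim output per head
conv = DotGatConv(10, 2, num_heads=3)
out = conv(g, feat)                            # out.shape == (6, 3, 2): (nodes, heads, out_feats)
out, attn = conv(g, feat, get_attention=True)  # attn: one score per edge and head, shape (num_edges, 3, 1)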
@@ -41,6 +41,8 @@ class DotGatConv(nn.Module):
         same value.
     out_feats : int
         Output feature size; i.e, the number of dimensions of :math:`h_i^{(l+1)}`.
+    num_heads : int
+        Number of heads in Multi-Head Attention.
     allow_zero_in_degree : bool, optional
         If there are 0-in-degree nodes in the graph, output for those nodes will be invalid
         since no message will be passed to those nodes. This is harmful for some applications
@@ -75,15 +77,27 @@ class DotGatConv(nn.Module):
     >>> g = dgl.graph(([0,1,2,3,2,5], [1,2,3,4,0,3]))
     >>> g = dgl.add_self_loop(g)
     >>> feat = th.ones(6, 10)
-    >>> gatconv = DotGatConv(10, 2)
-    >>> res = gatconv(g, feat)
+    >>> dotgatconv = DotGatConv(10, 2, num_heads=3)
+    >>> res = dotgatconv(g, feat)
     >>> res
-    tensor([[-0.6958, -0.8752],
-            [-0.6958, -0.8752],
-            [-0.6958, -0.8752],
-            [-0.6958, -0.8752],
-            [-0.6958, -0.8752],
-            [-0.6958, -0.8752]], grad_fn=<CopyReduceBackward>)
+    tensor([[[ 3.4570, 1.8634],
+             [ 1.3805, -0.0762],
+             [ 1.0390, -1.1479]],
+            [[ 3.4570, 1.8634],
+             [ 1.3805, -0.0762],
+             [ 1.0390, -1.1479]],
+            [[ 3.4570, 1.8634],
+             [ 1.3805, -0.0762],
+             [ 1.0390, -1.1479]],
+            [[ 3.4570, 1.8634],
+             [ 1.3805, -0.0762],
+             [ 1.0390, -1.1479]],
+            [[ 3.4570, 1.8634],
+             [ 1.3805, -0.0762],
+             [ 1.0390, -1.1479]],
+            [[ 3.4570, 1.8634],
+             [ 1.3805, -0.0762],
+             [ 1.0390, -1.1479]]], grad_fn=<BinaryReduceBackward>)

     >>> # Case 2: Unidirectional bipartite graph
     >>> u = [0, 1, 0, 0, 1]
@@ -91,28 +105,38 @@ class DotGatConv(nn.Module):
     >>> g = dgl.bipartite((u, v))
     >>> u_feat = th.tensor(np.random.rand(2, 5).astype(np.float32))
     >>> v_feat = th.tensor(np.random.rand(4, 10).astype(np.float32))
-    >>> gatconv = DotGatConv((5,10), 2)
-    >>> res = gatconv(g, (u_feat, v_feat))
+    >>> dotgatconv = DotGatConv((5,10), 2, 3)
+    >>> res = dotgatconv(g, (u_feat, v_feat))
     >>> res
-    tensor([[ 0.4718, 0.0864],
-            [ 0.7099, -0.0335],
-            [ 0.5869, 0.0284],
-            [ 0.4718, 0.0864]], grad_fn=<CopyReduceBackward>)
+    tensor([[[-0.6066, 1.0268],
+             [-0.5945, -0.4801],
+             [ 0.1594, 0.3825]],
+            [[ 0.0268, 1.0783],
+             [ 0.5041, -1.3025],
+             [ 0.6568, 0.7048]],
+            [[-0.2688, 1.0543],
+             [-0.0315, -0.9016],
+             [ 0.3943, 0.5347]],
+            [[-0.6066, 1.0268],
+             [-0.5945, -0.4801],
+             [ 0.1594, 0.3825]]], grad_fn=<BinaryReduceBackward>)
     """

     def __init__(self,
                  in_feats,
                  out_feats,
+                 num_heads,
                  allow_zero_in_degree=False):
         super(DotGatConv, self).__init__()
         self._in_src_feats, self._in_dst_feats = expand_as_pair(in_feats)
         self._out_feats = out_feats
         self._allow_zero_in_degree = allow_zero_in_degree
+        self._num_heads = num_heads

         if isinstance(in_feats, tuple):
-            self.fc_src = nn.Linear(self._in_src_feats, self._out_feats, bias=False)
-            self.fc_dst = nn.Linear(self._in_dst_feats, self._out_feats, bias=False)
+            self.fc_src = nn.Linear(self._in_src_feats, self._out_feats*self._num_heads, bias=False)
+            self.fc_dst = nn.Linear(self._in_dst_feats, self._out_feats*self._num_heads, bias=False)
         else:
-            self.fc = nn.Linear(self._in_src_feats, self._out_feats, bias=False)
+            self.fc = nn.Linear(self._in_src_feats, self._out_feats*self._num_heads, bias=False)

     def forward(self, graph, feat, get_attention=False):
         r"""
@@ -168,11 +192,11 @@ class DotGatConv(nn.Module):
             if isinstance(feat, tuple):
                 h_src = feat[0]
                 h_dst = feat[1]
-                feat_src = self.fc_src(h_src)
-                feat_dst = self.fc_dst(h_dst)
+                feat_src = self.fc_src(h_src).view(-1, self._num_heads, self._out_feats)
+                feat_dst = self.fc_dst(h_dst).view(-1, self._num_heads, self._out_feats)
             else:
                 h_src = feat
-                feat_src = feat_dst = self.fc(h_src)
+                feat_src = feat_dst = self.fc(h_src).view(-1, self._num_heads, self._out_feats)

                 if graph.is_block:
                     feat_dst = feat_src[:graph.number_of_dst_nodes()]
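The projection change above uses the standard multi-head trick: one wide linear layer computes all heads in a single matrix multiply, and a view splits the result into (nodes, heads, out_feats). A standalone sketch of that pattern, with sizes chosen only for illustration:

import torch as th
import torch.nn as nn

num_nodes, in_feats, out_feats, num_heads = 6, 10, 2, 3
fc = nn.Linear(in_feats, out_feats * num_heads, bias=False)
h = th.ones(num_nodes, in_feats)
# one matrix multiply produces all heads at once, then a view separates them
feat = fc(h).view(-1, num_heads, out_feats)
print(feat.shape)  # torch.Size([6, 3, 2])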
@@ -184,7 +208,7 @@ class DotGatConv(nn.Module):
             graph.apply_edges(fn.u_dot_v('ft', 'ft', 'a'))

             # Step 2. edge softmax to compute attention scores
-            graph.edata['sa'] = edge_softmax(graph, graph.edata['a'])
+            graph.edata['sa'] = edge_softmax(graph, graph.edata['a'])/(self._out_feats**0.5)

             # Step 3. Broadcast softmax value to each edge, and aggregate dst node
             graph.update_all(fn.u_mul_e('ft', 'sa', 'attn'), fn.sum('attn', 'agg_u'))
...
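For context on the "Add transformer like divisor" item in the commit message: in scaled dot-product attention, the raw scores q·k are divided by sqrt(d) so their magnitude stays stable before normalization. The sketch below shows that scoring for a single destination node's neighborhood in plain PyTorch; the names and shapes are illustrative and this is a reference formulation, not the DGL message-passing code above.

import torch as th

def scaled_dot_attention(q, k, v):
    # q, k, v: (num_neighbors, num_heads, d) for one destination node
    d = q.shape[-1]
    scores = (q * k).sum(dim=-1) / d ** 0.5        # scale the logits by sqrt(d)
    weights = th.softmax(scores, dim=0)            # normalize over the neighbors
    return (weights.unsqueeze(-1) * v).sum(dim=0)  # (num_heads, d) weighted sum of values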
@@ -779,6 +779,32 @@ def test_edge_conv_bi(g, idtype):
     x0 = F.randn((g.number_of_dst_nodes(), 5))
     h1 = edge_conv(g, (h0, x0))
     assert h1.shape == (g.number_of_dst_nodes(), 2)

+@parametrize_dtype
+@pytest.mark.parametrize('g', get_cases(['homo', 'block-bipartite'], exclude=['zero-degree']))
+def test_dotgat_conv(g, idtype):
+    g = g.astype(idtype).to(F.ctx())
+    ctx = F.ctx()
+    dotgat = nn.DotGatConv(5, 2, 4)
+    feat = F.randn((g.number_of_nodes(), 5))
+    dotgat = dotgat.to(ctx)
+    h = dotgat(g, feat)
+    assert h.shape == (g.number_of_nodes(), 4, 2)
+    _, a = dotgat(g, feat, get_attention=True)
+    assert a.shape == (g.number_of_edges(), 4, 1)
+
+@parametrize_dtype
+@pytest.mark.parametrize('g', get_cases(['bipartite'], exclude=['zero-degree']))
+def test_dotgat_conv_bi(g, idtype):
+    g = g.astype(idtype).to(F.ctx())
+    ctx = F.ctx()
+    dotgat = nn.DotGatConv((5, 5), 2, 4)
+    feat = (F.randn((g.number_of_src_nodes(), 5)), F.randn((g.number_of_dst_nodes(), 5)))
+    dotgat = dotgat.to(ctx)
+    h = dotgat(g, feat)
+    assert h.shape == (g.number_of_dst_nodes(), 4, 2)
+    _, a = dotgat(g, feat, get_attention=True)
+    assert a.shape == (g.number_of_edges(), 4, 1)
+
 def test_dense_cheb_conv():
     for k in range(1, 4):
@@ -1016,6 +1042,7 @@ if __name__ == '__main__':
     test_gated_graph_conv()
     test_nn_conv()
     test_gmm_conv()
+    test_dotgat_conv()
     test_dense_graph_conv()
     test_dense_sage_conv()
     test_dense_cheb_conv()
...