Unverified Commit e4ddafe9 authored by Chen Sirui, committed by GitHub

[NN] add multihead in DotGatConv (#2549)



* add multihead in DotGatConv

* Fix spacing issue

* Add Unit test for dotgat

* Modified Unit test for dotgat

* Add transformer like divisor

* Update dotgatconv.py
Co-authored-by: Chen <chesirui@3c22fbe5458c.ant.amazon.com>
Co-authored-by: Zihao Ye <expye@outlook.com>
parent 4ca706e1
......@@ -41,6 +41,8 @@ class DotGatConv(nn.Module):
same value.
out_feats : int
Output feature size; i.e, the number of dimensions of :math:`h_i^{(l+1)}`.
num_heads : int
Number of heads in Multi-Head Attention.
allow_zero_in_degree : bool, optional
If there are 0-in-degree nodes in the graph, output for those nodes will be invalid
since no message will be passed to those nodes. This is harmful for some applications
......@@ -75,15 +77,27 @@ class DotGatConv(nn.Module):
>>> g = dgl.graph(([0,1,2,3,2,5], [1,2,3,4,0,3]))
>>> g = dgl.add_self_loop(g)
>>> feat = th.ones(6, 10)
>>> gatconv = DotGatConv(10, 2)
>>> res = gatconv(g, feat)
>>> dotgatconv = DotGatConv(10, 2, num_heads=3)
>>> res = dotgatconv(g, feat)
>>> res
tensor([[-0.6958, -0.8752],
        [-0.6958, -0.8752],
        [-0.6958, -0.8752],
        [-0.6958, -0.8752],
        [-0.6958, -0.8752],
        [-0.6958, -0.8752]], grad_fn=<CopyReduceBackward>)
tensor([[[ 3.4570,  1.8634],
         [ 1.3805, -0.0762],
         [ 1.0390, -1.1479]],
        [[ 3.4570,  1.8634],
         [ 1.3805, -0.0762],
         [ 1.0390, -1.1479]],
        [[ 3.4570,  1.8634],
         [ 1.3805, -0.0762],
         [ 1.0390, -1.1479]],
        [[ 3.4570,  1.8634],
         [ 1.3805, -0.0762],
         [ 1.0390, -1.1479]],
        [[ 3.4570,  1.8634],
         [ 1.3805, -0.0762],
         [ 1.0390, -1.1479]],
        [[ 3.4570,  1.8634],
         [ 1.3805, -0.0762],
         [ 1.0390, -1.1479]]], grad_fn=<BinaryReduceBackward>)
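A short follow-up sketch (not part of the module): the heads are kept along dimension 1, so callers typically concatenate or average them downstream.
>>> res.flatten(1).shape  # concatenate heads
torch.Size([6, 6])
>>> res.mean(1).shape     # average heads
torch.Size([6, 2])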
>>> # Case 2: Unidirectional bipartite graph
>>> u = [0, 1, 0, 0, 1]
......@@ -91,28 +105,38 @@ class DotGatConv(nn.Module):
>>> g = dgl.bipartite((u, v))
>>> u_feat = th.tensor(np.random.rand(2, 5).astype(np.float32))
>>> v_feat = th.tensor(np.random.rand(4, 10).astype(np.float32))
>>> gatconv = DotGatConv((5,10), 2)
>>> res = gatconv(g, (u_feat, v_feat))
>>> dotgatconv = DotGatConv((5,10), 2, 3)
>>> res = dotgatconv(g, (u_feat, v_feat))
>>> res
tensor([[ 0.4718,  0.0864],
        [ 0.7099, -0.0335],
        [ 0.5869,  0.0284],
        [ 0.4718,  0.0864]], grad_fn=<CopyReduceBackward>)
tensor([[[-0.6066,  1.0268],
         [-0.5945, -0.4801],
         [ 0.1594,  0.3825]],
        [[ 0.0268,  1.0783],
         [ 0.5041, -1.3025],
         [ 0.6568,  0.7048]],
        [[-0.2688,  1.0543],
         [-0.0315, -0.9016],
         [ 0.3943,  0.5347]],
        [[-0.6066,  1.0268],
         [-0.5945, -0.4801],
         [ 0.1594,  0.3825]]], grad_fn=<BinaryReduceBackward>)
"""
def __init__(self,
in_feats,
out_feats,
num_heads,
allow_zero_in_degree=False):
super(DotGatConv, self).__init__()
self._in_src_feats, self._in_dst_feats = expand_as_pair(in_feats)
self._out_feats = out_feats
self._allow_zero_in_degree = allow_zero_in_degree
self._num_heads = num_heads
if isinstance(in_feats, tuple):
self.fc_src = nn.Linear(self._in_src_feats, self._out_feats, bias=False)
self.fc_dst = nn.Linear(self._in_dst_feats, self._out_feats, bias=False)
self.fc_src = nn.Linear(self._in_src_feats, self._out_feats*self._num_heads, bias=False)
self.fc_dst = nn.Linear(self._in_dst_feats, self._out_feats*self._num_heads, bias=False)
else:
self.fc = nn.Linear(self._in_src_feats, self._out_feats, bias=False)
self.fc = nn.Linear(self._in_src_feats, self._out_feats*self._num_heads, bias=False)
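For reference, a minimal standalone sketch of the per-head projection above (hypothetical names, not part of the patch): a single linear layer maps in_feats to out_feats * num_heads, and forward later views the result as (N, num_heads, out_feats).

import torch as th
import torch.nn as nn

# Hypothetical standalone shapes mirroring the module's __init__ and forward.
in_feats, out_feats, num_heads = 10, 2, 3
fc = nn.Linear(in_feats, out_feats * num_heads, bias=False)  # as in __init__
x = th.ones(6, in_feats)                                     # 6 nodes
proj = fc(x)                                                 # (6, 6)
per_head = proj.view(-1, num_heads, out_feats)               # (6, 3, 2), as in forward
assert per_head.shape == (6, 3, 2)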
def forward(self, graph, feat, get_attention=False):
r"""
......@@ -168,11 +192,11 @@ class DotGatConv(nn.Module):
if isinstance(feat, tuple):
h_src = feat[0]
h_dst = feat[1]
feat_src = self.fc_src(h_src)
feat_dst = self.fc_dst(h_dst)
feat_src = self.fc_src(h_src).view(-1, self._num_heads, self._out_feats)
feat_dst = self.fc_dst(h_dst).view(-1, self._num_heads, self._out_feats)
else:
h_src = feat
feat_src = feat_dst = self.fc(h_src)
feat_src = feat_dst = self.fc(h_src).view(-1, self._num_heads, self._out_feats)
if graph.is_block:
feat_dst = feat_src[:graph.number_of_dst_nodes()]
......@@ -184,7 +208,7 @@ class DotGatConv(nn.Module):
graph.apply_edges(fn.u_dot_v('ft', 'ft', 'a'))
# Step 2. edge softmax to compute attention scores
graph.edata['sa'] = edge_softmax(graph, graph.edata['a'])
graph.edata['sa'] = edge_softmax(graph, graph.edata['a'] / self._out_feats**0.5)
# Step 3. Broadcast softmax value to each edge, and aggregate dst node
graph.update_all(fn.u_mul_e('ft', 'sa', 'attn'), fn.sum('attn', 'agg_u'))
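As a sanity check on the attention step, here is a minimal pure-PyTorch sketch (assumed standalone tensors, not the DGL message-passing path) of scaled dot-product attention over one destination node's k in-neighbors. Following the transformer convention, the divisor is applied to the raw scores before the softmax, so the per-node weights still sum to 1.

import torch as th

# Hypothetical standalone tensors: k in-neighbors, one destination node.
num_heads, out_feats, k = 3, 2, 4
ft_src = th.randn(k, num_heads, out_feats)   # projected neighbor features
ft_dst = th.randn(num_heads, out_feats)      # projected destination feature

scores = (ft_src * ft_dst).sum(-1) / out_feats ** 0.5  # u_dot_v with the divisor
alpha = th.softmax(scores, dim=0)                      # edge_softmax over in-edges
out = (alpha.unsqueeze(-1) * ft_src).sum(0)            # u_mul_e + sum aggregation
assert th.allclose(alpha.sum(0), th.ones(num_heads))   # weights sum to 1 per head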
......
......@@ -779,6 +779,32 @@ def test_edge_conv_bi(g, idtype):
x0 = F.randn((g.number_of_dst_nodes(), 5))
h1 = edge_conv(g, (h0, x0))
assert h1.shape == (g.number_of_dst_nodes(), 2)
@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['homo', 'block-bipartite'], exclude=['zero-degree']))
def test_dotgat_conv(g, idtype):
g = g.astype(idtype).to(F.ctx())
ctx = F.ctx()
dotgat = nn.DotGatConv(5, 2, 4)
feat = F.randn((g.number_of_nodes(), 5))
dotgat = dotgat.to(ctx)
h = dotgat(g, feat)
assert h.shape == (g.number_of_nodes(), 4, 2)
_, a = dotgat(g, feat, get_attention=True)
assert a.shape == (g.number_of_edges(), 4, 1)
@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['bipartite'], exclude=['zero-degree']))
def test_dotgat_conv_bi(g, idtype):
g = g.astype(idtype).to(F.ctx())
ctx = F.ctx()
dotgat = nn.DotGatConv((5, 5), 2, 4)
feat = (F.randn((g.number_of_src_nodes(), 5)), F.randn((g.number_of_dst_nodes(), 5)))
dotgat = dotgat.to(ctx)
h = dotgat(g, feat)
assert h.shape == (g.number_of_dst_nodes(), 4, 2)
_, a = dotgat(g, feat, get_attention=True)
assert a.shape == (g.number_of_edges(), 4, 1)
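A minimal usage sketch matching the shape assertions in the tests above (the graph and feature sizes here are assumptions, not taken from the patch):

import dgl
import torch as th
from dgl.nn import DotGatConv

g = dgl.add_self_loop(dgl.graph(([0, 1, 2], [1, 2, 0])))
conv = DotGatConv(in_feats=5, out_feats=2, num_heads=4)
feat = th.randn(g.number_of_nodes(), 5)
h, attn = conv(g, feat, get_attention=True)
assert h.shape == (3, 4, 2)                        # (nodes, heads, out_feats)
assert attn.shape == (g.number_of_edges(), 4, 1)   # one score per edge per head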
def test_dense_cheb_conv():
for k in range(1, 4):
......@@ -1016,6 +1042,7 @@ if __name__ == '__main__':
test_gated_graph_conv()
test_nn_conv()
test_gmm_conv()
test_dotgat_conv()
test_dense_graph_conv()
test_dense_sage_conv()
test_dense_cheb_conv()
......