Unverified commit 013d1456, authored by Mufei Li and committed by GitHub

[NN] Attention Retrieval for NN Modules (#2397)

* Update

* Update
parent c038b71f
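
In short, each affected module's `forward`/`call` gains an optional `get_attention` flag; when it is set, the module returns a `(features, attention)` pair instead of the features alone. A minimal usage sketch with the PyTorch backend (the toy graph, feature sizes, and variable names below are illustrative, not taken from the diff):

```python
import dgl
import torch
from dgl.nn.pytorch import GATConv

# Toy graph; self-loops avoid GATConv's zero-in-degree error.
g = dgl.add_self_loop(dgl.graph(([0, 1, 2], [1, 2, 0])))
feat = torch.randn(3, 10)
gat = GATConv(in_feats=10, out_feats=4, num_heads=2)

h = gat(g, feat)                            # unchanged default: features only, (3, 2, 4)
h, attn = gat(g, feat, get_attention=True)  # new: also returns attention, (6, 2, 1)
```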
GATConv, MXNet backend:

@@ -201,7 +201,7 @@ class GATConv(nn.Block):
         """
         self._allow_zero_in_degree = set_value
 
-    def forward(self, graph, feat):
+    def forward(self, graph, feat, get_attention=False):
         r"""
 
         Description
@@ -217,12 +217,17 @@ class GATConv(nn.Block):
             :math:`D_{in}` is size of input feature, :math:`N` is the number of nodes.
             If a pair of mxnet.NDArray is given, the pair must contain two tensors of shape
             :math:`(N_{in}, D_{in_{src}})` and :math:`(N_{out}, D_{in_{dst}})`.
+        get_attention : bool, optional
+            Whether to also return the attention values. Defaults to False.
 
         Returns
         -------
         mxnet.NDArray
             The output feature of shape :math:`(N, H, D_{out})` where :math:`H`
             is the number of heads, and :math:`D_{out}` is size of output feature.
+        mxnet.NDArray, optional
+            The attention values of shape :math:`(E, H, 1)`, where :math:`E` is the number
+            of edges. This is returned only when :attr:`get_attention` is ``True``.
 
         Raises
         ------
@@ -288,4 +293,8 @@ class GATConv(nn.Block):
         # activation
         if self.activation:
             rst = self.activation(rst)
-        return rst
+
+        if get_attention:
+            return rst, graph.edata['a']
+        else:
+            return rst
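
The MXNet call is analogous; the updated tests below pass the flag positionally. A hedged sketch (the toy graph and sizes are made up; `initialize()` is the usual Gluon step, not part of this diff):

```python
import dgl
import mxnet as mx
from dgl.nn.mxnet import GATConv

# Toy graph with self-loops so no node has zero in-degree.
g = dgl.add_self_loop(dgl.graph(([0, 1, 2], [1, 2, 0])))
gat = GATConv(10, 4, num_heads=2)
gat.initialize()

feat = mx.nd.random.randn(3, 10)
rst, attn = gat(g, feat, True)   # flag passed positionally, as in the tests
print(rst.shape, attn.shape)     # (3, 2, 4), (6, 2, 1)
```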
DotGatConv, PyTorch backend:

@@ -114,7 +114,7 @@ class DotGatConv(nn.Module):
         else:
             self.fc = nn.Linear(self._in_src_feats, self._out_feats, bias=False)
 
-    def forward(self, graph, feat):
+    def forward(self, graph, feat, get_attention=False):
         r"""
 
         Description
@@ -130,12 +130,17 @@ class DotGatConv(nn.Module):
             :math:`D_{in}` is size of input feature, :math:`N` is the number of nodes.
             If a pair of torch.Tensor is given, the pair must contain two tensors of shape
             :math:`(N_{in}, D_{in_{src}})` and :math:`(N_{out}, D_{in_{dst}})`.
+        get_attention : bool, optional
+            Whether to also return the attention values. Defaults to False.
 
         Returns
         -------
         torch.Tensor
             The output feature of shape :math:`(N, D_{out})` where :math:`D_{out}` is size
             of output feature.
+        torch.Tensor, optional
+            The attention values of shape :math:`(E, 1)`, where :math:`E` is the number
+            of edges. This is returned only when :attr:`get_attention` is ``True``.
 
         Raises
         ------
@@ -187,4 +192,7 @@ class DotGatConv(nn.Module):
         # output results to the destination nodes
         rst = graph.dstdata['agg_u']
 
-        return rst
+        if get_attention:
+            return rst, graph.edata['sa']
+        else:
+            return rst
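
DotGatConv stores its dot-product scores under `graph.edata['sa']` and, being single-headed in this version, returns attention of shape `(E, 1)`. A sketch under that assumption (toy graph and sizes are illustrative):

```python
import dgl
import torch
from dgl.nn.pytorch import DotGatConv

g = dgl.add_self_loop(dgl.graph(([0, 1, 2], [1, 2, 0])))
conv = DotGatConv(10, 4)  # single-head dot-product attention
rst, attn = conv(g, torch.randn(3, 10), get_attention=True)
assert attn.shape == (g.number_of_edges(), 1)  # one value per edge
```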
GATConv, PyTorch backend:

@@ -208,7 +208,7 @@ class GATConv(nn.Module):
         """
         self._allow_zero_in_degree = set_value
 
-    def forward(self, graph, feat):
+    def forward(self, graph, feat, get_attention=False):
         r"""
 
         Description
@@ -224,12 +224,17 @@ class GATConv(nn.Module):
             :math:`D_{in}` is size of input feature, :math:`N` is the number of nodes.
             If a pair of torch.Tensor is given, the pair must contain two tensors of shape
             :math:`(N_{in}, D_{in_{src}})` and :math:`(N_{out}, D_{in_{dst}})`.
+        get_attention : bool, optional
+            Whether to also return the attention values. Defaults to False.
 
         Returns
         -------
         torch.Tensor
             The output feature of shape :math:`(N, H, D_{out})` where :math:`H`
             is the number of heads, and :math:`D_{out}` is size of output feature.
+        torch.Tensor, optional
+            The attention values of shape :math:`(E, H, 1)`, where :math:`E` is the number
+            of edges. This is returned only when :attr:`get_attention` is ``True``.
 
         Raises
         ------
@@ -294,4 +299,8 @@ class GATConv(nn.Module):
         # activation
         if self.activation:
             rst = self.activation(rst)
-        return rst
+
+        if get_attention:
+            return rst, graph.edata['a']
+        else:
+            return rst
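
Since `graph.edata['a']` holds edge-softmax weights, the returned values over each destination node's incoming edges should sum to one per head. A quick sanity-check sketch (toy graph and names are illustrative, not from the PR):

```python
import dgl
import dgl.function as fn
import torch
from dgl.nn.pytorch import GATConv

g = dgl.add_self_loop(dgl.graph(([0, 1, 2, 0], [1, 2, 0, 2])))
gat = GATConv(8, 6, num_heads=3)
h, attn = gat(g, torch.randn(3, 8), get_attention=True)  # attn: (7, 3, 1)

# Sum attention over each node's incoming edges; should be ~1.0 per head.
g.edata['a'] = attn
g.update_all(fn.copy_e('a', 'm'), fn.sum('m', 'a_sum'))
print(g.ndata['a_sum'].squeeze(-1))  # (3, 3) of ones, up to float error
```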
GATConv, TensorFlow backend:

@@ -195,7 +195,7 @@ class GATConv(layers.Layer):
         """
         self._allow_zero_in_degree = set_value
 
-    def call(self, graph, feat):
+    def call(self, graph, feat, get_attention=False):
         r"""
 
         Description
@@ -211,12 +211,17 @@ class GATConv(layers.Layer):
             :math:`D_{in}` is size of input feature, :math:`N` is the number of nodes.
             If a pair of tf.Tensor is given, the pair must contain two tensors of shape
             :math:`(N_{in}, D_{in_{src}})` and :math:`(N_{out}, D_{in_{dst}})`.
+        get_attention : bool, optional
+            Whether to also return the attention values. Defaults to False.
 
         Returns
         -------
         tf.Tensor
             The output feature of shape :math:`(N, H, D_{out})` where :math:`H`
             is the number of heads, and :math:`D_{out}` is size of output feature.
+        tf.Tensor, optional
+            The attention values of shape :math:`(E, H, 1)`, where :math:`E` is the number
+            of edges. This is returned only when :attr:`get_attention` is ``True``.
 
         Raises
         ------
@@ -282,4 +287,8 @@ class GATConv(layers.Layer):
         # activation
         if self.activation:
             rst = self.activation(rst)
-        return rst
+
+        if get_attention:
+            return rst, graph.edata['a']
+        else:
+            return rst
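
The TensorFlow variant works the same way through Keras's `call`; non-reserved keyword arguments such as `get_attention` are forwarded by `__call__`. A sketch with made-up sizes:

```python
import dgl
import tensorflow as tf
from dgl.nn.tensorflow import GATConv

g = dgl.add_self_loop(dgl.graph(([0, 1, 2], [1, 2, 0])))
gat = GATConv(10, 4, num_heads=2)
h, attn = gat(g, tf.random.normal((3, 10)), get_attention=True)
print(h.shape, attn.shape)  # (3, 2, 4), (6, 2, 1)
```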
Tests, MXNet backend:

@@ -167,6 +167,8 @@ def test_gat_conv(g, idtype):
     feat = F.randn((g.number_of_nodes(), 10))
     h = gat(g, feat)
     assert h.shape == (g.number_of_nodes(), 5, 20)
+    _, a = gat(g, feat, True)
+    assert a.shape == (g.number_of_edges(), 5, 1)
 
 @parametrize_dtype
 @pytest.mark.parametrize('g', get_cases(['bipartite'], exclude=['zero-degree']))
@@ -178,6 +180,8 @@ def test_gat_conv_bi(g, idtype):
     feat = (F.randn((g.number_of_src_nodes(), 5)), F.randn((g.number_of_dst_nodes(), 5)))
     h = gat(g, feat)
     assert h.shape == (g.number_of_dst_nodes(), 4, 2)
+    _, a = gat(g, feat, True)
+    assert a.shape == (g.number_of_edges(), 4, 1)
 
 @parametrize_dtype
 @pytest.mark.parametrize('g', get_cases(['homo', 'block-bipartite']))
Tests, PyTorch backend:

@@ -379,6 +379,8 @@ def test_gat_conv(g, idtype):
     gat = gat.to(ctx)
     h = gat(g, feat)
     assert h.shape == (g.number_of_nodes(), 4, 2)
+    _, a = gat(g, feat, get_attention=True)
+    assert a.shape == (g.number_of_edges(), 4, 1)
 
 @parametrize_dtype
 @pytest.mark.parametrize('g', get_cases(['bipartite'], exclude=['zero-degree']))
@@ -390,6 +392,8 @@ def test_gat_conv_bi(g, idtype):
     gat = gat.to(ctx)
     h = gat(g, feat)
     assert h.shape == (g.number_of_dst_nodes(), 4, 2)
+    _, a = gat(g, feat, get_attention=True)
+    assert a.shape == (g.number_of_edges(), 4, 1)
 
 @parametrize_dtype
 @pytest.mark.parametrize('g', get_cases(['homo', 'block-bipartite']))
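
The updated tests only assert the attention shape. A natural companion check, sketched here as a hypothetical addition (not part of the PR), is that the default call still returns a single tensor, i.e. the change is backward compatible:

```python
import dgl
import torch
from dgl.nn.pytorch import GATConv

def test_gat_conv_attention_optional():
    g = dgl.add_self_loop(dgl.graph(([0, 1], [1, 0])))
    gat = GATConv(4, 2, num_heads=2)
    out = gat(g, torch.randn(2, 4))
    assert isinstance(out, torch.Tensor)  # no tuple without the flag
    _, attn = gat(g, torch.randn(2, 4), get_attention=True)
    assert attn.shape == (g.number_of_edges(), 2, 1)
```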
Tests, TensorFlow backend:

@@ -266,6 +266,8 @@ def test_gat_conv(g, idtype):
     feat = F.randn((g.number_of_nodes(), 5))
     h = gat(g, feat)
     assert h.shape == (g.number_of_nodes(), 4, 2)
+    _, a = gat(g, feat, get_attention=True)
+    assert a.shape == (g.number_of_edges(), 4, 1)
 
 @parametrize_dtype
 @pytest.mark.parametrize('g', get_cases(['bipartite'], exclude=['zero-degree']))
@@ -276,6 +278,8 @@ def test_gat_conv_bi(g, idtype):
     feat = (F.randn((g.number_of_src_nodes(), 5)), F.randn((g.number_of_dst_nodes(), 5)))
     h = gat(g, feat)
     assert h.shape == (g.number_of_dst_nodes(), 4, 2)
+    _, a = gat(g, feat, get_attention=True)
+    assert a.shape == (g.number_of_edges(), 4, 1)
 
 @parametrize_dtype
 @pytest.mark.parametrize('g', get_cases(['homo', 'block-bipartite']))