Unverified commit 013d1456, authored by Mufei Li, committed by GitHub

[NN] Attention Retrieval for NN Modules (#2397)

* Update

* Update
parent c038b71f
@@ -201,7 +201,7 @@ class GATConv(nn.Block):
         """
         self._allow_zero_in_degree = set_value
 
-    def forward(self, graph, feat):
+    def forward(self, graph, feat, get_attention=False):
         r"""
         Description
@@ -217,12 +217,17 @@ class GATConv(nn.Block):
             :math:`D_{in}` is the size of the input feature and :math:`N` is the number of nodes.
             If a pair of mxnet.NDArray is given, the pair must contain two tensors of shape
             :math:`(N_{in}, D_{in_{src}})` and :math:`(N_{out}, D_{in_{dst}})`.
+        get_attention : bool, optional
+            Whether to return the attention values. Defaults to False.
 
         Returns
         -------
         mxnet.NDArray
             The output feature of shape :math:`(N, H, D_{out})`, where :math:`H`
             is the number of heads and :math:`D_{out}` is the size of the output feature.
+        mxnet.NDArray, optional
+            The attention values of shape :math:`(E, H, 1)`, where :math:`E` is the number
+            of edges. This is returned only when :attr:`get_attention` is ``True``.
 
         Raises
         ------
@@ -288,4 +293,8 @@ class GATConv(nn.Block):
         # activation
         if self.activation:
             rst = self.activation(rst)
-        return rst
+
+        if get_attention:
+            return rst, graph.edata['a']
+        else:
+            return rst
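For context, a minimal usage sketch of the new flag on the MXNet module. The toy graph, feature size, and head count below are illustrative, not taken from the PR, and the snippet assumes DGL is running on the MXNet backend:

    import dgl
    import mxnet as mx
    from dgl.nn.mxnet import GATConv

    # A 3-node, 3-edge toy graph; 10 input features, 4 heads, 2 output units.
    g = dgl.graph(([0, 1, 2], [1, 2, 0]))
    gat = GATConv(10, 2, num_heads=4)
    gat.initialize()
    feat = mx.nd.random.randn(g.number_of_nodes(), 10)
    # Gluon blocks take positional arguments, so pass the flag positionally
    # (get_attention=True), as the updated tests below do.
    rst, attn = gat(g, feat, True)
    # rst.shape  == (3, 4, 2), i.e. (N, H, D_out)
    # attn.shape == (3, 4, 1), i.e. (E, H, 1)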
@@ -114,7 +114,7 @@ class DotGatConv(nn.Module):
         else:
             self.fc = nn.Linear(self._in_src_feats, self._out_feats, bias=False)
 
-    def forward(self, graph, feat):
+    def forward(self, graph, feat, get_attention=False):
         r"""
         Description
@@ -130,12 +130,17 @@ class DotGatConv(nn.Module):
             :math:`D_{in}` is the size of the input feature and :math:`N` is the number of nodes.
             If a pair of torch.Tensor is given, the pair must contain two tensors of shape
             :math:`(N_{in}, D_{in_{src}})` and :math:`(N_{out}, D_{in_{dst}})`.
+        get_attention : bool, optional
+            Whether to return the attention values. Defaults to False.
 
         Returns
         -------
         torch.Tensor
             The output feature of shape :math:`(N, D_{out})`, where :math:`D_{out}`
             is the size of the output feature.
+        torch.Tensor, optional
+            The attention values of shape :math:`(E, 1)`, where :math:`E` is the number
+            of edges. This is returned only when :attr:`get_attention` is ``True``.
 
         Raises
         ------
@@ -187,4 +192,7 @@ class DotGatConv(nn.Module):
         # output results to the destination nodes
         rst = graph.dstdata['agg_u']
 
-        return rst
+        if get_attention:
+            return rst, graph.edata['sa']
+        else:
+            return rst
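The same flag on the single-head DotGatConv, sketched under the same illustrative assumptions (hypothetical toy graph and sizes); note that the returned attention here has no head axis:

    import dgl
    import torch
    from dgl.nn.pytorch import DotGatConv

    g = dgl.graph(([0, 1, 2], [1, 2, 0]))
    conv = DotGatConv(10, 2)
    feat = torch.randn(g.number_of_nodes(), 10)
    rst, attn = conv(g, feat, get_attention=True)
    # rst.shape  == (3, 2), i.e. (N, D_out)
    # attn.shape == (3, 1), i.e. (E, 1) -- dot-product attention, single head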
@@ -208,7 +208,7 @@ class GATConv(nn.Module):
         """
         self._allow_zero_in_degree = set_value
 
-    def forward(self, graph, feat):
+    def forward(self, graph, feat, get_attention=False):
         r"""
         Description
@@ -224,12 +224,17 @@ class GATConv(nn.Module):
             :math:`D_{in}` is the size of the input feature and :math:`N` is the number of nodes.
             If a pair of torch.Tensor is given, the pair must contain two tensors of shape
             :math:`(N_{in}, D_{in_{src}})` and :math:`(N_{out}, D_{in_{dst}})`.
+        get_attention : bool, optional
+            Whether to return the attention values. Defaults to False.
 
         Returns
         -------
         torch.Tensor
             The output feature of shape :math:`(N, H, D_{out})`, where :math:`H`
             is the number of heads and :math:`D_{out}` is the size of the output feature.
+        torch.Tensor, optional
+            The attention values of shape :math:`(E, H, 1)`, where :math:`E` is the number
+            of edges. This is returned only when :attr:`get_attention` is ``True``.
 
         Raises
         ------
@@ -294,4 +299,8 @@ class GATConv(nn.Module):
         # activation
         if self.activation:
             rst = self.activation(rst)
-        return rst
+
+        if get_attention:
+            return rst, graph.edata['a']
+        else:
+            return rst
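A minimal sketch for the PyTorch module, again with illustrative sizes only. The default call is unchanged; passing get_attention=True switches to the two-value return:

    import dgl
    import torch
    from dgl.nn.pytorch import GATConv

    g = dgl.graph(([0, 1, 2], [1, 2, 0]))
    gat = GATConv(10, 2, num_heads=4)
    feat = torch.randn(g.number_of_nodes(), 10)
    rst = gat(g, feat)                             # old single-output call still works
    rst, attn = gat(g, feat, get_attention=True)   # new two-output call
    # rst.shape  == (3, 4, 2), i.e. (N, H, D_out)
    # attn.shape == (3, 4, 1), i.e. (E, H, 1)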
@@ -195,7 +195,7 @@ class GATConv(layers.Layer):
         """
         self._allow_zero_in_degree = set_value
 
-    def call(self, graph, feat):
+    def call(self, graph, feat, get_attention=False):
         r"""
         Description
@@ -211,12 +211,17 @@ class GATConv(layers.Layer):
             :math:`D_{in}` is the size of the input feature and :math:`N` is the number of nodes.
             If a pair of tf.Tensor is given, the pair must contain two tensors of shape
             :math:`(N_{in}, D_{in_{src}})` and :math:`(N_{out}, D_{in_{dst}})`.
+        get_attention : bool, optional
+            Whether to return the attention values. Defaults to False.
 
         Returns
         -------
         tf.Tensor
             The output feature of shape :math:`(N, H, D_{out})`, where :math:`H`
             is the number of heads and :math:`D_{out}` is the size of the output feature.
+        tf.Tensor, optional
+            The attention values of shape :math:`(E, H, 1)`, where :math:`E` is the number
+            of edges. This is returned only when :attr:`get_attention` is ``True``.
 
         Raises
         ------
@@ -282,4 +287,8 @@ class GATConv(layers.Layer):
         # activation
         if self.activation:
             rst = self.activation(rst)
-        return rst
+
+        if get_attention:
+            return rst, graph.edata['a']
+        else:
+            return rst
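And the TensorFlow variant, under the same illustrative assumptions; this one additionally assumes DGL was started with the TensorFlow backend (e.g. DGLBACKEND=tensorflow):

    import dgl
    import tensorflow as tf
    from dgl.nn.tensorflow import GATConv

    g = dgl.graph(([0, 1, 2], [1, 2, 0]))
    gat = GATConv(10, 2, num_heads=4)
    feat = tf.random.normal((g.number_of_nodes(), 10))
    rst, attn = gat(g, feat, get_attention=True)
    # rst.shape  == (3, 4, 2), i.e. (N, H, D_out)
    # attn.shape == (3, 4, 1), i.e. (E, H, 1)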
@@ -167,6 +167,8 @@ def test_gat_conv(g, idtype):
     feat = F.randn((g.number_of_nodes(), 10))
     h = gat(g, feat)
     assert h.shape == (g.number_of_nodes(), 5, 20)
+    _, a = gat(g, feat, True)
+    assert a.shape == (g.number_of_edges(), 5, 1)
 
 @parametrize_dtype
 @pytest.mark.parametrize('g', get_cases(['bipartite'], exclude=['zero-degree']))
@@ -178,6 +180,8 @@ def test_gat_conv_bi(g, idtype):
     feat = (F.randn((g.number_of_src_nodes(), 5)), F.randn((g.number_of_dst_nodes(), 5)))
     h = gat(g, feat)
     assert h.shape == (g.number_of_dst_nodes(), 4, 2)
+    _, a = gat(g, feat, True)
+    assert a.shape == (g.number_of_edges(), 4, 1)
 
 @parametrize_dtype
 @pytest.mark.parametrize('g', get_cases(['homo', 'block-bipartite']))
@@ -379,6 +379,8 @@ def test_gat_conv(g, idtype):
     gat = gat.to(ctx)
     h = gat(g, feat)
     assert h.shape == (g.number_of_nodes(), 4, 2)
+    _, a = gat(g, feat, get_attention=True)
+    assert a.shape == (g.number_of_edges(), 4, 1)
 
 @parametrize_dtype
 @pytest.mark.parametrize('g', get_cases(['bipartite'], exclude=['zero-degree']))
@@ -390,6 +392,8 @@ def test_gat_conv_bi(g, idtype):
     gat = gat.to(ctx)
     h = gat(g, feat)
     assert h.shape == (g.number_of_dst_nodes(), 4, 2)
+    _, a = gat(g, feat, get_attention=True)
+    assert a.shape == (g.number_of_edges(), 4, 1)
 
 @parametrize_dtype
 @pytest.mark.parametrize('g', get_cases(['homo', 'block-bipartite']))
@@ -266,6 +266,8 @@ def test_gat_conv(g, idtype):
     feat = F.randn((g.number_of_nodes(), 5))
     h = gat(g, feat)
     assert h.shape == (g.number_of_nodes(), 4, 2)
+    _, a = gat(g, feat, get_attention=True)
+    assert a.shape == (g.number_of_edges(), 4, 1)
 
 @parametrize_dtype
 @pytest.mark.parametrize('g', get_cases(['bipartite'], exclude=['zero-degree']))
@@ -276,6 +278,8 @@ def test_gat_conv_bi(g, idtype):
     feat = (F.randn((g.number_of_src_nodes(), 5)), F.randn((g.number_of_dst_nodes(), 5)))
     h = gat(g, feat)
     assert h.shape == (g.number_of_dst_nodes(), 4, 2)
+    _, a = gat(g, feat, get_attention=True)
+    assert a.shape == (g.number_of_edges(), 4, 1)
 
 @parametrize_dtype
 @pytest.mark.parametrize('g', get_cases(['homo', 'block-bipartite']))