Unverified Commit 12e97c54 authored by Sai Kandregula, committed by GitHub

[Feature] get_attention parameter in GlobalAttentionPooling (#3837)



* get_attention parameter in GlobalAttentionPooling

* removed trailing whitespace

* lint fix

Co-authored-by: decoherencer <decoherencer@users.noreply.github.com>
Co-authored-by: Mufei Li <mufeili1996@gmail.com>
Co-authored-by: Quan (Andy) Gan <coin2028@hotmail.com>
parent e9fd65e9
@@ -419,7 +419,7 @@ class GlobalAttentionPooling(nn.Module):
         self.gate_nn = gate_nn
         self.feat_nn = feat_nn
 
-    def forward(self, graph, feat):
+    def forward(self, graph, feat, get_attention=False):
         r"""
 
         Compute global attention pooling.
@@ -431,12 +431,17 @@ class GlobalAttentionPooling(nn.Module):
         feat : torch.Tensor
             The input node feature with shape :math:`(N, D)` where :math:`N` is the
             number of nodes in the graph, and :math:`D` means the size of features.
+        get_attention : bool, optional
+            Whether to return the attention values from gate_nn. Default to False.
 
         Returns
         -------
         torch.Tensor
             The output feature with shape :math:`(B, D)`, where :math:`B` refers
             to the batch size.
+        torch.Tensor, optional
+            The attention values of shape :math:`(N, 1)`, where :math:`N` is the number of
+            nodes in the graph. This is returned only when :attr:`get_attention` is ``True``.
         """
         with graph.local_scope():
             gate = self.gate_nn(feat)
@@ -451,7 +456,10 @@ class GlobalAttentionPooling(nn.Module):
             readout = sum_nodes(graph, 'r')
             graph.ndata.pop('r')
 
-            return readout
+            if get_attention:
+                return readout, gate
+            else:
+                return readout
 
 
 class Set2Set(nn.Module):
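For reference, a minimal usage sketch of the new flag (not part of this commit; the graphs, feature size, and gate network below are illustrative assumptions):

    # Illustrative example only: exercises the new get_attention flag.
    import dgl
    import torch
    import torch.nn as nn
    from dgl.nn.pytorch import GlobalAttentionPooling

    # Two small graphs batched together: B = 2 graphs, N = 7 nodes total.
    g1 = dgl.graph(([0, 1], [1, 2]))            # 3 nodes
    g2 = dgl.graph(([0, 1, 2], [1, 2, 3]))      # 4 nodes
    bg = dgl.batch([g1, g2])

    feat = torch.randn(bg.num_nodes(), 16)      # (N, D) node features, D = 16
    gate_nn = nn.Linear(16, 1)                  # scores each node with a scalar gate
    pool = GlobalAttentionPooling(gate_nn)

    readout = pool(bg, feat)                    # (B, D) = (2, 16), old behavior unchanged
    readout, attn = pool(bg, feat, get_attention=True)
    print(attn.shape)                           # torch.Size([7, 1]), one weight per node

Note that in the part of forward elided between the second and third hunks, the gate scores are softmax-normalized per graph before the weighted sum, so the returned attention values sum to one within each graph of the batch.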