"git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "94b27fb8da1002a560449f4b8c0fc92e22115c40"
Unverified commit 12e97c54 authored by Sai Kandregula, committed by GitHub

[Feature] get_attention parameter in GlobalAttentionPooling (#3837)



* get_attention parameter in GlobalAttentionPooling

* removed trailing whitespace

* lint fix
Co-authored-by: decoherencer <decoherencer@users.noreply.github.com>
Co-authored-by: Mufei Li <mufeili1996@gmail.com>
Co-authored-by: Quan (Andy) Gan <coin2028@hotmail.com>
parent e9fd65e9
@@ -419,7 +419,7 @@ class GlobalAttentionPooling(nn.Module):
         self.gate_nn = gate_nn
         self.feat_nn = feat_nn
 
-    def forward(self, graph, feat):
+    def forward(self, graph, feat, get_attention=False):
         r"""
         Compute global attention pooling.
@@ -431,12 +431,17 @@ class GlobalAttentionPooling(nn.Module):
         feat : torch.Tensor
             The input node feature with shape :math:`(N, D)` where :math:`N` is the
             number of nodes in the graph, and :math:`D` means the size of features.
+        get_attention : bool, optional
+            Whether to return the attention values from gate_nn. Defaults to False.
 
         Returns
         -------
         torch.Tensor
             The output feature with shape :math:`(B, D)`, where :math:`B` refers
             to the batch size.
+        torch.Tensor, optional
+            The attention values of shape :math:`(N, 1)`, where :math:`N` is the
+            number of nodes in the graph. This is returned only when
+            :attr:`get_attention` is ``True``.
         """
         with graph.local_scope():
             gate = self.gate_nn(feat)
@@ -451,7 +456,10 @@ class GlobalAttentionPooling(nn.Module):
             readout = sum_nodes(graph, 'r')
             graph.ndata.pop('r')
 
-            return readout
+            if get_attention:
+                return readout, gate
+            else:
+                return readout
 
 
 class Set2Set(nn.Module):
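For reference, here is a minimal usage sketch of the new parameter. The toy graph, feature size, and Linear gate below are illustrative assumptions, not part of this PR:

import dgl
import torch
import torch.nn as nn
from dgl.nn.pytorch.glob import GlobalAttentionPooling

# A toy graph with 4 nodes and 5-dimensional node features.
g = dgl.graph(([0, 1, 2], [1, 2, 3]))
feat = torch.randn(4, 5)

# gate_nn maps each node feature to a scalar attention score.
pool = GlobalAttentionPooling(nn.Linear(5, 1))

# Default behavior is unchanged: only the pooled readout, shape (B, D) = (1, 5).
readout = pool(g, feat)

# With get_attention=True, the per-node attention weights of shape
# (N, 1) = (4, 1) are returned alongside the readout.
readout, attention = pool(g, feat, get_attention=True)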