"""Readout for AttentiveFP""" import dgl import torch import torch.nn as nn import torch.nn.functional as F __all__ = ['AttentiveFPReadout'] class GlobalPool(nn.Module): """One-step readout in AttentiveFP Parameters ---------- feat_size : int Size for the input node features, graph features and output graph representations. dropout : float The probability for performing dropout. """ def __init__(self, feat_size, dropout): super(GlobalPool, self).__init__() self.compute_logits = nn.Sequential( nn.Linear(2 * feat_size, 1), nn.LeakyReLU() ) self.project_nodes = nn.Sequential( nn.Dropout(dropout), nn.Linear(feat_size, feat_size) ) self.gru = nn.GRUCell(feat_size, feat_size) def forward(self, g, node_feats, g_feats, get_node_weight=False): """Perform one-step readout Parameters ---------- g : DGLGraph DGLGraph for a batch of graphs. node_feats : float32 tensor of shape (V, node_feat_size) Input node features. V for the number of nodes. g_feats : float32 tensor of shape (G, graph_feat_size) Input graph features. G for the number of graphs. get_node_weight : bool Whether to get the weights of atoms during readout. Returns ------- float32 tensor of shape (G, graph_feat_size) Updated graph features. float32 tensor of shape (V, 1) The weights of nodes in readout. """ with g.local_scope(): g.ndata['z'] = self.compute_logits( torch.cat([dgl.broadcast_nodes(g, F.relu(g_feats)), node_feats], dim=1)) g.ndata['a'] = dgl.softmax_nodes(g, 'z') g.ndata['hv'] = self.project_nodes(node_feats) g_repr = dgl.sum_nodes(g, 'hv', 'a') context = F.elu(g_repr) if get_node_weight: return self.gru(context, g_feats), g.ndata['a'] else: return self.gru(context, g_feats) class AttentiveFPReadout(nn.Module): """Readout in AttentiveFP AttentiveFP is introduced in `Pushing the Boundaries of Molecular Representation for Drug Discovery with the Graph Attention Mechanism `__ This class computes graph representations out of node features. Parameters ---------- feat_size : int Size for the input node features, graph features and output graph representations. num_timesteps : int Times of updating the graph representations with GRU. Default to 2. dropout : float The probability for performing dropout. Default to 0. """ def __init__(self, feat_size, num_timesteps=2, dropout=0.): super(AttentiveFPReadout, self).__init__() self.readouts = nn.ModuleList() for t in range(num_timesteps): self.readouts.append(GlobalPool(feat_size, dropout)) def forward(self, g, node_feats, get_node_weight=False): """Computes graph representations out of node features. Parameters ---------- g : DGLGraph DGLGraph for a batch of graphs. node_feats : float32 tensor of shape (V, node_feat_size) Input node features. V for the number of nodes. get_node_weight : bool Whether to get the weights of nodes in readout. Default to False. Returns ------- g_feats : float32 tensor of shape (G, graph_feat_size) Graph representations computed. G for the number of graphs. node_weights : list of float32 tensor of shape (V, 1), optional This is returned when ``get_node_weight`` is ``True``. The list has a length ``num_timesteps`` and ``node_weights[i]`` gives the node weights in the i-th update. 
""" with g.local_scope(): g.ndata['hv'] = node_feats g_feats = dgl.sum_nodes(g, 'hv') if get_node_weight: node_weights = [] for readout in self.readouts: if get_node_weight: g_feats, node_weights_t = readout(g, node_feats, g_feats, get_node_weight) node_weights.append(node_weights_t) else: g_feats = readout(g, node_feats, g_feats) if get_node_weight: return g_feats, node_weights else: return g_feats