Unverified Commit 36c7b771 authored by Mufei Li, committed by GitHub

[LifeSci] Move to Independent Repo (#1592)

* Move LifeSci

* Remove doc
parent 94c67203
Model Zoo API
==================
We provide two major APIs for the model zoo. For the time being, only PyTorch is supported.
- `model_zoo.chem.[Model_Name]` to load the model skeleton
- `model_zoo.chem.load_pretrained([Pretrained_Model_Name])` to load the model with pretrained weights
Models are placed in `python/dgl/model_zoo/chem`.
Each model should include the following elements:
- Papers related to the model
- Model's input and output
- Dataset compatible with the model
- Documentation for all the customizable configs
- Credits (contributor information)
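For illustration, here is a minimal usage sketch of the two APIs. The configuration values and the pretrained model name `GCN_Tox21` are examples rather than an exhaustive reference:

```python
from dgl import model_zoo

# Load a model skeleton with custom configurations
model = model_zoo.chem.GCNClassifier(in_feats=74,
                                     gcn_hidden_feats=[64, 64],
                                     n_tasks=12)

# Load a model with pretrained weights
model = model_zoo.chem.load_pretrained('GCN_Tox21')
```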
"""Package for model zoo."""
from . import chem
# pylint: disable=C0111
"""Model Zoo Package"""
from .classifiers import GCNClassifier, GATClassifier
from .schnet import SchNet
from .mgcn import MGCNModel
from .mpnn import MPNNModel
from .dgmg import DGMG
from .jtnn import DGLJTNNVAE
from .pretrain import load_pretrained
from .attentive_fp import AttentiveFP
from .acnn import ACNN
"""Atomic Convolutional Networks for Predicting Protein-Ligand Binding Affinity"""
# pylint: disable=C0103, C0123
import itertools
import torch
import torch.nn as nn
from ...nn.pytorch import AtomicConv
from ...contrib.deprecation import deprecated
def truncated_normal_(tensor, mean=0., std=1.):
"""Fills the given tensor in-place with elements sampled from the truncated normal
distribution parameterized by mean and std.
The generated values follow a normal distribution with specified mean and
standard deviation, except that values whose magnitude is more than 2 std
from the mean are dropped.
    Credit to Ruotian Luo for this implementation:
https://discuss.pytorch.org/t/implementing-truncated-normal-initializer/4778/15.
Parameters
----------
tensor : Float32 tensor of arbitrary shape
Tensor to be filled.
mean : float
Mean of the truncated normal distribution.
std : float
Standard deviation of the truncated normal distribution.
"""
shape = tensor.shape
tmp = tensor.new_empty(shape + (4,)).normal_()
valid = (tmp < 2) & (tmp > -2)
ind = valid.max(-1, keepdim=True)[1]
tensor.data.copy_(tmp.gather(-1, ind).squeeze(-1))
tensor.data.mul_(std).add_(mean)
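# Illustrative sketch (not part of the original module): after the fill,
# every value should lie within two standard deviations of the mean.
#
#     w = torch.empty(3, 5)
#     truncated_normal_(w, mean=0., std=0.1)
#     assert (w - 0.).abs().max() <= 2 * 0.1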
class ACNNPredictor(nn.Module):
"""Predictor for ACNN.
Parameters
----------
in_size : int
Number of radial filters used.
hidden_sizes : list of int
Specifying the hidden sizes for all layers in the predictor.
weight_init_stddevs : list of float
Specifying the standard deviations to use for truncated normal
        distributions in initializing weights for the predictor.
dropouts : list of float
Specifying the dropouts to use for all layers in the predictor.
    features_to_use : None or float tensor of shape (T)
        In the original paper, these are atomic numbers to consider, representing
        the types of atoms. T for the number of types of atomic numbers.
num_tasks : int
Output size.
"""
def __init__(self, in_size, hidden_sizes, weight_init_stddevs,
dropouts, features_to_use, num_tasks):
super(ACNNPredictor, self).__init__()
        if features_to_use is not None:
in_size *= len(features_to_use)
modules = []
for i, h in enumerate(hidden_sizes):
linear_layer = nn.Linear(in_size, h)
truncated_normal_(linear_layer.weight, std=weight_init_stddevs[i])
modules.append(linear_layer)
modules.append(nn.ReLU())
modules.append(nn.Dropout(dropouts[i]))
in_size = h
linear_layer = nn.Linear(in_size, num_tasks)
truncated_normal_(linear_layer.weight, std=weight_init_stddevs[-1])
modules.append(linear_layer)
self.project = nn.Sequential(*modules)
def forward(self, batch_size, frag1_node_indices_in_complex, frag2_node_indices_in_complex,
ligand_conv_out, protein_conv_out, complex_conv_out):
"""Perform the prediction.
Parameters
----------
batch_size : int
Number of datapoints in a batch.
        frag1_node_indices_in_complex : Int64 tensor of shape (V1)
            Indices for atoms in the first fragment (ligand) in the batched complex.
        frag2_node_indices_in_complex : list of int of length V2
            Indices for atoms in the second fragment (protein) in the batched complex.
        ligand_conv_out : Float32 tensor of shape (V1, K * T)
            Updated ligand node representations. V1 for the number of atoms in the
            ligand, K for the number of radial filters, and T for the number of types
            of atomic numbers.
        protein_conv_out : Float32 tensor of shape (V2, K * T)
            Updated protein node representations. V2 for the number of
            atoms in the protein, K for the number of radial filters,
            and T for the number of types of atomic numbers.
        complex_conv_out : Float32 tensor of shape (V1 + V2, K * T)
            Updated complex node representations. V1 and V2 separately
            for the number of atoms in the ligand and protein, K for
            the number of radial filters, and T for the number of
            types of atomic numbers.
Returns
-------
Float32 tensor of shape (B, O)
Predicted protein-ligand binding affinity. B for the number
of protein-ligand pairs in the batch and O for the number of tasks.
"""
ligand_feats = self.project(ligand_conv_out) # (V1, O)
protein_feats = self.project(protein_conv_out) # (V2, O)
complex_feats = self.project(complex_conv_out) # (V1+V2, O)
ligand_energy = ligand_feats.reshape(batch_size, -1).sum(-1, keepdim=True) # (B, O)
protein_energy = protein_feats.reshape(batch_size, -1).sum(-1, keepdim=True) # (B, O)
complex_ligand_energy = complex_feats[frag1_node_indices_in_complex].reshape(
batch_size, -1).sum(-1, keepdim=True)
complex_protein_energy = complex_feats[frag2_node_indices_in_complex].reshape(
batch_size, -1).sum(-1, keepdim=True)
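        # Following the paper, the binding affinity is predicted as an energy
        # difference: DeltaG = E(complex) - [E(ligand) + E(protein)].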
complex_energy = complex_ligand_energy + complex_protein_energy
return complex_energy - (ligand_energy + protein_energy)
class ACNN(nn.Module):
"""Atomic Convolutional Networks.
The model was proposed in `Atomic Convolutional Networks for
Predicting Protein-Ligand Binding Affinity <https://arxiv.org/abs/1703.10603>`__.
Parameters
----------
hidden_sizes : list of int
Specifying the hidden sizes for all layers in the predictor.
weight_init_stddevs : list of float
Specifying the standard deviations to use for truncated normal
        distributions in initializing weights for the predictor.
dropouts : list of float
Specifying the dropouts to use for all layers in the predictor.
features_to_use : None or float tensor of shape (T)
In the original paper, these are atomic numbers to consider, representing the types
of atoms. T for the number of types of atomic numbers. Default to None.
radial : None or list
If not None, the list consists of 3 lists of floats, separately for the
options of interaction cutoff, the options of rbf kernel mean and the
options of rbf kernel scaling. If None, a default option of
``[[12.0], [0.0, 2.0, 4.0, 6.0, 8.0], [4.0]]`` will be used.
num_tasks : int
Number of output tasks.
"""
@deprecated('Import ACNN from dgllife.model instead.')
def __init__(self, hidden_sizes, weight_init_stddevs, dropouts,
features_to_use=None, radial=None, num_tasks=1):
super(ACNN, self).__init__()
if radial is None:
radial = [[12.0], [0.0, 2.0, 4.0, 6.0, 8.0], [4.0]]
# Take the product of sets of options and get a list of 3-tuples.
        radial_params = list(itertools.product(*radial))
radial_params = torch.stack(list(map(torch.tensor, zip(*radial_params))), dim=1)
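        # radial_params now has shape (K, 3), where K is the number of radial
        # filters, i.e. the number of (cutoff, mean, scaling) combinations.
        # The default radial above gives K = 1 * 5 * 1 = 5 filters.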
interaction_cutoffs = radial_params[:, 0]
rbf_kernel_means = radial_params[:, 1]
rbf_kernel_scaling = radial_params[:, 2]
self.ligand_conv = AtomicConv(interaction_cutoffs, rbf_kernel_means,
rbf_kernel_scaling, features_to_use)
self.protein_conv = AtomicConv(interaction_cutoffs, rbf_kernel_means,
rbf_kernel_scaling, features_to_use)
self.complex_conv = AtomicConv(interaction_cutoffs, rbf_kernel_means,
rbf_kernel_scaling, features_to_use)
self.predictor = ACNNPredictor(radial_params.shape[0], hidden_sizes,
weight_init_stddevs, dropouts, features_to_use, num_tasks)
def forward(self, graph):
"""Apply the model for prediction.
Parameters
----------
graph : DGLHeteroGraph
DGLHeteroGraph consisting of the ligand graph, the protein graph
and the complex graph, along with preprocessed features.
Returns
-------
Float32 tensor of shape (B, O)
Predicted protein-ligand binding affinity. B for the number
of protein-ligand pairs in the batch and O for the number of tasks.
"""
ligand_graph = graph[('ligand_atom', 'ligand', 'ligand_atom')]
ligand_graph_node_feats = ligand_graph.ndata['atomic_number']
assert ligand_graph_node_feats.shape[-1] == 1
ligand_graph_distances = ligand_graph.edata['distance']
ligand_conv_out = self.ligand_conv(ligand_graph,
ligand_graph_node_feats,
ligand_graph_distances)
protein_graph = graph[('protein_atom', 'protein', 'protein_atom')]
protein_graph_node_feats = protein_graph.ndata['atomic_number']
assert protein_graph_node_feats.shape[-1] == 1
protein_graph_distances = protein_graph.edata['distance']
protein_conv_out = self.protein_conv(protein_graph,
protein_graph_node_feats,
protein_graph_distances)
complex_graph = graph[:, 'complex', :]
complex_graph_node_feats = complex_graph.ndata['atomic_number']
assert complex_graph_node_feats.shape[-1] == 1
complex_graph_distances = complex_graph.edata['distance']
complex_conv_out = self.complex_conv(complex_graph,
complex_graph_node_feats,
complex_graph_distances)
frag1_node_indices_in_complex = torch.where(complex_graph.ndata['_TYPE'] == 0)[0]
frag2_node_indices_in_complex = list(set(range(complex_graph.number_of_nodes())) -
set(frag1_node_indices_in_complex.tolist()))
return self.predictor(
graph.batch_size,
frag1_node_indices_in_complex,
frag2_node_indices_in_complex,
ligand_conv_out, protein_conv_out, complex_conv_out)
# pylint: disable=C0103, W0612, E1101
"""Pushing the Boundaries of Molecular Representation for Drug Discovery
with the Graph Attention Mechanism"""
import dgl
import torch
import torch.nn as nn
import torch.nn.functional as F
from ... import function as fn
from ...contrib.deprecation import deprecated
from ...nn.pytorch.softmax import edge_softmax
class AttentiveGRU1(nn.Module):
"""Update node features with attention and GRU.
Parameters
----------
node_feat_size : int
Size for the input node (atom) features.
edge_feat_size : int
Size for the input edge (bond) features.
edge_hidden_size : int
Size for the intermediate edge (bond) representations.
dropout : float
The probability for performing dropout.
"""
def __init__(self, node_feat_size, edge_feat_size, edge_hidden_size, dropout):
super(AttentiveGRU1, self).__init__()
self.edge_transform = nn.Sequential(
nn.Dropout(dropout),
nn.Linear(edge_feat_size, edge_hidden_size)
)
self.gru = nn.GRUCell(edge_hidden_size, node_feat_size)
def forward(self, g, edge_logits, edge_feats, node_feats):
"""
Parameters
----------
g : DGLGraph
edge_logits : float32 tensor of shape (E, 1)
The edge logits based on which softmax will be performed for weighting
edges within 1-hop neighborhoods. E represents the number of edges.
edge_feats : float32 tensor of shape (E, M1)
Previous edge features.
node_feats : float32 tensor of shape (V, M2)
Previous node features.
Returns
-------
float32 tensor of shape (V, M2)
Updated node features.
"""
g = g.local_var()
g.edata['e'] = edge_softmax(g, edge_logits) * self.edge_transform(edge_feats)
g.update_all(fn.copy_edge('e', 'm'), fn.sum('m', 'c'))
context = F.elu(g.ndata['c'])
return F.relu(self.gru(context, node_feats))
class AttentiveGRU2(nn.Module):
"""Update node features with attention and GRU.
Parameters
----------
node_feat_size : int
Size for the input node (atom) features.
edge_hidden_size : int
Size for the intermediate edge (bond) representations.
dropout : float
The probability for performing dropout.
"""
def __init__(self, node_feat_size, edge_hidden_size, dropout):
super(AttentiveGRU2, self).__init__()
self.project_node = nn.Sequential(
nn.Dropout(dropout),
nn.Linear(node_feat_size, edge_hidden_size)
)
self.gru = nn.GRUCell(edge_hidden_size, node_feat_size)
def forward(self, g, edge_logits, node_feats):
"""
Parameters
----------
g : DGLGraph
edge_logits : float32 tensor of shape (E, 1)
The edge logits based on which softmax will be performed for weighting
edges within 1-hop neighborhoods. E represents the number of edges.
node_feats : float32 tensor of shape (V, M2)
Previous node features.
Returns
-------
float32 tensor of shape (V, M2)
Updated node features.
"""
g = g.local_var()
g.edata['a'] = edge_softmax(g, edge_logits)
g.ndata['hv'] = self.project_node(node_feats)
g.update_all(fn.src_mul_edge('hv', 'a', 'm'), fn.sum('m', 'c'))
context = F.elu(g.ndata['c'])
return F.relu(self.gru(context, node_feats))
class GetContext(nn.Module):
"""Generate context for each node (atom) by message passing at the beginning.
Parameters
----------
node_feat_size : int
Size for the input node (atom) features.
edge_feat_size : int
Size for the input edge (bond) features.
graph_feat_size : int
Size of the learned graph representation (molecular fingerprint).
dropout : float
The probability for performing dropout.
"""
def __init__(self, node_feat_size, edge_feat_size, graph_feat_size, dropout):
super(GetContext, self).__init__()
self.project_node = nn.Sequential(
nn.Linear(node_feat_size, graph_feat_size),
nn.LeakyReLU()
)
self.project_edge1 = nn.Sequential(
nn.Linear(node_feat_size + edge_feat_size, graph_feat_size),
nn.LeakyReLU()
)
self.project_edge2 = nn.Sequential(
nn.Dropout(dropout),
nn.Linear(2 * graph_feat_size, 1),
nn.LeakyReLU()
)
self.attentive_gru = AttentiveGRU1(graph_feat_size, graph_feat_size,
graph_feat_size, dropout)
def apply_edges1(self, edges):
"""Edge feature update."""
return {'he1': torch.cat([edges.src['hv'], edges.data['he']], dim=1)}
def apply_edges2(self, edges):
"""Edge feature update."""
return {'he2': torch.cat([edges.dst['hv_new'], edges.data['he1']], dim=1)}
def forward(self, g, node_feats, edge_feats):
"""
Parameters
----------
g : DGLGraph
Constructed DGLGraphs.
node_feats : float32 tensor of shape (V, N1)
Input node features. V for the number of nodes and N1 for the feature size.
edge_feats : float32 tensor of shape (E, N2)
Input edge features. E for the number of edges and N2 for the feature size.
Returns
-------
float32 tensor of shape (V, N3)
Updated node features.
"""
g = g.local_var()
g.ndata['hv'] = node_feats
g.ndata['hv_new'] = self.project_node(node_feats)
g.edata['he'] = edge_feats
g.apply_edges(self.apply_edges1)
g.edata['he1'] = self.project_edge1(g.edata['he1'])
g.apply_edges(self.apply_edges2)
logits = self.project_edge2(g.edata['he2'])
return self.attentive_gru(g, logits, g.edata['he1'], g.ndata['hv_new'])
class GNNLayer(nn.Module):
"""GNNLayer for updating node features.
Parameters
----------
node_feat_size : int
Size for the input node features.
graph_feat_size : int
Size for the input graph features.
dropout : float
The probability for performing dropout.
"""
def __init__(self, node_feat_size, graph_feat_size, dropout):
super(GNNLayer, self).__init__()
self.project_edge = nn.Sequential(
nn.Dropout(dropout),
nn.Linear(2 * node_feat_size, 1),
nn.LeakyReLU()
)
self.attentive_gru = AttentiveGRU2(node_feat_size, graph_feat_size, dropout)
def apply_edges(self, edges):
"""Edge feature update by concatenating the features of the destination
and source nodes."""
return {'he': torch.cat([edges.dst['hv'], edges.src['hv']], dim=1)}
def forward(self, g, node_feats):
"""
Parameters
----------
g : DGLGraph
Constructed DGLGraphs.
node_feats : float32 tensor of shape (V, N1)
Input node features. V for the number of nodes and N1 for the feature size.
Returns
-------
float32 tensor of shape (V, N1)
Updated node features.
"""
g = g.local_var()
g.ndata['hv'] = node_feats
g.apply_edges(self.apply_edges)
logits = self.project_edge(g.edata['he'])
return self.attentive_gru(g, logits, node_feats)
class GlobalPool(nn.Module):
"""Graph feature update.
Parameters
----------
node_feat_size : int
Size for the input node features.
graph_feat_size : int
Size for the input graph features.
dropout : float
The probability for performing dropout.
"""
def __init__(self, node_feat_size, graph_feat_size, dropout):
super(GlobalPool, self).__init__()
self.compute_logits = nn.Sequential(
nn.Linear(node_feat_size + graph_feat_size, 1),
nn.LeakyReLU()
)
self.project_nodes = nn.Sequential(
nn.Dropout(dropout),
nn.Linear(node_feat_size, graph_feat_size)
)
self.gru = nn.GRUCell(graph_feat_size, graph_feat_size)
def forward(self, g, node_feats, g_feats, get_node_weight=False):
"""
Parameters
----------
g : DGLGraph
Constructed DGLGraphs.
node_feats : float32 tensor of shape (V, N1)
Input node features. V for the number of nodes and N1 for the feature size.
g_feats : float32 tensor of shape (G, N2)
Input graph features. G for the number of graphs and N2 for the feature size.
get_node_weight : bool
Whether to get the weights of atoms during readout.
Returns
-------
float32 tensor of shape (G, N2)
Updated graph features.
        float32 tensor of shape (V, 1)
            The weights of nodes in readout, returned only when
            get_node_weight is True.
"""
with g.local_scope():
g.ndata['z'] = self.compute_logits(
torch.cat([dgl.broadcast_nodes(g, F.relu(g_feats)), node_feats], dim=1))
g.ndata['a'] = dgl.softmax_nodes(g, 'z')
g.ndata['hv'] = self.project_nodes(node_feats)
context = F.elu(dgl.sum_nodes(g, 'hv', 'a'))
if get_node_weight:
return self.gru(context, g_feats), g.ndata['a']
else:
return self.gru(context, g_feats)
class AttentiveFP(nn.Module):
"""`Pushing the Boundaries of Molecular Representation for Drug Discovery with the Graph
Attention Mechanism <https://www.ncbi.nlm.nih.gov/pubmed/31408336>`__
Parameters
----------
node_feat_size : int
Size for the input node (atom) features.
edge_feat_size : int
Size for the input edge (bond) features.
num_layers : int
Number of GNN layers.
num_timesteps : int
Number of timesteps for updating the molecular representation with GRU.
graph_feat_size : int
Size of the learned graph representation (molecular fingerprint).
output_size : int
Size of the prediction (target labels).
dropout : float
The probability for performing dropout.
"""
@deprecated('Import AttentiveFPPredictor from dgllife.model instead.', 'class')
def __init__(self,
node_feat_size,
edge_feat_size,
num_layers,
num_timesteps,
graph_feat_size,
output_size,
dropout):
super(AttentiveFP, self).__init__()
self.init_context = GetContext(node_feat_size, edge_feat_size, graph_feat_size, dropout)
self.gnn_layers = nn.ModuleList()
for i in range(num_layers - 1):
self.gnn_layers.append(GNNLayer(graph_feat_size, graph_feat_size, dropout))
self.readouts = nn.ModuleList()
for t in range(num_timesteps):
self.readouts.append(GlobalPool(graph_feat_size, graph_feat_size, dropout))
self.predict = nn.Sequential(
nn.Dropout(dropout),
nn.Linear(graph_feat_size, output_size)
)
def forward(self, g, node_feats, edge_feats, get_node_weight=False):
"""
Parameters
----------
g : DGLGraph
Constructed DGLGraphs.
node_feats : float32 tensor of shape (V, N1)
Input node features. V for the number of nodes and N1 for the feature size.
edge_feats : float32 tensor of shape (E, N2)
Input edge features. E for the number of edges and N2 for the feature size.
get_node_weight : bool
Whether to get the weights of atoms during readout.
Returns
-------
float32 tensor of shape (G, N3)
Prediction for the graphs. G for the number of graphs and N3 for the output size.
        node_weights : list of float32 tensors of shape (V, 1)
            Weights of nodes in all readout operations, returned only when
            get_node_weight is True.
"""
node_feats = self.init_context(g, node_feats, edge_feats)
for gnn in self.gnn_layers:
node_feats = gnn(g, node_feats)
with g.local_scope():
g.ndata['hv'] = node_feats
g_feats = dgl.sum_nodes(g, 'hv')
if get_node_weight:
node_weights = []
for readout in self.readouts:
if get_node_weight:
g_feats, node_weights_t = readout(g, node_feats, g_feats, get_node_weight)
node_weights.append(node_weights_t)
else:
g_feats = readout(g, node_feats, g_feats)
if get_node_weight:
return self.predict(g_feats), node_weights
else:
return self.predict(g_feats)
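# Illustrative usage sketch; the feature sizes below are assumptions for the
# example, not values fixed by the module:
#
#     model = AttentiveFP(node_feat_size=39, edge_feat_size=10, num_layers=2,
#                         num_timesteps=2, graph_feat_size=200, output_size=1,
#                         dropout=0.2)
#     preds = model(bg, node_feats, edge_feats)  # (G, 1) for G graphs in bg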
# pylint: disable=C0111, C0103, C0200
import torch
import torch.nn as nn
import torch.nn.functional as F
from .gnn import GCNLayer, GATLayer
from ...readout import max_nodes
from ...nn.pytorch import WeightAndSum
from ...contrib.deprecation import deprecated
class MLPBinaryClassifier(nn.Module):
"""MLP for soft binary classification over multiple tasks from molecule representations.
Parameters
----------
in_feats : int
Number of input molecular graph features
hidden_feats : int
Number of molecular graph features in hidden layers
n_tasks : int
Number of tasks, also output size
dropout : float
The probability for dropout. Default to be 0., i.e. no
dropout is performed.
"""
def __init__(self, in_feats, hidden_feats, n_tasks, dropout=0.):
super(MLPBinaryClassifier, self).__init__()
self.predict = nn.Sequential(
nn.Dropout(dropout),
nn.Linear(in_feats, hidden_feats),
nn.ReLU(),
nn.BatchNorm1d(hidden_feats),
nn.Linear(hidden_feats, n_tasks)
)
def forward(self, h):
"""Perform soft binary classification over multiple tasks
Parameters
----------
h : FloatTensor of shape (B, M3)
* B is the number of molecules in a batch
* M3 is the input molecule feature size, must match in_feats in initialization
Returns
-------
FloatTensor of shape (B, n_tasks)
"""
return self.predict(h)
class BaseGNNClassifier(nn.Module):
"""GCN based predictor for multitask prediction on molecular graphs
We assume each task requires to perform a binary classification.
Parameters
----------
gnn_out_feats : int
Number of atom representation features after using GNN
n_tasks : int
Number of prediction tasks
classifier_hidden_feats : int
Number of molecular graph features in hidden layers of the MLP Classifier
dropout : float
The probability for dropout. Default to be 0., i.e. no
dropout is performed.
"""
def __init__(self, gnn_out_feats, n_tasks, classifier_hidden_feats=128, dropout=0.):
super(BaseGNNClassifier, self).__init__()
self.gnn_layers = nn.ModuleList()
self.weighted_sum_readout = WeightAndSum(gnn_out_feats)
self.g_feats = 2 * gnn_out_feats
self.soft_classifier = MLPBinaryClassifier(
self.g_feats, classifier_hidden_feats, n_tasks, dropout)
def forward(self, g, feats):
"""Multi-task prediction for a batch of molecules
Parameters
----------
g : DGLGraph
DGLGraph with batch size B for processing multiple molecules in parallel
feats : FloatTensor of shape (N, M0)
Initial features for all atoms in the batch of molecules
Returns
-------
FloatTensor of shape (B, n_tasks)
Soft prediction for all tasks on the batch of molecules
"""
# Update atom features with GNNs
for gnn in self.gnn_layers:
feats = gnn(g, feats)
# Compute molecule features from atom features
h_g_sum = self.weighted_sum_readout(g, feats)
with g.local_scope():
g.ndata['h'] = feats
h_g_max = max_nodes(g, 'h')
h_g = torch.cat([h_g_sum, h_g_max], dim=1)
# Multi-task prediction
return self.soft_classifier(h_g)
class GCNClassifier(BaseGNNClassifier):
"""GCN based predictor for multitask prediction on molecular graphs
    We assume each task is a binary classification problem.
Parameters
----------
in_feats : int
Number of input atom features
    gcn_hidden_feats : list of int
        gcn_hidden_feats[i] gives the number of output atom features
        in the (i+1)-th GCN layer
n_tasks : int
Number of prediction tasks
classifier_hidden_feats : int
Number of molecular graph features in hidden layers of the MLP Classifier
dropout : float
The probability for dropout. Default to be 0., i.e. no
dropout is performed.
"""
@deprecated('Import GCNPredictor from dgllife.model instead.', 'class')
def __init__(self, in_feats, gcn_hidden_feats, n_tasks,
classifier_hidden_feats=128, dropout=0.):
super(GCNClassifier, self).__init__(gnn_out_feats=gcn_hidden_feats[-1],
n_tasks=n_tasks,
classifier_hidden_feats=classifier_hidden_feats,
dropout=dropout)
for i in range(len(gcn_hidden_feats)):
out_feats = gcn_hidden_feats[i]
self.gnn_layers.append(GCNLayer(in_feats, out_feats))
in_feats = out_feats
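# Illustrative usage sketch; the feature sizes below are assumptions:
#
#     model = GCNClassifier(in_feats=74, gcn_hidden_feats=[64, 64], n_tasks=12)
#     logits = model(bg, atom_feats)  # (B, 12), one soft prediction per task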
class GATClassifier(BaseGNNClassifier):
"""GAT based predictor for multitask prediction on molecular graphs.
We assume each task requires to perform a binary classification.
Parameters
----------
in_feats : int
Number of input atom features
"""
@deprecated('Import GATPredictor from dgllife.model instead.', 'class')
def __init__(self, in_feats, gat_hidden_feats, num_heads,
n_tasks, classifier_hidden_feats=128, dropout=0):
super(GATClassifier, self).__init__(gnn_out_feats=gat_hidden_feats[-1],
n_tasks=n_tasks,
classifier_hidden_feats=classifier_hidden_feats,
dropout=dropout)
assert len(gat_hidden_feats) == len(num_heads), \
'Got gat_hidden_feats with length {:d} and num_heads with length {:d}, ' \
'expect them to be the same.'.format(len(gat_hidden_feats), len(num_heads))
num_layers = len(num_heads)
for l in range(num_layers):
if l > 0:
in_feats = gat_hidden_feats[l - 1] * num_heads[l - 1]
if l == num_layers - 1:
agg_mode = 'mean'
agg_act = None
else:
agg_mode = 'flatten'
agg_act = F.elu
self.gnn_layers.append(GATLayer(in_feats, gat_hidden_feats[l], num_heads[l],
feat_drop=dropout, attn_drop=dropout,
agg_mode=agg_mode, activation=agg_act))
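# Illustrative usage sketch; the feature sizes and head counts are assumptions:
#
#     model = GATClassifier(in_feats=74, gat_hidden_feats=[32, 32],
#                           num_heads=[4, 4], n_tasks=12)
#     logits = model(bg, atom_feats)  # (B, 12)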
# pylint: disable=C0103, W0622, R1710, W0104
"""
Learning Deep Generative Models of Graphs
https://arxiv.org/pdf/1803.03324.pdf
"""
from functools import partial
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
from torch.distributions import Categorical
import dgl
from dgl import DGLGraph
from dgl.contrib.deprecation import deprecated
try:
from rdkit import Chem
except ImportError:
pass
class MoleculeEnv(object):
"""MDP environment for generating molecules.
Parameters
----------
atom_types : list
E.g. ['C', 'N']
bond_types : list
E.g. [Chem.rdchem.BondType.SINGLE, Chem.rdchem.BondType.DOUBLE,
Chem.rdchem.BondType.TRIPLE, Chem.rdchem.BondType.AROMATIC]
"""
def __init__(self, atom_types, bond_types):
super(MoleculeEnv, self).__init__()
self.atom_types = atom_types
self.bond_types = bond_types
self.atom_type_to_id = dict()
self.bond_type_to_id = dict()
for id, a_type in enumerate(atom_types):
self.atom_type_to_id[a_type] = id
for id, b_type in enumerate(bond_types):
self.bond_type_to_id[b_type] = id
def get_decision_sequence(self, mol, atom_order):
"""Extract a decision sequence with which DGMG can generate the
molecule with a specified atom order.
Parameters
----------
mol : Chem.rdchem.Mol
atom_order : list
Specifies a mapping between the original atom
indices and the new atom indices. In particular,
atom_order[i] is re-labeled as i.
Returns
-------
        decisions : list
            decisions[t] is a 2-tuple (i, j)
            - If i = 0, j specifies either the type of the atom to add,
              self.atom_types[j], or termination with j = len(self.atom_types)
            - If i = 1, j specifies either the type of the bond to add,
              self.bond_types[j], or termination with j = len(self.bond_types)
            - If i = 2, j specifies the destination atom id for the bond to add.
              With the formulation of DGMG, atom j must have been created by an
              earlier decision.
"""
decisions = []
old2new = dict()
for new_id, old_id in enumerate(atom_order):
atom = mol.GetAtomWithIdx(old_id)
a_type = atom.GetSymbol()
decisions.append((0, self.atom_type_to_id[a_type]))
for bond in atom.GetBonds():
u = bond.GetBeginAtomIdx()
v = bond.GetEndAtomIdx()
if v == old_id:
u, v = v, u
if v in old2new:
decisions.append((1, self.bond_type_to_id[bond.GetBondType()]))
decisions.append((2, old2new[v]))
decisions.append((1, len(self.bond_types)))
old2new[old_id] = new_id
decisions.append((0, len(self.atom_types)))
return decisions
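    # Illustrative example: with atom_types ['C'] and bond_types
    # [Chem.rdchem.BondType.SINGLE], ethane ('CC') with atom_order [0, 1]
    # yields [(0, 0), (1, 1), (0, 0), (1, 0), (2, 0), (1, 1), (0, 1)]:
    # add C, stop bonds, add C, add a single bond, connect it to atom 0,
    # stop bonds, stop atoms.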
def reset(self, rdkit_mol=False):
"""Setup for generating a new molecule
Parameters
----------
rdkit_mol : bool
Whether to keep a Chem.rdchem.Mol object so
that we know what molecule is being generated
"""
self.dgl_graph = DGLGraph()
# If there are some features for nodes and edges,
# zero tensors will be set for those of new nodes and edges.
self.dgl_graph.set_n_initializer(dgl.frame.zero_initializer)
self.dgl_graph.set_e_initializer(dgl.frame.zero_initializer)
self.mol = None
if rdkit_mol:
# RWMol is a molecule class that is intended to be edited.
self.mol = Chem.RWMol(Chem.MolFromSmiles(''))
def num_atoms(self):
"""Get the number of atoms for the current molecule.
Returns
-------
int
"""
return self.dgl_graph.number_of_nodes()
def add_atom(self, type):
"""Add an atom of the specified type.
Parameters
----------
type : int
Should be in the range of [0, len(self.atom_types) - 1]
"""
self.dgl_graph.add_nodes(1)
if self.mol is not None:
self.mol.AddAtom(Chem.Atom(self.atom_types[type]))
def add_bond(self, u, v, type, bi_direction=True):
"""Add a bond of the specified type between atom u and v.
Parameters
----------
u : int
Index for the first atom
v : int
Index for the second atom
type : int
Index for the bond type
bi_direction : bool
Whether to add edges for both directions in the DGLGraph.
If not, we will only add the edge (u, v).
"""
if bi_direction:
self.dgl_graph.add_edges([u, v], [v, u])
else:
self.dgl_graph.add_edge(u, v)
if self.mol is not None:
self.mol.AddBond(u, v, self.bond_types[type])
def get_current_smiles(self):
"""Get the generated molecule in SMILES
Returns
-------
s : str
SMILES
"""
assert self.mol is not None, 'Expect a Chem.rdchem.Mol object initialized.'
s = Chem.MolToSmiles(self.mol)
return s
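# Illustrative usage sketch for MoleculeEnv:
#
#     env = MoleculeEnv(atom_types=['C', 'N'],
#                       bond_types=[Chem.rdchem.BondType.SINGLE])
#     env.reset(rdkit_mol=True)
#     env.add_atom(0)            # add a carbon
#     env.add_atom(0)            # add another carbon
#     env.add_bond(0, 1, 0)      # connect them with a single bond
#     env.get_current_smiles()   # 'CC'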
class GraphEmbed(nn.Module):
"""Compute a molecule representations out of atom representations.
Parameters
----------
node_hidden_size : int
Size of atom representation
"""
def __init__(self, node_hidden_size):
super(GraphEmbed, self).__init__()
# Setting from the paper
self.graph_hidden_size = 2 * node_hidden_size
# Embed graphs
self.node_gating = nn.Sequential(
nn.Linear(node_hidden_size, 1),
nn.Sigmoid()
)
self.node_to_graph = nn.Linear(node_hidden_size,
self.graph_hidden_size)
def forward(self, g):
"""
Parameters
----------
g : DGLGraph
Current molecule graph
Returns
-------
tensor of dtype float32 and shape (1, self.graph_hidden_size)
Computed representation for the current molecule graph
"""
if g.number_of_nodes() == 0:
# Use a zero tensor for an empty molecule.
return torch.zeros(1, self.graph_hidden_size)
else:
# Node features are stored as hv in ndata.
hvs = g.ndata['hv']
return (self.node_gating(hvs) *
self.node_to_graph(hvs)).sum(0, keepdim=True)
class GraphProp(nn.Module):
"""Perform message passing over a molecule graph and update its atom representations.
Parameters
----------
    num_prop_rounds : int
        Number of message passing rounds each time graph propagation is performed
node_hidden_size : int
Size of atom representation
edge_hidden_size : int
Size of bond representation
"""
def __init__(self, num_prop_rounds, node_hidden_size, edge_hidden_size):
super(GraphProp, self).__init__()
self.num_prop_rounds = num_prop_rounds
# Setting from the paper
self.node_activation_hidden_size = 2 * node_hidden_size
message_funcs = []
self.reduce_funcs = []
node_update_funcs = []
for t in range(num_prop_rounds):
# input being [hv, hu, xuv]
message_funcs.append(nn.Linear(2 * node_hidden_size + edge_hidden_size,
self.node_activation_hidden_size))
self.reduce_funcs.append(partial(self.dgmg_reduce, round=t))
node_update_funcs.append(
nn.GRUCell(self.node_activation_hidden_size,
node_hidden_size))
self.message_funcs = nn.ModuleList(message_funcs)
self.node_update_funcs = nn.ModuleList(node_update_funcs)
def dgmg_msg(self, edges):
"""For an edge u->v, send a message concat([h_u, x_uv])
Parameters
----------
edges : batch of edges
Returns
-------
dict
Dictionary containing messages for the edge batch,
with the messages being tensors of shape (B, F1),
B for the number of edges and F1 for the message size.
"""
return {'m': torch.cat([edges.src['hv'],
edges.data['he']],
dim=1)}
def dgmg_reduce(self, nodes, round):
"""Aggregate messages.
Parameters
----------
nodes : batch of nodes
round : int
Update round
Returns
-------
dict
Dictionary containing aggregated messages for each node
in the batch, with the messages being tensors of shape
(B, F2), B for the number of nodes and F2 for the aggregated
message size
"""
hv_old = nodes.data['hv']
m = nodes.mailbox['m']
# Make copies of original atom representations to match the
# number of messages.
message = torch.cat([
hv_old.unsqueeze(1).expand(-1, m.size(1), -1), m], dim=2)
node_activation = (self.message_funcs[round](message)).sum(1)
return {'a': node_activation}
def forward(self, g):
"""
Parameters
----------
g : DGLGraph
"""
if g.number_of_edges() == 0:
return
else:
for t in range(self.num_prop_rounds):
g.update_all(message_func=self.dgmg_msg,
reduce_func=self.reduce_funcs[t])
g.ndata['hv'] = self.node_update_funcs[t](
g.ndata['a'], g.ndata['hv'])
class AddNode(nn.Module):
"""Stop or add an atom of a particular type.
Parameters
----------
env : MoleculeEnv
Environment for generating molecules
graph_embed_func : callable taking g as input
Function for computing molecule representation
node_hidden_size : int
Size of atom representation
dropout : float
Probability for dropout
"""
def __init__(self, env, graph_embed_func, node_hidden_size, dropout):
super(AddNode, self).__init__()
self.env = env
n_node_types = len(env.atom_types)
self.graph_op = {'embed': graph_embed_func}
self.stop = n_node_types
self.add_node = nn.Sequential(
nn.Linear(graph_embed_func.graph_hidden_size, graph_embed_func.graph_hidden_size),
nn.Dropout(p=dropout),
nn.Linear(graph_embed_func.graph_hidden_size, n_node_types + 1)
)
        # If a node is added, initialize its hv
self.node_type_embed = nn.Embedding(n_node_types, node_hidden_size)
self.initialize_hv = nn.Linear(node_hidden_size + \
graph_embed_func.graph_hidden_size,
node_hidden_size)
self.init_node_activation = torch.zeros(1, 2 * node_hidden_size)
self.dropout = nn.Dropout(p=dropout)
def _initialize_node_repr(self, g, node_type, graph_embed):
"""Initialize atom representation
Parameters
----------
g : DGLGraph
node_type : int
Index for the type of the new atom
graph_embed : tensor of dtype float32
Molecule representation
"""
num_nodes = g.number_of_nodes()
hv_init = torch.cat([
self.node_type_embed(torch.LongTensor([node_type])),
graph_embed], dim=1)
hv_init = self.dropout(hv_init)
hv_init = self.initialize_hv(hv_init)
g.nodes[num_nodes - 1].data['hv'] = hv_init
g.nodes[num_nodes - 1].data['a'] = self.init_node_activation
def prepare_log_prob(self, compute_log_prob):
"""Setup for returning log likelihood
Parameters
----------
compute_log_prob : bool
Whether to compute log likelihood
"""
if compute_log_prob:
self.log_prob = []
self.compute_log_prob = compute_log_prob
def forward(self, action=None):
"""
Parameters
----------
action : None or int
If None, a new action will be sampled. If not None,
teacher forcing will be used to enforce the decision of the
corresponding action.
Returns
-------
stop : bool
Whether we stop adding new atoms
"""
g = self.env.dgl_graph
graph_embed = self.graph_op['embed'](g)
logits = self.add_node(graph_embed).view(1, -1)
probs = F.softmax(logits, dim=1)
if action is None:
action = Categorical(probs).sample().item()
stop = bool(action == self.stop)
if not stop:
self.env.add_atom(action)
self._initialize_node_repr(g, action, graph_embed)
if self.compute_log_prob:
sample_log_prob = F.log_softmax(logits, dim=1)[:, action: action + 1]
self.log_prob.append(sample_log_prob)
return stop
class AddEdge(nn.Module):
"""Stop or add a bond of a particular type.
Parameters
----------
env : MoleculeEnv
Environment for generating molecules
graph_embed_func : callable taking g as input
Function for computing molecule representation
node_hidden_size : int
Size of atom representation
dropout : float
Probability for dropout
"""
def __init__(self, env, graph_embed_func, node_hidden_size, dropout):
super(AddEdge, self).__init__()
self.env = env
n_bond_types = len(env.bond_types)
self.stop = n_bond_types
self.graph_op = {'embed': graph_embed_func}
self.add_edge = nn.Sequential(
nn.Linear(graph_embed_func.graph_hidden_size + node_hidden_size,
graph_embed_func.graph_hidden_size + node_hidden_size),
nn.Dropout(p=dropout),
nn.Linear(graph_embed_func.graph_hidden_size + node_hidden_size, n_bond_types + 1)
)
def prepare_log_prob(self, compute_log_prob):
"""Setup for returning log likelihood
Parameters
----------
compute_log_prob : bool
Whether to compute log likelihood
"""
if compute_log_prob:
self.log_prob = []
self.compute_log_prob = compute_log_prob
def forward(self, action=None):
"""
Parameters
----------
action : None or int
If None, a new action will be sampled. If not None,
teacher forcing will be used to enforce the decision of the
corresponding action.
Returns
-------
stop : bool
Whether we stop adding new bonds
action : int
The type for the new bond
"""
g = self.env.dgl_graph
graph_embed = self.graph_op['embed'](g)
src_embed = g.nodes[g.number_of_nodes() - 1].data['hv']
logits = self.add_edge(
torch.cat([graph_embed, src_embed], dim=1))
probs = F.softmax(logits, dim=1)
if action is None:
action = Categorical(probs).sample().item()
stop = bool(action == self.stop)
if self.compute_log_prob:
sample_log_prob = F.log_softmax(logits, dim=1)[:, action: action + 1]
self.log_prob.append(sample_log_prob)
return stop, action
class ChooseDestAndUpdate(nn.Module):
"""Choose the atom to connect for the new bond.
Parameters
----------
env : MoleculeEnv
Environment for generating molecules
graph_prop_func : callable taking g as input
Function for performing message passing
and updating atom representations
node_hidden_size : int
Size of atom representation
dropout : float
Probability for dropout
"""
def __init__(self, env, graph_prop_func, node_hidden_size, dropout):
super(ChooseDestAndUpdate, self).__init__()
self.env = env
n_bond_types = len(self.env.bond_types)
# To be used for one-hot encoding of bond type
self.bond_embedding = torch.eye(n_bond_types)
self.graph_op = {'prop': graph_prop_func}
self.choose_dest = nn.Sequential(
nn.Linear(2 * node_hidden_size + n_bond_types, 2 * node_hidden_size + n_bond_types),
nn.Dropout(p=dropout),
nn.Linear(2 * node_hidden_size + n_bond_types, 1)
)
def _initialize_edge_repr(self, g, src_list, dest_list, edge_embed):
"""Initialize bond representation
Parameters
----------
g : DGLGraph
src_list : list of int
source atoms for new bonds
dest_list : list of int
destination atoms for new bonds
edge_embed : 2D tensor of dtype float32
Embeddings for the new bonds
"""
g.edges[src_list, dest_list].data['he'] = edge_embed.expand(len(src_list), -1)
def prepare_log_prob(self, compute_log_prob):
"""Setup for returning log likelihood
Parameters
----------
compute_log_prob : bool
Whether to compute log likelihood
"""
if compute_log_prob:
self.log_prob = []
self.compute_log_prob = compute_log_prob
def forward(self, bond_type, dest):
"""
Parameters
----------
bond_type : int
The type for the new bond
dest : int or None
If None, a new action will be sampled. If not None,
teacher forcing will be used to enforce the decision of the
corresponding action.
"""
g = self.env.dgl_graph
src = g.number_of_nodes() - 1
possible_dests = range(src)
src_embed_expand = g.nodes[src].data['hv'].expand(src, -1)
possible_dests_embed = g.nodes[possible_dests].data['hv']
edge_embed = self.bond_embedding[bond_type: bond_type + 1]
dests_scores = self.choose_dest(
torch.cat([possible_dests_embed,
src_embed_expand,
edge_embed.expand(src, -1)], dim=1)).view(1, -1)
dests_probs = F.softmax(dests_scores, dim=1)
if dest is None:
dest = Categorical(dests_probs).sample().item()
if not g.has_edge_between(src, dest):
# For undirected graphs, we add edges for both directions
# so that we can perform graph propagation.
src_list = [src, dest]
dest_list = [dest, src]
self.env.add_bond(src, dest, bond_type)
self._initialize_edge_repr(g, src_list, dest_list, edge_embed)
# Perform message passing when new bonds are added.
self.graph_op['prop'](g)
if self.compute_log_prob:
if dests_probs.nelement() > 1:
self.log_prob.append(
F.log_softmax(dests_scores, dim=1)[:, dest: dest + 1])
def weights_init(m):
'''Function to initialize weights for models
Code from https://gist.github.com/jeasinema/ed9236ce743c8efaf30fa2ff732749f5
Usage:
model = Model()
        model.apply(weights_init)
'''
if isinstance(m, nn.Linear):
init.xavier_normal_(m.weight.data)
init.normal_(m.bias.data)
elif isinstance(m, nn.GRUCell):
for param in m.parameters():
if len(param.shape) >= 2:
init.orthogonal_(param.data)
else:
init.normal_(param.data)
def dgmg_message_weight_init(m):
"""Weight initialization for graph propagation module
These are suggested by the author. This should only be used for
the message passing functions, i.e. fe's in the paper.
"""
def _weight_init(m):
if isinstance(m, nn.Linear):
init.normal_(m.weight.data, std=1./10)
init.normal_(m.bias.data, std=1./10)
else:
raise ValueError('Expected the input to be of type nn.Linear!')
if isinstance(m, nn.ModuleList):
for layer in m:
layer.apply(_weight_init)
else:
m.apply(_weight_init)
class DGMG(nn.Module):
"""DGMG model
`Learning Deep Generative Models of Graphs <https://arxiv.org/abs/1803.03324>`__
Users only need to initialize an instance of this class.
Parameters
----------
atom_types : list
E.g. ['C', 'N']
bond_types : list
E.g. [Chem.rdchem.BondType.SINGLE, Chem.rdchem.BondType.DOUBLE,
Chem.rdchem.BondType.TRIPLE, Chem.rdchem.BondType.AROMATIC]
node_hidden_size : int
Size of atom representation
    num_prop_rounds : int
        Number of message passing rounds each time graph propagation is performed
dropout : float
Probability for dropout
"""
@deprecated('Import DGMG from dgllife.model instead.', 'class')
def __init__(self, atom_types, bond_types, node_hidden_size, num_prop_rounds, dropout):
super(DGMG, self).__init__()
self.env = MoleculeEnv(atom_types, bond_types)
# Graph embedding module
self.graph_embed = GraphEmbed(node_hidden_size)
# Graph propagation module
# For one-hot encoding, edge_hidden_size is just the number of bond types
self.graph_prop = GraphProp(num_prop_rounds, node_hidden_size, len(self.env.bond_types))
# Actions
self.add_node_agent = AddNode(
self.env, self.graph_embed, node_hidden_size, dropout)
self.add_edge_agent = AddEdge(
self.env, self.graph_embed, node_hidden_size, dropout)
self.choose_dest_agent = ChooseDestAndUpdate(
self.env, self.graph_prop, node_hidden_size, dropout)
# Weight initialization
self.init_weights()
def init_weights(self):
"""Initialize model weights"""
self.graph_embed.apply(weights_init)
self.graph_prop.apply(weights_init)
self.add_node_agent.apply(weights_init)
self.add_edge_agent.apply(weights_init)
self.choose_dest_agent.apply(weights_init)
self.graph_prop.message_funcs.apply(dgmg_message_weight_init)
def count_step(self):
"""Increment the step by 1."""
self.step_count += 1
def prepare_log_prob(self, compute_log_prob):
"""Setup for returning log likelihood
Parameters
----------
compute_log_prob : bool
Whether to compute log likelihood
"""
self.compute_log_prob = compute_log_prob
self.add_node_agent.prepare_log_prob(compute_log_prob)
self.add_edge_agent.prepare_log_prob(compute_log_prob)
self.choose_dest_agent.prepare_log_prob(compute_log_prob)
def add_node_and_update(self, a=None):
"""Decide if to add a new atom.
If a new atom should be added, update the graph.
Parameters
----------
a : None or int
If None, a new action will be sampled. If not None,
teacher forcing will be used to enforce the decision of the
corresponding action.
"""
self.count_step()
return self.add_node_agent(a)
def add_edge_or_not(self, a=None):
"""Decide if to add a new bond.
Parameters
----------
a : None or int
If None, a new action will be sampled. If not None,
teacher forcing will be used to enforce the decision of the
corresponding action.
"""
self.count_step()
return self.add_edge_agent(a)
def choose_dest_and_update(self, bond_type, a=None):
"""Choose destination and connect it to the latest atom.
Add edges for both directions and update the graph.
Parameters
----------
bond_type : int
The type of the new bond to add
a : None or int
If None, a new action will be sampled. If not None,
teacher forcing will be used to enforce the decision of the
corresponding action.
"""
self.count_step()
self.choose_dest_agent(bond_type, a)
def get_log_prob(self):
"""Compute the log likelihood for the decision sequence,
typically corresponding to the generation of a molecule.
Returns
-------
        torch.Tensor containing a single float
"""
return torch.cat(self.add_node_agent.log_prob).sum()\
+ torch.cat(self.add_edge_agent.log_prob).sum()\
+ torch.cat(self.choose_dest_agent.log_prob).sum()
def teacher_forcing(self, actions):
"""Generate a molecule according to a sequence of actions.
Parameters
----------
actions : list of 2-tuples of int
actions[t] gives (i, j), the action to execute by DGMG at timestep t.
- If i = 0, j specifies either the type of the atom to add or termination
- If i = 1, j specifies either the type of the bond to add or termination
            - If i = 2, j specifies the destination atom id for the bond to add.
              With the formulation of DGMG, atom j must have been created by an
              earlier decision.
"""
stop_node = self.add_node_and_update(a=actions[self.step_count][1])
while not stop_node:
# A new atom was just added.
stop_edge, bond_type = self.add_edge_or_not(a=actions[self.step_count][1])
while not stop_edge:
# A new bond is to be added.
self.choose_dest_and_update(bond_type, a=actions[self.step_count][1])
stop_edge, bond_type = self.add_edge_or_not(a=actions[self.step_count][1])
stop_node = self.add_node_and_update(a=actions[self.step_count][1])
def rollout(self, max_num_steps):
"""Sample a molecule from the distribution learned by DGMG."""
stop_node = self.add_node_and_update()
while (not stop_node) and (self.step_count <= max_num_steps):
stop_edge, bond_type = self.add_edge_or_not()
if self.env.num_atoms() == 1:
stop_edge = True
while (not stop_edge) and (self.step_count <= max_num_steps):
self.choose_dest_and_update(bond_type)
stop_edge, bond_type = self.add_edge_or_not()
stop_node = self.add_node_and_update()
def forward(self, actions=None, rdkit_mol=False, compute_log_prob=False, max_num_steps=400):
"""
Parameters
----------
actions : list of 2-tuples or None.
If actions are not None, generate a molecule according to actions.
Otherwise, a molecule will be generated based on sampled actions.
rdkit_mol : bool
Whether to maintain a Chem.rdchem.Mol object. This brings extra
computational cost, but is necessary if we are interested in
        inspecting the generated molecule.
compute_log_prob : bool
Whether to compute log likelihood
max_num_steps : int
        Maximum number of steps allowed. This only takes effect
        during inference and prevents the generation process from
        running forever.
Returns
-------
        torch.Tensor containing a single float, optional
The log likelihood for the actions taken
str, optional
The generated molecule in the form of SMILES
"""
# Initialize an empty molecule
self.step_count = 0
self.env.reset(rdkit_mol=rdkit_mol)
self.prepare_log_prob(compute_log_prob)
if actions is not None:
# A sequence of decisions is given, use teacher forcing
self.teacher_forcing(actions)
else:
# Sample a molecule from the distribution learned by DGMG
self.rollout(max_num_steps)
if compute_log_prob and rdkit_mol:
return self.get_log_prob(), self.env.get_current_smiles()
if compute_log_prob:
return self.get_log_prob()
if rdkit_mol:
return self.env.get_current_smiles()
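# Illustrative usage sketch; the hyperparameters below are assumptions:
#
#     model = DGMG(atom_types=['C', 'N', 'O'],
#                  bond_types=[Chem.rdchem.BondType.SINGLE,
#                              Chem.rdchem.BondType.DOUBLE],
#                  node_hidden_size=128, num_prop_rounds=2, dropout=0.2)
#     model.eval()
#     smiles = model(rdkit_mol=True)  # sample a molecule and return its SMILES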
# pylint: disable=C0103, E1101
"""GNN layers for updating atom representations"""
import torch.nn as nn
import torch.nn.functional as F
from ...nn.pytorch import GraphConv, GATConv
class GCNLayer(nn.Module):
"""Single layer GCN for updating node features
Parameters
----------
in_feats : int
Number of input atom features
out_feats : int
Number of output atom features
activation : activation function
Default to be ReLU
residual : bool
Whether to use residual connection, default to be True
batchnorm : bool
Whether to use batch normalization on the output,
default to be True
dropout : float
The probability for dropout. Default to be 0., i.e. no
dropout is performed.
"""
def __init__(self, in_feats, out_feats, activation=F.relu,
residual=True, batchnorm=True, dropout=0.):
super(GCNLayer, self).__init__()
self.activation = activation
self.graph_conv = GraphConv(in_feats=in_feats, out_feats=out_feats,
norm=False, activation=activation)
self.dropout = nn.Dropout(dropout)
self.residual = residual
if residual:
self.res_connection = nn.Linear(in_feats, out_feats)
self.bn = batchnorm
if batchnorm:
self.bn_layer = nn.BatchNorm1d(out_feats)
def forward(self, g, feats):
"""Update atom representations
Parameters
----------
g : DGLGraph
DGLGraph with batch size B for processing multiple molecules in parallel
feats : FloatTensor of shape (N, M1)
* N is the total number of atoms in the batched graph
* M1 is the input atom feature size, must match in_feats in initialization
Returns
-------
new_feats : FloatTensor of shape (N, M2)
* M2 is the output atom feature size, must match out_feats in initialization
"""
new_feats = self.graph_conv(g, feats)
if self.residual:
res_feats = self.activation(self.res_connection(feats))
new_feats = new_feats + res_feats
new_feats = self.dropout(new_feats)
if self.bn:
new_feats = self.bn_layer(new_feats)
return new_feats
class GATLayer(nn.Module):
"""Single layer GAT for updating node features
Parameters
----------
in_feats : int
Number of input atom features
out_feats : int
Number of output atom features for each attention head
num_heads : int
Number of attention heads
feat_drop : float
Dropout applied to the input features
attn_drop : float
Dropout applied to attention values of edges
alpha : float
Hyperparameter in LeakyReLU, slope for negative values. Default to be 0.2
residual : bool
Whether to perform skip connection, default to be False
agg_mode : str
The way to aggregate multi-head attention results, can be either
'flatten' for concatenating all head results or 'mean' for averaging
all head results
activation : activation function or None
Activation function applied to aggregated multi-head results, default to be None.
"""
def __init__(self, in_feats, out_feats, num_heads, feat_drop, attn_drop,
alpha=0.2, residual=True, agg_mode='flatten', activation=None):
super(GATLayer, self).__init__()
self.gnn = GATConv(in_feats=in_feats, out_feats=out_feats, num_heads=num_heads,
feat_drop=feat_drop, attn_drop=attn_drop,
negative_slope=alpha, residual=residual)
assert agg_mode in ['flatten', 'mean']
self.agg_mode = agg_mode
self.activation = activation
def forward(self, bg, feats):
"""Update atom representations
Parameters
----------
bg : DGLGraph
Batched DGLGraphs for processing multiple molecules in parallel
feats : FloatTensor of shape (N, M1)
* N is the total number of atoms in the batched graph
* M1 is the input atom feature size, must match in_feats in initialization
Returns
-------
new_feats : FloatTensor of shape (N, M2)
* M2 is the output atom feature size. If self.agg_mode == 'flatten', this would
be out_feats * num_heads, else it would be just out_feats.
"""
new_feats = self.gnn(bg, feats)
if self.agg_mode == 'flatten':
new_feats = new_feats.flatten(1)
else:
new_feats = new_feats.mean(1)
if self.activation is not None:
new_feats = self.activation(new_feats)
return new_feats
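# Illustrative usage sketch for the two layers; feature sizes are assumptions:
#
#     gcn = GCNLayer(in_feats=74, out_feats=64)
#     h = gcn(bg, atom_feats)   # (N, 64)
#     gat = GATLayer(in_feats=74, out_feats=32, num_heads=4,
#                    feat_drop=0., attn_drop=0.)
#     h = gat(bg, atom_feats)   # (N, 128) with the default 'flatten' mode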
"""JTNN Module"""
from .chemutils import decode_stereo
from .jtnn_vae import DGLJTNNVAE
from .mol_tree import Vocab
from .mpn import DGLMPN
from .nnutils import create_var, cuda
# pylint: disable=C0111, C0103, E1101, W0611, W0612, W0703, C0200, R1710
from collections import defaultdict
import rdkit.Chem as Chem
from rdkit.Chem.EnumerateStereoisomers import EnumerateStereoisomers
from scipy.sparse import csr_matrix
from scipy.sparse.csgraph import minimum_spanning_tree
MST_MAX_WEIGHT = 100
MAX_NCAND = 2000
def set_atommap(mol, num=0):
for atom in mol.GetAtoms():
atom.SetAtomMapNum(num)
def get_mol(smiles):
mol = Chem.MolFromSmiles(smiles)
if mol is None:
return None
Chem.Kekulize(mol)
return mol
def get_smiles(mol):
return Chem.MolToSmiles(mol, kekuleSmiles=True)
def decode_stereo(smiles2D):
mol = Chem.MolFromSmiles(smiles2D)
dec_isomers = list(EnumerateStereoisomers(mol))
dec_isomers = [Chem.MolFromSmiles(Chem.MolToSmiles(
mol, isomericSmiles=True)) for mol in dec_isomers]
smiles3D = [Chem.MolToSmiles(mol, isomericSmiles=True)
for mol in dec_isomers]
chiralN = [atom.GetIdx() for atom in dec_isomers[0].GetAtoms() if int(
atom.GetChiralTag()) > 0 and atom.GetSymbol() == "N"]
if len(chiralN) > 0:
for mol in dec_isomers:
for idx in chiralN:
mol.GetAtomWithIdx(idx).SetChiralTag(
Chem.rdchem.ChiralType.CHI_UNSPECIFIED)
smiles3D.append(Chem.MolToSmiles(mol, isomericSmiles=True))
return smiles3D
def sanitize(mol):
try:
smiles = get_smiles(mol)
mol = get_mol(smiles)
except Exception:
return None
return mol
def copy_atom(atom):
new_atom = Chem.Atom(atom.GetSymbol())
new_atom.SetFormalCharge(atom.GetFormalCharge())
new_atom.SetAtomMapNum(atom.GetAtomMapNum())
return new_atom
def copy_edit_mol(mol):
new_mol = Chem.RWMol(Chem.MolFromSmiles(''))
for atom in mol.GetAtoms():
new_atom = copy_atom(atom)
new_mol.AddAtom(new_atom)
for bond in mol.GetBonds():
a1 = bond.GetBeginAtom().GetIdx()
a2 = bond.GetEndAtom().GetIdx()
bt = bond.GetBondType()
new_mol.AddBond(a1, a2, bt)
return new_mol
def get_clique_mol(mol, atoms):
smiles = Chem.MolFragmentToSmiles(mol, atoms, kekuleSmiles=True)
new_mol = Chem.MolFromSmiles(smiles, sanitize=False)
new_mol = copy_edit_mol(new_mol).GetMol()
new_mol = sanitize(new_mol) # We assume this is not None
return new_mol
def tree_decomp(mol):
n_atoms = mol.GetNumAtoms()
if n_atoms == 1:
return [[0]], []
cliques = []
for bond in mol.GetBonds():
a1 = bond.GetBeginAtom().GetIdx()
a2 = bond.GetEndAtom().GetIdx()
if not bond.IsInRing():
cliques.append([a1, a2])
ssr = [list(x) for x in Chem.GetSymmSSSR(mol)]
cliques.extend(ssr)
nei_list = [[] for i in range(n_atoms)]
for i in range(len(cliques)):
for atom in cliques[i]:
nei_list[atom].append(i)
# Merge Rings with intersection > 2 atoms
for i in range(len(cliques)):
if len(cliques[i]) <= 2:
continue
for atom in cliques[i]:
for j in nei_list[atom]:
if i >= j or len(cliques[j]) <= 2:
continue
inter = set(cliques[i]) & set(cliques[j])
if len(inter) > 2:
cliques[i].extend(cliques[j])
cliques[i] = list(set(cliques[i]))
cliques[j] = []
cliques = [c for c in cliques if len(c) > 0]
nei_list = [[] for i in range(n_atoms)]
for i in range(len(cliques)):
for atom in cliques[i]:
nei_list[atom].append(i)
# Build edges and add singleton cliques
edges = defaultdict(int)
for atom in range(n_atoms):
if len(nei_list[atom]) <= 1:
continue
cnei = nei_list[atom]
bonds = [c for c in cnei if len(cliques[c]) == 2]
rings = [c for c in cnei if len(cliques[c]) > 4]
# In general, if len(cnei) >= 3, a singleton should be added, but 1
# bond + 2 ring is currently not dealt with.
if len(bonds) > 2 or (len(bonds) == 2 and len(cnei) > 2):
cliques.append([atom])
c2 = len(cliques) - 1
for c1 in cnei:
edges[(c1, c2)] = 1
elif len(rings) > 2: # Multiple (n>2) complex rings
cliques.append([atom])
c2 = len(cliques) - 1
for c1 in cnei:
edges[(c1, c2)] = MST_MAX_WEIGHT - 1
else:
for i in range(len(cnei)):
for j in range(i + 1, len(cnei)):
c1, c2 = cnei[i], cnei[j]
inter = set(cliques[c1]) & set(cliques[c2])
if edges[(c1, c2)] < len(inter):
# cnei[i] < cnei[j] by construction
edges[(c1, c2)] = len(inter)
edges = [u + (MST_MAX_WEIGHT - v,) for u, v in edges.items()]
if len(edges) == 0:
return cliques, edges
# Compute Maximum Spanning Tree
row, col, data = list(zip(*edges))
n_clique = len(cliques)
clique_graph = csr_matrix((data, (row, col)), shape=(n_clique, n_clique))
junc_tree = minimum_spanning_tree(clique_graph)
row, col = junc_tree.nonzero()
edges = [(row[i], col[i]) for i in range(len(row))]
return (cliques, edges)
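# Illustrative example: for cyclohexane ('C1CCCCC1'), tree_decomp returns a
# single ring clique and no tree edges, i.e. ([[0, 1, 2, 3, 4, 5]], []).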
def atom_equal(a1, a2):
return a1.GetSymbol() == a2.GetSymbol() and a1.GetFormalCharge() == a2.GetFormalCharge()
# Bond type not considered because all aromatic (so SINGLE matches DOUBLE)
def ring_bond_equal(b1, b2, reverse=False):
b1 = (b1.GetBeginAtom(), b1.GetEndAtom())
if reverse:
b2 = (b2.GetEndAtom(), b2.GetBeginAtom())
else:
b2 = (b2.GetBeginAtom(), b2.GetEndAtom())
return atom_equal(b1[0], b2[0]) and atom_equal(b1[1], b2[1])
def attach_mols_nx(ctr_mol, neighbors, prev_nodes, nei_amap):
prev_nids = [node['nid'] for node in prev_nodes]
for nei_node in prev_nodes + neighbors:
nei_id, nei_mol = nei_node['nid'], nei_node['mol']
amap = nei_amap[nei_id]
for atom in nei_mol.GetAtoms():
if atom.GetIdx() not in amap:
new_atom = copy_atom(atom)
amap[atom.GetIdx()] = ctr_mol.AddAtom(new_atom)
if nei_mol.GetNumBonds() == 0:
nei_atom = nei_mol.GetAtomWithIdx(0)
ctr_atom = ctr_mol.GetAtomWithIdx(amap[0])
ctr_atom.SetAtomMapNum(nei_atom.GetAtomMapNum())
else:
for bond in nei_mol.GetBonds():
a1 = amap[bond.GetBeginAtom().GetIdx()]
a2 = amap[bond.GetEndAtom().GetIdx()]
if ctr_mol.GetBondBetweenAtoms(a1, a2) is None:
ctr_mol.AddBond(a1, a2, bond.GetBondType())
elif nei_id in prev_nids: # father node overrides
ctr_mol.RemoveBond(a1, a2)
ctr_mol.AddBond(a1, a2, bond.GetBondType())
return ctr_mol
def local_attach_nx(ctr_mol, neighbors, prev_nodes, amap_list):
ctr_mol = copy_edit_mol(ctr_mol)
nei_amap = {nei['nid']: {} for nei in prev_nodes + neighbors}
for nei_id, ctr_atom, nei_atom in amap_list:
nei_amap[nei_id][nei_atom] = ctr_atom
ctr_mol = attach_mols_nx(ctr_mol, neighbors, prev_nodes, nei_amap)
return ctr_mol.GetMol()
# This version records idx mapping between ctr_mol and nei_mol
def enum_attach_nx(ctr_mol, nei_node, amap, singletons):
nei_mol, nei_idx = nei_node['mol'], nei_node['nid']
att_confs = []
black_list = [atom_idx for nei_id, atom_idx,
_ in amap if nei_id in singletons]
ctr_atoms = [atom for atom in ctr_mol.GetAtoms() if atom.GetIdx()
not in black_list]
ctr_bonds = [bond for bond in ctr_mol.GetBonds()]
if nei_mol.GetNumBonds() == 0: # neighbor singleton
nei_atom = nei_mol.GetAtomWithIdx(0)
used_list = [atom_idx for _, atom_idx, _ in amap]
for atom in ctr_atoms:
if atom_equal(atom, nei_atom) and atom.GetIdx() not in used_list:
new_amap = amap + [(nei_idx, atom.GetIdx(), 0)]
att_confs.append(new_amap)
elif nei_mol.GetNumBonds() == 1: # neighbor is a bond
bond = nei_mol.GetBondWithIdx(0)
bond_val = int(bond.GetBondTypeAsDouble())
b1, b2 = bond.GetBeginAtom(), bond.GetEndAtom()
for atom in ctr_atoms:
# Optimization for carbon: skip it if it lacks enough hydrogens (other atoms may change valence)
if atom.GetAtomicNum() == 6 and atom.GetTotalNumHs() < bond_val:
continue
if atom_equal(atom, b1):
new_amap = amap + [(nei_idx, atom.GetIdx(), b1.GetIdx())]
att_confs.append(new_amap)
elif atom_equal(atom, b2):
new_amap = amap + [(nei_idx, atom.GetIdx(), b2.GetIdx())]
att_confs.append(new_amap)
else:
# intersection is an atom
for a1 in ctr_atoms:
for a2 in nei_mol.GetAtoms():
if atom_equal(a1, a2):
# Optimization for carbon: skip if the combined hydrogen count is
# below 4 (other atoms may change valence)
if a1.GetAtomicNum() == 6 and a1.GetTotalNumHs() + a2.GetTotalNumHs() < 4:
continue
new_amap = amap + [(nei_idx, a1.GetIdx(), a2.GetIdx())]
att_confs.append(new_amap)
# intersection is a bond
if ctr_mol.GetNumBonds() > 1:
for b1 in ctr_bonds:
for b2 in nei_mol.GetBonds():
if ring_bond_equal(b1, b2):
new_amap = amap + [(nei_idx,
b1.GetBeginAtom().GetIdx(),
b2.GetBeginAtom().GetIdx()),
(nei_idx,
b1.GetEndAtom().GetIdx(),
b2.GetEndAtom().GetIdx())]
att_confs.append(new_amap)
if ring_bond_equal(b1, b2, reverse=True):
new_amap = amap + [(nei_idx,
b1.GetBeginAtom().GetIdx(),
b2.GetEndAtom().GetIdx()),
(nei_idx,
b1.GetEndAtom().GetIdx(),
b2.GetBeginAtom().GetIdx())]
att_confs.append(new_amap)
return att_confs
# Try rings first for a speed-up
def enum_assemble_nx(node, neighbors, prev_nodes=None, prev_amap=None):
if prev_nodes is None:
prev_nodes = []
if prev_amap is None:
prev_amap = []
all_attach_confs = []
singletons = [nei_node['nid'] for nei_node in neighbors +
prev_nodes if nei_node['mol'].GetNumAtoms() == 1]
def search(cur_amap, depth):
if len(all_attach_confs) > MAX_NCAND:
return None
if depth == len(neighbors):
all_attach_confs.append(cur_amap)
return None
nei_node = neighbors[depth]
cand_amap = enum_attach_nx(node['mol'], nei_node, cur_amap, singletons)
cand_smiles = set()
candidates = []
for amap in cand_amap:
cand_mol = local_attach_nx(
node['mol'], neighbors[:depth + 1], prev_nodes, amap)
cand_mol = sanitize(cand_mol)
if cand_mol is None:
continue
smiles = get_smiles(cand_mol)
if smiles in cand_smiles:
continue
cand_smiles.add(smiles)
candidates.append(amap)
if len(candidates) == 0:
return []
for new_amap in candidates:
search(new_amap, depth + 1)
search(prev_amap, 0)
cand_smiles = set()
candidates = []
for amap in all_attach_confs:
cand_mol = local_attach_nx(node['mol'], neighbors, prev_nodes, amap)
cand_mol = Chem.MolFromSmiles(Chem.MolToSmiles(cand_mol))
smiles = Chem.MolToSmiles(cand_mol)
if smiles in cand_smiles:
continue
cand_smiles.add(smiles)
Chem.Kekulize(cand_mol)
candidates.append((smiles, cand_mol, amap))
return candidates
# Only used for debugging purposes
def dfs_assemble_nx(
graph,
cur_mol,
global_amap,
fa_amap,
cur_node_id,
fa_node_id):
cur_node = graph.nodes_dict[cur_node_id]
fa_node = graph.nodes_dict[fa_node_id] if fa_node_id is not None else None
fa_nid = fa_node['nid'] if fa_node is not None else -1
prev_nodes = [fa_node] if fa_node is not None else []
children_id = [nei for nei in graph[cur_node_id]
if graph.nodes_dict[nei]['nid'] != fa_nid]
children = [graph.nodes_dict[nei] for nei in children_id]
neighbors = [nei for nei in children if nei['mol'].GetNumAtoms() > 1]
neighbors = sorted(
neighbors, key=lambda x: x['mol'].GetNumAtoms(), reverse=True)
singletons = [nei for nei in children if nei['mol'].GetNumAtoms() == 1]
neighbors = singletons + neighbors
cur_amap = [(fa_nid, a2, a1)
for nid, a1, a2 in fa_amap if nid == cur_node['nid']]
cands = enum_assemble_nx(
graph.nodes_dict[cur_node_id], neighbors, prev_nodes, cur_amap)
if len(cands) == 0:
return
cand_smiles, _, cand_amap = zip(*cands)
label_idx = cand_smiles.index(cur_node['label'])
label_amap = cand_amap[label_idx]
for nei_id, ctr_atom, nei_atom in label_amap:
if nei_id == fa_nid:
continue
global_amap[nei_id][nei_atom] = global_amap[cur_node['nid']][ctr_atom]
# father is already attached
cur_mol = attach_mols_nx(cur_mol, children, [], global_amap)
for nei_node_id, nei_node in zip(children_id, children):
if not nei_node['is_leaf']:
dfs_assemble_nx(graph, cur_mol, global_amap,
label_amap, nei_node_id, cur_node_id)
# pylint: disable=C0111, C0103, E1101, W0611, W0612, W1508
# pylint: disable=redefined-outer-name
import os
import rdkit.Chem as Chem
import torch
import torch.nn as nn
import dgl.function as DGLF
from dgl import DGLGraph, mean_nodes
from .nnutils import cuda
ELEM_LIST = ['C', 'N', 'O', 'S', 'F', 'Si', 'P', 'Cl', 'Br', 'Mg', 'Na',
'Ca', 'Fe', 'Al', 'I', 'B', 'K', 'Se', 'Zn', 'H', 'Cu', 'Mn', 'unknown']
ATOM_FDIM = len(ELEM_LIST) + 6 + 5 + 1
BOND_FDIM = 5
MAX_NB = 10
PAPER = os.getenv('PAPER', False)
def onek_encoding_unk(x, allowable_set):
if x not in allowable_set:
x = allowable_set[-1]
return [x == s for s in allowable_set]
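# A minimal sketch (hypothetical helper, never called by the library) of how
# onek_encoding_unk falls back to the trailing 'unknown' slot for elements
# outside the allowable set.
def _demo_onek_encoding_unk():
    assert onek_encoding_unk('Si', ELEM_LIST)[ELEM_LIST.index('Si')]
    assert onek_encoding_unk('Xx', ELEM_LIST)[-1]  # unseen -> 'unknown'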
# Note that during graph decoding they don't predict stereochemistry-related
# characteristics (i.e. chiral atoms, E-Z, cis-trans). Instead, they decode
# the 2-D graph first, then enumerate all possible 3-D forms and find the
# one with the highest score.
def atom_features(atom):
return (torch.Tensor(onek_encoding_unk(atom.GetSymbol(), ELEM_LIST)
+ onek_encoding_unk(atom.GetDegree(),
[0, 1, 2, 3, 4, 5])
+ onek_encoding_unk(atom.GetFormalCharge(),
[-1, -2, 1, 2, 0])
+ [atom.GetIsAromatic()]))
def bond_features(bond):
bt = bond.GetBondType()
return torch.Tensor([bt == Chem.rdchem.BondType.SINGLE,
bt == Chem.rdchem.BondType.DOUBLE,
bt == Chem.rdchem.BondType.TRIPLE,
bt == Chem.rdchem.BondType.AROMATIC,
bond.IsInRing()])
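# A minimal sketch (hypothetical helper, never called by the library) checking
# that the feature vectors agree with the declared ATOM_FDIM / BOND_FDIM.
def _demo_feature_dims():
    mol = Chem.MolFromSmiles('CC')  # ethane: two atoms, one bond
    assert atom_features(mol.GetAtomWithIdx(0)).shape[0] == ATOM_FDIM
    assert bond_features(mol.GetBondWithIdx(0)).shape[0] == BOND_FDIM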
def mol2dgl_single(cand_batch):
cand_graphs = []
tree_mess_source_edges = [] # map these edges from trees to...
tree_mess_target_edges = [] # these edges on candidate graphs
tree_mess_target_nodes = []
n_nodes = 0
atom_x = []
bond_x = []
for mol, mol_tree, ctr_node_id in cand_batch:
n_atoms = mol.GetNumAtoms()
g = DGLGraph()
for i, atom in enumerate(mol.GetAtoms()):
assert i == atom.GetIdx()
atom_x.append(atom_features(atom))
g.add_nodes(n_atoms)
bond_src = []
bond_dst = []
for i, bond in enumerate(mol.GetBonds()):
a1 = bond.GetBeginAtom()
a2 = bond.GetEndAtom()
begin_idx = a1.GetIdx()
end_idx = a2.GetIdx()
features = bond_features(bond)
bond_src.append(begin_idx)
bond_dst.append(end_idx)
bond_x.append(features)
bond_src.append(end_idx)
bond_dst.append(begin_idx)
bond_x.append(features)
x_nid, y_nid = a1.GetAtomMapNum(), a2.GetAtomMapNum()
# Tree node ID in the batch
x_bid = mol_tree.nodes_dict[x_nid - 1]['idx'] if x_nid > 0 else -1
y_bid = mol_tree.nodes_dict[y_nid - 1]['idx'] if y_nid > 0 else -1
if x_bid >= 0 and y_bid >= 0 and x_bid != y_bid:
if mol_tree.has_edge_between(x_bid, y_bid):
tree_mess_target_edges.append(
(begin_idx + n_nodes, end_idx + n_nodes))
tree_mess_source_edges.append((x_bid, y_bid))
tree_mess_target_nodes.append(end_idx + n_nodes)
if mol_tree.has_edge_between(y_bid, x_bid):
tree_mess_target_edges.append(
(end_idx + n_nodes, begin_idx + n_nodes))
tree_mess_source_edges.append((y_bid, x_bid))
tree_mess_target_nodes.append(begin_idx + n_nodes)
n_nodes += n_atoms
g.add_edges(bond_src, bond_dst)
cand_graphs.append(g)
return cand_graphs, torch.stack(atom_x), \
torch.stack(bond_x) if len(bond_x) > 0 else torch.zeros(0), \
torch.LongTensor(tree_mess_source_edges), \
torch.LongTensor(tree_mess_target_edges), \
torch.LongTensor(tree_mess_target_nodes)
mpn_loopy_bp_msg = DGLF.copy_src(src='msg', out='msg')
mpn_loopy_bp_reduce = DGLF.sum(msg='msg', out='accum_msg')
class LoopyBPUpdate(nn.Module):
def __init__(self, hidden_size):
super(LoopyBPUpdate, self).__init__()
self.hidden_size = hidden_size
self.W_h = nn.Linear(hidden_size, hidden_size, bias=False)
def forward(self, node):
msg_input = node.data['msg_input']
msg_delta = self.W_h(node.data['accum_msg'] + node.data['alpha'])
msg = torch.relu(msg_input + msg_delta)
return {'msg': msg}
if PAPER:
mpn_gather_msg = [
DGLF.copy_edge(edge='msg', out='msg'),
DGLF.copy_edge(edge='alpha', out='alpha')
]
else:
mpn_gather_msg = DGLF.copy_edge(edge='msg', out='msg')
if PAPER:
mpn_gather_reduce = [
DGLF.sum(msg='msg', out='m'),
DGLF.sum(msg='alpha', out='accum_alpha'),
]
else:
mpn_gather_reduce = DGLF.sum(msg='msg', out='m')
class GatherUpdate(nn.Module):
def __init__(self, hidden_size):
super(GatherUpdate, self).__init__()
self.hidden_size = hidden_size
self.W_o = nn.Linear(ATOM_FDIM + hidden_size, hidden_size)
def forward(self, node):
if PAPER:
#m = node['m']
m = node.data['m'] + node.data['accum_alpha']
else:
m = node.data['m'] + node.data['alpha']
return {
'h': torch.relu(self.W_o(torch.cat([node.data['x'], m], 1))),
}
class DGLJTMPN(nn.Module):
def __init__(self, hidden_size, depth):
nn.Module.__init__(self)
self.depth = depth
self.W_i = nn.Linear(ATOM_FDIM + BOND_FDIM, hidden_size, bias=False)
self.loopy_bp_updater = LoopyBPUpdate(hidden_size)
self.gather_updater = GatherUpdate(hidden_size)
self.hidden_size = hidden_size
self.n_samples_total = 0
self.n_nodes_total = 0
self.n_edges_total = 0
self.n_passes = 0
def forward(self, cand_batch, mol_tree_batch):
cand_graphs, tree_mess_src_edges, tree_mess_tgt_edges, tree_mess_tgt_nodes = cand_batch
n_samples = len(cand_graphs)
cand_line_graph = cand_graphs.line_graph(
backtracking=False, shared=True)
n_nodes = cand_graphs.number_of_nodes()
n_edges = cand_graphs.number_of_edges()
cand_graphs = self.run(
cand_graphs, cand_line_graph, tree_mess_src_edges, tree_mess_tgt_edges,
tree_mess_tgt_nodes, mol_tree_batch)
g_repr = mean_nodes(cand_graphs, 'h')
self.n_samples_total += n_samples
self.n_nodes_total += n_nodes
self.n_edges_total += n_edges
self.n_passes += 1
return g_repr
def run(self, cand_graphs, cand_line_graph, tree_mess_src_edges, tree_mess_tgt_edges,
tree_mess_tgt_nodes, mol_tree_batch):
n_nodes = cand_graphs.number_of_nodes()
cand_graphs.apply_edges(
func=lambda edges: {'src_x': edges.src['x']},
)
bond_features = cand_line_graph.ndata['x']
source_features = cand_line_graph.ndata['src_x']
features = torch.cat([source_features, bond_features], 1)
msg_input = self.W_i(features)
cand_line_graph.ndata.update({
'msg_input': msg_input,
'msg': torch.relu(msg_input),
'accum_msg': torch.zeros_like(msg_input),
})
zero_node_state = bond_features.new(n_nodes, self.hidden_size).zero_()
cand_graphs.ndata.update({
'm': zero_node_state.clone(),
'h': zero_node_state.clone(),
})
cand_graphs.edata['alpha'] = \
cuda(torch.zeros(cand_graphs.number_of_edges(), self.hidden_size))
cand_graphs.ndata['alpha'] = zero_node_state
if tree_mess_src_edges.shape[0] > 0:
if PAPER:
src_u, src_v = tree_mess_src_edges.unbind(1)
tgt_u, tgt_v = tree_mess_tgt_edges.unbind(1)
alpha = mol_tree_batch.edges[src_u, src_v].data['m']
cand_graphs.edges[tgt_u, tgt_v].data['alpha'] = alpha
else:
src_u, src_v = tree_mess_src_edges.unbind(1)
alpha = mol_tree_batch.edges[src_u, src_v].data['m']
node_idx = (tree_mess_tgt_nodes
.to(device=zero_node_state.device)[:, None]
.expand_as(alpha))
node_alpha = zero_node_state.clone().scatter_add(0, node_idx, alpha)
cand_graphs.ndata['alpha'] = node_alpha
cand_graphs.apply_edges(
func=lambda edges: {'alpha': edges.src['alpha']},
)
for i in range(self.depth - 1):
cand_line_graph.update_all(
mpn_loopy_bp_msg,
mpn_loopy_bp_reduce,
self.loopy_bp_updater,
)
cand_graphs.update_all(
mpn_gather_msg,
mpn_gather_reduce,
self.gather_updater,
)
return cand_graphs
# pylint: disable=C0111, C0103, E1101, W0611, W0612
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import dgl.function as DGLF
from dgl import batch, dfs_labeled_edges_generator
from .chemutils import enum_assemble_nx, get_mol
from .mol_tree_nx import DGLMolTree
from .nnutils import GRUUpdate, cuda
MAX_NB = 8
MAX_DECODE_LEN = 100
def dfs_order(forest, roots):
edges = dfs_labeled_edges_generator(forest, roots, has_reverse_edge=True)
for e, l in zip(*edges):
# This exploits the fact that for molecule trees the reverse edge ID
# equals the forward edge ID xor 1. In general, reverse edges should be
# located with find_edges().
yield e ^ l, l
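# A minimal sketch (hypothetical helper, never called by the library) of the
# xor trick above: edges are added in (forward, reverse) pairs, so edge 2k
# and edge 2k + 1 are mutual reverses and e ^ 1 flips between them.
def _demo_reverse_edge_xor():
    assert 4 ^ 1 == 5 and 5 ^ 1 == 4
    assert all((e ^ 1) ^ 1 == e for e in range(8))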
dec_tree_node_msg = DGLF.copy_edge(edge='m', out='m')
dec_tree_node_reduce = DGLF.sum(msg='m', out='h')
def dec_tree_node_update(nodes):
return {'new': nodes.data['new'].clone().zero_()}
dec_tree_edge_msg = [DGLF.copy_src(
src='m', out='m'), DGLF.copy_src(src='rm', out='rm')]
dec_tree_edge_reduce = [
DGLF.sum(msg='m', out='s'), DGLF.sum(msg='rm', out='accum_rm')]
def have_slots(fa_slots, ch_slots):
if len(fa_slots) > 2 and len(ch_slots) > 2:
return True
matches = []
for i, s1 in enumerate(fa_slots):
a1, c1, h1 = s1
for j, s2 in enumerate(ch_slots):
a2, c2, h2 = s2
if a1 == a2 and c1 == c2 and (a1 != "C" or h1 + h2 >= 4):
matches.append((i, j))
if len(matches) == 0:
return False
fa_match, ch_match = list(zip(*matches))
if len(set(fa_match)) == 1 and 1 < len(fa_slots) <= 2: # never remove atom from ring
fa_slots.pop(fa_match[0])
if len(set(ch_match)) == 1 and 1 < len(ch_slots) <= 2: # never remove atom from ring
ch_slots.pop(ch_match[0])
return True
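# A minimal sketch (hypothetical helper, never called by the library):
# two carbons with at least 4 total hydrogens can attach, while differing
# element symbols never match.
def _demo_have_slots():
    assert have_slots([('C', 0, 3)], [('C', 0, 2)])
    assert not have_slots([('C', 0, 3)], [('N', 0, 2)])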
def can_assemble(mol_tree, u, v_node_dict):
u_node_dict = mol_tree.nodes_dict[u]
u_neighbors = mol_tree.successors(u)
u_neighbors_node_dict = [
mol_tree.nodes_dict[_u]
for _u in u_neighbors
if _u in mol_tree.nodes_dict
]
neis = u_neighbors_node_dict + [v_node_dict]
for i, nei in enumerate(neis):
nei['nid'] = i
neighbors = [nei for nei in neis if nei['mol'].GetNumAtoms() > 1]
neighbors = sorted(
neighbors, key=lambda x: x['mol'].GetNumAtoms(), reverse=True)
singletons = [nei for nei in neis if nei['mol'].GetNumAtoms() == 1]
neighbors = singletons + neighbors
cands = enum_assemble_nx(u_node_dict, neighbors)
return len(cands) > 0
def create_node_dict(smiles, clique=None):
if clique is None:
clique = []
return dict(
smiles=smiles,
mol=get_mol(smiles),
clique=clique,
)
class DGLJTNNDecoder(nn.Module):
def __init__(self, vocab, hidden_size, latent_size, embedding=None):
nn.Module.__init__(self)
self.hidden_size = hidden_size
self.vocab_size = vocab.size()
self.vocab = vocab
if embedding is None:
self.embedding = nn.Embedding(self.vocab_size, hidden_size)
else:
self.embedding = embedding
self.dec_tree_edge_update = GRUUpdate(hidden_size)
self.W = nn.Linear(latent_size + hidden_size, hidden_size)
self.U = nn.Linear(latent_size + 2 * hidden_size, hidden_size)
self.W_o = nn.Linear(hidden_size, self.vocab_size)
self.U_s = nn.Linear(hidden_size, 1)
def forward(self, mol_trees, tree_vec):
'''
The training procedure which computes the prediction loss given the
ground truth tree
'''
mol_tree_batch = batch(mol_trees)
mol_tree_batch_lg = mol_tree_batch.line_graph(
backtracking=False, shared=True)
n_trees = len(mol_trees)
return self.run(mol_tree_batch, mol_tree_batch_lg, n_trees, tree_vec)
def run(self, mol_tree_batch, mol_tree_batch_lg, n_trees, tree_vec):
node_offset = np.cumsum([0] + mol_tree_batch.batch_num_nodes)
root_ids = node_offset[:-1]
n_nodes = mol_tree_batch.number_of_nodes()
n_edges = mol_tree_batch.number_of_edges()
mol_tree_batch.ndata.update({
'x': self.embedding(mol_tree_batch.ndata['wid']),
'h': cuda(torch.zeros(n_nodes, self.hidden_size)),
# whether it's a newly generated node
'new': cuda(torch.ones(n_nodes).byte()),
})
mol_tree_batch.edata.update({
's': cuda(torch.zeros(n_edges, self.hidden_size)),
'm': cuda(torch.zeros(n_edges, self.hidden_size)),
'r': cuda(torch.zeros(n_edges, self.hidden_size)),
'z': cuda(torch.zeros(n_edges, self.hidden_size)),
'src_x': cuda(torch.zeros(n_edges, self.hidden_size)),
'dst_x': cuda(torch.zeros(n_edges, self.hidden_size)),
'rm': cuda(torch.zeros(n_edges, self.hidden_size)),
'accum_rm': cuda(torch.zeros(n_edges, self.hidden_size)),
})
mol_tree_batch.apply_edges(
func=lambda edges: {
'src_x': edges.src['x'], 'dst_x': edges.dst['x']},
)
# input tensors for stop prediction (p) and label prediction (q)
p_inputs = []
p_targets = []
q_inputs = []
q_targets = []
# Predict root
mol_tree_batch.pull(
root_ids,
dec_tree_node_msg,
dec_tree_node_reduce,
dec_tree_node_update,
)
# Extract hidden states and store them for stop/label prediction
h = mol_tree_batch.nodes[root_ids].data['h']
x = mol_tree_batch.nodes[root_ids].data['x']
p_inputs.append(torch.cat([x, h, tree_vec], 1))
# If the out degree is 0 we don't generate any edges at all
root_out_degrees = mol_tree_batch.out_degrees(root_ids)
q_inputs.append(torch.cat([h, tree_vec], 1))
q_targets.append(mol_tree_batch.nodes[root_ids].data['wid'])
# Traverse the tree and predict on children
for eid, p in dfs_order(mol_tree_batch, root_ids):
u, v = mol_tree_batch.find_edges(eid)
p_target_list = torch.zeros_like(root_out_degrees)
p_target_list[root_out_degrees > 0] = 1 - p
p_target_list = p_target_list[root_out_degrees >= 0]
p_targets.append(torch.tensor(p_target_list))
root_out_degrees -= (root_out_degrees == 0).long()
root_out_degrees -= torch.tensor(np.isin(root_ids,
v).astype('int64'))
mol_tree_batch_lg.pull(
eid,
dec_tree_edge_msg,
dec_tree_edge_reduce,
self.dec_tree_edge_update,
)
is_new = mol_tree_batch.nodes[v].data['new']
mol_tree_batch.pull(
v,
dec_tree_node_msg,
dec_tree_node_reduce,
dec_tree_node_update,
)
# Extract
n_repr = mol_tree_batch.nodes[v].data
h = n_repr['h']
x = n_repr['x']
tree_vec_set = tree_vec[root_out_degrees >= 0]
wid = n_repr['wid']
p_inputs.append(torch.cat([x, h, tree_vec_set], 1))
# Only newly generated nodes are needed for label prediction
# NOTE: The following works since the uncomputed messages are zeros.
q_input = torch.cat([h, tree_vec_set], 1)[is_new]
q_target = wid[is_new]
if q_input.shape[0] > 0:
q_inputs.append(q_input)
q_targets.append(q_target)
p_targets.append(torch.zeros((root_out_degrees == 0).sum()).long())
# Batch compute the stop/label prediction losses
p_inputs = torch.cat(p_inputs, 0)
p_targets = cuda(torch.cat(p_targets, 0))
q_inputs = torch.cat(q_inputs, 0)
q_targets = torch.cat(q_targets, 0)
q = self.W_o(torch.relu(self.W(q_inputs)))
p = self.U_s(torch.relu(self.U(p_inputs)))[:, 0]
p_loss = F.binary_cross_entropy_with_logits(
p, p_targets.float(), size_average=False
) / n_trees
q_loss = F.cross_entropy(q, q_targets, size_average=False) / n_trees
p_acc = ((p > 0).long() == p_targets).sum().float() / \
p_targets.shape[0]
q_acc = (q.max(1)[1] == q_targets).float().sum() / q_targets.shape[0]
self.q_inputs = q_inputs
self.q_targets = q_targets
self.q = q
self.p_inputs = p_inputs
self.p_targets = p_targets
self.p = p
return q_loss, p_loss, q_acc, p_acc
def decode(self, mol_vec):
assert mol_vec.shape[0] == 1
mol_tree = DGLMolTree(None)
init_hidden = cuda(torch.zeros(1, self.hidden_size))
root_hidden = torch.cat([init_hidden, mol_vec], 1)
root_hidden = F.relu(self.W(root_hidden))
root_score = self.W_o(root_hidden)
_, root_wid = torch.max(root_score, 1)
root_wid = root_wid.view(1)
mol_tree.add_nodes(1) # root
mol_tree.nodes[0].data['wid'] = root_wid
mol_tree.nodes[0].data['x'] = self.embedding(root_wid)
mol_tree.nodes[0].data['h'] = init_hidden
mol_tree.nodes[0].data['fail'] = cuda(torch.tensor([0]))
mol_tree.nodes_dict[0] = root_node_dict = create_node_dict(
self.vocab.get_smiles(root_wid))
stack, trace = [], []
stack.append((0, self.vocab.get_slots(root_wid)))
all_nodes = {0: root_node_dict}
first = True
new_node_id = 0
new_edge_id = 0
for step in range(MAX_DECODE_LEN):
u, u_slots = stack[-1]
udata = mol_tree.nodes[u].data
x = udata['x']
h = udata['h']
# Predict stop
p_input = torch.cat([x, h, mol_vec], 1)
p_score = torch.sigmoid(self.U_s(torch.relu(self.U(p_input))))
backtrack = (p_score.item() < 0.5)
if not backtrack:
# Predict the next clique. Note that the prediction may fail due
# to a lack of assemblable components
mol_tree.add_nodes(1)
new_node_id += 1
v = new_node_id
mol_tree.add_edges(u, v)
uv = new_edge_id
new_edge_id += 1
if first:
mol_tree.edata.update({
's': cuda(torch.zeros(1, self.hidden_size)),
'm': cuda(torch.zeros(1, self.hidden_size)),
'r': cuda(torch.zeros(1, self.hidden_size)),
'z': cuda(torch.zeros(1, self.hidden_size)),
'src_x': cuda(torch.zeros(1, self.hidden_size)),
'dst_x': cuda(torch.zeros(1, self.hidden_size)),
'rm': cuda(torch.zeros(1, self.hidden_size)),
'accum_rm': cuda(torch.zeros(1, self.hidden_size)),
})
first = False
mol_tree.edges[uv].data['src_x'] = mol_tree.nodes[u].data['x']
# keeping dst_x 0 is fine as h on the new edge doesn't depend on it.
# DGL doesn't dynamically maintain a line graph.
mol_tree_lg = mol_tree.line_graph(
backtracking=False, shared=True)
mol_tree_lg.pull(
uv,
dec_tree_edge_msg,
dec_tree_edge_reduce,
self.dec_tree_edge_update.update_zm,
)
mol_tree.pull(
v,
dec_tree_node_msg,
dec_tree_node_reduce,
)
vdata = mol_tree.nodes[v].data
h_v = vdata['h']
q_input = torch.cat([h_v, mol_vec], 1)
q_score = torch.softmax(
self.W_o(torch.relu(self.W(q_input))), -1)
_, sort_wid = torch.sort(q_score, 1, descending=True)
sort_wid = sort_wid.squeeze()
next_wid = None
for wid in sort_wid.tolist()[:5]:
slots = self.vocab.get_slots(wid)
cand_node_dict = create_node_dict(
self.vocab.get_smiles(wid))
if (have_slots(u_slots, slots) and can_assemble(mol_tree, u, cand_node_dict)):
next_wid = wid
next_slots = slots
next_node_dict = cand_node_dict
break
if next_wid is None:
# Failed to add an actual child; v is a spurious node
# and we mark it.
vdata['fail'] = cuda(torch.tensor([1]))
backtrack = True
else:
next_wid = cuda(torch.tensor([next_wid]))
vdata['wid'] = next_wid
vdata['x'] = self.embedding(next_wid)
mol_tree.nodes_dict[v] = next_node_dict
all_nodes[v] = next_node_dict
stack.append((v, next_slots))
mol_tree.add_edge(v, u)
vu = new_edge_id
new_edge_id += 1
mol_tree.edges[uv].data['dst_x'] = mol_tree.nodes[v].data['x']
mol_tree.edges[vu].data['src_x'] = mol_tree.nodes[v].data['x']
mol_tree.edges[vu].data['dst_x'] = mol_tree.nodes[u].data['x']
# DGL doesn't dynamically maintain a line graph.
mol_tree_lg = mol_tree.line_graph(
backtracking=False, shared=True)
mol_tree_lg.apply_nodes(
self.dec_tree_edge_update.update_r,
uv
)
if backtrack:
if len(stack) == 1:
break # At root, terminate
pu, _ = stack[-2]
u_pu = mol_tree.edge_id(u, pu)
mol_tree_lg.pull(
u_pu,
dec_tree_edge_msg,
dec_tree_edge_reduce,
self.dec_tree_edge_update,
)
mol_tree.pull(
pu,
dec_tree_node_msg,
dec_tree_node_reduce,
)
stack.pop()
effective_nodes = mol_tree.filter_nodes(
lambda nodes: nodes.data['fail'] != 1)
effective_nodes, _ = torch.sort(effective_nodes)
return mol_tree, all_nodes, effective_nodes
# pylint: disable=C0111, C0103, E1101, W0611, W0612
import numpy as np
import torch
import torch.nn as nn
import dgl.function as DGLF
from dgl import batch, bfs_edges_generator
from .nnutils import GRUUpdate, cuda
MAX_NB = 8
def level_order(forest, roots):
edges = bfs_edges_generator(forest, roots)
_, leaves = forest.find_edges(edges[-1])
edges_back = bfs_edges_generator(forest, roots, reverse=True)
yield from reversed(edges_back)
yield from edges
enc_tree_msg = [DGLF.copy_src(src='m', out='m'),
DGLF.copy_src(src='rm', out='rm')]
enc_tree_reduce = [DGLF.sum(msg='m', out='s'),
DGLF.sum(msg='rm', out='accum_rm')]
enc_tree_gather_msg = DGLF.copy_edge(edge='m', out='m')
enc_tree_gather_reduce = DGLF.sum(msg='m', out='m')
class EncoderGatherUpdate(nn.Module):
def __init__(self, hidden_size):
nn.Module.__init__(self)
self.hidden_size = hidden_size
self.W = nn.Linear(2 * hidden_size, hidden_size)
def forward(self, nodes):
x = nodes.data['x']
m = nodes.data['m']
return {
'h': torch.relu(self.W(torch.cat([x, m], 1))),
}
class DGLJTNNEncoder(nn.Module):
def __init__(self, vocab, hidden_size, embedding=None):
nn.Module.__init__(self)
self.hidden_size = hidden_size
self.vocab_size = vocab.size()
self.vocab = vocab
if embedding is None:
self.embedding = nn.Embedding(self.vocab_size, hidden_size)
else:
self.embedding = embedding
self.enc_tree_update = GRUUpdate(hidden_size)
self.enc_tree_gather_update = EncoderGatherUpdate(hidden_size)
def forward(self, mol_trees):
mol_tree_batch = batch(mol_trees)
# Build line graph to prepare for belief propagation
mol_tree_batch_lg = mol_tree_batch.line_graph(
backtracking=False, shared=True)
return self.run(mol_tree_batch, mol_tree_batch_lg)
def run(self, mol_tree_batch, mol_tree_batch_lg):
# Tree roots are designated as node 0, so in the batched graph we can
# simply find the corresponding node IDs by looking at node_offset
node_offset = np.cumsum([0] + mol_tree_batch.batch_num_nodes)
root_ids = node_offset[:-1]
n_nodes = mol_tree_batch.number_of_nodes()
n_edges = mol_tree_batch.number_of_edges()
# Assign structure embeddings to tree nodes
mol_tree_batch.ndata.update({
'x': self.embedding(mol_tree_batch.ndata['wid']),
'h': cuda(torch.zeros(n_nodes, self.hidden_size)),
})
# Initialize the intermediate variables according to Eq (4)-(8).
# Also initialize the src_x and dst_x fields.
# TODO: context?
mol_tree_batch.edata.update({
's': cuda(torch.zeros(n_edges, self.hidden_size)),
'm': cuda(torch.zeros(n_edges, self.hidden_size)),
'r': cuda(torch.zeros(n_edges, self.hidden_size)),
'z': cuda(torch.zeros(n_edges, self.hidden_size)),
'src_x': cuda(torch.zeros(n_edges, self.hidden_size)),
'dst_x': cuda(torch.zeros(n_edges, self.hidden_size)),
'rm': cuda(torch.zeros(n_edges, self.hidden_size)),
'accum_rm': cuda(torch.zeros(n_edges, self.hidden_size)),
})
# Send the source/destination node features to edges
mol_tree_batch.apply_edges(
func=lambda edges: {
'src_x': edges.src['x'], 'dst_x': edges.dst['x']},
)
# Message passing
# This exploits the fact that the reduce function is a sum of incoming
# messages, and the uncomputed messages are zero vectors. Essentially,
# we can always compute s_ij as the sum of incoming m_ij, whether or
# not m_ij has actually been computed.
for eid in level_order(mol_tree_batch, root_ids):
#eid = mol_tree_batch.edge_ids(u, v)
mol_tree_batch_lg.pull(
eid,
enc_tree_msg,
enc_tree_reduce,
self.enc_tree_update,
)
# Readout
mol_tree_batch.update_all(
enc_tree_gather_msg,
enc_tree_gather_reduce,
self.enc_tree_gather_update,
)
root_vecs = mol_tree_batch.nodes[root_ids].data['h']
return mol_tree_batch, root_vecs
# pylint: disable=C0111, C0103, E1101, W0611, W0612, C0200
import copy
import torch
import torch.nn as nn
import torch.nn.functional as F
import rdkit.Chem as Chem
from ....graph import batch, unbatch
from ....contrib.deprecation import deprecated
from ....data.utils import get_download_dir
from .chemutils import (attach_mols_nx, copy_edit_mol, decode_stereo,
enum_assemble_nx, set_atommap)
from .jtmpn import DGLJTMPN
from .jtmpn import mol2dgl_single as mol2dgl_dec
from .jtnn_dec import DGLJTNNDecoder
from .jtnn_enc import DGLJTNNEncoder
from .mol_tree import Vocab
from .mpn import DGLMPN
from .mpn import mol2dgl_single as mol2dgl_enc
from .nnutils import cuda, move_dgl_to_cuda
class DGLJTNNVAE(nn.Module):
"""
`Junction Tree Variational Autoencoder for Molecular Graph Generation
<https://arxiv.org/abs/1802.04364>`__
"""
@deprecated('Import DGLJTNNVAE from dgllife.model instead.', 'class')
def __init__(self, hidden_size, latent_size, depth, vocab=None, vocab_file=None):
super(DGLJTNNVAE, self).__init__()
if vocab is None:
if vocab_file is None:
vocab_file = '{}/jtnn/{}.txt'.format(
get_download_dir(), 'vocab')
self.vocab = Vocab([x.strip("\r\n ")
for x in open(vocab_file)])
else:
self.vocab = vocab
self.hidden_size = hidden_size
self.latent_size = latent_size
self.depth = depth
self.embedding = nn.Embedding(self.vocab.size(), hidden_size)
self.mpn = DGLMPN(hidden_size, depth)
self.jtnn = DGLJTNNEncoder(self.vocab, hidden_size, self.embedding)
self.decoder = DGLJTNNDecoder(
self.vocab, hidden_size, latent_size // 2, self.embedding)
self.jtmpn = DGLJTMPN(hidden_size, depth)
self.T_mean = nn.Linear(hidden_size, latent_size // 2)
self.T_var = nn.Linear(hidden_size, latent_size // 2)
self.G_mean = nn.Linear(hidden_size, latent_size // 2)
self.G_var = nn.Linear(hidden_size, latent_size // 2)
self.n_nodes_total = 0
self.n_passes = 0
self.n_edges_total = 0
self.n_tree_nodes_total = 0
@staticmethod
def move_to_cuda(mol_batch):
for t in mol_batch['mol_trees']:
move_dgl_to_cuda(t)
move_dgl_to_cuda(mol_batch['mol_graph_batch'])
if 'cand_graph_batch' in mol_batch:
move_dgl_to_cuda(mol_batch['cand_graph_batch'])
if mol_batch.get('stereo_cand_graph_batch') is not None:
move_dgl_to_cuda(mol_batch['stereo_cand_graph_batch'])
def encode(self, mol_batch):
mol_graphs = mol_batch['mol_graph_batch']
mol_vec = self.mpn(mol_graphs)
mol_tree_batch, tree_vec = self.jtnn(mol_batch['mol_trees'])
self.n_nodes_total += mol_graphs.number_of_nodes()
self.n_edges_total += mol_graphs.number_of_edges()
self.n_tree_nodes_total += sum(t.number_of_nodes()
for t in mol_batch['mol_trees'])
self.n_passes += 1
return mol_tree_batch, tree_vec, mol_vec
def sample(self, tree_vec, mol_vec, e1=None, e2=None):
tree_mean = self.T_mean(tree_vec)
tree_log_var = -torch.abs(self.T_var(tree_vec))
mol_mean = self.G_mean(mol_vec)
mol_log_var = -torch.abs(self.G_var(mol_vec))
epsilon = cuda(torch.randn(*tree_mean.shape)) if e1 is None else e1
tree_vec = tree_mean + torch.exp(tree_log_var / 2) * epsilon
epsilon = cuda(torch.randn(*mol_mean.shape)) if e2 is None else e2
mol_vec = mol_mean + torch.exp(mol_log_var / 2) * epsilon
z_mean = torch.cat([tree_mean, mol_mean], 1)
z_log_var = torch.cat([tree_log_var, mol_log_var], 1)
return tree_vec, mol_vec, z_mean, z_log_var
def forward(self, mol_batch, beta=0, e1=None, e2=None):
self.move_to_cuda(mol_batch)
mol_trees = mol_batch['mol_trees']
batch_size = len(mol_trees)
mol_tree_batch, tree_vec, mol_vec = self.encode(mol_batch)
tree_vec, mol_vec, z_mean, z_log_var = self.sample(
tree_vec, mol_vec, e1, e2)
kl_loss = -0.5 * torch.sum(
1.0 + z_log_var - z_mean * z_mean - torch.exp(z_log_var)) / batch_size
word_loss, topo_loss, word_acc, topo_acc = self.decoder(
mol_trees, tree_vec)
assm_loss, assm_acc = self.assm(mol_batch, mol_tree_batch, mol_vec)
stereo_loss, stereo_acc = self.stereo(mol_batch, mol_vec)
loss = word_loss + topo_loss + assm_loss + 2 * stereo_loss + beta * kl_loss
return loss, kl_loss, word_acc, topo_acc, assm_acc, stereo_acc
def assm(self, mol_batch, mol_tree_batch, mol_vec):
cands = [mol_batch['cand_graph_batch'],
mol_batch['tree_mess_src_e'],
mol_batch['tree_mess_tgt_e'],
mol_batch['tree_mess_tgt_n']]
cand_vec = self.jtmpn(cands, mol_tree_batch)
cand_vec = self.G_mean(cand_vec)
batch_idx = cuda(torch.LongTensor(mol_batch['cand_batch_idx']))
mol_vec = mol_vec[batch_idx]
mol_vec = mol_vec.view(-1, 1, self.latent_size // 2)
cand_vec = cand_vec.view(-1, self.latent_size // 2, 1)
scores = (mol_vec @ cand_vec)[:, 0, 0]
cnt, tot, acc = 0, 0, 0
all_loss = []
for i, mol_tree in enumerate(mol_batch['mol_trees']):
comp_nodes = [node_id for node_id, node in mol_tree.nodes_dict.items()
if len(node['cands']) > 1 and not node['is_leaf']]
cnt += len(comp_nodes)
# segmented accuracy and cross entropy
for node_id in comp_nodes:
node = mol_tree.nodes_dict[node_id]
label = node['cands'].index(node['label'])
ncand = len(node['cands'])
cur_score = scores[tot:tot + ncand]
tot += ncand
if cur_score[label].item() >= cur_score.max().item():
acc += 1
label = cuda(torch.LongTensor([label]))
all_loss.append(
F.cross_entropy(cur_score.view(1, -1), label, size_average=False))
all_loss = sum(all_loss) / len(mol_batch['mol_trees'])
return all_loss, acc / cnt
def stereo(self, mol_batch, mol_vec):
stereo_cands = mol_batch['stereo_cand_graph_batch']
batch_idx = mol_batch['stereo_cand_batch_idx']
labels = mol_batch['stereo_cand_labels']
lengths = mol_batch['stereo_cand_lengths']
if len(labels) == 0:
# Only one stereoisomer exists; do nothing
return cuda(torch.tensor(0.)), 1.
batch_idx = cuda(torch.LongTensor(batch_idx))
stereo_cands = self.mpn(stereo_cands)
stereo_cands = self.G_mean(stereo_cands)
stereo_labels = mol_vec[batch_idx]
scores = F.cosine_similarity(stereo_cands, stereo_labels)
st, acc = 0, 0
all_loss = []
for label, le in zip(labels, lengths):
cur_scores = scores[st:st + le]
if cur_scores.data[label].item() >= cur_scores.max().item():
acc += 1
label = cuda(torch.LongTensor([label]))
all_loss.append(
F.cross_entropy(cur_scores.view(1, -1), label, size_average=False))
st += le
all_loss = sum(all_loss) / len(labels)
return all_loss, acc / len(labels)
def decode(self, tree_vec, mol_vec):
mol_tree, nodes_dict, effective_nodes = self.decoder.decode(tree_vec)
effective_nodes_list = effective_nodes.tolist()
nodes_dict = [nodes_dict[v] for v in effective_nodes_list]
for i, (node_id, node) in enumerate(zip(effective_nodes_list, nodes_dict)):
node['idx'] = i
node['nid'] = i + 1
node['is_leaf'] = True
if mol_tree.in_degree(node_id) > 1:
node['is_leaf'] = False
set_atommap(node['mol'], node['nid'])
mol_tree_sg = mol_tree.subgraph(effective_nodes)
mol_tree_sg.copy_from_parent()
mol_tree_msg, _ = self.jtnn([mol_tree_sg])
mol_tree_msg = unbatch(mol_tree_msg)[0]
mol_tree_msg.nodes_dict = nodes_dict
cur_mol = copy_edit_mol(nodes_dict[0]['mol'])
global_amap = [{}] + [{} for node in nodes_dict]
global_amap[1] = {atom.GetIdx(): atom.GetIdx()
for atom in cur_mol.GetAtoms()}
cur_mol = self.dfs_assemble(
mol_tree_msg, mol_vec, cur_mol, global_amap, [], 0, None)
if cur_mol is None:
return None
cur_mol = cur_mol.GetMol()
set_atommap(cur_mol)
cur_mol = Chem.MolFromSmiles(Chem.MolToSmiles(cur_mol))
if cur_mol is None:
return None
smiles2D = Chem.MolToSmiles(cur_mol)
stereo_cands = decode_stereo(smiles2D)
if len(stereo_cands) == 1:
return stereo_cands[0]
stereo_graphs = [mol2dgl_enc(c) for c in stereo_cands]
stereo_cand_graphs, atom_x, bond_x = \
zip(*stereo_graphs)
stereo_cand_graphs = batch(stereo_cand_graphs)
atom_x = cuda(torch.cat(atom_x))
bond_x = cuda(torch.cat(bond_x))
stereo_cand_graphs.ndata['x'] = atom_x
stereo_cand_graphs.edata['x'] = bond_x
stereo_cand_graphs.edata['src_x'] = atom_x.new(
bond_x.shape[0], atom_x.shape[1]).zero_()
stereo_vecs = self.mpn(stereo_cand_graphs)
stereo_vecs = self.G_mean(stereo_vecs)
scores = F.cosine_similarity(stereo_vecs, mol_vec)
_, max_id = scores.max(0)
return stereo_cands[max_id.item()]
def dfs_assemble(self, mol_tree_msg, mol_vec, cur_mol,
global_amap, fa_amap, cur_node_id, fa_node_id):
nodes_dict = mol_tree_msg.nodes_dict
fa_node = nodes_dict[fa_node_id] if fa_node_id is not None else None
cur_node = nodes_dict[cur_node_id]
fa_nid = fa_node['nid'] if fa_node is not None else -1
prev_nodes = [fa_node] if fa_node is not None else []
children_node_id = [v for v in mol_tree_msg.successors(cur_node_id).tolist()
if nodes_dict[v]['nid'] != fa_nid]
children = [nodes_dict[v] for v in children_node_id]
neighbors = [nei for nei in children if nei['mol'].GetNumAtoms() > 1]
neighbors = sorted(
neighbors, key=lambda x: x['mol'].GetNumAtoms(), reverse=True)
singletons = [nei for nei in children if nei['mol'].GetNumAtoms() == 1]
neighbors = singletons + neighbors
cur_amap = [(fa_nid, a2, a1)
for nid, a1, a2 in fa_amap if nid == cur_node['nid']]
cands = enum_assemble_nx(cur_node, neighbors, prev_nodes, cur_amap)
if len(cands) == 0:
return None
cand_smiles, cand_mols, cand_amap = list(zip(*cands))
cands = [(candmol, mol_tree_msg, cur_node_id) for candmol in cand_mols]
cand_graphs, atom_x, bond_x, tree_mess_src_edges, \
tree_mess_tgt_edges, tree_mess_tgt_nodes = mol2dgl_dec(
cands)
cand_graphs = batch(cand_graphs)
atom_x = cuda(atom_x)
bond_x = cuda(bond_x)
cand_graphs.ndata['x'] = atom_x
cand_graphs.edata['x'] = bond_x
cand_graphs.edata['src_x'] = atom_x.new(
bond_x.shape[0], atom_x.shape[1]).zero_()
cand_vecs = self.jtmpn(
(cand_graphs, tree_mess_src_edges,
tree_mess_tgt_edges, tree_mess_tgt_nodes),
mol_tree_msg,
)
cand_vecs = self.G_mean(cand_vecs)
mol_vec = mol_vec.squeeze()
scores = cand_vecs @ mol_vec
_, cand_idx = torch.sort(scores, descending=True)
backup_mol = Chem.RWMol(cur_mol)
for i in range(len(cand_idx)):
cur_mol = Chem.RWMol(backup_mol)
pred_amap = cand_amap[cand_idx[i].item()]
new_global_amap = copy.deepcopy(global_amap)
for nei_id, ctr_atom, nei_atom in pred_amap:
if nei_id == fa_nid:
continue
new_global_amap[nei_id][nei_atom] = new_global_amap[cur_node['nid']][ctr_atom]
cur_mol = attach_mols_nx(cur_mol, children, [], new_global_amap)
new_mol = cur_mol.GetMol()
new_mol = Chem.MolFromSmiles(Chem.MolToSmiles(new_mol))
if new_mol is None:
continue
result = True
for nei_node_id, nei_node in zip(children_node_id, children):
if nei_node['is_leaf']:
continue
cur_mol = self.dfs_assemble(
mol_tree_msg, mol_vec, cur_mol, new_global_amap, pred_amap,
nei_node_id, cur_node_id)
if cur_mol is None:
result = False
break
if result:
return cur_mol
return None
# pylint: disable=C0111, C0103, E1101, W0611, W0612
import copy
import rdkit.Chem as Chem
def get_slots(smiles):
mol = Chem.MolFromSmiles(smiles)
return [(atom.GetSymbol(), atom.GetFormalCharge(), atom.GetTotalNumHs())
for atom in mol.GetAtoms()]
class Vocab(object):
def __init__(self, smiles_list):
self.vocab = smiles_list
self.vmap = {x: i for i, x in enumerate(self.vocab)}
self.slots = [get_slots(smiles) for smiles in self.vocab]
def get_index(self, smiles):
return self.vmap[smiles]
def get_smiles(self, idx):
return self.vocab[idx]
def get_slots(self, idx):
return copy.deepcopy(self.slots[idx])
def size(self):
return len(self.vocab)
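# A minimal sketch (hypothetical helper, never called by the library) of the
# Vocab API on a toy vocabulary of two clique SMILES.
def _demo_vocab():
    vocab = Vocab(['C', 'C1=CC=CC=C1'])
    assert vocab.size() == 2
    assert vocab.get_index('C') == 0
    assert vocab.get_smiles(1) == 'C1=CC=CC=C1'
    # Each slot is (symbol, formal charge, total H count) per atom.
    assert vocab.get_slots(0) == [('C', 0, 4)]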
# pylint: disable=C0111, C0103, E1101, W0611, W0612
import numpy as np
import rdkit.Chem as Chem
from dgl import DGLGraph
from .chemutils import (decode_stereo, enum_assemble_nx, get_clique_mol,
get_mol, get_smiles, set_atommap, tree_decomp)
class DGLMolTree(DGLGraph):
def __init__(self, smiles):
DGLGraph.__init__(self)
self.nodes_dict = {}
if smiles is None:
return
self.smiles = smiles
self.mol = get_mol(smiles)
# Stereo Generation
mol = Chem.MolFromSmiles(smiles)
self.smiles3D = Chem.MolToSmiles(mol, isomericSmiles=True)
self.smiles2D = Chem.MolToSmiles(mol)
self.stereo_cands = decode_stereo(self.smiles2D)
# cliques: a list of lists of atom indices
cliques, edges = tree_decomp(self.mol)
root = 0
for i, c in enumerate(cliques):
cmol = get_clique_mol(self.mol, c)
csmiles = get_smiles(cmol)
self.nodes_dict[i] = dict(
smiles=csmiles,
mol=get_mol(csmiles),
clique=c,
)
if min(c) == 0:
root = i
self.add_nodes(len(cliques))
# The clique with atom ID 0 becomes root
if root > 0:
for attr in self.nodes_dict[0]:
self.nodes_dict[0][attr], self.nodes_dict[root][attr] = \
self.nodes_dict[root][attr], self.nodes_dict[0][attr]
src = np.zeros((len(edges) * 2,), dtype='int')
dst = np.zeros((len(edges) * 2,), dtype='int')
for i, (_x, _y) in enumerate(edges):
x = 0 if _x == root else root if _x == 0 else _x
y = 0 if _y == root else root if _y == 0 else _y
src[2 * i] = x
dst[2 * i] = y
src[2 * i + 1] = y
dst[2 * i + 1] = x
self.add_edges(src, dst)
for i in self.nodes_dict:
self.nodes_dict[i]['nid'] = i + 1
if self.out_degree(i) > 1: # Leaf node mol is not marked
set_atommap(self.nodes_dict[i]['mol'],
self.nodes_dict[i]['nid'])
self.nodes_dict[i]['is_leaf'] = (self.out_degree(i) == 1)
def treesize(self):
return self.number_of_nodes()
def _recover_node(self, i, original_mol):
node = self.nodes_dict[i]
clique = []
clique.extend(node['clique'])
if not node['is_leaf']:
for cidx in node['clique']:
original_mol.GetAtomWithIdx(cidx).SetAtomMapNum(node['nid'])
for j in self.successors(i).numpy():
nei_node = self.nodes_dict[j]
clique.extend(nei_node['clique'])
if nei_node['is_leaf']: # Leaf node, no need to mark
continue
for cidx in nei_node['clique']:
# allow a singleton node to override the atom mapping
if cidx not in node['clique'] or len(nei_node['clique']) == 1:
atom = original_mol.GetAtomWithIdx(cidx)
atom.SetAtomMapNum(nei_node['nid'])
clique = list(set(clique))
label_mol = get_clique_mol(original_mol, clique)
node['label'] = Chem.MolToSmiles(
Chem.MolFromSmiles(get_smiles(label_mol)))
node['label_mol'] = get_mol(node['label'])
for cidx in clique:
original_mol.GetAtomWithIdx(cidx).SetAtomMapNum(0)
return node['label']
def _assemble_node(self, i):
neighbors = [self.nodes_dict[j] for j in self.successors(i).numpy()
if self.nodes_dict[j]['mol'].GetNumAtoms() > 1]
neighbors = sorted(
neighbors, key=lambda x: x['mol'].GetNumAtoms(), reverse=True)
singletons = [self.nodes_dict[j] for j in self.successors(i).numpy()
if self.nodes_dict[j]['mol'].GetNumAtoms() == 1]
neighbors = singletons + neighbors
cands = enum_assemble_nx(self.nodes_dict[i], neighbors)
if len(cands) > 0:
self.nodes_dict[i]['cands'], self.nodes_dict[i]['cand_mols'], _ = list(
zip(*cands))
self.nodes_dict[i]['cands'] = list(self.nodes_dict[i]['cands'])
self.nodes_dict[i]['cand_mols'] = list(
self.nodes_dict[i]['cand_mols'])
else:
self.nodes_dict[i]['cands'] = []
self.nodes_dict[i]['cand_mols'] = []
def recover(self):
for i in self.nodes_dict:
self._recover_node(i, self.mol)
def assemble(self):
for i in self.nodes_dict:
self._assemble_node(i)
# pylint: disable=C0111, C0103, E1101, W0611, W0612
# pylint: disable=redefined-outer-name
import rdkit.Chem as Chem
import torch
import torch.nn as nn
import torch.nn.functional as F
import dgl.function as DGLF
from dgl import DGLGraph, mean_nodes
from .chemutils import get_mol
ELEM_LIST = ['C', 'N', 'O', 'S', 'F', 'Si', 'P', 'Cl', 'Br', 'Mg', 'Na',
'Ca', 'Fe', 'Al', 'I', 'B', 'K', 'Se', 'Zn', 'H', 'Cu', 'Mn', 'unknown']
ATOM_FDIM = len(ELEM_LIST) + 6 + 5 + 4 + 1
BOND_FDIM = 5 + 6
MAX_NB = 6
def onek_encoding_unk(x, allowable_set):
if x not in allowable_set:
x = allowable_set[-1]
return [x == s for s in allowable_set]
def atom_features(atom):
return (torch.Tensor(onek_encoding_unk(atom.GetSymbol(), ELEM_LIST)
+ onek_encoding_unk(atom.GetDegree(),
[0, 1, 2, 3, 4, 5])
+ onek_encoding_unk(atom.GetFormalCharge(),
[-1, -2, 1, 2, 0])
+ onek_encoding_unk(int(atom.GetChiralTag()), [0, 1, 2, 3])
+ [atom.GetIsAromatic()]))
def bond_features(bond):
bt = bond.GetBondType()
stereo = int(bond.GetStereo())
fbond = [bt == Chem.rdchem.BondType.SINGLE, bt == Chem.rdchem.BondType.DOUBLE, bt ==
Chem.rdchem.BondType.TRIPLE, bt == Chem.rdchem.BondType.AROMATIC, bond.IsInRing()]
fstereo = onek_encoding_unk(stereo, [0, 1, 2, 3, 4, 5])
return torch.Tensor(fbond + fstereo)
def mol2dgl_single(smiles):
n_edges = 0
atom_x = []
bond_x = []
mol = get_mol(smiles)
n_atoms = mol.GetNumAtoms()
n_bonds = mol.GetNumBonds()
graph = DGLGraph()
for i, atom in enumerate(mol.GetAtoms()):
assert i == atom.GetIdx()
atom_x.append(atom_features(atom))
graph.add_nodes(n_atoms)
bond_src = []
bond_dst = []
for i, bond in enumerate(mol.GetBonds()):
begin_idx = bond.GetBeginAtom().GetIdx()
end_idx = bond.GetEndAtom().GetIdx()
features = bond_features(bond)
bond_src.append(begin_idx)
bond_dst.append(end_idx)
bond_x.append(features)
# set up the reverse direction
bond_src.append(end_idx)
bond_dst.append(begin_idx)
bond_x.append(features)
graph.add_edges(bond_src, bond_dst)
n_edges += n_bonds
return graph, torch.stack(atom_x), \
torch.stack(bond_x) if len(bond_x) > 0 else torch.zeros(0)
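# A minimal sketch (hypothetical helper, never called by the library);
# assumes the legacy mutable DGLGraph API this module targets. Ethanol has
# 3 heavy atoms and 2 bonds, i.e. 4 directed edges.
def _demo_mol2dgl_single():
    graph, atom_x, bond_x = mol2dgl_single('CCO')
    assert graph.number_of_nodes() == 3 and graph.number_of_edges() == 4
    assert atom_x.shape == (3, ATOM_FDIM) and bond_x.shape == (4, BOND_FDIM)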
mpn_loopy_bp_msg = DGLF.copy_src(src='msg', out='msg')
mpn_loopy_bp_reduce = DGLF.sum(msg='msg', out='accum_msg')
class LoopyBPUpdate(nn.Module):
def __init__(self, hidden_size):
super(LoopyBPUpdate, self).__init__()
self.hidden_size = hidden_size
self.W_h = nn.Linear(hidden_size, hidden_size, bias=False)
def forward(self, nodes):
msg_input = nodes.data['msg_input']
msg_delta = self.W_h(nodes.data['accum_msg'])
msg = F.relu(msg_input + msg_delta)
return {'msg': msg}
mpn_gather_msg = DGLF.copy_edge(edge='msg', out='msg')
mpn_gather_reduce = DGLF.sum(msg='msg', out='m')
class GatherUpdate(nn.Module):
def __init__(self, hidden_size):
super(GatherUpdate, self).__init__()
self.hidden_size = hidden_size
self.W_o = nn.Linear(ATOM_FDIM + hidden_size, hidden_size)
def forward(self, nodes):
m = nodes.data['m']
return {
'h': F.relu(self.W_o(torch.cat([nodes.data['x'], m], 1))),
}
class DGLMPN(nn.Module):
def __init__(self, hidden_size, depth):
super(DGLMPN, self).__init__()
self.depth = depth
self.W_i = nn.Linear(ATOM_FDIM + BOND_FDIM, hidden_size, bias=False)
self.loopy_bp_updater = LoopyBPUpdate(hidden_size)
self.gather_updater = GatherUpdate(hidden_size)
self.hidden_size = hidden_size
self.n_samples_total = 0
self.n_nodes_total = 0
self.n_edges_total = 0
self.n_passes = 0
def forward(self, mol_graph):
n_samples = mol_graph.batch_size
mol_line_graph = mol_graph.line_graph(backtracking=False, shared=True)
n_nodes = mol_graph.number_of_nodes()
n_edges = mol_graph.number_of_edges()
mol_graph = self.run(mol_graph, mol_line_graph)
# TODO: replace with unbatch or readout
g_repr = mean_nodes(mol_graph, 'h')
self.n_samples_total += n_samples
self.n_nodes_total += n_nodes
self.n_edges_total += n_edges
self.n_passes += 1
return g_repr
def run(self, mol_graph, mol_line_graph):
n_nodes = mol_graph.number_of_nodes()
mol_graph.apply_edges(
func=lambda edges: {'src_x': edges.src['x']},
)
e_repr = mol_line_graph.ndata
bond_features = e_repr['x']
source_features = e_repr['src_x']
features = torch.cat([source_features, bond_features], 1)
msg_input = self.W_i(features)
mol_line_graph.ndata.update({
'msg_input': msg_input,
'msg': F.relu(msg_input),
'accum_msg': torch.zeros_like(msg_input),
})
mol_graph.ndata.update({
'm': bond_features.new(n_nodes, self.hidden_size).zero_(),
'h': bond_features.new(n_nodes, self.hidden_size).zero_(),
})
for i in range(self.depth - 1):
mol_line_graph.update_all(
mpn_loopy_bp_msg,
mpn_loopy_bp_reduce,
self.loopy_bp_updater,
)
mol_graph.update_all(
mpn_gather_msg,
mpn_gather_reduce,
self.gather_updater,
)
return mol_graph
# pylint: disable=C0111, C0103, E1101, W0611, W0612
import os
import torch
import torch.nn as nn
from torch.autograd import Variable
def create_var(tensor, requires_grad=None):
if requires_grad is None:
return Variable(tensor)
else:
return Variable(tensor, requires_grad=requires_grad)
def cuda(tensor):
if torch.cuda.is_available() and not os.getenv('NOCUDA', None):
return tensor.cuda()
else:
return tensor
class GRUUpdate(nn.Module):
def __init__(self, hidden_size):
nn.Module.__init__(self)
self.hidden_size = hidden_size
self.W_z = nn.Linear(2 * hidden_size, hidden_size)
self.W_r = nn.Linear(hidden_size, hidden_size, bias=False)
self.U_r = nn.Linear(hidden_size, hidden_size)
self.W_h = nn.Linear(2 * hidden_size, hidden_size)
def update_zm(self, node):
src_x = node.data['src_x']
s = node.data['s']
rm = node.data['accum_rm']
z = torch.sigmoid(self.W_z(torch.cat([src_x, s], 1)))
m = torch.tanh(self.W_h(torch.cat([src_x, rm], 1)))
m = (1 - z) * s + z * m
return {'m': m, 'z': z}
def update_r(self, node, zm=None):
dst_x = node.data['dst_x']
m = node.data['m'] if zm is None else zm['m']
r_1 = self.W_r(dst_x)
r_2 = self.U_r(m)
r = torch.sigmoid(r_1 + r_2)
return {'r': r, 'rm': r * m}
def forward(self, node):
dic = self.update_zm(node)
dic.update(self.update_r(node, zm=dic))
return dic
def move_dgl_to_cuda(g):
g.ndata.update({k: cuda(g.ndata[k]) for k in g.ndata})
g.edata.update({k: cuda(g.edata[k]) for k in g.edata})
# -*- coding: utf-8 -*-
# pylint: disable=C0103, E1101, C0111
"""
The implementation of neural network layers used in SchNet and MGCN.
"""
import torch
import torch.nn as nn
from torch.nn import Softplus
import numpy as np
from ... import function as fn
class AtomEmbedding(nn.Module):
"""
Convert the atom (node) list to atom embeddings.
The atoms with the same element share the same initial embedding.
Parameters
----------
dim : int
Size of embeddings, default to be 128.
type_num : int
The largest atomic number of atoms in the dataset, default to be 100.
pre_train : None or pre-trained embeddings
Pre-trained embeddings, default to be None.
"""
def __init__(self, dim=128, type_num=100, pre_train=None):
super(AtomEmbedding, self).__init__()
self._dim = dim
self._type_num = type_num
if pre_train is not None:
self.embedding = nn.Embedding.from_pretrained(pre_train, padding_idx=0)
else:
self.embedding = nn.Embedding(type_num, dim, padding_idx=0)
def forward(self, atom_types):
"""
Parameters
----------
atom_types : int64 tensor of shape (B1)
Types for atoms in the graph(s), B1 for the number of atoms.
Returns
-------
float32 tensor of shape (B1, self._dim)
Atom embeddings.
"""
return self.embedding(atom_types)
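# A minimal sketch (hypothetical helper, never called by the library):
# embed three atoms by atomic number (C, N, O).
def _demo_atom_embedding():
    emb = AtomEmbedding(dim=128)
    atom_types = torch.LongTensor([6, 7, 8])
    assert emb(atom_types).shape == (3, 128)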
class EdgeEmbedding(nn.Module):
"""
Module for embedding edges. Edges linking the same pair of atom types
share the same initial embedding.
Parameters
----------
dim : int
Size of embeddings, default to be 128.
edge_num : int
Maximum number of edge types allowed, default to be 3000.
pre_train : Edge embeddings or None
Pre-trained edge embeddings, default to be None.
"""
def __init__(self, dim=128, edge_num=3000, pre_train=None):
super(EdgeEmbedding, self).__init__()
self._dim = dim
self._edge_num = edge_num
if pre_train is not None:
self.embedding = nn.Embedding.from_pretrained(pre_train, padding_idx=0)
else:
self.embedding = nn.Embedding(edge_num, dim, padding_idx=0)
def generate_edge_type(self, edges):
"""Generate edge type.
The edge type is based on the type of the src & dst atom.
Note that directions are not distinguished, e.g. C-O and O-C are the same edge type.
To map a pair of nodes to one number, we use an unordered pairing function.
See more details in this discussion:
https://math.stackexchange.com/questions/23503/create-unique-number-from-2-numbers
Note that edge_num should be larger than the square of the maximum atomic
number in the dataset.
Parameters
----------
edges : EdgeBatch
Edges for deciding types
Returns
-------
dict
Stores the edge types in "type"
"""
atom_type_x = edges.src['ntype']
atom_type_y = edges.dst['ntype']
return {
'etype': atom_type_x * atom_type_y + \
(torch.abs(atom_type_x - atom_type_y) - 1) ** 2 / 4
}
def forward(self, g, atom_types):
"""Compute edge embeddings
Parameters
----------
g : DGLGraph
The graph to compute edge embeddings
atom_types : int64 tensor of shape (B1)
Types for atoms in the graph(s), B1 for the number of atoms.
Returns
-------
float32 tensor of shape (B2, self._dim)
Computed edge embeddings
"""
g = g.local_var()
g.ndata['ntype'] = atom_types
g.apply_edges(self.generate_edge_type)
return self.embedding(g.edata.pop('etype'))
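# A minimal sketch (hypothetical helper, never called by the library) of the
# unordered pairing function in generate_edge_type, written with plain ints
# (legacy PyTorch floor-divides integer tensors, matching // here):
# C-O and O-C map to the same edge type.
def _demo_edge_type_symmetry():
    def pair(x, y):
        return x * y + (abs(x - y) - 1) ** 2 // 4
    assert pair(6, 8) == pair(8, 6) == 48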
class ShiftSoftplus(nn.Module):
"""
ShiftSoftplus activation function:
1/beta * log(1 + exp(beta * x)) - log(shift)
Parameters
----------
beta : int
Default to be 1.
shift : int
Default to be 2.
threshold : int
Default to be 20.
"""
def __init__(self, beta=1, shift=2, threshold=20):
super(ShiftSoftplus, self).__init__()
self.shift = shift
self.softplus = Softplus(beta, threshold)
def forward(self, x):
"""Applies the activation function"""
return self.softplus(x) - np.log(float(self.shift))
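# A minimal sketch (hypothetical helper, never called by the library):
# with beta=1 and shift=2, softplus(0) = log(2), so the shifted output is 0.
def _demo_shift_softplus():
    act = ShiftSoftplus()
    assert torch.allclose(act(torch.zeros(1)), torch.zeros(1), atol=1e-6)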
class RBFLayer(nn.Module):
"""
Radial basis function layer.
e(d) = exp(- gamma * ||d - mu_k||^2)
With the default parameters below, this corresponds to:
* gamma = 10
* 0 <= mu_k <= 30 for k=1~300
Parameters
----------
low : int
Smallest value to take for mu_k, default to be 0.
high : int
Largest value to take for mu_k, default to be 30.
gap : float
Difference between two consecutive values for mu_k, default to be 0.1.
dim : int
Output size for each center, default to be 1.
"""
def __init__(self, low=0, high=30, gap=0.1, dim=1):
super(RBFLayer, self).__init__()
self._low = low
self._high = high
self._dim = dim
self._n_centers = int(np.ceil((high - low) / gap))
centers = np.linspace(low, high, self._n_centers)
self.centers = torch.tensor(centers, dtype=torch.float, requires_grad=False)
self.centers = nn.Parameter(self.centers, requires_grad=False)
self._fan_out = self._dim * self._n_centers
self._gap = centers[1] - centers[0]
def forward(self, edge_distances):
"""
Parameters
----------
edge_distances : float32 tensor of shape (B, 1)
Edge distances, B for the number of edges.
Returns
-------
float32 tensor of shape (B, self._fan_out)
Computed RBF results
"""
radial = edge_distances - self.centers
coef = -1 / self._gap
return torch.exp(coef * (radial ** 2))
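# A minimal sketch (hypothetical helper, never called by the library):
# the default settings expand each edge distance into 300 radial basis values.
def _demo_rbf_layer():
    rbf = RBFLayer(low=0, high=30, gap=0.1)
    out = rbf(torch.tensor([[1.0], [2.5]]))
    assert out.shape == (2, 300)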
class CFConv(nn.Module):
"""
The continuous-filter convolution layer in SchNet.
Parameters
----------
rbf_dim : int
Dimension of the RBF layer output
dim : int
Dimension of output, default to be 64
act : activation function or None
Activation function, default to be shifted softplus.
"""
def __init__(self, rbf_dim, dim=64, act=None):
super(CFConv, self).__init__()
self._rbf_dim = rbf_dim
self._dim = dim
if act is None:
activation = nn.Softplus(beta=0.5, threshold=14)
else:
activation = act
self.project = nn.Sequential(
nn.Linear(self._rbf_dim, self._dim),
activation,
nn.Linear(self._dim, self._dim)
)
def forward(self, g, node_weight, rbf_out):
"""
Parameters
----------
g : DGLGraph
The graph for performing convolution
node_weight : float32 tensor of shape (B1, D1)
The weight of nodes in message passing, B1 for number of nodes and
D1 for node weight size.
rbf_out : float32 tensor of shape (B2, D2)
The output of RBFLayer, B2 for number of edges and D2 for rbf out size.
"""
g = g.local_var()
e = self.project(rbf_out)
g.ndata['node_weight'] = node_weight
g.edata['e'] = e
g.update_all(fn.u_mul_e('node_weight', 'e', 'm'), fn.sum('m', 'h'))
return g.ndata.pop('h')
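# A minimal sketch (hypothetical helper, never called by the library);
# assumes a DGL version exposing dgl.graph for toy-graph construction.
# One directed edge carries a projected RBF filter from node 0 to node 1.
def _demo_cfconv():
    import dgl
    g = dgl.graph(([0], [1]))
    conv = CFConv(rbf_dim=300, dim=64)
    h = conv(g, torch.randn(2, 64), torch.randn(1, 300))
    assert h.shape == (2, 64)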
class Interaction(nn.Module):
"""
The interaction layer in the SchNet model.
Parameters
----------
rbf_dim : int
Dimension of the RBF layer output
dim : int
Dimension of intermediate representations
"""
def __init__(self, rbf_dim, dim):
super(Interaction, self).__init__()
self._dim = dim
self.node_layer1 = nn.Linear(dim, dim, bias=False)
self.cfconv = CFConv(rbf_dim, dim, Softplus(beta=0.5, threshold=14))
self.node_layer2 = nn.Sequential(
nn.Linear(dim, dim),
Softplus(beta=0.5, threshold=14),
nn.Linear(dim, dim)
)
def forward(self, g, n_feat, rbf_out):
"""
Parameters
----------
g : DGLGraph
The graph for performing convolution
n_feat : float32 tensor of shape (B1, D1)
Node features, B1 for number of nodes and D1 for feature size.
rbf_out : float32 tensor of shape (B2, D2)
The output of RBFLayer, B2 for number of edges and D2 for rbf out size.
Returns
-------
float32 tensor of shape (B1, D1)
Updated node representations
"""
n_weight = self.node_layer1(n_feat)
new_n_feat = self.cfconv(g, n_weight, rbf_out)
new_n_feat = self.node_layer2(new_n_feat)
return n_feat + new_n_feat
class VEConv(nn.Module):
"""
The Vertex-Edge convolution layer in MGCN, which takes both edge and vertex
features into consideration.
Parameters
----------
rbf_dim : int
Size of the RBF layer output
dim : int
Size of intermediate representations, default to be 64.
update_edge : bool
Whether to apply a linear layer to update edge representations, default to be True.
"""
def __init__(self, rbf_dim, dim=64, update_edge=True):
super(VEConv, self).__init__()
self._rbf_dim = rbf_dim
self._dim = dim
self._update_edge = update_edge
self.update_rbf = nn.Sequential(
nn.Linear(self._rbf_dim, self._dim),
nn.Softplus(beta=0.5, threshold=14),
nn.Linear(self._dim, self._dim)
)
self.update_efeat = nn.Linear(self._dim, self._dim)
def forward(self, g, n_feat, e_feat, rbf_out):
"""
Parameters
----------
g : DGLGraph
The graph for performing convolution
n_feat : float32 tensor of shape (B1, D1)
Node features, B1 for number of nodes and D1 for feature size.
e_feat : float32 tensor of shape (B2, D2)
Edge features. B2 for number of edges and D2 for
the edge feature size.
rbf_out : float32 tensor of shape (B2, D3)
The output of RBFLayer, B2 for number of edges and D3 for rbf out size.
Returns
-------
n_feat : float32 tensor
Updated node features.
e_feat : float32 tensor
(Potentially updated) edge features
"""
rbf_out = self.update_rbf(rbf_out)
if self._update_edge:
e_feat = self.update_efeat(e_feat)
g = g.local_var()
g.ndata.update({'n_feat': n_feat})
g.edata.update({'rbf_out': rbf_out, 'e_feat': e_feat})
g.update_all(message_func=[fn.u_mul_e('n_feat', 'rbf_out', 'm_0'),
fn.copy_e('e_feat', 'm_1')],
reduce_func=[fn.sum('m_0', 'n_feat_0'),
fn.sum('m_1', 'n_feat_1')])
n_feat = g.ndata.pop('n_feat_0') + g.ndata.pop('n_feat_1')
return n_feat, e_feat
class MultiLevelInteraction(nn.Module):
"""
The multilevel interaction in the MGCN model.
Parameters
----------
rbf_dim : int
Dimension of the RBF layer output
dim : int
Dimension of intermediate representations
"""
def __init__(self, rbf_dim, dim):
super(MultiLevelInteraction, self).__init__()
self._atom_dim = dim
self.node_layer1 = nn.Linear(dim, dim, bias=True)
self.conv_layer = VEConv(rbf_dim, dim)
self.activation = nn.Softplus(beta=0.5, threshold=14)
self.edge_layer1 = nn.Linear(dim, dim, bias=True)
self.node_out = nn.Sequential(
nn.Linear(dim, dim),
nn.Softplus(beta=0.5, threshold=14),
nn.Linear(dim, dim)
)
def forward(self, g, n_feat, e_feat, rbf_out):
"""
Parameters
----------
g : DGLGraph
The graph for performing convolution
n_feat : float32 tensor of shape (B1, D1)
Node features, B1 for number of nodes and D1 for feature size.
e_feat : float32 tensor of shape (B2, D2)
Edge features. B2 for number of edges and D2 for
the edge feature size.
rbf_out : float32 tensor of shape (B2, D3)
The output of RBFLayer, B2 for number of edges and D3 for rbf out size.
Returns
-------
n_feat : float32 tensor
Updated node representations
e_feat : float32 tensor
Updated edge representations
"""
new_n_feat = self.node_layer1(n_feat)
new_n_feat, e_feat = self.conv_layer(g, new_n_feat, e_feat, rbf_out)
new_n_feat = self.node_out(new_n_feat)
n_feat = n_feat + new_n_feat
e_feat = self.activation(self.edge_layer1(e_feat))
return n_feat, e_feat
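# A minimal sketch of how node and edge representations co-evolve in one
# MultiLevelInteraction step (illustrative only; the toy graph and sizes
# below are assumptions, not part of the API):
#
#     import dgl
#     import torch
#     g = dgl.DGLGraph()
#     g.add_nodes(3)
#     g.add_edges([0, 1, 2], [1, 2, 0])
#     layer = MultiLevelInteraction(rbf_dim=32, dim=64)
#     n_feat = torch.randn(3, 64)      # per-node representations
#     e_feat = torch.randn(3, 64)      # per-edge representations
#     rbf_out = torch.randn(3, 32)     # RBF expansion of edge distances
#     n_feat, e_feat = layer(g, n_feat, e_feat, rbf_out)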
# -*- coding:utf-8 -*-
# pylint: disable=C0103, C0111, W0621
"""Implementation of MGCN model"""
import torch
import torch.nn as nn
from .layers import AtomEmbedding, RBFLayer, EdgeEmbedding, \
MultiLevelInteraction
from ...nn.pytorch import SumPooling
from ...contrib.deprecation import deprecated
class MGCNModel(nn.Module):
"""
`Molecular Property Prediction: A Multilevel
Quantum Interactions Modeling Perspective <https://arxiv.org/abs/1906.11081>`__
Parameters
----------
dim : int
Size for embeddings, default to be 128.
width : int
Width in the RBF layer, default to be 1.
cutoff : float
The maximum distance between nodes, default to be 5.0.
edge_dim : int
Size for edge embedding, default to be 128.
output_dim : int
Number of target properties to predict, default to be 1.
n_conv : int
Number of convolutional layers, default to be 3.
norm : bool
Whether to perform normalization, default to be False.
atom_ref : Atom embeddings or None
If not None, it will be used to initialize per-atom reference contributions
(``e0``) that are added to the prediction. Default to be None.
pre_train : Atom embeddings or None
If None, random initialization will be used for atom representations. Otherwise,
it will be used to initialize atom representations. Default to be None.
"""
@deprecated('Import MGCNPredictor from dgllife.model instead.', 'class')
def __init__(self,
dim=128,
width=1,
cutoff=5.0,
edge_dim=128,
output_dim=1,
n_conv=3,
norm=False,
atom_ref=None,
pre_train=None):
super(MGCNModel, self).__init__()
self._dim = dim
self.output_dim = output_dim
self.edge_dim = edge_dim
self.cutoff = cutoff
self.width = width
self.n_conv = n_conv
self.atom_ref = atom_ref
self.norm = norm
if pre_train is None:
self.embedding_layer = AtomEmbedding(dim)
else:
self.embedding_layer = AtomEmbedding(pre_train=pre_train)
self.rbf_layer = RBFLayer(0, cutoff, width)
self.edge_embedding_layer = EdgeEmbedding(dim=edge_dim)
if atom_ref is not None:
self.e0 = AtomEmbedding(1, pre_train=atom_ref)
self.conv_layers = nn.ModuleList([
MultiLevelInteraction(self.rbf_layer._fan_out, dim)
for i in range(n_conv)
])
self.out_project = nn.Sequential(
nn.Linear(dim * (self.n_conv + 1), 64),
nn.Softplus(beta=1, threshold=20),
nn.Linear(64, output_dim)
)
self.readout = SumPooling()
def set_mean_std(self, mean, std, device="cpu"):
"""Set the mean and std of atom representations for normalization.
Parameters
----------
mean : list or numpy array
The mean of labels
std : list or numpy array
The std of labels
device : str or torch.device
Device for storing the mean and std
"""
self.mean_per_node = torch.tensor(mean, device=device)
self.std_per_node = torch.tensor(std, device=device)
def forward(self, g, atom_types, edge_distances):
"""Predict molecule labels
Parameters
----------
g : DGLGraph
Input DGLGraph for molecule(s)
atom_types : int64 tensor of shape (B1)
Types for atoms in the graph(s), B1 for the number of atoms.
edge_distances : float32 tensor of shape (B2, 1)
Edge distances, B2 for the number of edges.
Returns
-------
prediction : float32 tensor of shape (B, output_dim)
Model prediction for the batch of graphs, B for the number
of graphs, output_dim for the prediction size.
"""
h = self.embedding_layer(atom_types)
e = self.edge_embedding_layer(g, atom_types)
rbf_out = self.rbf_layer(edge_distances)
all_layer_h = [h]
for idx in range(self.n_conv):
h, e = self.conv_layers[idx](g, h, e, rbf_out)
all_layer_h.append(h)
# concat multilevel representations
h = torch.cat(all_layer_h, dim=1)
h = self.out_project(h)
if self.atom_ref is not None:
h_ref = self.e0(atom_types)
h = h + h_ref
if self.norm:
h = h * self.std_per_node + self.mean_per_node
return self.readout(g, h)
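# A minimal end-to-end sketch, runnable as a script (illustrative only; the toy
# triangle graph, atomic numbers and random distances below are assumptions,
# real inputs come from molecule featurization utilities):
if __name__ == '__main__':
    import dgl

    model = MGCNModel(dim=128, output_dim=1)
    g = dgl.DGLGraph()
    g.add_nodes(3)                                       # e.g. a 3-atom molecule
    g.add_edges([0, 0, 1, 1, 2, 2], [1, 2, 0, 2, 0, 1])  # fully connected, both directions
    atom_types = torch.LongTensor([6, 1, 1])             # atomic numbers, e.g. C, H, H
    edge_distances = torch.rand(6, 1) * 5.0              # one distance per edge, within cutoff
    prediction = model(g, atom_types, edge_distances)    # shape (1, 1)
    print(prediction)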