Unverified Commit 545cc065 authored by Mufei Li's avatar Mufei Li Committed by GitHub
Browse files

[DGL-LifeSci] WLN for Reaction Center Prediction (#1360)

* Update

* Update

* Update

* Update

* Update

* Update

* Update

* Update

* Update

* Update

* Update

* Update

* Update

* Update

* Update

* Update

* Update

* Update

* Update

* Update

* Update

* Update

* Update

* Update

* Update

* Update

* Update

* Update

* Update

* Update

* Update

* Update

* Update

* Update

* Update

* Update

* Update

* Update

* Update

* Update

* Update

* Update
parent 5a1ef70f
...@@ -10,6 +10,8 @@ with graph neural networks. ...@@ -10,6 +10,8 @@ with graph neural networks.
We provide various functionalities, including but not limited to methods for graph construction, We provide various functionalities, including but not limited to methods for graph construction,
featurization, and evaluation, model architectures, training scripts and pre-trained models. featurization, and evaluation, model architectures, training scripts and pre-trained models.
**For a full list of work implemented in DGL-LifeSci, see [here](examples/README.md).**
## Dependencies ## Dependencies
For the time being, we only support PyTorch. For the time being, we only support PyTorch.
...@@ -19,14 +21,12 @@ Depending on the features you want to use, you may need to manually install the ...@@ -19,14 +21,12 @@ Depending on the features you want to use, you may need to manually install the
- RDKit 2018.09.3 - RDKit 2018.09.3
- We recommend installation with `conda install -c conda-forge rdkit==2018.09.3`. For other installation recipes, - We recommend installation with `conda install -c conda-forge rdkit==2018.09.3`. For other installation recipes,
see the [official documentation](https://www.rdkit.org/docs/Install.html). see the [official documentation](https://www.rdkit.org/docs/Install.html).
- MDTraj - (optional) MDTraj
- We recommend installation with `conda install -c conda-forge mdtraj`. For alternative ways of installation, - We recommend installation with `conda install -c conda-forge mdtraj`. For alternative ways of installation,
see the [official documentation](http://mdtraj.org/1.9.3/installation.html). see the [official documentation](http://mdtraj.org/1.9.3/installation.html).
## Organization ## Organization
For a full list of work implemented in DGL-LifeSci, **see implemented.md**.
``` ```
dgllife dgllife
data data
...@@ -123,8 +123,9 @@ SVG(Draw.MolsToGridImage(mols, molsPerRow=4, subImgSize=(180, 150), useSVG=True) ...@@ -123,8 +123,9 @@ SVG(Draw.MolsToGridImage(mols, molsPerRow=4, subImgSize=(180, 150), useSVG=True)
Below we provide some reference numbers to show how DGL improves the speed of training models per epoch in seconds. Below we provide some reference numbers to show how DGL improves the speed of training models per epoch in seconds.
| Model | Original Implementation | DGL Implementation | Improvement | | Model | Original Implementation | DGL Implementation | Improvement |
| -------------------------- | ----------------------- | ------------------ | ----------- | | ---------------------------------- | ----------------------- | ------------------ | ----------- |
| GCN on Tox21 | 5.5 (DeepChem) | 1.0 | 5.5x | | GCN on Tox21 | 5.5 (DeepChem) | 1.0 | 5.5x |
| AttentiveFP on Aromaticity | 6.0 | 1.2 | 5x | | AttentiveFP on Aromaticity | 6.0 | 1.2 | 5x |
| JTNN on ZINC | 1826 | 743 | 2.5x | | JTNN on ZINC | 1826 | 743 | 2.5x |
| WLN for reaction center prediction | 11657 | 5095 | 2.3x | |
__version__ = '0.1.0' __version__ = '0.2.0'
...@@ -4,3 +4,4 @@ from .csv_dataset import * ...@@ -4,3 +4,4 @@ from .csv_dataset import *
from .pdbbind import * from .pdbbind import *
from .pubchem_aromaticity import * from .pubchem_aromaticity import *
from .tox21 import * from .tox21 import *
from .uspto import *
...@@ -146,7 +146,7 @@ class TencentAlchemyDataset(object): ...@@ -146,7 +146,7 @@ class TencentAlchemyDataset(object):
contest is ongoing. contest is ongoing.
mol_to_graph: callable, str -> DGLGraph mol_to_graph: callable, str -> DGLGraph
A function turning an RDKit molecule instance into a DGLGraph. A function turning an RDKit molecule instance into a DGLGraph.
Default to :func:`dgl.data.chem.mol_to_complete_graph`. Default to :func:`dgllife.utils.mol_to_complete_graph`.
node_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict node_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict
Featurization for nodes like atoms in a molecule, which can be used to update Featurization for nodes like atoms in a molecule, which can be used to update
ndata for a DGLGraph. By default, we construct graphs where nodes represent atoms ndata for a DGLGraph. By default, we construct graphs where nodes represent atoms
......
...@@ -23,7 +23,7 @@ class PubChemBioAssayAromaticity(MoleculeCSVDataset): ...@@ -23,7 +23,7 @@ class PubChemBioAssayAromaticity(MoleculeCSVDataset):
---------- ----------
smiles_to_graph: callable, str -> DGLGraph smiles_to_graph: callable, str -> DGLGraph
A function turning smiles into a DGLGraph. A function turning smiles into a DGLGraph.
Default to :func:`dgl.data.chem.smiles_to_bigraph`. Default to :func:`dgllife.utils.smiles_to_bigraph`.
node_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict node_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict
Featurization for nodes like atoms in a molecule, which can be used to update Featurization for nodes like atoms in a molecule, which can be used to update
ndata for a DGLGraph. Default to None. ndata for a DGLGraph. Default to None.
......
...@@ -30,7 +30,7 @@ class Tox21(MoleculeCSVDataset): ...@@ -30,7 +30,7 @@ class Tox21(MoleculeCSVDataset):
---------- ----------
smiles_to_graph: callable, str -> DGLGraph smiles_to_graph: callable, str -> DGLGraph
A function turning smiles into a DGLGraph. A function turning smiles into a DGLGraph.
Default to :func:`dgl.data.chem.smiles_to_bigraph`. Default to :func:`dgllife.utils.smiles_to_bigraph`.
node_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict node_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict
Featurization for nodes like atoms in a molecule, which can be used to update Featurization for nodes like atoms in a molecule, which can be used to update
ndata for a DGLGraph. Default to None. ndata for a DGLGraph. Default to None.
......
This diff is collapsed.
...@@ -5,3 +5,4 @@ from .gcn import * ...@@ -5,3 +5,4 @@ from .gcn import *
from .mgcn import * from .mgcn import *
from .mpnn import * from .mpnn import *
from .schnet import * from .schnet import *
from .wln import *
"""WLN"""
import dgl.function as fn
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Parameter
__all__ = ['WLN']
class WLNLinear(nn.Module):
"""Linear layer for WLN
Let stddev be
.. math::
\min(\frac{1.0}{\sqrt{in_feats}}, 0.1)
The weight of the linear layer is initialized from a normal distribution
with mean 0 and std as specified in stddev.
Parameters
----------
in_feats : int
Size for the input.
out_feats : int
Size for the output.
bias : bool
Whether bias will be added to the output. Default to True.
"""
def __init__(self, in_feats, out_feats, bias=True):
super(WLNLinear, self).__init__()
self.in_feats = in_feats
self.out_feats = out_feats
self.weight = Parameter(torch.Tensor(out_feats, in_feats))
if bias:
self.bias = Parameter(torch.Tensor(out_feats))
else:
self.register_parameter('bias', None)
self.reset_parameters()
def reset_parameters(self):
"""Initialize model parameters."""
stddev = min(1.0 / math.sqrt(self.in_feats), 0.1)
nn.init.normal_(self.weight, std=stddev)
if self.bias is not None:
fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight)
bound = 1 / math.sqrt(fan_in)
nn.init.uniform_(self.bias, -bound, bound)
def forward(self, input):
"""Applies the layer.
Parameters
----------
input : float32 tensor of shape (N, *, in_feats)
N for the number of samples, * for any additional dimensions.
Returns
-------
float32 tensor of shape (N, *, out_feats)
Result of the layer.
"""
return F.linear(input, self.weight, self.bias)
def extra_repr(self):
"""Return a description of the layer."""
return 'in_feats={}, out_feats={}, bias={}'.format(
self.in_feats, self.out_feats, self.bias is not None
)
class WLN(nn.Module):
"""Weisfeiler-Lehman Network (WLN)
WLN is introduced in `Predicting Organic Reaction Outcomes with
Weisfeiler-Lehman Network <https://arxiv.org/abs/1709.04555>`__.
This class performs message passing and updates node representations.
Parameters
----------
node_in_feats : int
Size for the input node features.
edge_in_feats : int
Size for the input edge features.
node_out_feats : int
Size for the output node representations. Default to 300.
n_layers : int
Number of times for message passing. Note that same parameters
are shared across n_layers message passing. Default to 3.
"""
def __init__(self,
node_in_feats,
edge_in_feats,
node_out_feats=300,
n_layers=3):
super(WLN, self).__init__()
self.n_layers = n_layers
self.project_node_in_feats = nn.Sequential(
WLNLinear(node_in_feats, node_out_feats, bias=False),
nn.ReLU()
)
self.project_concatenated_messages = nn.Sequential(
WLNLinear(edge_in_feats + node_out_feats, node_out_feats),
nn.ReLU()
)
self.get_new_node_feats = nn.Sequential(
WLNLinear(2 * node_out_feats, node_out_feats),
nn.ReLU()
)
self.project_edge_messages = WLNLinear(edge_in_feats, node_out_feats, bias=False)
self.project_node_messages = WLNLinear(node_out_feats, node_out_feats, bias=False)
self.project_self = WLNLinear(node_out_feats, node_out_feats, bias=False)
def forward(self, g, node_feats, edge_feats):
"""Performs message passing and updates node representations.
Parameters
----------
g : DGLGraph
DGLGraph for a batch of graphs
node_feats : float32 tensor of shape (V, node_in_feats)
Input node features. V for the number of nodes.
edge_feats : float32 tensor of shape (E, edge_in_feats)
Input edge features. E for the number of edges.
Returns
-------
float32 tensor of shape (V, node_out_feats)
Updated node representations.
"""
node_feats = self.project_node_in_feats(node_feats)
for l in range(self.n_layers):
g = g.local_var()
g.ndata['hv'] = node_feats
g.apply_edges(fn.copy_src('hv', 'he_src'))
concat_edge_feats = torch.cat([g.edata['he_src'], edge_feats], dim=1)
g.edata['he'] = self.project_concatenated_messages(concat_edge_feats)
g.update_all(fn.copy_edge('he', 'm'), fn.sum('m', 'hv_new'))
g = g.local_var()
g.ndata['hv'] = self.project_node_messages(node_feats)
g.edata['he'] = self.project_edge_messages(edge_feats)
g.update_all(fn.u_mul_e('hv', 'he', 'm'), fn.sum('m', 'h_nbr'))
h_self = self.project_self(node_feats) # (V, node_out_feats)
return g.ndata['h_nbr'] * h_self
...@@ -9,3 +9,4 @@ from .schnet_predictor import * ...@@ -9,3 +9,4 @@ from .schnet_predictor import *
from .mgcn_predictor import * from .mgcn_predictor import *
from .mpnn_predictor import * from .mpnn_predictor import *
from .acnn import * from .acnn import *
from .wln_reaction_center import *
"""Weisfeiler-Lehman Network (WLN) for Reaction Center Prediction."""
import dgl.function as fn
import torch
import torch.nn as nn
from ..gnn.wln import WLNLinear, WLN
__all__ = ['WLNReactionCenter']
class WLNContext(nn.Module):
"""Attention-based context computation for each node.
A context vector is computed by taking a weighted sum of node representations,
with weights computed from an attention module.
Parameters
----------
node_in_feats : int
Size for the input node features.
node_pair_in_feats : int
Size for the input features of node pairs.
"""
def __init__(self, node_in_feats, node_pair_in_feats):
super(WLNContext, self).__init__()
self.project_feature_sum = WLNLinear(node_in_feats, node_in_feats, bias=False)
self.project_node_pair_feature = WLNLinear(node_pair_in_feats, node_in_feats)
self.compute_attention = nn.Sequential(
nn.ReLU(),
WLNLinear(node_in_feats, 1),
nn.Sigmoid()
)
def forward(self, batch_complete_graphs, node_feats, feat_sum, node_pair_feat):
"""Compute context vectors for each node.
Parameters
----------
batch_complete_graphs : DGLGraph
A batch of fully connected graphs.
node_feats : float32 tensor of shape (V, node_in_feats)
Input node features. V for the number of nodes.
feat_sum : float32 tensor of shape (E_full, node_in_feats)
Sum of node_feats between each pair of nodes. E_full for the number of
edges in the batch of complete graphs.
node_pair_feat : float32 tensor of shape (E_full, node_pair_in_feats)
Input features for each pair of nodes. E_full for the number of edges in
the batch of complete graphs.
Returns
-------
node_contexts : float32 tensor of shape (V, node_in_feats)
Context vectors for nodes.
"""
with batch_complete_graphs.local_scope():
batch_complete_graphs.ndata['hv'] = node_feats
batch_complete_graphs.edata['a'] = self.compute_attention(
self.project_feature_sum(feat_sum) + \
self.project_node_pair_feature(node_pair_feat)
)
batch_complete_graphs.update_all(
fn.src_mul_edge('hv', 'a', 'm'), fn.sum('m', 'context'))
node_contexts = batch_complete_graphs.ndata.pop('context')
return node_contexts
class WLNReactionCenter(nn.Module):
r"""Weisfeiler-Lehman Network (WLN) for Reaction Center Prediction.
The model is introduced in `Predicting Organic Reaction Outcomes with
Weisfeiler-Lehman Network <https://arxiv.org/abs/1709.04555>`__.
The model uses WLN to update atom representations and then predicts the
score for each pair of atoms to form a bond.
Parameters
----------
node_in_feats : int
Size for the input node features.
edge_in_feats : int
Size for the input edge features.
node_out_feats : int
Size for the output node representations. Default to 300.
node_pair_in_feats : int
Size for the input features of node pairs.
n_layers : int
Number of times for message passing. Note that same parameters
are shared across n_layers message passing. Default to 3.
n_tasks : int
Number of tasks for prediction.
"""
def __init__(self,
node_in_feats,
edge_in_feats,
node_pair_in_feats,
node_out_feats=300,
n_layers=3,
n_tasks=5):
super(WLNReactionCenter, self).__init__()
self.gnn = WLN(node_in_feats=node_in_feats,
edge_in_feats=edge_in_feats,
node_out_feats=node_out_feats,
n_layers=n_layers)
self.context_module = WLNContext(node_in_feats=node_out_feats,
node_pair_in_feats=node_pair_in_feats)
self.project_feature_sum = WLNLinear(node_out_feats, node_out_feats, bias=False)
self.project_node_pair_feature = WLNLinear(node_pair_in_feats, node_out_feats, bias=False)
self.project_context_sum = WLNLinear(node_out_feats, node_out_feats)
self.predict = nn.Sequential(
nn.ReLU(),
WLNLinear(node_out_feats, n_tasks)
)
def forward(self, batch_mol_graphs, batch_complete_graphs,
node_feats, edge_feats, node_pair_feats):
r"""Predict score for each pair of nodes.
Parameters
----------
batch_mol_graphs : DGLGraph
A batch of molecular graphs.
batch_complete_graphs : DGLGraph
A batch of fully connected graphs.
node_feats : float32 tensor of shape (V, node_in_feats)
Input node features. V for the number of nodes.
edge_feats : float32 tensor of shape (E, edge_in_feats)
Input edge features. E for the number of edges.
node_pair_feats : float32 tensor of shape (E_full, node_pair_in_feats)
Input features for each pair of nodes. E_full for the number of edges in
the batch of complete graphs.
Returns
-------
scores : float32 tensor of shape (E_full, 5)
Predicted scores for each pair of atoms to perform one of the following
5 actions in reaction:
* The bond between them gets broken
* Forming a single bond
* Forming a double bond
* Forming a triple bond
* Forming an aromatic bond
biased_scores : float32 tensor of shape (E_full, 5)
Comparing to scores, a bias is added if the pair is for a same atom.
"""
node_feats = self.gnn(batch_mol_graphs, node_feats, edge_feats)
# Compute context vectors for all atoms, which are weighted sum of atom
# representations in all reactants.
with batch_complete_graphs.local_scope():
batch_complete_graphs.ndata['hv'] = node_feats
batch_complete_graphs.apply_edges(fn.u_add_v('hv', 'hv', 'feature_sum'))
feat_sum = batch_complete_graphs.edata.pop('feature_sum')
node_contexts = self.context_module(batch_complete_graphs, node_feats,
feat_sum, node_pair_feats)
# Predict score
with batch_complete_graphs.local_scope():
batch_complete_graphs.ndata['context'] = node_contexts
batch_complete_graphs.apply_edges(fn.u_add_v('context', 'context', 'context_sum'))
scores = self.predict(
self.project_feature_sum(feat_sum) + \
self.project_node_pair_feature(node_pair_feats) + \
self.project_context_sum(batch_complete_graphs.edata['context_sum'])
)
# Masking self loops
nodes = batch_complete_graphs.nodes()
e_ids = batch_complete_graphs.edge_ids(nodes, nodes)
bias = torch.zeros(scores.shape[0], 5).to(scores.device)
bias[e_ids, :] = 1e4
biased_scores = scores - bias
return scores, biased_scores
...@@ -6,7 +6,8 @@ import torch.nn.functional as F ...@@ -6,7 +6,8 @@ import torch.nn.functional as F
from dgl.data.utils import _get_dgl_url, download, get_download_dir, extract_archive from dgl.data.utils import _get_dgl_url, download, get_download_dir, extract_archive
from rdkit import Chem from rdkit import Chem
from ..model import GCNPredictor, GATPredictor, AttentiveFPPredictor, DGMG, DGLJTNNVAE from ..model import GCNPredictor, GATPredictor, AttentiveFPPredictor, DGMG, DGLJTNNVAE, \
WLNReactionCenter
__all__ = ['load_pretrained'] __all__ = ['load_pretrained']
...@@ -18,7 +19,8 @@ URL = { ...@@ -18,7 +19,8 @@ URL = {
'DGMG_ChEMBL_random': 'pre_trained/dgmg_ChEMBL_random.pth', 'DGMG_ChEMBL_random': 'pre_trained/dgmg_ChEMBL_random.pth',
'DGMG_ZINC_canonical': 'pre_trained/dgmg_ZINC_canonical.pth', 'DGMG_ZINC_canonical': 'pre_trained/dgmg_ZINC_canonical.pth',
'DGMG_ZINC_random': 'pre_trained/dgmg_ZINC_random.pth', 'DGMG_ZINC_random': 'pre_trained/dgmg_ZINC_random.pth',
'JTNN_ZINC': 'pre_trained/JTNN_ZINC.pth' 'JTNN_ZINC': 'pre_trained/JTNN_ZINC.pth',
'wln_center_uspto': 'dgllife/pre_trained/wln_center_uspto.pth'
} }
def download_and_load_checkpoint(model_name, model, model_postfix, def download_and_load_checkpoint(model_name, model, model_postfix,
...@@ -72,6 +74,7 @@ def load_pretrained(model_name, log=True): ...@@ -72,6 +74,7 @@ def load_pretrained(model_name, log=True):
* ``'DGMG_ZINC_canonical'`` * ``'DGMG_ZINC_canonical'``
* ``'DGMG_ZINC_random'`` * ``'DGMG_ZINC_random'``
* ``'JTNN_ZINC'`` * ``'JTNN_ZINC'``
* ``'wln_center_uspto'``
log : bool log : bool
Whether to print progress for model loading Whether to print progress for model loading
...@@ -134,4 +137,12 @@ def load_pretrained(model_name, log=True): ...@@ -134,4 +137,12 @@ def load_pretrained(model_name, log=True):
hidden_size=450, hidden_size=450,
latent_size=56) latent_size=56)
elif model_name == 'wln_center_uspto':
model = WLNReactionCenter(node_in_feats=82,
edge_in_feats=6,
node_pair_in_feats=10,
node_out_feats=300,
n_layers=3,
n_tasks=5)
return download_and_load_checkpoint(model_name, model, URL[model_name], log=log) return download_and_load_checkpoint(model_name, model, URL[model_name], log=log)
...@@ -14,6 +14,8 @@ __all__ = ['one_hot_encoding', ...@@ -14,6 +14,8 @@ __all__ = ['one_hot_encoding',
'atom_degree', 'atom_degree',
'atom_total_degree_one_hot', 'atom_total_degree_one_hot',
'atom_total_degree', 'atom_total_degree',
'atom_explicit_valence_one_hot',
'atom_explicit_valence',
'atom_implicit_valence_one_hot', 'atom_implicit_valence_one_hot',
'atom_implicit_valence', 'atom_implicit_valence',
'atom_hybridization_one_hot', 'atom_hybridization_one_hot',
...@@ -25,6 +27,8 @@ __all__ = ['one_hot_encoding', ...@@ -25,6 +27,8 @@ __all__ = ['one_hot_encoding',
'atom_num_radical_electrons', 'atom_num_radical_electrons',
'atom_is_aromatic_one_hot', 'atom_is_aromatic_one_hot',
'atom_is_aromatic', 'atom_is_aromatic',
'atom_is_in_ring_one_hot',
'atom_is_in_ring',
'atom_chiral_tag_one_hot', 'atom_chiral_tag_one_hot',
'atom_mass', 'atom_mass',
'ConcatFeaturizer', 'ConcatFeaturizer',
...@@ -224,8 +228,45 @@ def atom_total_degree(atom): ...@@ -224,8 +228,45 @@ def atom_total_degree(atom):
""" """
return [atom.GetTotalDegree()] return [atom.GetTotalDegree()]
def atom_explicit_valence_one_hot(atom, allowable_set=None, encode_unknown=False):
"""One hot encoding for the explicit valence of an aotm.
Parameters
----------
atom : rdkit.Chem.rdchem.Atom
RDKit atom instance.
allowable_set : list of int
Atom explicit valences to consider. Default: ``1`` - ``6``.
encode_unknown : bool
If True, map inputs not in the allowable set to the
additional last element. (Default: False)
Returns
-------
list
List of boolean values where at most one value is True.
"""
if allowable_set is None:
allowable_set = list(range(1, 7))
return one_hot_encoding(atom.GetExplicitValence(), allowable_set, encode_unknown)
def atom_explicit_valence(atom):
"""Get the explicit valence of an atom.
Parameters
----------
atom : rdkit.Chem.rdchem.Atom
RDKit atom instance.
Returns
-------
list
List containing one int only.
"""
return [atom.GetExplicitValence()]
def atom_implicit_valence_one_hot(atom, allowable_set=None, encode_unknown=False): def atom_implicit_valence_one_hot(atom, allowable_set=None, encode_unknown=False):
"""One hot encoding for the implicit valences of an atom. """One hot encoding for the implicit valence of an atom.
Parameters Parameters
---------- ----------
...@@ -437,6 +478,43 @@ def atom_is_aromatic(atom): ...@@ -437,6 +478,43 @@ def atom_is_aromatic(atom):
""" """
return [atom.GetIsAromatic()] return [atom.GetIsAromatic()]
def atom_is_in_ring_one_hot(atom, allowable_set=None, encode_unknown=False):
"""One hot encoding for whether the atom is in ring.
Parameters
----------
atom : rdkit.Chem.rdchem.Atom
RDKit atom instance.
allowable_set : list of bool
Conditions to consider. Default: ``False`` and ``True``.
encode_unknown : bool
If True, map inputs not in the allowable set to the
additional last element. (Default: False)
Returns
-------
list
List of boolean values where at most one value is True.
"""
if allowable_set is None:
allowable_set = [False, True]
return one_hot_encoding(atom.IsInRing(), allowable_set, encode_unknown)
def atom_is_in_ring(atom):
"""Get whether the atom is in ring.
Parameters
----------
atom : rdkit.Chem.rdchem.Atom
RDKit atom instance.
Returns
-------
list
List containing one bool only.
"""
return [atom.IsInRing()]
def atom_chiral_tag_one_hot(atom, allowable_set=None, encode_unknown=False): def atom_chiral_tag_one_hot(atom, allowable_set=None, encode_unknown=False):
"""One hot encoding for the chiral tag of an atom. """One hot encoding for the chiral tag of an atom.
......
...@@ -18,7 +18,7 @@ __all__ = ['mol_to_graph', ...@@ -18,7 +18,7 @@ __all__ = ['mol_to_graph',
'mol_to_complete_graph', 'mol_to_complete_graph',
'k_nearest_neighbors'] 'k_nearest_neighbors']
def mol_to_graph(mol, graph_constructor, node_featurizer, edge_featurizer): def mol_to_graph(mol, graph_constructor, node_featurizer, edge_featurizer, canonical_atom_order):
"""Convert an RDKit molecule object into a DGLGraph and featurize for it. """Convert an RDKit molecule object into a DGLGraph and featurize for it.
Parameters Parameters
...@@ -33,14 +33,18 @@ def mol_to_graph(mol, graph_constructor, node_featurizer, edge_featurizer): ...@@ -33,14 +33,18 @@ def mol_to_graph(mol, graph_constructor, node_featurizer, edge_featurizer):
edge_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict edge_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict
Featurization for edges like bonds in a molecule, which can be used to Featurization for edges like bonds in a molecule, which can be used to
update edata for a DGLGraph. update edata for a DGLGraph.
canonical_atom_order : bool
Whether to use a canonical order of atoms returned by RDKit. Setting it
to true might change the order of atoms in the graph constructed.
Returns Returns
------- -------
g : DGLGraph g : DGLGraph
Converted DGLGraph for the molecule Converted DGLGraph for the molecule
""" """
new_order = rdmolfiles.CanonicalRankAtoms(mol) if canonical_atom_order:
mol = rdmolops.RenumberAtoms(mol, new_order) new_order = rdmolfiles.CanonicalRankAtoms(mol)
mol = rdmolops.RenumberAtoms(mol, new_order)
g = graph_constructor(mol) g = graph_constructor(mol)
if node_featurizer is not None: if node_featurizer is not None:
...@@ -103,7 +107,8 @@ def construct_bigraph_from_mol(mol, add_self_loop=False): ...@@ -103,7 +107,8 @@ def construct_bigraph_from_mol(mol, add_self_loop=False):
def mol_to_bigraph(mol, add_self_loop=False, def mol_to_bigraph(mol, add_self_loop=False,
node_featurizer=None, node_featurizer=None,
edge_featurizer=None): edge_featurizer=None,
canonical_atom_order=True):
"""Convert an RDKit molecule object into a bi-directed DGLGraph and featurize for it. """Convert an RDKit molecule object into a bi-directed DGLGraph and featurize for it.
Parameters Parameters
...@@ -118,6 +123,10 @@ def mol_to_bigraph(mol, add_self_loop=False, ...@@ -118,6 +123,10 @@ def mol_to_bigraph(mol, add_self_loop=False,
edge_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict edge_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict
Featurization for edges like bonds in a molecule, which can be used to update Featurization for edges like bonds in a molecule, which can be used to update
edata for a DGLGraph. Default to None. edata for a DGLGraph. Default to None.
canonical_atom_order : bool
Whether to use a canonical order of atoms returned by RDKit. Setting it
to true might change the order of atoms in the graph constructed. Default
to True.
Returns Returns
------- -------
...@@ -125,11 +134,12 @@ def mol_to_bigraph(mol, add_self_loop=False, ...@@ -125,11 +134,12 @@ def mol_to_bigraph(mol, add_self_loop=False,
Bi-directed DGLGraph for the molecule Bi-directed DGLGraph for the molecule
""" """
return mol_to_graph(mol, partial(construct_bigraph_from_mol, add_self_loop=add_self_loop), return mol_to_graph(mol, partial(construct_bigraph_from_mol, add_self_loop=add_self_loop),
node_featurizer, edge_featurizer) node_featurizer, edge_featurizer, canonical_atom_order)
def smiles_to_bigraph(smiles, add_self_loop=False, def smiles_to_bigraph(smiles, add_self_loop=False,
node_featurizer=None, node_featurizer=None,
edge_featurizer=None): edge_featurizer=None,
canonical_atom_order=True):
"""Convert a SMILES into a bi-directed DGLGraph and featurize for it. """Convert a SMILES into a bi-directed DGLGraph and featurize for it.
Parameters Parameters
...@@ -144,6 +154,10 @@ def smiles_to_bigraph(smiles, add_self_loop=False, ...@@ -144,6 +154,10 @@ def smiles_to_bigraph(smiles, add_self_loop=False,
edge_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict edge_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict
Featurization for edges like bonds in a molecule, which can be used to update Featurization for edges like bonds in a molecule, which can be used to update
edata for a DGLGraph. Default to None. edata for a DGLGraph. Default to None.
canonical_atom_order : bool
Whether to use a canonical order of atoms returned by RDKit. Setting it
to true might change the order of atoms in the graph constructed. Default
to True.
Returns Returns
------- -------
...@@ -151,7 +165,8 @@ def smiles_to_bigraph(smiles, add_self_loop=False, ...@@ -151,7 +165,8 @@ def smiles_to_bigraph(smiles, add_self_loop=False,
Bi-directed DGLGraph for the molecule Bi-directed DGLGraph for the molecule
""" """
mol = Chem.MolFromSmiles(smiles) mol = Chem.MolFromSmiles(smiles)
return mol_to_bigraph(mol, add_self_loop, node_featurizer, edge_featurizer) return mol_to_bigraph(mol, add_self_loop, node_featurizer,
edge_featurizer, canonical_atom_order)
def construct_complete_graph_from_mol(mol, add_self_loop=False): def construct_complete_graph_from_mol(mol, add_self_loop=False):
"""Construct a complete graph with topology only for the molecule """Construct a complete graph with topology only for the molecule
...@@ -174,26 +189,20 @@ def construct_complete_graph_from_mol(mol, add_self_loop=False): ...@@ -174,26 +189,20 @@ def construct_complete_graph_from_mol(mol, add_self_loop=False):
g : DGLGraph g : DGLGraph
Empty complete graph topology of the molecule Empty complete graph topology of the molecule
""" """
g = DGLGraph()
num_atoms = mol.GetNumAtoms() num_atoms = mol.GetNumAtoms()
g.add_nodes(num_atoms) edge_list = []
for i in range(num_atoms):
if add_self_loop: for j in range(num_atoms):
g.add_edges( if i != j or add_self_loop:
[i for i in range(num_atoms) for j in range(num_atoms)], edge_list.append((i, j))
[j for i in range(num_atoms) for j in range(num_atoms)]) g = DGLGraph(edge_list)
else:
g.add_edges(
[i for i in range(num_atoms) for j in range(num_atoms - 1)], [
j for i in range(num_atoms)
for j in range(num_atoms) if i != j
])
return g return g
def mol_to_complete_graph(mol, add_self_loop=False, def mol_to_complete_graph(mol, add_self_loop=False,
node_featurizer=None, node_featurizer=None,
edge_featurizer=None): edge_featurizer=None,
canonical_atom_order=True):
"""Convert an RDKit molecule into a complete DGLGraph and featurize for it. """Convert an RDKit molecule into a complete DGLGraph and featurize for it.
Parameters Parameters
...@@ -208,6 +217,10 @@ def mol_to_complete_graph(mol, add_self_loop=False, ...@@ -208,6 +217,10 @@ def mol_to_complete_graph(mol, add_self_loop=False,
edge_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict edge_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict
Featurization for edges like bonds in a molecule, which can be used to update Featurization for edges like bonds in a molecule, which can be used to update
edata for a DGLGraph. Default to None. edata for a DGLGraph. Default to None.
canonical_atom_order : bool
Whether to use a canonical order of atoms returned by RDKit. Setting it
to true might change the order of atoms in the graph constructed. Default
to True.
Returns Returns
------- -------
...@@ -215,11 +228,12 @@ def mol_to_complete_graph(mol, add_self_loop=False, ...@@ -215,11 +228,12 @@ def mol_to_complete_graph(mol, add_self_loop=False,
Complete DGLGraph for the molecule Complete DGLGraph for the molecule
""" """
return mol_to_graph(mol, partial(construct_complete_graph_from_mol, add_self_loop=add_self_loop), return mol_to_graph(mol, partial(construct_complete_graph_from_mol, add_self_loop=add_self_loop),
node_featurizer, edge_featurizer) node_featurizer, edge_featurizer, canonical_atom_order)
def smiles_to_complete_graph(smiles, add_self_loop=False, def smiles_to_complete_graph(smiles, add_self_loop=False,
node_featurizer=None, node_featurizer=None,
edge_featurizer=None): edge_featurizer=None,
canonical_atom_order=True):
"""Convert a SMILES into a complete DGLGraph and featurize for it. """Convert a SMILES into a complete DGLGraph and featurize for it.
Parameters Parameters
...@@ -234,6 +248,10 @@ def smiles_to_complete_graph(smiles, add_self_loop=False, ...@@ -234,6 +248,10 @@ def smiles_to_complete_graph(smiles, add_self_loop=False,
edge_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict edge_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict
Featurization for edges like bonds in a molecule, which can be used to update Featurization for edges like bonds in a molecule, which can be used to update
edata for a DGLGraph. Default to None. edata for a DGLGraph. Default to None.
canonical_atom_order : bool
Whether to use a canonical order of atoms returned by RDKit. Setting it
to true might change the order of atoms in the graph constructed. Default
to True.
Returns Returns
------- -------
...@@ -241,7 +259,8 @@ def smiles_to_complete_graph(smiles, add_self_loop=False, ...@@ -241,7 +259,8 @@ def smiles_to_complete_graph(smiles, add_self_loop=False,
Complete DGLGraph for the molecule Complete DGLGraph for the molecule
""" """
mol = Chem.MolFromSmiles(smiles) mol = Chem.MolFromSmiles(smiles)
return mol_to_complete_graph(mol, add_self_loop, node_featurizer, edge_featurizer) return mol_to_complete_graph(mol, add_self_loop, node_featurizer,
edge_featurizer, canonical_atom_order)
def k_nearest_neighbors(coordinates, neighbor_cutoff, max_num_neighbors): def k_nearest_neighbors(coordinates, neighbor_cutoff, max_num_neighbors):
"""Find k nearest neighbors for each atom based on the 3D coordinates and """Find k nearest neighbors for each atom based on the 3D coordinates and
......
# Work Implemented in DGL-LifeSci # Work Implemented in DGL-LifeSci
We provide various examples across 3 applications -- property prediction, generative models and protein-ligand binding affinity prediction.
## Datasets/Benchmarks ## Datasets/Benchmarks
- MoleculeNet: A Benchmark for Molecular Machine Learning [[paper]](https://arxiv.org/abs/1703.00564), [[website]](http://moleculenet.ai/) - MoleculeNet: A Benchmark for Molecular Machine Learning [[paper]](https://arxiv.org/abs/1703.00564), [[website]](http://moleculenet.ai/)
- [Tox21 with DGL](dgllife/data/tox21.py) - [Tox21 with DGL](../dgllife/data/tox21.py)
- [PDBBind with DGL](dgllife/data/pdbbind.py) - [PDBBind with DGL](../dgllife/data/pdbbind.py)
- Alchemy: A Quantum Chemistry Dataset for Benchmarking AI Models [[paper]](https://arxiv.org/abs/1906.09427), [[github]](https://github.com/tencent-alchemy/Alchemy) - Alchemy: A Quantum Chemistry Dataset for Benchmarking AI Models [[paper]](https://arxiv.org/abs/1906.09427), [[github]](https://github.com/tencent-alchemy/Alchemy)
- [Alchemy with DGL](dgllife/data/alchemy.py) - [Alchemy with DGL](../dgllife/data/alchemy.py)
## Property Prediction ## Property Prediction
- Semi-Supervised Classification with Graph Convolutional Networks (GCN) [[paper]](https://arxiv.org/abs/1609.02907), [[github]](https://github.com/tkipf/gcn) - Semi-Supervised Classification with Graph Convolutional Networks (GCN) [[paper]](https://arxiv.org/abs/1609.02907), [[github]](https://github.com/tkipf/gcn)
- [GCN-Based Predictor with DGL](dgllife/model/model_zoo/gcn_predictor.py) - [GCN-Based Predictor with DGL](../dgllife/model/model_zoo/gcn_predictor.py)
- [Example for Molecule Classification](examples/property_prediction/classification.py) - [Example for Molecule Classification](property_prediction/classification.py)
- Graph Attention Networks (GAT) [[paper]](https://arxiv.org/abs/1710.10903), [[github]](https://github.com/PetarV-/GAT) - Graph Attention Networks (GAT) [[paper]](https://arxiv.org/abs/1710.10903), [[github]](https://github.com/PetarV-/GAT)
- [GAT-Based Predictor with DGL](dgllife/model/model_zoo/gat_predictor.py) - [GAT-Based Predictor with DGL](../dgllife/model/model_zoo/gat_predictor.py)
- [Example for Molecule Classification](examples/property_prediction/classification.py) - [Example for Molecule Classification](property_prediction/classification.py)
- SchNet: A continuous-filter convolutional neural network for modeling quantum interactions [[paper]](https://arxiv.org/abs/1706.08566), [[github]](https://github.com/atomistic-machine-learning/SchNet) - SchNet: A continuous-filter convolutional neural network for modeling quantum interactions [[paper]](https://arxiv.org/abs/1706.08566), [[github]](https://github.com/atomistic-machine-learning/SchNet)
- [SchNet with DGL](dgllife/model/model_zoo/schnet_predictor.py) - [SchNet with DGL](../dgllife/model/model_zoo/schnet_predictor.py)
- [Example for Molecule Regression](examples/property_prediction/regression.py) - [Example for Molecule Regression](property_prediction/regression.py)
- Molecular Property Prediction: A Multilevel Quantum Interactions Modeling Perspective (MGCN) [[paper]](https://arxiv.org/abs/1906.11081) - Molecular Property Prediction: A Multilevel Quantum Interactions Modeling Perspective (MGCN) [[paper]](https://arxiv.org/abs/1906.11081)
- [MGCN with DGL](dgllife/model/model_zoo/mgcn_predictor.py) - [MGCN with DGL](../dgllife/model/model_zoo/mgcn_predictor.py)
- [Example for Molecule Regression](examples/property_prediction/regression.py) - [Example for Molecule Regression](property_prediction/regression.py)
- Neural Message Passing for Quantum Chemistry (MPNN) [[paper]](https://arxiv.org/abs/1704.01212), [[github]](https://github.com/brain-research/mpnn) - Neural Message Passing for Quantum Chemistry (MPNN) [[paper]](https://arxiv.org/abs/1704.01212), [[github]](https://github.com/brain-research/mpnn)
- [MPNN with DGL](dgllife/model/model_zoo/mpnn_predictor.py) - [MPNN with DGL](../dgllife/model/model_zoo/mpnn_predictor.py)
- [Example for Molecule Regression](examples/property_prediction/regression.py) - [Example for Molecule Regression](property_prediction/regression.py)
- Pushing the Boundaries of Molecular Representation for Drug Discovery with the Graph Attention Mechanism (AttentiveFP) [[paper]](https://pubs.acs.org/doi/abs/10.1021/acs.jmedchem.9b00959) - Pushing the Boundaries of Molecular Representation for Drug Discovery with the Graph Attention Mechanism (AttentiveFP) [[paper]](https://pubs.acs.org/doi/abs/10.1021/acs.jmedchem.9b00959)
- [AttentiveFP with DGL](dgllife/model/model_zoo/attentivefp_predictor.py) - [AttentiveFP with DGL](../dgllife/model/model_zoo/attentivefp_predictor.py)
- [Example for Molecule Regression](examples/property_prediction/regression.py) - [Example for Molecule Regression](property_prediction/regression.py)
## Generative Models ## Generative Models
- Learning Deep Generative Models of Graphs (DGMG) [[paper]](https://arxiv.org/abs/1803.03324) - Learning Deep Generative Models of Graphs (DGMG) [[paper]](https://arxiv.org/abs/1803.03324)
- [DGMG with DGL](dgllife/model/model_zoo/dgmg.py) - [DGMG with DGL](../dgllife/model/model_zoo/dgmg.py)
- [Example Training Script](examples/generative_models/dgmg) - [Example Training Script](generative_models/dgmg)
- Junction Tree Variational Autoencoder for Molecular Graph Generation (JTNN) [[paper]](https://arxiv.org/abs/1802.04364) - Junction Tree Variational Autoencoder for Molecular Graph Generation (JTNN) [[paper]](https://arxiv.org/abs/1802.04364)
- [JTNN with DGL](dgllife/model/model_zoo/jtnn) - [JTNN with DGL](../dgllife/model/model_zoo/jtnn)
- [Example Training Script](examples/generative_models/jtnn) - [Example Training Script](generative_models/jtnn)
## Binding Affinity Prediction ## Binding Affinity Prediction
- Atomic Convolutional Networks for Predicting Protein-Ligand Binding Affinity (ACNN) [[paper]](https://arxiv.org/abs/1703.10603), [[github]](https://github.com/deepchem/deepchem/tree/master/contrib/atomicconv) - Atomic Convolutional Networks for Predicting Protein-Ligand Binding Affinity (ACNN) [[paper]](https://arxiv.org/abs/1703.10603), [[github]](https://github.com/deepchem/deepchem/tree/master/contrib/atomicconv)
- [ACNN with DGL](dgllife/model/model_zoo/acnn.py) - [ACNN with DGL](../dgllife/model/model_zoo/acnn.py)
- [Example Training Script](examples/binding_affinity_prediction) - [Example Training Script](binding_affinity_prediction)
## Reaction Prediction
- A graph-convolutional neural network model for the prediction of chemical reactivity [[paper]](https://pubs.rsc.org/en/content/articlelanding/2019/sc/c8sc04228d#!divAbstract), [[github]](https://github.com/connorcoley/rexgen_direct)
- An earlier version was published in NeurIPS 2017 as "Predicting Organic Reaction Outcomes with Weisfeiler-Lehman Network" [[paper]](https://arxiv.org/abs/1709.04555)
- [WLN with DGL for Reaction Center Prediction](../dgllife/model/model_zoo/wln_reaction_center.py)
- [Example Script](reaction_prediction/rexgen_direct)
...@@ -83,12 +83,14 @@ def load_dataset_for_regression(args): ...@@ -83,12 +83,14 @@ def load_dataset_for_regression(args):
def collate_molgraphs(data): def collate_molgraphs(data):
"""Batching a list of datapoints for dataloader. """Batching a list of datapoints for dataloader.
Parameters Parameters
---------- ----------
data : list of 3-tuples or 4-tuples. data : list of 3-tuples or 4-tuples.
Each tuple is for a single datapoint, consisting of Each tuple is for a single datapoint, consisting of
a SMILES, a DGLGraph, all-task labels and optionally a SMILES, a DGLGraph, all-task labels and optionally
a binary mask indicating the existence of labels. a binary mask indicating the existence of labels.
Returns Returns
------- -------
smiles : list smiles : list
......
# A graph-convolutional neural network model for the prediction of chemical reactivity
- [paper in Chemical Science](https://pubs.rsc.org/en/content/articlelanding/2019/sc/c8sc04228d#!divAbstract)
- [authors' code](https://github.com/connorcoley/rexgen_direct)
An earlier version of the work was published in NeurIPS 2017 as
["Predicting Organic Reaction Outcomes with Weisfeiler-Lehman Network"](https://arxiv.org/abs/1709.04555) with some
slight difference in modeling.
## Dataset
The example by default works with reactions from USPTO (United States Patent and Trademark) granted patents,
collected by Lowe [1]. After removing duplicates and erroneous reactions, the authors obtain a set of 480K reactions.
The dataset is divided into 400K, 40K, and 40K for training, validation and test.
## Reaction Center Prediction
### Modeling
Reaction centers refer to the pairs of atoms that lose/form a bond in the reactions. A Graph neural network
(Weisfeiler-Lehman Network in this case) is trained to update the representations of all atoms. Then we combine
pairs of atom representations to predict the likelihood for the corresponding atoms to form/lose a bond.
For evaluation, we select pairs of atoms with top-k scores for each reaction and compute the proportion of reactions
whose reaction centers have all been selected.
### Training with Default Options
We use GPU whenever possible. To train the model with default options, simply do
```bash
python find_reaction_center.py
```
Once the training process starts, the progress will be printed out in the terminal as follows:
```bash
Epoch 1/50, iter 8150/20452 | time/minibatch 0.0260 | loss 8.4788 | grad norm 12.9927
Epoch 1/50, iter 8200/20452 | time/minibatch 0.0260 | loss 8.6722 | grad norm 14.0833
```
After an epoch of training is completed, we evaluate the model on the validation set and
print the evaluation results as follows:
```bash
Epoch 4/50, validation | acc@10 0.8213 | acc@20 0.9016 |
```
By default, we store the model per 10000 iterations in `center_results`.
**Speedup**: For an epoch of training, our implementation takes about 5095s for the first epoch while the authors'
implementation takes about 11657s, which is roughly a speedup by 2.3x.
For model evaluation, we can choose whether to exclude reactants not contributing atoms to the product
(e.g. reagents and solvents) in top-k atom pair selection, which will make the task easier.
For the easier evaluation, do
```bash
python find_reaction_center.py --easy
```
A summary of the model performance is as follows:
| Item | Top 6 accuracy | Top 8 accuracy | Top 10 accuracy |
| --------------- | -------------- | -------------- | --------------- |
| Paper | 89.8 | 92.0 | 93.3 |
| Hard evaluation | 86.5 | 89.6 | 91.2 |
| Easy evaluation | 88.9 | 92.0 | 93.5 |
### Pre-trained Model
We provide a pre-trained model so users do not need to train from scratch. To evaluate the pre-trained model, simply do
```bash
python find_reaction_center.py -p
```
### Adapting to a new dataset.
New datasets should be processed such that each line corresponds to the SMILES for a reaction like below:
```bash
[CH3:14][NH2:15].[N+:1](=[O:2])([O-:3])[c:4]1[cH:5][c:6]([C:7](=[O:8])[OH:9])[cH:10][cH:11][c:12]1[Cl:13].[OH2:16]>>[N+:1](=[O:2])([O-:3])[c:4]1[cH:5][c:6]([C:7](=[O:8])[OH:9])[cH:10][cH:11][c:12]1[NH:15][CH3:14]
```
The reactants are placed before `>>` and the product is placed after `>>`. The reactants are separated by `.`.
In addition, atom mapping information is provided.
You can then train a model on new datasets with
```bash
python find_reaction_center.py --train-path X --val-path Y --test-path Z
```
where `X`, `Y`, `Z` are paths to the new training/validation/test set as described above.
## References
[1] D. M.Lowe, Patent reaction extraction: downloads,
https://bitbucket.org/dan2097/patent-reaction-extraction/downloads, 2014.
# Configuration for reaction center identification
reaction_center_config = {
'batch_size': 20,
'hidden_size': 300,
'max_norm': 5.0,
'node_in_feats': 82,
'edge_in_feats': 6,
'node_pair_in_feats': 10,
'node_out_feats': 300,
'n_layers': 3,
'n_tasks': 5,
'lr': 0.001,
'num_epochs': 25,
'print_every': 50,
'decay_every': 10000, # Learning rate decay
'lr_decay_factor': 0.9,
'top_ks': [6, 8, 10],
'max_k': 80
}
import numpy as np
import time
import torch
from dgllife.data import USPTO, WLNReactionDataset
from dgllife.model import WLNReactionCenter, load_pretrained
from torch.nn import BCEWithLogitsLoss
from torch.nn.utils import clip_grad_norm_
from torch.optim import Adam
from torch.optim.lr_scheduler import StepLR
from torch.utils.data import DataLoader
from utils import setup, collate, reaction_center_prediction, \
rough_eval_on_a_loader, reaction_center_final_eval
def main(args):
setup(args)
if args['train_path'] is None:
train_set = USPTO('train')
else:
train_set = WLNReactionDataset(raw_file_path=args['train_path'],
mol_graph_path='train.bin')
if args['val_path'] is None:
val_set = USPTO('val')
else:
val_set = WLNReactionDataset(raw_file_path=args['val_path'],
mol_graph_path='val.bin')
if args['test_path'] is None:
test_set = USPTO('test')
else:
test_set = WLNReactionDataset(raw_file_path=args['test_path'],
mol_graph_path='test.bin')
train_loader = DataLoader(train_set, batch_size=args['batch_size'],
collate_fn=collate, shuffle=True)
val_loader = DataLoader(val_set, batch_size=args['batch_size'],
collate_fn=collate, shuffle=False)
test_loader = DataLoader(test_set, batch_size=args['batch_size'],
collate_fn=collate, shuffle=False)
if args['pre_trained']:
model = load_pretrained('wln_center_uspto').to(args['device'])
args['num_epochs'] = 0
else:
model = WLNReactionCenter(node_in_feats=args['node_in_feats'],
edge_in_feats=args['edge_in_feats'],
node_pair_in_feats=args['node_pair_in_feats'],
node_out_feats=args['node_out_feats'],
n_layers=args['n_layers'],
n_tasks=args['n_tasks']).to(args['device'])
criterion = BCEWithLogitsLoss(reduction='sum')
optimizer = Adam(model.parameters(), lr=args['lr'])
scheduler = StepLR(optimizer, step_size=args['decay_every'], gamma=args['lr_decay_factor'])
total_iter = 0
grad_norm_sum = 0
loss_sum = 0
dur = []
for epoch in range(args['num_epochs']):
t0 = time.time()
for batch_id, batch_data in enumerate(train_loader):
total_iter += 1
batch_reactions, batch_graph_edits, batch_mols, batch_mol_graphs, \
batch_complete_graphs, batch_atom_pair_labels = batch_data
labels = batch_atom_pair_labels.to(args['device'])
pred, biased_pred = reaction_center_prediction(
args['device'], model, batch_mol_graphs, batch_complete_graphs)
loss = criterion(pred, labels) / len(batch_reactions)
loss_sum += loss.cpu().detach().data.item()
optimizer.zero_grad()
loss.backward()
grad_norm = clip_grad_norm_(model.parameters(), args['max_norm'])
grad_norm_sum += grad_norm
optimizer.step()
scheduler.step()
if total_iter % args['print_every'] == 0:
progress = 'Epoch {:d}/{:d}, iter {:d}/{:d} | time/minibatch {:.4f} | ' \
'loss {:.4f} | grad norm {:.4f}'.format(
epoch + 1, args['num_epochs'], batch_id + 1, len(train_loader),
(np.sum(dur) + time.time() - t0) / total_iter, loss_sum / args['print_every'],
grad_norm_sum / args['print_every'])
grad_norm_sum = 0
loss_sum = 0
print(progress)
if total_iter % args['decay_every'] == 0:
torch.save(model.state_dict(), args['result_path'] + '/model.pkl')
dur.append(time.time() - t0)
print('Epoch {:d}/{:d}, validation '.format(epoch + 1, args['num_epochs']) + \
rough_eval_on_a_loader(args, model, val_loader))
del train_loader
del val_loader
del train_set
del val_set
print('Evaluation on the test set.')
test_result = reaction_center_final_eval(args, model, test_loader, args['easy'])
print(test_result)
with open(args['result_path'] + '/results.txt', 'w') as f:
f.write(test_result)
if __name__ == '__main__':
from argparse import ArgumentParser
from configure import reaction_center_config
parser = ArgumentParser(description='Reaction Center Identification')
parser.add_argument('--result-path', type=str, default='center_results',
help='Path to training results')
parser.add_argument('--train-path', type=str, default=None,
help='Path to a new training set. '
'If None, we will use the default training set in USPTO.')
parser.add_argument('--val-path', type=str, default=None,
help='Path to a new validation set. '
'If None, we will use the default validation set in USPTO.')
parser.add_argument('--test-path', type=str, default=None,
help='Path to a new test set.'
'If None, we will use the default test set in USPTO.')
parser.add_argument('-p', '--pre-trained', action='store_true', default=False,
help='If true, we will directly evaluate a '
'pretrained model on the test set.')
parser.add_argument('--easy', action='store_true', default=False,
help='Whether to exclude reactants not contributing atoms to the '
'product in top-k atom pair selection, which will make the '
'task easier.')
args = parser.parse_args().__dict__
args.update(reaction_center_config)
main(args)
import dgl
import errno
import numpy as np
import os
import random
import torch
from collections import defaultdict
from rdkit import Chem
def mkdir_p(path):
"""Create a folder for the given path.
Parameters
----------
path: str
Folder to create
"""
try:
os.makedirs(path)
print('Created directory {}'.format(path))
except OSError as exc:
if exc.errno == errno.EEXIST and os.path.isdir(path):
print('Directory {} already exists.'.format(path))
else:
raise
def setup(args, seed=0):
"""Setup for the experiment:
1. Decide whether to use CPU or GPU for training
2. Fix random seed for python, NumPy and PyTorch.
Parameters
----------
seed : int
Random seed to use.
Returns
-------
args
Updated configuration
"""
assert args['max_k'] >= max(args['top_ks']), \
'Expect max_k to be no smaller than the possible options ' \
'of top_ks, got {:d} and {:d}'.format(args['max_k'], max(args['top_ks']))
if torch.cuda.is_available():
args['device'] = 'cuda:0'
else:
args['device'] = 'cpu'
# Set random seed
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed(seed)
mkdir_p(args['result_path'])
return args
def collate(data):
"""Collate multiple datapoints
Parameters
----------
data : list of 7-tuples
Each tuple is for a single datapoint, consisting of
a reaction, graph edits in the reaction, an RDKit molecule instance for all reactants,
a DGLGraph for all reactants, a complete graph for all reactants, the features for each
pair of atoms and the labels for each pair of atoms.
Returns
-------
reactions : list of str
List of reactions.
graph_edits : list of str
List of graph edits in the reactions.
mols : list of rdkit.Chem.rdchem.Mol
List of RDKit molecule instances for the reactants.
batch_mol_graphs : DGLGraph
DGLGraph for a batch of molecular graphs.
batch_complete_graphs : DGLGraph
DGLGraph for a batch of complete graphs.
batch_atom_pair_labels : float32 tensor of shape (V, 10)
Labels of atom pairs in the batch of graphs.
"""
reactions, graph_edits, mols, mol_graphs, complete_graphs, \
atom_pair_feats, atom_pair_labels = map(list, zip(*data))
batch_mol_graphs = dgl.batch(mol_graphs)
batch_mol_graphs.set_n_initializer(dgl.init.zero_initializer)
batch_mol_graphs.set_e_initializer(dgl.init.zero_initializer)
batch_complete_graphs = dgl.batch(complete_graphs)
batch_complete_graphs.set_n_initializer(dgl.init.zero_initializer)
batch_complete_graphs.set_e_initializer(dgl.init.zero_initializer)
batch_complete_graphs.edata['feats'] = torch.cat(atom_pair_feats, dim=0)
batch_atom_pair_labels = torch.cat(atom_pair_labels, dim=0)
return reactions, graph_edits, mols, batch_mol_graphs, \
batch_complete_graphs, batch_atom_pair_labels
def reaction_center_prediction(device, model, mol_graphs, complete_graphs):
"""Perform a soft prediction on reaction center.
Parameters
----------
device : str
Device to use for computation, e.g. 'cpu', 'cuda:0'
model : nn.Module
Model for prediction.
mol_graphs : DGLGraph
DGLGraph for a batch of molecular graphs
complete_graphs : DGLGraph
DGLGraph for a batch of complete graphs
Returns
-------
scores : float32 tensor of shape (E_full, 5)
Predicted scores for each pair of atoms to perform one of the following
5 actions in reaction:
* The bond between them gets broken
* Forming a single bond
* Forming a double bond
* Forming a triple bond
* Forming an aromatic bond
biased_scores : float32 tensor of shape (E_full, 5)
Comparing to scores, a bias is added if the pair is for a same atom.
"""
node_feats = mol_graphs.ndata.pop('hv').to(device)
edge_feats = mol_graphs.edata.pop('he').to(device)
node_pair_feats = complete_graphs.edata.pop('feats').to(device)
return model(mol_graphs, complete_graphs, node_feats, edge_feats, node_pair_feats)
def rough_eval(complete_graphs, preds, labels, num_correct):
batch_size = complete_graphs.batch_size
start = 0
for i in range(batch_size):
end = start + complete_graphs.batch_num_edges[i]
preds_i = preds[start:end, :].flatten()
labels_i = labels[start:end, :].flatten()
for k in num_correct.keys():
topk_values, topk_indices = torch.topk(preds_i, k)
is_correct = labels_i[topk_indices].sum() == labels_i.sum().float().cpu().data.item()
num_correct[k].append(is_correct)
start = end
def rough_eval_on_a_loader(args, model, data_loader):
"""A rough evaluation of model performance in the middle of training.
For final evaluation, we will eliminate some possibilities based on prior knowledge.
Parameters
----------
args : dict
Configurations fot the experiment.
model : nn.Module
Model for reaction center prediction.
data_loader : torch.utils.data.DataLoader
Loader for fetching and batching data.
Returns
-------
str
Message for evluation result.
"""
model.eval()
num_correct = {k: [] for k in args['top_ks']}
for batch_id, batch_data in enumerate(data_loader):
batch_reactions, batch_graph_edits, batch_mols, batch_mol_graphs, \
batch_complete_graphs, batch_atom_pair_labels = batch_data
with torch.no_grad():
pred, biased_pred = reaction_center_prediction(
args['device'], model, batch_mol_graphs, batch_complete_graphs)
rough_eval(batch_complete_graphs, biased_pred, batch_atom_pair_labels, num_correct)
msg = '|'
for k, correct_count in num_correct.items():
msg += ' acc@{:d} {:.4f} |'.format(k, np.mean(correct_count))
return msg
def eval(complete_graphs, preds, reactions, graph_edits, num_correct, max_k, easy):
"""Evaluate top-k accuracies for reaction center prediction.
Parameters
----------
complete_graphs : DGLGraph
DGLGraph for a batch of complete graphs
preds : float32 tensor of shape (E_full, 5)
Soft predictions for reaction center, E_full being the number of possible
atom-pairs and 5 being the number of possible bond changes
reactions : list of str
List of reactions.
graph_edits : list of str
List of graph edits in the reactions.
num_correct : dict
Counting the number of datapoints for meeting top-k accuracies.
max_k : int
Maximum number of atom pairs to be selected. This is intended to be larger
than max(num_correct.keys()) as we will filter out many atom pairs due to
considerations such as avoiding duplicates.
easy : bool
If True, reactants not contributing atoms to the product will be excluded in
top-k atom pair selection, which will make the task easier.
"""
# 0 for losing the bond
# 1, 2, 3, 1.5 separately for forming a single, double, triple or aromatic bond.
bond_change_to_id = {0.0: 0, 1:1, 2:2, 3:3, 1.5:4}
id_to_bond_change = {v: k for k, v in bond_change_to_id.items()}
num_change_types = len(bond_change_to_id)
batch_size = complete_graphs.batch_size
start = 0
for i in range(batch_size):
# Decide which atom-pairs will be considered.
reaction_i = reactions[i]
reaction_atoms_i = []
reaction_bonds_i = defaultdict(bool)
reactants_i, _, product_i = reaction_i.split('>')
product_mol_i = Chem.MolFromSmiles(product_i)
product_atoms_i = set([atom.GetAtomMapNum() for atom in product_mol_i.GetAtoms()])
for reactant in reactants_i.split('.'):
reactant_mol = Chem.MolFromSmiles(reactant)
reactant_atoms = [atom.GetAtomMapNum() for atom in reactant_mol.GetAtoms()]
if (len(set(reactant_atoms) & product_atoms_i) > 0) or (not easy):
reaction_atoms_i.extend(reactant_atoms)
for bond in reactant_mol.GetBonds():
end_atoms = sorted([bond.GetBeginAtom().GetAtomMapNum(),
bond.GetEndAtom().GetAtomMapNum()])
bond = tuple(end_atoms + [bond.GetBondTypeAsDouble()])
reaction_bonds_i[bond] = True
num_nodes = complete_graphs.batch_num_nodes[i]
end = start + complete_graphs.batch_num_edges[i]
preds_i = preds[start:end, :].flatten()
candidate_bonds = []
topk_values, topk_indices = torch.topk(preds_i, max_k)
for j in range(max_k):
preds_i_j = topk_indices[j].cpu().item()
# A bond change can be either losing the bond or forming a
# single, double, triple or aromatic bond
change_id = preds_i_j % num_change_types
change_type = id_to_bond_change[change_id]
pair_id = preds_i_j // num_change_types
atom1 = pair_id // num_nodes + 1
atom2 = pair_id % num_nodes + 1
# Avoid duplicates and an atom cannot form a bond with itself
if atom1 >= atom2:
continue
if atom1 not in reaction_atoms_i:
continue
if atom2 not in reaction_atoms_i:
continue
candidate = (int(atom1), int(atom2), float(change_type))
if reaction_bonds_i[candidate]:
continue
candidate_bonds.append(candidate)
gold_bonds = []
gold_edits = graph_edits[i]
for edit in gold_edits.split(';'):
atom1, atom2, change_type = edit.split('-')
atom1, atom2 = int(atom1), int(atom2)
gold_bonds.append((min(atom1, atom2), max(atom1, atom2), float(change_type)))
for k in num_correct.keys():
if set(gold_bonds) <= set(candidate_bonds[:k]):
num_correct[k] += 1
start = end
def reaction_center_final_eval(args, model, data_loader, easy):
"""Final evaluation of model performance.
args : dict
Configurations fot the experiment.
model : nn.Module
Model for reaction center prediction.
data_loader : torch.utils.data.DataLoader
Loader for fetching and batching data.
easy : bool
If True, reactants not contributing atoms to the product will be excluded in
top-k atom pair selection, which will make the task easier.
Returns
-------
msg : str
Summary of the top-k evaluation.
"""
model.eval()
num_correct = {k: 0 for k in args['top_ks']}
for batch_id, batch_data in enumerate(data_loader):
batch_reactions, batch_graph_edits, batch_mols, batch_mol_graphs, \
batch_complete_graphs, batch_atom_pair_labels = batch_data
with torch.no_grad():
pred, biased_pred = reaction_center_prediction(
args['device'], model, batch_mol_graphs, batch_complete_graphs)
eval(batch_complete_graphs, biased_pred, batch_reactions,
batch_graph_edits, num_correct, args['max_k'], easy)
msg = '|'
for k, correct_count in num_correct.items():
msg += ' acc@{:d} {:.4f} |'.format(k, correct_count / len(data_loader.dataset))
return msg
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment