Unverified Commit 545cc065 authored by Mufei Li's avatar Mufei Li Committed by GitHub
Browse files

[DGL-LifeSci] WLN for Reaction Center Prediction (#1360)

* Update

* Update

* Update

* Update

* Update

* Update

* Update

* Update

* Update

* Update

* Update

* Update

* Update

* Update

* Update

* Update

* Update

* Update

* Update

* Update

* Update

* Update

* Update

* Update

* Update

* Update

* Update

* Update

* Update

* Update

* Update

* Update

* Update

* Update

* Update

* Update

* Update

* Update

* Update

* Update

* Update

* Update
parent 5a1ef70f
......@@ -10,6 +10,8 @@ with graph neural networks.
We provide various functionalities, including but not limited to methods for graph construction,
featurization, and evaluation, model architectures, training scripts and pre-trained models.
**For a full list of work implemented in DGL-LifeSci, see [here](examples/README.md).**
## Dependencies
For the time being, we only support PyTorch.
......@@ -19,14 +21,12 @@ Depending on the features you want to use, you may need to manually install the
- RDKit 2018.09.3
- We recommend installation with `conda install -c conda-forge rdkit==2018.09.3`. For other installation recipes,
see the [official documentation](https://www.rdkit.org/docs/Install.html).
- MDTraj
- (optional) MDTraj
- We recommend installation with `conda install -c conda-forge mdtraj`. For alternative ways of installation,
see the [official documentation](http://mdtraj.org/1.9.3/installation.html).
## Organization
For a full list of work implemented in DGL-LifeSci, **see implemented.md**.
```
dgllife
data
......@@ -124,7 +124,8 @@ SVG(Draw.MolsToGridImage(mols, molsPerRow=4, subImgSize=(180, 150), useSVG=True)
Below we provide some reference numbers to show how DGL improves the speed of training models per epoch in seconds.
| Model | Original Implementation | DGL Implementation | Improvement |
| -------------------------- | ----------------------- | ------------------ | ----------- |
| ---------------------------------- | ----------------------- | ------------------ | ----------- |
| GCN on Tox21 | 5.5 (DeepChem) | 1.0 | 5.5x |
| AttentiveFP on Aromaticity | 6.0 | 1.2 | 5x |
| JTNN on ZINC | 1826 | 743 | 2.5x |
| WLN for reaction center prediction | 11657 | 5095 | 2.3x | |
__version__ = '0.1.0'
__version__ = '0.2.0'
......@@ -4,3 +4,4 @@ from .csv_dataset import *
from .pdbbind import *
from .pubchem_aromaticity import *
from .tox21 import *
from .uspto import *
......@@ -146,7 +146,7 @@ class TencentAlchemyDataset(object):
contest is ongoing.
mol_to_graph: callable, str -> DGLGraph
A function turning an RDKit molecule instance into a DGLGraph.
Default to :func:`dgl.data.chem.mol_to_complete_graph`.
Default to :func:`dgllife.utils.mol_to_complete_graph`.
node_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict
Featurization for nodes like atoms in a molecule, which can be used to update
ndata for a DGLGraph. By default, we construct graphs where nodes represent atoms
......
......@@ -23,7 +23,7 @@ class PubChemBioAssayAromaticity(MoleculeCSVDataset):
----------
smiles_to_graph: callable, str -> DGLGraph
A function turning smiles into a DGLGraph.
Default to :func:`dgl.data.chem.smiles_to_bigraph`.
Default to :func:`dgllife.utils.smiles_to_bigraph`.
node_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict
Featurization for nodes like atoms in a molecule, which can be used to update
ndata for a DGLGraph. Default to None.
......
......@@ -30,7 +30,7 @@ class Tox21(MoleculeCSVDataset):
----------
smiles_to_graph: callable, str -> DGLGraph
A function turning smiles into a DGLGraph.
Default to :func:`dgl.data.chem.smiles_to_bigraph`.
Default to :func:`dgllife.utils.smiles_to_bigraph`.
node_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict
Featurization for nodes like atoms in a molecule, which can be used to update
ndata for a DGLGraph. Default to None.
......
"""USPTO for reaction prediction"""
import numpy as np
import os
import torch
from collections import defaultdict
from dgl.data.utils import get_download_dir, download, _get_dgl_url, extract_archive, \
save_graphs, load_graphs
from functools import partial
from rdkit import Chem, RDLogger
from rdkit.Chem import rdmolops
from tqdm import tqdm
from ..utils.featurizers import BaseAtomFeaturizer, ConcatFeaturizer, atom_type_one_hot, \
atom_degree_one_hot, atom_explicit_valence_one_hot, atom_implicit_valence_one_hot, \
atom_is_aromatic, BaseBondFeaturizer, bond_type_one_hot, bond_is_conjugated, bond_is_in_ring
from ..utils.mol_to_graph import mol_to_bigraph, mol_to_complete_graph
__all__ = ['WLNReactionDataset',
'USPTO']
# Disable RDKit warnings
RDLogger.DisableLog('rdApp.*')
# Atom types distinguished in featurization
atom_types = ['C', 'N', 'O', 'S', 'F', 'Si', 'P', 'Cl', 'Br', 'Mg', 'Na', 'Ca', 'Fe',
'As', 'Al', 'I', 'B', 'V', 'K', 'Tl', 'Yb', 'Sb', 'Sn', 'Ag', 'Pd', 'Co',
'Se', 'Ti', 'Zn', 'H', 'Li', 'Ge', 'Cu', 'Au', 'Ni', 'Cd', 'In', 'Mn', 'Zr',
'Cr', 'Pt', 'Hg', 'Pb', 'W', 'Ru', 'Nb', 'Re', 'Te', 'Rh', 'Tc', 'Ba', 'Bi',
'Hf', 'Mo', 'U', 'Sm', 'Os', 'Ir', 'Ce', 'Gd', 'Ga', 'Cs']
default_node_featurizer = BaseAtomFeaturizer({
'hv': ConcatFeaturizer(
[partial(atom_type_one_hot, allowable_set=atom_types, encode_unknown=True),
partial(atom_degree_one_hot, allowable_set=list(range(6))),
atom_explicit_valence_one_hot,
partial(atom_implicit_valence_one_hot, allowable_set=list(range(6))),
atom_is_aromatic]
)
})
default_edge_featurizer = BaseBondFeaturizer({
'he': ConcatFeaturizer([
bond_type_one_hot, bond_is_conjugated, bond_is_in_ring]
)
})
def default_atom_pair_featurizer(reactants):
"""Featurize each pair of atoms, which will be used in updating
the edata of a complete DGLGraph.
The features include the bond type between the atoms (if any) and whether
they belong to the same molecule. It is used in the global attention mechanism.
Parameters
----------
reactants : str
SMILES for reactants
data_field : str
Key for storing the features in DGLGraph.edata. Default to 'atom_pair'
Returns
-------
float32 tensor of shape (V^2, 10)
features for each pair of atoms.
"""
# Decide the reactant membership for each atom
atom_to_reactant = dict()
reactant_list = reactants.split('.')
for id, s in enumerate(reactant_list):
mol = Chem.MolFromSmiles(s)
for atom in mol.GetAtoms():
atom_to_reactant[atom.GetIntProp('molAtomMapNumber') - 1] = id
# Construct mapping from atom pair to RDKit bond object
all_reactant_mol = Chem.MolFromSmiles(reactants)
atom_pair_to_bond = dict()
for bond in all_reactant_mol.GetBonds():
atom1 = bond.GetBeginAtom().GetIntProp('molAtomMapNumber') - 1
atom2 = bond.GetEndAtom().GetIntProp('molAtomMapNumber') - 1
atom_pair_to_bond[(atom1, atom2)] = bond
atom_pair_to_bond[(atom2, atom1)] = bond
def _featurize_a_bond(bond):
return bond_type_one_hot(bond) + bond_is_conjugated(bond) + bond_is_in_ring(bond)
features = []
num_atoms = all_reactant_mol.GetNumAtoms()
for i in range(num_atoms):
for j in range(num_atoms):
pair_feature = np.zeros(10)
if i == j:
features.append(pair_feature)
continue
bond = atom_pair_to_bond.get((i, j), None)
if bond is not None:
pair_feature[1:7] = _featurize_a_bond(bond)
else:
pair_feature[0] = 1.
pair_feature[-4] = 1. if atom_to_reactant[i] != atom_to_reactant[j] else 0.
pair_feature[-3] = 1. if atom_to_reactant[i] == atom_to_reactant[j] else 0.
pair_feature[-2] = 1. if len(reactant_list) == 1 else 0.
pair_feature[-1] = 1. if len(reactant_list) > 1 else 0.
features.append(pair_feature)
return torch.from_numpy(np.stack(features, axis=0).astype(np.float32))
def get_pair_label(reactants_mol, graph_edits):
"""Construct labels for each pair of atoms in reaction center prediction
Parameters
----------
reactants_mol : rdkit.Chem.rdchem.Mol
RDKit molecule instance for all reactants in a reaction
graph_edits : str
Specifying which pairs of atoms loss a bond or form a particular bond in the reaction
Returns
-------
float32 tensor of shape (V^2, 5)
Labels constructed. V for the number of atoms in the reactants.
"""
# 0 for losing the bond
# 1, 2, 3, 1.5 separately for forming a single, double, triple or aromatic bond.
bond_change_to_id = {0.0: 0, 1:1, 2:2, 3:3, 1.5:4}
pair_to_changes = defaultdict(list)
for edit in graph_edits.split(';'):
a1, a2, change = edit.split('-')
atom1 = int(a1) - 1
atom2 = int(a2) - 1
change = bond_change_to_id[float(change)]
pair_to_changes[(atom1, atom2)].append(change)
pair_to_changes[(atom2, atom1)].append(change)
num_atoms = reactants_mol.GetNumAtoms()
labels = torch.zeros((num_atoms, num_atoms, 5))
for pair in pair_to_changes.keys():
i, j = pair
labels[i, j, pair_to_changes[(j, i)]] = 1.
return labels.reshape(-1, 5)
def get_bond_changes(reaction):
"""Get the bond changes in a reaction.
Parameters
----------
reaction : str
SMILES for a reaction, e.g. [CH3:14][NH2:15].[N+:1](=[O:2])([O-:3])[c:4]1[cH:5][c:6]([C:7]
(=[O:8])[OH:9])[cH:10][cH:11][c:12]1[Cl:13].[OH2:16]>>[N+:1](=[O:2])([O-:3])[c:4]1[cH:5]
[c:6]([C:7](=[O:8])[OH:9])[cH:10][cH:11][c:12]1[NH:15][CH3:14]. It consists of reactants,
products and the atom mapping.
Returns
-------
bond_changes : set of 3-tuples
Each tuple consists of (atom1, atom2, change type)
There are 5 possible values for change type. 0 for losing the bond, and 1, 2, 3, 1.5
separately for forming a single, double, triple or aromatic bond.
"""
reactants = Chem.MolFromSmiles(reaction.split('>')[0])
products = Chem.MolFromSmiles(reaction.split('>')[2])
conserved_maps = [
a.GetProp('molAtomMapNumber')
for a in products.GetAtoms() if a.HasProp('molAtomMapNumber')]
bond_changes = set() # keep track of bond changes
# Look at changed bonds
bonds_prev = {}
for bond in reactants.GetBonds():
nums = sorted(
[bond.GetBeginAtom().GetProp('molAtomMapNumber'),
bond.GetEndAtom().GetProp('molAtomMapNumber')])
if (nums[0] not in conserved_maps) and (nums[1] not in conserved_maps):
continue
bonds_prev['{}~{}'.format(nums[0], nums[1])] = bond.GetBondTypeAsDouble()
bonds_new = {}
for bond in products.GetBonds():
nums = sorted(
[bond.GetBeginAtom().GetProp('molAtomMapNumber'),
bond.GetEndAtom().GetProp('molAtomMapNumber')])
bonds_new['{}~{}'.format(nums[0], nums[1])] = bond.GetBondTypeAsDouble()
for bond in bonds_prev:
if bond not in bonds_new:
# lost bond
bond_changes.add((bond.split('~')[0], bond.split('~')[1], 0.0))
else:
if bonds_prev[bond] != bonds_new[bond]:
# changed bond
bond_changes.add((bond.split('~')[0], bond.split('~')[1], bonds_new[bond]))
for bond in bonds_new:
if bond not in bonds_prev:
# new bond
bond_changes.add((bond.split('~')[0], bond.split('~')[1], bonds_new[bond]))
return bond_changes
def process_file(path):
"""Pre-process a file of reactions for working with WLN.
Parameters
----------
path : str
Path to the file of reactions
"""
with open(path, 'r') as input_file, open(path + '.proc', 'w') as output_file:
for line in tqdm(input_file):
reaction = line.strip()
bond_changes = get_bond_changes(reaction)
output_file.write('{} {}\n'.format(
reaction,
';'.join(['{}-{}-{}'.format(x[0], x[1], x[2]) for x in bond_changes])))
print('Finished processing {}'.format(path))
class WLNReactionDataset(object):
"""Dataset for reaction prediction with WLN
Parameters
----------
raw_file_path : str
Path to the raw reaction file, where each line is the SMILES for a reaction.
We will check if raw_file_path + '.proc' exists, where each line has the reaction
SMILES and the corresponding graph edits. If not, we will preprocess
the raw reaction file.
mol_graph_path : str
Path to save/load DGLGraphs for molecules.
mol_to_graph: callable, str -> DGLGraph
A function turning RDKit molecule instances into DGLGraphs.
Default to :func:`dgllife.utils.mol_to_bigraph`.
node_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict
Featurization for nodes like atoms in a molecule, which can be used to update
ndata for a DGLGraph. By default, we consider descriptors including atom type,
atom degree, atom explicit valence, atom implicit valence, aromaticity.
edge_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict
Featurization for edges like bonds in a molecule, which can be used to update
edata for a DGLGraph. By default, we consider descriptors including bond type,
whether bond is conjugated and whether bond is in ring.
atom_pair_featurizer : callable, str -> dict
Featurization for each pair of atoms in multiple reactants. The result will be
used to update edata in the complete DGLGraphs. By default, the features include
the bond type between the atoms (if any) and whether they belong to the same molecule.
load : bool
Whether to load the previously pre-processed dataset or pre-process from scratch.
``load`` should be False when we want to try different graph construction and
featurization methods and need to preprocess from scratch. Default to True.
"""
def __init__(self,
raw_file_path,
mol_graph_path,
mol_to_graph=mol_to_bigraph,
node_featurizer=default_node_featurizer,
edge_featurizer=default_edge_featurizer,
atom_pair_featurizer=default_atom_pair_featurizer,
load=True):
super(WLNReactionDataset, self).__init__()
self._atom_pair_featurizer = atom_pair_featurizer
self.atom_pair_features = []
self.atom_pair_labels = []
# Map number of nodes to a corresponding complete graph
self.complete_graphs = dict()
path_to_reaction_file = raw_file_path + '.proc'
if not os.path.isfile(path_to_reaction_file):
# Pre-process graph edits information
process_file(raw_file_path)
full_mols, full_reactions, full_graph_edits = \
self.load_reaction_data(path_to_reaction_file)
if load and os.path.isfile(mol_graph_path):
self.reactant_mol_graphs, _ = load_graphs(mol_graph_path)
else:
self.reactant_mol_graphs = []
for i in range(len(full_mols)):
if i % 10000 == 0:
print('Processing reaction {:d}/{:d}'.format(i + 1, len(full_mols)))
mol = full_mols[i]
reactant_mol_graph = mol_to_graph(mol, node_featurizer=node_featurizer,
edge_featurizer=edge_featurizer,
canonical_atom_order=False)
self.reactant_mol_graphs.append(reactant_mol_graph)
save_graphs(mol_graph_path, self.reactant_mol_graphs)
self.mols = full_mols
self.reactions = full_reactions
self.graph_edits = full_graph_edits
self.atom_pair_features.extend([None for _ in range(len(self.mols))])
self.atom_pair_labels.extend([None for _ in range(len(self.mols))])
def load_reaction_data(self, file_path):
"""Load reaction data from the raw file.
Parameters
----------
file_path : str
Path to read the file.
Returns
-------
all_mols : list of rdkit.Chem.rdchem.Mol
RDKit molecule instances
all_reactions : list of str
Reactions
all_graph_edits : list of str
Graph edits in the reactions.
"""
all_mols = []
all_reactions = []
all_graph_edits = []
with open(file_path, 'r') as f:
for i, line in enumerate(f):
if i % 10000 == 0:
print('Processing line {:d}'.format(i))
# Each line represents a reaction and the corresponding graph edits
#
# reaction example:
# [CH3:14][OH:15].[NH2:12][NH2:13].[OH2:11].[n:1]1[n:2][cH:3][c:4]
# ([C:7]([O:9][CH3:8])=[O:10])[cH:5][cH:6]1>>[n:1]1[n:2][cH:3][c:4]
# ([C:7](=[O:9])[NH:12][NH2:13])[cH:5][cH:6]1
# The reactants are on the left-hand-side of the reaction and the product
# is on the right-hand-side of the reaction. The numbers represent atom mapping.
#
# graph_edits example:
# 23-33-1.0;23-25-0.0
# For a triplet a-b-c, a and b are the atoms that form or loss the bond.
# c specifies the particular change, 0.0 for losing a bond, 1.0, 2.0, 3.0 and
# 1.5 separately for forming a single, double, triple or aromatic bond.
reaction, graph_edits = line.strip("\r\n ").split()
reactants = reaction.split('>')[0]
mol = Chem.MolFromSmiles(reactants)
if mol is None:
continue
# Reorder atoms according to the order specified in the atom map
atom_map_order = [-1 for _ in range(mol.GetNumAtoms())]
for i in range(mol.GetNumAtoms()):
atom = mol.GetAtomWithIdx(i)
atom_map_order[atom.GetIntProp('molAtomMapNumber') - 1] = i
mol = rdmolops.RenumberAtoms(mol, atom_map_order)
all_mols.append(mol)
all_reactions.append(reaction)
all_graph_edits.append(graph_edits)
return all_mols, all_reactions, all_graph_edits
def __len__(self):
"""Get the size for the dataset.
Returns
-------
int
Number of reactions in the dataset.
"""
return len(self.mols)
def __getitem__(self, item):
"""Get the i-th datapoint.
Returns
-------
str
Reaction
str
Graph edits for the reaction
rdkit.Chem.rdchem.Mol
RDKit molecule instance
DGLGraph
DGLGraph for the ith molecular graph
DGLGraph
Complete DGLGraph, which will be needed for predicting
scores between each pair of atoms
float32 tensor of shape (V^2, 10)
Features for each pair of atoms.
float32 tensor of shape (V^2, 5)
Labels for reaction center prediction.
V for the number of atoms in the reactants.
"""
mol = self.mols[item]
num_atoms = mol.GetNumAtoms()
if num_atoms not in self.complete_graphs:
self.complete_graphs[num_atoms] = mol_to_complete_graph(
mol, add_self_loop=True, canonical_atom_order=True)
if self.atom_pair_features[item] is None:
reactants = self.reactions[item].split('>')[0]
self.atom_pair_features[item] = self._atom_pair_featurizer(reactants)
if self.atom_pair_labels[item] is None:
self.atom_pair_labels[item] = get_pair_label(mol, self.graph_edits[item])
return self.reactions[item], self.graph_edits[item], mol, \
self.reactant_mol_graphs[item], \
self.complete_graphs[num_atoms], \
self.atom_pair_features[item], \
self.atom_pair_labels[item]
class USPTO(WLNReactionDataset):
"""USPTO dataset for reaction prediction.
The dataset contains reactions from patents granted by United States Patent
and Trademark Office (USPTO), collected by Lowe [1]. Jin et al. removes duplicates
and erroneous reactions, obtaining a set of 480K reactions. They divide it
into 400K, 40K, and 40K for training, validation and test.
References:
* [1] Patent reaction extraction
* [2] Predicting Organic Reaction Outcomes with Weisfeiler-Lehman Network
Parameters
----------
subset : str
Whether to use the training/validation/test set as in Jin et al.
* 'train' for the training set
* 'val' for the validation set
* 'test' for the test set
mol_to_graph: callable, str -> DGLGraph
A function turning RDKit molecule instances into DGLGraphs.
Default to :func:`dgllife.utils.mol_to_bigraph`.
node_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict
Featurization for nodes like atoms in a molecule, which can be used to update
ndata for a DGLGraph. By default, we consider descriptors including atom type,
atom degree, atom explicit valence, atom implicit valence, aromaticity.
edge_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict
Featurization for edges like bonds in a molecule, which can be used to update
edata for a DGLGraph. By default, we consider descriptors including bond type,
whether bond is conjugated and whether bond is in ring.
atom_pair_featurizer : callable, str -> dict
Featurization for each pair of atoms in multiple reactants. The result will be
used to update edata in the complete DGLGraphs. By default, the features include
the bond type between the atoms (if any) and whether they belong to the same molecule.
load : bool
Whether to load the previously pre-processed dataset or pre-process from scratch.
``load`` should be False when we want to try different graph construction and
featurization methods and need to preprocess from scratch. Default to True.
"""
def __init__(self,
subset,
mol_to_graph=mol_to_bigraph,
node_featurizer=default_node_featurizer,
edge_featurizer=default_edge_featurizer,
atom_pair_featurizer=default_atom_pair_featurizer,
load=True):
assert subset in ['train', 'val', 'test'], \
'Expect subset to be "train" or "val" or "test", got {}'.format(subset)
print('Preparing {} subset of USPTO'.format(subset))
self._subset = subset
if subset == 'val':
subset = 'valid'
self._url = 'dataset/uspto.zip'
data_path = get_download_dir() + '/uspto.zip'
extracted_data_path = get_download_dir() + '/uspto'
download(_get_dgl_url(self._url), path=data_path)
extract_archive(data_path, extracted_data_path)
super(USPTO, self).__init__(
raw_file_path=extracted_data_path + '/{}.txt'.format(subset),
mol_graph_path=extracted_data_path + '/{}_mol_graphs.bin'.format(subset),
mol_to_graph=mol_to_graph,
node_featurizer=node_featurizer,
edge_featurizer=edge_featurizer,
atom_pair_featurizer=atom_pair_featurizer,
load=load)
@property
def subset(self):
"""Get the subset used for USPTO
Returns
-------
str
* 'full' for the complete dataset
* 'train' for the training set
* 'val' for the validation set
* 'test' for the test set
"""
return self._subset
......@@ -5,3 +5,4 @@ from .gcn import *
from .mgcn import *
from .mpnn import *
from .schnet import *
from .wln import *
"""WLN"""
import dgl.function as fn
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Parameter
__all__ = ['WLN']
class WLNLinear(nn.Module):
"""Linear layer for WLN
Let stddev be
.. math::
\min(\frac{1.0}{\sqrt{in_feats}}, 0.1)
The weight of the linear layer is initialized from a normal distribution
with mean 0 and std as specified in stddev.
Parameters
----------
in_feats : int
Size for the input.
out_feats : int
Size for the output.
bias : bool
Whether bias will be added to the output. Default to True.
"""
def __init__(self, in_feats, out_feats, bias=True):
super(WLNLinear, self).__init__()
self.in_feats = in_feats
self.out_feats = out_feats
self.weight = Parameter(torch.Tensor(out_feats, in_feats))
if bias:
self.bias = Parameter(torch.Tensor(out_feats))
else:
self.register_parameter('bias', None)
self.reset_parameters()
def reset_parameters(self):
"""Initialize model parameters."""
stddev = min(1.0 / math.sqrt(self.in_feats), 0.1)
nn.init.normal_(self.weight, std=stddev)
if self.bias is not None:
fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight)
bound = 1 / math.sqrt(fan_in)
nn.init.uniform_(self.bias, -bound, bound)
def forward(self, input):
"""Applies the layer.
Parameters
----------
input : float32 tensor of shape (N, *, in_feats)
N for the number of samples, * for any additional dimensions.
Returns
-------
float32 tensor of shape (N, *, out_feats)
Result of the layer.
"""
return F.linear(input, self.weight, self.bias)
def extra_repr(self):
"""Return a description of the layer."""
return 'in_feats={}, out_feats={}, bias={}'.format(
self.in_feats, self.out_feats, self.bias is not None
)
class WLN(nn.Module):
"""Weisfeiler-Lehman Network (WLN)
WLN is introduced in `Predicting Organic Reaction Outcomes with
Weisfeiler-Lehman Network <https://arxiv.org/abs/1709.04555>`__.
This class performs message passing and updates node representations.
Parameters
----------
node_in_feats : int
Size for the input node features.
edge_in_feats : int
Size for the input edge features.
node_out_feats : int
Size for the output node representations. Default to 300.
n_layers : int
Number of times for message passing. Note that same parameters
are shared across n_layers message passing. Default to 3.
"""
def __init__(self,
node_in_feats,
edge_in_feats,
node_out_feats=300,
n_layers=3):
super(WLN, self).__init__()
self.n_layers = n_layers
self.project_node_in_feats = nn.Sequential(
WLNLinear(node_in_feats, node_out_feats, bias=False),
nn.ReLU()
)
self.project_concatenated_messages = nn.Sequential(
WLNLinear(edge_in_feats + node_out_feats, node_out_feats),
nn.ReLU()
)
self.get_new_node_feats = nn.Sequential(
WLNLinear(2 * node_out_feats, node_out_feats),
nn.ReLU()
)
self.project_edge_messages = WLNLinear(edge_in_feats, node_out_feats, bias=False)
self.project_node_messages = WLNLinear(node_out_feats, node_out_feats, bias=False)
self.project_self = WLNLinear(node_out_feats, node_out_feats, bias=False)
def forward(self, g, node_feats, edge_feats):
"""Performs message passing and updates node representations.
Parameters
----------
g : DGLGraph
DGLGraph for a batch of graphs
node_feats : float32 tensor of shape (V, node_in_feats)
Input node features. V for the number of nodes.
edge_feats : float32 tensor of shape (E, edge_in_feats)
Input edge features. E for the number of edges.
Returns
-------
float32 tensor of shape (V, node_out_feats)
Updated node representations.
"""
node_feats = self.project_node_in_feats(node_feats)
for l in range(self.n_layers):
g = g.local_var()
g.ndata['hv'] = node_feats
g.apply_edges(fn.copy_src('hv', 'he_src'))
concat_edge_feats = torch.cat([g.edata['he_src'], edge_feats], dim=1)
g.edata['he'] = self.project_concatenated_messages(concat_edge_feats)
g.update_all(fn.copy_edge('he', 'm'), fn.sum('m', 'hv_new'))
g = g.local_var()
g.ndata['hv'] = self.project_node_messages(node_feats)
g.edata['he'] = self.project_edge_messages(edge_feats)
g.update_all(fn.u_mul_e('hv', 'he', 'm'), fn.sum('m', 'h_nbr'))
h_self = self.project_self(node_feats) # (V, node_out_feats)
return g.ndata['h_nbr'] * h_self
......@@ -9,3 +9,4 @@ from .schnet_predictor import *
from .mgcn_predictor import *
from .mpnn_predictor import *
from .acnn import *
from .wln_reaction_center import *
"""Weisfeiler-Lehman Network (WLN) for Reaction Center Prediction."""
import dgl.function as fn
import torch
import torch.nn as nn
from ..gnn.wln import WLNLinear, WLN
__all__ = ['WLNReactionCenter']
class WLNContext(nn.Module):
"""Attention-based context computation for each node.
A context vector is computed by taking a weighted sum of node representations,
with weights computed from an attention module.
Parameters
----------
node_in_feats : int
Size for the input node features.
node_pair_in_feats : int
Size for the input features of node pairs.
"""
def __init__(self, node_in_feats, node_pair_in_feats):
super(WLNContext, self).__init__()
self.project_feature_sum = WLNLinear(node_in_feats, node_in_feats, bias=False)
self.project_node_pair_feature = WLNLinear(node_pair_in_feats, node_in_feats)
self.compute_attention = nn.Sequential(
nn.ReLU(),
WLNLinear(node_in_feats, 1),
nn.Sigmoid()
)
def forward(self, batch_complete_graphs, node_feats, feat_sum, node_pair_feat):
"""Compute context vectors for each node.
Parameters
----------
batch_complete_graphs : DGLGraph
A batch of fully connected graphs.
node_feats : float32 tensor of shape (V, node_in_feats)
Input node features. V for the number of nodes.
feat_sum : float32 tensor of shape (E_full, node_in_feats)
Sum of node_feats between each pair of nodes. E_full for the number of
edges in the batch of complete graphs.
node_pair_feat : float32 tensor of shape (E_full, node_pair_in_feats)
Input features for each pair of nodes. E_full for the number of edges in
the batch of complete graphs.
Returns
-------
node_contexts : float32 tensor of shape (V, node_in_feats)
Context vectors for nodes.
"""
with batch_complete_graphs.local_scope():
batch_complete_graphs.ndata['hv'] = node_feats
batch_complete_graphs.edata['a'] = self.compute_attention(
self.project_feature_sum(feat_sum) + \
self.project_node_pair_feature(node_pair_feat)
)
batch_complete_graphs.update_all(
fn.src_mul_edge('hv', 'a', 'm'), fn.sum('m', 'context'))
node_contexts = batch_complete_graphs.ndata.pop('context')
return node_contexts
class WLNReactionCenter(nn.Module):
r"""Weisfeiler-Lehman Network (WLN) for Reaction Center Prediction.
The model is introduced in `Predicting Organic Reaction Outcomes with
Weisfeiler-Lehman Network <https://arxiv.org/abs/1709.04555>`__.
The model uses WLN to update atom representations and then predicts the
score for each pair of atoms to form a bond.
Parameters
----------
node_in_feats : int
Size for the input node features.
edge_in_feats : int
Size for the input edge features.
node_out_feats : int
Size for the output node representations. Default to 300.
node_pair_in_feats : int
Size for the input features of node pairs.
n_layers : int
Number of times for message passing. Note that same parameters
are shared across n_layers message passing. Default to 3.
n_tasks : int
Number of tasks for prediction.
"""
def __init__(self,
node_in_feats,
edge_in_feats,
node_pair_in_feats,
node_out_feats=300,
n_layers=3,
n_tasks=5):
super(WLNReactionCenter, self).__init__()
self.gnn = WLN(node_in_feats=node_in_feats,
edge_in_feats=edge_in_feats,
node_out_feats=node_out_feats,
n_layers=n_layers)
self.context_module = WLNContext(node_in_feats=node_out_feats,
node_pair_in_feats=node_pair_in_feats)
self.project_feature_sum = WLNLinear(node_out_feats, node_out_feats, bias=False)
self.project_node_pair_feature = WLNLinear(node_pair_in_feats, node_out_feats, bias=False)
self.project_context_sum = WLNLinear(node_out_feats, node_out_feats)
self.predict = nn.Sequential(
nn.ReLU(),
WLNLinear(node_out_feats, n_tasks)
)
def forward(self, batch_mol_graphs, batch_complete_graphs,
node_feats, edge_feats, node_pair_feats):
r"""Predict score for each pair of nodes.
Parameters
----------
batch_mol_graphs : DGLGraph
A batch of molecular graphs.
batch_complete_graphs : DGLGraph
A batch of fully connected graphs.
node_feats : float32 tensor of shape (V, node_in_feats)
Input node features. V for the number of nodes.
edge_feats : float32 tensor of shape (E, edge_in_feats)
Input edge features. E for the number of edges.
node_pair_feats : float32 tensor of shape (E_full, node_pair_in_feats)
Input features for each pair of nodes. E_full for the number of edges in
the batch of complete graphs.
Returns
-------
scores : float32 tensor of shape (E_full, 5)
Predicted scores for each pair of atoms to perform one of the following
5 actions in reaction:
* The bond between them gets broken
* Forming a single bond
* Forming a double bond
* Forming a triple bond
* Forming an aromatic bond
biased_scores : float32 tensor of shape (E_full, 5)
Comparing to scores, a bias is added if the pair is for a same atom.
"""
node_feats = self.gnn(batch_mol_graphs, node_feats, edge_feats)
# Compute context vectors for all atoms, which are weighted sum of atom
# representations in all reactants.
with batch_complete_graphs.local_scope():
batch_complete_graphs.ndata['hv'] = node_feats
batch_complete_graphs.apply_edges(fn.u_add_v('hv', 'hv', 'feature_sum'))
feat_sum = batch_complete_graphs.edata.pop('feature_sum')
node_contexts = self.context_module(batch_complete_graphs, node_feats,
feat_sum, node_pair_feats)
# Predict score
with batch_complete_graphs.local_scope():
batch_complete_graphs.ndata['context'] = node_contexts
batch_complete_graphs.apply_edges(fn.u_add_v('context', 'context', 'context_sum'))
scores = self.predict(
self.project_feature_sum(feat_sum) + \
self.project_node_pair_feature(node_pair_feats) + \
self.project_context_sum(batch_complete_graphs.edata['context_sum'])
)
# Masking self loops
nodes = batch_complete_graphs.nodes()
e_ids = batch_complete_graphs.edge_ids(nodes, nodes)
bias = torch.zeros(scores.shape[0], 5).to(scores.device)
bias[e_ids, :] = 1e4
biased_scores = scores - bias
return scores, biased_scores
......@@ -6,7 +6,8 @@ import torch.nn.functional as F
from dgl.data.utils import _get_dgl_url, download, get_download_dir, extract_archive
from rdkit import Chem
from ..model import GCNPredictor, GATPredictor, AttentiveFPPredictor, DGMG, DGLJTNNVAE
from ..model import GCNPredictor, GATPredictor, AttentiveFPPredictor, DGMG, DGLJTNNVAE, \
WLNReactionCenter
__all__ = ['load_pretrained']
......@@ -18,7 +19,8 @@ URL = {
'DGMG_ChEMBL_random': 'pre_trained/dgmg_ChEMBL_random.pth',
'DGMG_ZINC_canonical': 'pre_trained/dgmg_ZINC_canonical.pth',
'DGMG_ZINC_random': 'pre_trained/dgmg_ZINC_random.pth',
'JTNN_ZINC': 'pre_trained/JTNN_ZINC.pth'
'JTNN_ZINC': 'pre_trained/JTNN_ZINC.pth',
'wln_center_uspto': 'dgllife/pre_trained/wln_center_uspto.pth'
}
def download_and_load_checkpoint(model_name, model, model_postfix,
......@@ -72,6 +74,7 @@ def load_pretrained(model_name, log=True):
* ``'DGMG_ZINC_canonical'``
* ``'DGMG_ZINC_random'``
* ``'JTNN_ZINC'``
* ``'wln_center_uspto'``
log : bool
Whether to print progress for model loading
......@@ -134,4 +137,12 @@ def load_pretrained(model_name, log=True):
hidden_size=450,
latent_size=56)
elif model_name == 'wln_center_uspto':
model = WLNReactionCenter(node_in_feats=82,
edge_in_feats=6,
node_pair_in_feats=10,
node_out_feats=300,
n_layers=3,
n_tasks=5)
return download_and_load_checkpoint(model_name, model, URL[model_name], log=log)
......@@ -14,6 +14,8 @@ __all__ = ['one_hot_encoding',
'atom_degree',
'atom_total_degree_one_hot',
'atom_total_degree',
'atom_explicit_valence_one_hot',
'atom_explicit_valence',
'atom_implicit_valence_one_hot',
'atom_implicit_valence',
'atom_hybridization_one_hot',
......@@ -25,6 +27,8 @@ __all__ = ['one_hot_encoding',
'atom_num_radical_electrons',
'atom_is_aromatic_one_hot',
'atom_is_aromatic',
'atom_is_in_ring_one_hot',
'atom_is_in_ring',
'atom_chiral_tag_one_hot',
'atom_mass',
'ConcatFeaturizer',
......@@ -224,8 +228,45 @@ def atom_total_degree(atom):
"""
return [atom.GetTotalDegree()]
def atom_explicit_valence_one_hot(atom, allowable_set=None, encode_unknown=False):
"""One hot encoding for the explicit valence of an aotm.
Parameters
----------
atom : rdkit.Chem.rdchem.Atom
RDKit atom instance.
allowable_set : list of int
Atom explicit valences to consider. Default: ``1`` - ``6``.
encode_unknown : bool
If True, map inputs not in the allowable set to the
additional last element. (Default: False)
Returns
-------
list
List of boolean values where at most one value is True.
"""
if allowable_set is None:
allowable_set = list(range(1, 7))
return one_hot_encoding(atom.GetExplicitValence(), allowable_set, encode_unknown)
def atom_explicit_valence(atom):
"""Get the explicit valence of an atom.
Parameters
----------
atom : rdkit.Chem.rdchem.Atom
RDKit atom instance.
Returns
-------
list
List containing one int only.
"""
return [atom.GetExplicitValence()]
def atom_implicit_valence_one_hot(atom, allowable_set=None, encode_unknown=False):
"""One hot encoding for the implicit valences of an atom.
"""One hot encoding for the implicit valence of an atom.
Parameters
----------
......@@ -437,6 +478,43 @@ def atom_is_aromatic(atom):
"""
return [atom.GetIsAromatic()]
def atom_is_in_ring_one_hot(atom, allowable_set=None, encode_unknown=False):
"""One hot encoding for whether the atom is in ring.
Parameters
----------
atom : rdkit.Chem.rdchem.Atom
RDKit atom instance.
allowable_set : list of bool
Conditions to consider. Default: ``False`` and ``True``.
encode_unknown : bool
If True, map inputs not in the allowable set to the
additional last element. (Default: False)
Returns
-------
list
List of boolean values where at most one value is True.
"""
if allowable_set is None:
allowable_set = [False, True]
return one_hot_encoding(atom.IsInRing(), allowable_set, encode_unknown)
def atom_is_in_ring(atom):
"""Get whether the atom is in ring.
Parameters
----------
atom : rdkit.Chem.rdchem.Atom
RDKit atom instance.
Returns
-------
list
List containing one bool only.
"""
return [atom.IsInRing()]
def atom_chiral_tag_one_hot(atom, allowable_set=None, encode_unknown=False):
"""One hot encoding for the chiral tag of an atom.
......
......@@ -18,7 +18,7 @@ __all__ = ['mol_to_graph',
'mol_to_complete_graph',
'k_nearest_neighbors']
def mol_to_graph(mol, graph_constructor, node_featurizer, edge_featurizer):
def mol_to_graph(mol, graph_constructor, node_featurizer, edge_featurizer, canonical_atom_order):
"""Convert an RDKit molecule object into a DGLGraph and featurize for it.
Parameters
......@@ -33,12 +33,16 @@ def mol_to_graph(mol, graph_constructor, node_featurizer, edge_featurizer):
edge_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict
Featurization for edges like bonds in a molecule, which can be used to
update edata for a DGLGraph.
canonical_atom_order : bool
Whether to use a canonical order of atoms returned by RDKit. Setting it
to true might change the order of atoms in the graph constructed.
Returns
-------
g : DGLGraph
Converted DGLGraph for the molecule
"""
if canonical_atom_order:
new_order = rdmolfiles.CanonicalRankAtoms(mol)
mol = rdmolops.RenumberAtoms(mol, new_order)
g = graph_constructor(mol)
......@@ -103,7 +107,8 @@ def construct_bigraph_from_mol(mol, add_self_loop=False):
def mol_to_bigraph(mol, add_self_loop=False,
node_featurizer=None,
edge_featurizer=None):
edge_featurizer=None,
canonical_atom_order=True):
"""Convert an RDKit molecule object into a bi-directed DGLGraph and featurize for it.
Parameters
......@@ -118,6 +123,10 @@ def mol_to_bigraph(mol, add_self_loop=False,
edge_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict
Featurization for edges like bonds in a molecule, which can be used to update
edata for a DGLGraph. Default to None.
canonical_atom_order : bool
Whether to use a canonical order of atoms returned by RDKit. Setting it
to true might change the order of atoms in the graph constructed. Default
to True.
Returns
-------
......@@ -125,11 +134,12 @@ def mol_to_bigraph(mol, add_self_loop=False,
Bi-directed DGLGraph for the molecule
"""
return mol_to_graph(mol, partial(construct_bigraph_from_mol, add_self_loop=add_self_loop),
node_featurizer, edge_featurizer)
node_featurizer, edge_featurizer, canonical_atom_order)
def smiles_to_bigraph(smiles, add_self_loop=False,
node_featurizer=None,
edge_featurizer=None):
edge_featurizer=None,
canonical_atom_order=True):
"""Convert a SMILES into a bi-directed DGLGraph and featurize for it.
Parameters
......@@ -144,6 +154,10 @@ def smiles_to_bigraph(smiles, add_self_loop=False,
edge_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict
Featurization for edges like bonds in a molecule, which can be used to update
edata for a DGLGraph. Default to None.
canonical_atom_order : bool
Whether to use a canonical order of atoms returned by RDKit. Setting it
to true might change the order of atoms in the graph constructed. Default
to True.
Returns
-------
......@@ -151,7 +165,8 @@ def smiles_to_bigraph(smiles, add_self_loop=False,
Bi-directed DGLGraph for the molecule
"""
mol = Chem.MolFromSmiles(smiles)
return mol_to_bigraph(mol, add_self_loop, node_featurizer, edge_featurizer)
return mol_to_bigraph(mol, add_self_loop, node_featurizer,
edge_featurizer, canonical_atom_order)
def construct_complete_graph_from_mol(mol, add_self_loop=False):
"""Construct a complete graph with topology only for the molecule
......@@ -174,26 +189,20 @@ def construct_complete_graph_from_mol(mol, add_self_loop=False):
g : DGLGraph
Empty complete graph topology of the molecule
"""
g = DGLGraph()
num_atoms = mol.GetNumAtoms()
g.add_nodes(num_atoms)
if add_self_loop:
g.add_edges(
[i for i in range(num_atoms) for j in range(num_atoms)],
[j for i in range(num_atoms) for j in range(num_atoms)])
else:
g.add_edges(
[i for i in range(num_atoms) for j in range(num_atoms - 1)], [
j for i in range(num_atoms)
for j in range(num_atoms) if i != j
])
edge_list = []
for i in range(num_atoms):
for j in range(num_atoms):
if i != j or add_self_loop:
edge_list.append((i, j))
g = DGLGraph(edge_list)
return g
def mol_to_complete_graph(mol, add_self_loop=False,
node_featurizer=None,
edge_featurizer=None):
edge_featurizer=None,
canonical_atom_order=True):
"""Convert an RDKit molecule into a complete DGLGraph and featurize for it.
Parameters
......@@ -208,6 +217,10 @@ def mol_to_complete_graph(mol, add_self_loop=False,
edge_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict
Featurization for edges like bonds in a molecule, which can be used to update
edata for a DGLGraph. Default to None.
canonical_atom_order : bool
Whether to use a canonical order of atoms returned by RDKit. Setting it
to true might change the order of atoms in the graph constructed. Default
to True.
Returns
-------
......@@ -215,11 +228,12 @@ def mol_to_complete_graph(mol, add_self_loop=False,
Complete DGLGraph for the molecule
"""
return mol_to_graph(mol, partial(construct_complete_graph_from_mol, add_self_loop=add_self_loop),
node_featurizer, edge_featurizer)
node_featurizer, edge_featurizer, canonical_atom_order)
def smiles_to_complete_graph(smiles, add_self_loop=False,
node_featurizer=None,
edge_featurizer=None):
edge_featurizer=None,
canonical_atom_order=True):
"""Convert a SMILES into a complete DGLGraph and featurize for it.
Parameters
......@@ -234,6 +248,10 @@ def smiles_to_complete_graph(smiles, add_self_loop=False,
edge_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict
Featurization for edges like bonds in a molecule, which can be used to update
edata for a DGLGraph. Default to None.
canonical_atom_order : bool
Whether to use a canonical order of atoms returned by RDKit. Setting it
to true might change the order of atoms in the graph constructed. Default
to True.
Returns
-------
......@@ -241,7 +259,8 @@ def smiles_to_complete_graph(smiles, add_self_loop=False,
Complete DGLGraph for the molecule
"""
mol = Chem.MolFromSmiles(smiles)
return mol_to_complete_graph(mol, add_self_loop, node_featurizer, edge_featurizer)
return mol_to_complete_graph(mol, add_self_loop, node_featurizer,
edge_featurizer, canonical_atom_order)
def k_nearest_neighbors(coordinates, neighbor_cutoff, max_num_neighbors):
"""Find k nearest neighbors for each atom based on the 3D coordinates and
......
# Work Implemented in DGL-LifeSci
We provide various examples across 3 applications -- property prediction, generative models and protein-ligand binding affinity prediction.
## Datasets/Benchmarks
- MoleculeNet: A Benchmark for Molecular Machine Learning [[paper]](https://arxiv.org/abs/1703.00564), [[website]](http://moleculenet.ai/)
- [Tox21 with DGL](dgllife/data/tox21.py)
- [PDBBind with DGL](dgllife/data/pdbbind.py)
- [Tox21 with DGL](../dgllife/data/tox21.py)
- [PDBBind with DGL](../dgllife/data/pdbbind.py)
- Alchemy: A Quantum Chemistry Dataset for Benchmarking AI Models [[paper]](https://arxiv.org/abs/1906.09427), [[github]](https://github.com/tencent-alchemy/Alchemy)
- [Alchemy with DGL](dgllife/data/alchemy.py)
- [Alchemy with DGL](../dgllife/data/alchemy.py)
## Property Prediction
- Semi-Supervised Classification with Graph Convolutional Networks (GCN) [[paper]](https://arxiv.org/abs/1609.02907), [[github]](https://github.com/tkipf/gcn)
- [GCN-Based Predictor with DGL](dgllife/model/model_zoo/gcn_predictor.py)
- [Example for Molecule Classification](examples/property_prediction/classification.py)
- [GCN-Based Predictor with DGL](../dgllife/model/model_zoo/gcn_predictor.py)
- [Example for Molecule Classification](property_prediction/classification.py)
- Graph Attention Networks (GAT) [[paper]](https://arxiv.org/abs/1710.10903), [[github]](https://github.com/PetarV-/GAT)
- [GAT-Based Predictor with DGL](dgllife/model/model_zoo/gat_predictor.py)
- [Example for Molecule Classification](examples/property_prediction/classification.py)
- [GAT-Based Predictor with DGL](../dgllife/model/model_zoo/gat_predictor.py)
- [Example for Molecule Classification](property_prediction/classification.py)
- SchNet: A continuous-filter convolutional neural network for modeling quantum interactions [[paper]](https://arxiv.org/abs/1706.08566), [[github]](https://github.com/atomistic-machine-learning/SchNet)
- [SchNet with DGL](dgllife/model/model_zoo/schnet_predictor.py)
- [Example for Molecule Regression](examples/property_prediction/regression.py)
- [SchNet with DGL](../dgllife/model/model_zoo/schnet_predictor.py)
- [Example for Molecule Regression](property_prediction/regression.py)
- Molecular Property Prediction: A Multilevel Quantum Interactions Modeling Perspective (MGCN) [[paper]](https://arxiv.org/abs/1906.11081)
- [MGCN with DGL](dgllife/model/model_zoo/mgcn_predictor.py)
- [Example for Molecule Regression](examples/property_prediction/regression.py)
- [MGCN with DGL](../dgllife/model/model_zoo/mgcn_predictor.py)
- [Example for Molecule Regression](property_prediction/regression.py)
- Neural Message Passing for Quantum Chemistry (MPNN) [[paper]](https://arxiv.org/abs/1704.01212), [[github]](https://github.com/brain-research/mpnn)
- [MPNN with DGL](dgllife/model/model_zoo/mpnn_predictor.py)
- [Example for Molecule Regression](examples/property_prediction/regression.py)
- [MPNN with DGL](../dgllife/model/model_zoo/mpnn_predictor.py)
- [Example for Molecule Regression](property_prediction/regression.py)
- Pushing the Boundaries of Molecular Representation for Drug Discovery with the Graph Attention Mechanism (AttentiveFP) [[paper]](https://pubs.acs.org/doi/abs/10.1021/acs.jmedchem.9b00959)
- [AttentiveFP with DGL](dgllife/model/model_zoo/attentivefp_predictor.py)
- [Example for Molecule Regression](examples/property_prediction/regression.py)
- [AttentiveFP with DGL](../dgllife/model/model_zoo/attentivefp_predictor.py)
- [Example for Molecule Regression](property_prediction/regression.py)
## Generative Models
- Learning Deep Generative Models of Graphs (DGMG) [[paper]](https://arxiv.org/abs/1803.03324)
- [DGMG with DGL](dgllife/model/model_zoo/dgmg.py)
- [Example Training Script](examples/generative_models/dgmg)
- [DGMG with DGL](../dgllife/model/model_zoo/dgmg.py)
- [Example Training Script](generative_models/dgmg)
- Junction Tree Variational Autoencoder for Molecular Graph Generation (JTNN) [[paper]](https://arxiv.org/abs/1802.04364)
- [JTNN with DGL](dgllife/model/model_zoo/jtnn)
- [Example Training Script](examples/generative_models/jtnn)
- [JTNN with DGL](../dgllife/model/model_zoo/jtnn)
- [Example Training Script](generative_models/jtnn)
## Binding Affinity Prediction
- Atomic Convolutional Networks for Predicting Protein-Ligand Binding Affinity (ACNN) [[paper]](https://arxiv.org/abs/1703.10603), [[github]](https://github.com/deepchem/deepchem/tree/master/contrib/atomicconv)
- [ACNN with DGL](dgllife/model/model_zoo/acnn.py)
- [Example Training Script](examples/binding_affinity_prediction)
- [ACNN with DGL](../dgllife/model/model_zoo/acnn.py)
- [Example Training Script](binding_affinity_prediction)
## Reaction Prediction
- A graph-convolutional neural network model for the prediction of chemical reactivity [[paper]](https://pubs.rsc.org/en/content/articlelanding/2019/sc/c8sc04228d#!divAbstract), [[github]](https://github.com/connorcoley/rexgen_direct)
- An earlier version was published in NeurIPS 2017 as "Predicting Organic Reaction Outcomes with Weisfeiler-Lehman Network" [[paper]](https://arxiv.org/abs/1709.04555)
- [WLN with DGL for Reaction Center Prediction](../dgllife/model/model_zoo/wln_reaction_center.py)
- [Example Script](reaction_prediction/rexgen_direct)
......@@ -83,12 +83,14 @@ def load_dataset_for_regression(args):
def collate_molgraphs(data):
"""Batching a list of datapoints for dataloader.
Parameters
----------
data : list of 3-tuples or 4-tuples.
Each tuple is for a single datapoint, consisting of
a SMILES, a DGLGraph, all-task labels and optionally
a binary mask indicating the existence of labels.
Returns
-------
smiles : list
......
# A graph-convolutional neural network model for the prediction of chemical reactivity
- [paper in Chemical Science](https://pubs.rsc.org/en/content/articlelanding/2019/sc/c8sc04228d#!divAbstract)
- [authors' code](https://github.com/connorcoley/rexgen_direct)
An earlier version of the work was published in NeurIPS 2017 as
["Predicting Organic Reaction Outcomes with Weisfeiler-Lehman Network"](https://arxiv.org/abs/1709.04555) with some
slight difference in modeling.
## Dataset
The example by default works with reactions from USPTO (United States Patent and Trademark) granted patents,
collected by Lowe [1]. After removing duplicates and erroneous reactions, the authors obtain a set of 480K reactions.
The dataset is divided into 400K, 40K, and 40K for training, validation and test.
## Reaction Center Prediction
### Modeling
Reaction centers refer to the pairs of atoms that lose/form a bond in the reactions. A Graph neural network
(Weisfeiler-Lehman Network in this case) is trained to update the representations of all atoms. Then we combine
pairs of atom representations to predict the likelihood for the corresponding atoms to form/lose a bond.
For evaluation, we select pairs of atoms with top-k scores for each reaction and compute the proportion of reactions
whose reaction centers have all been selected.
### Training with Default Options
We use GPU whenever possible. To train the model with default options, simply do
```bash
python find_reaction_center.py
```
Once the training process starts, the progress will be printed out in the terminal as follows:
```bash
Epoch 1/50, iter 8150/20452 | time/minibatch 0.0260 | loss 8.4788 | grad norm 12.9927
Epoch 1/50, iter 8200/20452 | time/minibatch 0.0260 | loss 8.6722 | grad norm 14.0833
```
After an epoch of training is completed, we evaluate the model on the validation set and
print the evaluation results as follows:
```bash
Epoch 4/50, validation | acc@10 0.8213 | acc@20 0.9016 |
```
By default, we store the model per 10000 iterations in `center_results`.
**Speedup**: For an epoch of training, our implementation takes about 5095s for the first epoch while the authors'
implementation takes about 11657s, which is roughly a speedup by 2.3x.
For model evaluation, we can choose whether to exclude reactants not contributing atoms to the product
(e.g. reagents and solvents) in top-k atom pair selection, which will make the task easier.
For the easier evaluation, do
```bash
python find_reaction_center.py --easy
```
A summary of the model performance is as follows:
| Item | Top 6 accuracy | Top 8 accuracy | Top 10 accuracy |
| --------------- | -------------- | -------------- | --------------- |
| Paper | 89.8 | 92.0 | 93.3 |
| Hard evaluation | 86.5 | 89.6 | 91.2 |
| Easy evaluation | 88.9 | 92.0 | 93.5 |
### Pre-trained Model
We provide a pre-trained model so users do not need to train from scratch. To evaluate the pre-trained model, simply do
```bash
python find_reaction_center.py -p
```
### Adapting to a new dataset.
New datasets should be processed such that each line corresponds to the SMILES for a reaction like below:
```bash
[CH3:14][NH2:15].[N+:1](=[O:2])([O-:3])[c:4]1[cH:5][c:6]([C:7](=[O:8])[OH:9])[cH:10][cH:11][c:12]1[Cl:13].[OH2:16]>>[N+:1](=[O:2])([O-:3])[c:4]1[cH:5][c:6]([C:7](=[O:8])[OH:9])[cH:10][cH:11][c:12]1[NH:15][CH3:14]
```
The reactants are placed before `>>` and the product is placed after `>>`. The reactants are separated by `.`.
In addition, atom mapping information is provided.
You can then train a model on new datasets with
```bash
python find_reaction_center.py --train-path X --val-path Y --test-path Z
```
where `X`, `Y`, `Z` are paths to the new training/validation/test set as described above.
## References
[1] D. M.Lowe, Patent reaction extraction: downloads,
https://bitbucket.org/dan2097/patent-reaction-extraction/downloads, 2014.
# Configuration for reaction center identification
reaction_center_config = {
'batch_size': 20,
'hidden_size': 300,
'max_norm': 5.0,
'node_in_feats': 82,
'edge_in_feats': 6,
'node_pair_in_feats': 10,
'node_out_feats': 300,
'n_layers': 3,
'n_tasks': 5,
'lr': 0.001,
'num_epochs': 25,
'print_every': 50,
'decay_every': 10000, # Learning rate decay
'lr_decay_factor': 0.9,
'top_ks': [6, 8, 10],
'max_k': 80
}
import numpy as np
import time
import torch
from dgllife.data import USPTO, WLNReactionDataset
from dgllife.model import WLNReactionCenter, load_pretrained
from torch.nn import BCEWithLogitsLoss
from torch.nn.utils import clip_grad_norm_
from torch.optim import Adam
from torch.optim.lr_scheduler import StepLR
from torch.utils.data import DataLoader
from utils import setup, collate, reaction_center_prediction, \
rough_eval_on_a_loader, reaction_center_final_eval
def main(args):
setup(args)
if args['train_path'] is None:
train_set = USPTO('train')
else:
train_set = WLNReactionDataset(raw_file_path=args['train_path'],
mol_graph_path='train.bin')
if args['val_path'] is None:
val_set = USPTO('val')
else:
val_set = WLNReactionDataset(raw_file_path=args['val_path'],
mol_graph_path='val.bin')
if args['test_path'] is None:
test_set = USPTO('test')
else:
test_set = WLNReactionDataset(raw_file_path=args['test_path'],
mol_graph_path='test.bin')
train_loader = DataLoader(train_set, batch_size=args['batch_size'],
collate_fn=collate, shuffle=True)
val_loader = DataLoader(val_set, batch_size=args['batch_size'],
collate_fn=collate, shuffle=False)
test_loader = DataLoader(test_set, batch_size=args['batch_size'],
collate_fn=collate, shuffle=False)
if args['pre_trained']:
model = load_pretrained('wln_center_uspto').to(args['device'])
args['num_epochs'] = 0
else:
model = WLNReactionCenter(node_in_feats=args['node_in_feats'],
edge_in_feats=args['edge_in_feats'],
node_pair_in_feats=args['node_pair_in_feats'],
node_out_feats=args['node_out_feats'],
n_layers=args['n_layers'],
n_tasks=args['n_tasks']).to(args['device'])
criterion = BCEWithLogitsLoss(reduction='sum')
optimizer = Adam(model.parameters(), lr=args['lr'])
scheduler = StepLR(optimizer, step_size=args['decay_every'], gamma=args['lr_decay_factor'])
total_iter = 0
grad_norm_sum = 0
loss_sum = 0
dur = []
for epoch in range(args['num_epochs']):
t0 = time.time()
for batch_id, batch_data in enumerate(train_loader):
total_iter += 1
batch_reactions, batch_graph_edits, batch_mols, batch_mol_graphs, \
batch_complete_graphs, batch_atom_pair_labels = batch_data
labels = batch_atom_pair_labels.to(args['device'])
pred, biased_pred = reaction_center_prediction(
args['device'], model, batch_mol_graphs, batch_complete_graphs)
loss = criterion(pred, labels) / len(batch_reactions)
loss_sum += loss.cpu().detach().data.item()
optimizer.zero_grad()
loss.backward()
grad_norm = clip_grad_norm_(model.parameters(), args['max_norm'])
grad_norm_sum += grad_norm
optimizer.step()
scheduler.step()
if total_iter % args['print_every'] == 0:
progress = 'Epoch {:d}/{:d}, iter {:d}/{:d} | time/minibatch {:.4f} | ' \
'loss {:.4f} | grad norm {:.4f}'.format(
epoch + 1, args['num_epochs'], batch_id + 1, len(train_loader),
(np.sum(dur) + time.time() - t0) / total_iter, loss_sum / args['print_every'],
grad_norm_sum / args['print_every'])
grad_norm_sum = 0
loss_sum = 0
print(progress)
if total_iter % args['decay_every'] == 0:
torch.save(model.state_dict(), args['result_path'] + '/model.pkl')
dur.append(time.time() - t0)
print('Epoch {:d}/{:d}, validation '.format(epoch + 1, args['num_epochs']) + \
rough_eval_on_a_loader(args, model, val_loader))
del train_loader
del val_loader
del train_set
del val_set
print('Evaluation on the test set.')
test_result = reaction_center_final_eval(args, model, test_loader, args['easy'])
print(test_result)
with open(args['result_path'] + '/results.txt', 'w') as f:
f.write(test_result)
if __name__ == '__main__':
from argparse import ArgumentParser
from configure import reaction_center_config
parser = ArgumentParser(description='Reaction Center Identification')
parser.add_argument('--result-path', type=str, default='center_results',
help='Path to training results')
parser.add_argument('--train-path', type=str, default=None,
help='Path to a new training set. '
'If None, we will use the default training set in USPTO.')
parser.add_argument('--val-path', type=str, default=None,
help='Path to a new validation set. '
'If None, we will use the default validation set in USPTO.')
parser.add_argument('--test-path', type=str, default=None,
help='Path to a new test set.'
'If None, we will use the default test set in USPTO.')
parser.add_argument('-p', '--pre-trained', action='store_true', default=False,
help='If true, we will directly evaluate a '
'pretrained model on the test set.')
parser.add_argument('--easy', action='store_true', default=False,
help='Whether to exclude reactants not contributing atoms to the '
'product in top-k atom pair selection, which will make the '
'task easier.')
args = parser.parse_args().__dict__
args.update(reaction_center_config)
main(args)
import dgl
import errno
import numpy as np
import os
import random
import torch
from collections import defaultdict
from rdkit import Chem
def mkdir_p(path):
"""Create a folder for the given path.
Parameters
----------
path: str
Folder to create
"""
try:
os.makedirs(path)
print('Created directory {}'.format(path))
except OSError as exc:
if exc.errno == errno.EEXIST and os.path.isdir(path):
print('Directory {} already exists.'.format(path))
else:
raise
def setup(args, seed=0):
"""Setup for the experiment:
1. Decide whether to use CPU or GPU for training
2. Fix random seed for python, NumPy and PyTorch.
Parameters
----------
seed : int
Random seed to use.
Returns
-------
args
Updated configuration
"""
assert args['max_k'] >= max(args['top_ks']), \
'Expect max_k to be no smaller than the possible options ' \
'of top_ks, got {:d} and {:d}'.format(args['max_k'], max(args['top_ks']))
if torch.cuda.is_available():
args['device'] = 'cuda:0'
else:
args['device'] = 'cpu'
# Set random seed
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed(seed)
mkdir_p(args['result_path'])
return args
def collate(data):
"""Collate multiple datapoints
Parameters
----------
data : list of 7-tuples
Each tuple is for a single datapoint, consisting of
a reaction, graph edits in the reaction, an RDKit molecule instance for all reactants,
a DGLGraph for all reactants, a complete graph for all reactants, the features for each
pair of atoms and the labels for each pair of atoms.
Returns
-------
reactions : list of str
List of reactions.
graph_edits : list of str
List of graph edits in the reactions.
mols : list of rdkit.Chem.rdchem.Mol
List of RDKit molecule instances for the reactants.
batch_mol_graphs : DGLGraph
DGLGraph for a batch of molecular graphs.
batch_complete_graphs : DGLGraph
DGLGraph for a batch of complete graphs.
batch_atom_pair_labels : float32 tensor of shape (V, 10)
Labels of atom pairs in the batch of graphs.
"""
reactions, graph_edits, mols, mol_graphs, complete_graphs, \
atom_pair_feats, atom_pair_labels = map(list, zip(*data))
batch_mol_graphs = dgl.batch(mol_graphs)
batch_mol_graphs.set_n_initializer(dgl.init.zero_initializer)
batch_mol_graphs.set_e_initializer(dgl.init.zero_initializer)
batch_complete_graphs = dgl.batch(complete_graphs)
batch_complete_graphs.set_n_initializer(dgl.init.zero_initializer)
batch_complete_graphs.set_e_initializer(dgl.init.zero_initializer)
batch_complete_graphs.edata['feats'] = torch.cat(atom_pair_feats, dim=0)
batch_atom_pair_labels = torch.cat(atom_pair_labels, dim=0)
return reactions, graph_edits, mols, batch_mol_graphs, \
batch_complete_graphs, batch_atom_pair_labels
def reaction_center_prediction(device, model, mol_graphs, complete_graphs):
"""Perform a soft prediction on reaction center.
Parameters
----------
device : str
Device to use for computation, e.g. 'cpu', 'cuda:0'
model : nn.Module
Model for prediction.
mol_graphs : DGLGraph
DGLGraph for a batch of molecular graphs
complete_graphs : DGLGraph
DGLGraph for a batch of complete graphs
Returns
-------
scores : float32 tensor of shape (E_full, 5)
Predicted scores for each pair of atoms to perform one of the following
5 actions in reaction:
* The bond between them gets broken
* Forming a single bond
* Forming a double bond
* Forming a triple bond
* Forming an aromatic bond
biased_scores : float32 tensor of shape (E_full, 5)
Comparing to scores, a bias is added if the pair is for a same atom.
"""
node_feats = mol_graphs.ndata.pop('hv').to(device)
edge_feats = mol_graphs.edata.pop('he').to(device)
node_pair_feats = complete_graphs.edata.pop('feats').to(device)
return model(mol_graphs, complete_graphs, node_feats, edge_feats, node_pair_feats)
def rough_eval(complete_graphs, preds, labels, num_correct):
batch_size = complete_graphs.batch_size
start = 0
for i in range(batch_size):
end = start + complete_graphs.batch_num_edges[i]
preds_i = preds[start:end, :].flatten()
labels_i = labels[start:end, :].flatten()
for k in num_correct.keys():
topk_values, topk_indices = torch.topk(preds_i, k)
is_correct = labels_i[topk_indices].sum() == labels_i.sum().float().cpu().data.item()
num_correct[k].append(is_correct)
start = end
def rough_eval_on_a_loader(args, model, data_loader):
"""A rough evaluation of model performance in the middle of training.
For final evaluation, we will eliminate some possibilities based on prior knowledge.
Parameters
----------
args : dict
Configurations fot the experiment.
model : nn.Module
Model for reaction center prediction.
data_loader : torch.utils.data.DataLoader
Loader for fetching and batching data.
Returns
-------
str
Message for evluation result.
"""
model.eval()
num_correct = {k: [] for k in args['top_ks']}
for batch_id, batch_data in enumerate(data_loader):
batch_reactions, batch_graph_edits, batch_mols, batch_mol_graphs, \
batch_complete_graphs, batch_atom_pair_labels = batch_data
with torch.no_grad():
pred, biased_pred = reaction_center_prediction(
args['device'], model, batch_mol_graphs, batch_complete_graphs)
rough_eval(batch_complete_graphs, biased_pred, batch_atom_pair_labels, num_correct)
msg = '|'
for k, correct_count in num_correct.items():
msg += ' acc@{:d} {:.4f} |'.format(k, np.mean(correct_count))
return msg
def eval(complete_graphs, preds, reactions, graph_edits, num_correct, max_k, easy):
"""Evaluate top-k accuracies for reaction center prediction.
Parameters
----------
complete_graphs : DGLGraph
DGLGraph for a batch of complete graphs
preds : float32 tensor of shape (E_full, 5)
Soft predictions for reaction center, E_full being the number of possible
atom-pairs and 5 being the number of possible bond changes
reactions : list of str
List of reactions.
graph_edits : list of str
List of graph edits in the reactions.
num_correct : dict
Counting the number of datapoints for meeting top-k accuracies.
max_k : int
Maximum number of atom pairs to be selected. This is intended to be larger
than max(num_correct.keys()) as we will filter out many atom pairs due to
considerations such as avoiding duplicates.
easy : bool
If True, reactants not contributing atoms to the product will be excluded in
top-k atom pair selection, which will make the task easier.
"""
# 0 for losing the bond
# 1, 2, 3, 1.5 separately for forming a single, double, triple or aromatic bond.
bond_change_to_id = {0.0: 0, 1:1, 2:2, 3:3, 1.5:4}
id_to_bond_change = {v: k for k, v in bond_change_to_id.items()}
num_change_types = len(bond_change_to_id)
batch_size = complete_graphs.batch_size
start = 0
for i in range(batch_size):
# Decide which atom-pairs will be considered.
reaction_i = reactions[i]
reaction_atoms_i = []
reaction_bonds_i = defaultdict(bool)
reactants_i, _, product_i = reaction_i.split('>')
product_mol_i = Chem.MolFromSmiles(product_i)
product_atoms_i = set([atom.GetAtomMapNum() for atom in product_mol_i.GetAtoms()])
for reactant in reactants_i.split('.'):
reactant_mol = Chem.MolFromSmiles(reactant)
reactant_atoms = [atom.GetAtomMapNum() for atom in reactant_mol.GetAtoms()]
if (len(set(reactant_atoms) & product_atoms_i) > 0) or (not easy):
reaction_atoms_i.extend(reactant_atoms)
for bond in reactant_mol.GetBonds():
end_atoms = sorted([bond.GetBeginAtom().GetAtomMapNum(),
bond.GetEndAtom().GetAtomMapNum()])
bond = tuple(end_atoms + [bond.GetBondTypeAsDouble()])
reaction_bonds_i[bond] = True
num_nodes = complete_graphs.batch_num_nodes[i]
end = start + complete_graphs.batch_num_edges[i]
preds_i = preds[start:end, :].flatten()
candidate_bonds = []
topk_values, topk_indices = torch.topk(preds_i, max_k)
for j in range(max_k):
preds_i_j = topk_indices[j].cpu().item()
# A bond change can be either losing the bond or forming a
# single, double, triple or aromatic bond
change_id = preds_i_j % num_change_types
change_type = id_to_bond_change[change_id]
pair_id = preds_i_j // num_change_types
atom1 = pair_id // num_nodes + 1
atom2 = pair_id % num_nodes + 1
# Avoid duplicates and an atom cannot form a bond with itself
if atom1 >= atom2:
continue
if atom1 not in reaction_atoms_i:
continue
if atom2 not in reaction_atoms_i:
continue
candidate = (int(atom1), int(atom2), float(change_type))
if reaction_bonds_i[candidate]:
continue
candidate_bonds.append(candidate)
gold_bonds = []
gold_edits = graph_edits[i]
for edit in gold_edits.split(';'):
atom1, atom2, change_type = edit.split('-')
atom1, atom2 = int(atom1), int(atom2)
gold_bonds.append((min(atom1, atom2), max(atom1, atom2), float(change_type)))
for k in num_correct.keys():
if set(gold_bonds) <= set(candidate_bonds[:k]):
num_correct[k] += 1
start = end
def reaction_center_final_eval(args, model, data_loader, easy):
"""Final evaluation of model performance.
args : dict
Configurations fot the experiment.
model : nn.Module
Model for reaction center prediction.
data_loader : torch.utils.data.DataLoader
Loader for fetching and batching data.
easy : bool
If True, reactants not contributing atoms to the product will be excluded in
top-k atom pair selection, which will make the task easier.
Returns
-------
msg : str
Summary of the top-k evaluation.
"""
model.eval()
num_correct = {k: 0 for k in args['top_ks']}
for batch_id, batch_data in enumerate(data_loader):
batch_reactions, batch_graph_edits, batch_mols, batch_mol_graphs, \
batch_complete_graphs, batch_atom_pair_labels = batch_data
with torch.no_grad():
pred, biased_pred = reaction_center_prediction(
args['device'], model, batch_mol_graphs, batch_complete_graphs)
eval(batch_complete_graphs, biased_pred, batch_reactions,
batch_graph_edits, num_correct, args['max_k'], easy)
msg = '|'
for k, correct_count in num_correct.items():
msg += ' acc@{:d} {:.4f} |'.format(k, correct_count / len(data_loader.dataset))
return msg
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment