"vscode:/vscode.git/clone" did not exist on "fff7fbabe63f44e6aa4fe7d3a4a5a3215370ad4f"
Unverified commit 36c7b771 authored by Mufei Li, committed by GitHub

[LifeSci] Move to Independent Repo (#1592)

* Move LifeSci

* Remove doc
@@ -113,7 +113,6 @@ a useful manual for in-depth developers.
api/python/data
api/python/nodeflow
api/python/random
- api/python/model_zoo
.. toctree::
:maxdepth: 1
# DGL for Chemistry
With atoms being nodes and bonds being edges, molecular graphs are among the core objects for study in Chemistry.
Deep learning on graphs can be beneficial for various applications in Chemistry like drug and material discovery
[1], [2], [12].
To make it easy for domain scientists, the DGL team releases a model zoo for Chemistry, spanning three use cases
-- property prediction, molecule generation/optimization and binding affinity prediction.
With pre-trained models and training scripts, we hope this model zoo will be helpful for both
the chemistry community and the deep learning community to further their research.
## Dependencies
Before you proceed, depending on the model/task you are interested in,
you may need to install the dependencies below:
- PyTorch 1.2
- Check the [official website](https://pytorch.org/) for installation guide.
- RDKit 2018.09.3
- We recommend installation with `conda install -c conda-forge rdkit==2018.09.3`. For other installation recipes,
see the [official documentation](https://www.rdkit.org/docs/Install.html).
- Pdbfixer
- We recommend installation with `conda install -c omnia pdbfixer`. To install from source, see the
[manual](http://htmlpreview.github.io/?https://raw.github.com/pandegroup/pdbfixer/master/Manual.html).
- MDTraj
- We recommend installation with `conda install -c conda-forge mdtraj`. For alternative ways of installation,
see the [official documentation](http://mdtraj.org/1.9.3/installation.html).
The remaining dependencies can be installed with `pip install -r requirements.txt`.
## Speed Reference
Below we provide reference numbers showing how DGL improves per-epoch training time, measured in seconds.
| Model | Original Implementation | DGL Implementation | Improvement |
| -------------------------- | ----------------------- | ------------------ | ----------- |
| GCN on Tox21 | 5.5 (DeepChem) | 1.0 | 5.5x |
| AttentiveFP on Aromaticity | 6.0 | 1.2 | 5x |
| JTNN on ZINC | 1826 | 743 | 2.5x |
## Featurization and Representation Learning
Fingerprints have been a widely used concept in cheminformatics. Chemists developed hand-designed rules to convert
molecules into binary strings where each bit indicates the presence or absence of a particular substructure. The
development of fingerprints makes the comparison of molecules a lot easier. Earlier machine learning methods were
mostly developed based on molecular fingerprints.
Graph neural networks make it possible to learn data-driven representations of molecules from the atoms, bonds and
molecular graph topology, which may be viewed as learned fingerprints [3].
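For comparison, a hand-designed fingerprint can be computed directly with RDKit. The snippet below is only an illustrative sketch; the SMILES string and fingerprint parameters are arbitrary choices and not part of the model zoo:
```python
from rdkit import Chem
from rdkit.Chem import AllChem

mol = Chem.MolFromSmiles('CCOc1ccc2nc(S(N)(=O)=O)sc2c1')
# 2048-bit Morgan (ECFP-like) fingerprint: each bit indicates the presence
# or absence of a circular substructure around some atom
fp = AllChem.GetMorganFingerprintAsBitVect(mol, 2, nBits=2048)
print(fp.GetNumOnBits())
```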
## Property Prediction
To evaluate molecules as drug candidates, we need to know their properties and activities. In practice, these are
mostly measured via wet-lab experiments. We can cast property prediction as a regression or classification problem,
which can be quite difficult in practice due to the scarcity of labeled data.
### Models
- **Graph Convolutional Networks** [3], [9]: Graph Convolutional Networks (GCNs) have been one of the most popular graph
neural networks and they can be easily extended for graph-level prediction.
- **Graph Attention Networks** [10]: Graph Attention Networks (GATs) incorporate multi-head attention into GCNs,
explicitly modeling the interactions between adjacent atoms.
- **SchNet** [4]: SchNet is a deep learning architecture for modeling quantum interactions in molecules that utilizes
continuous-filter convolutional layers.
- **Multilevel Graph Convolutional Neural Network** [5]: Multilevel Graph Convolutional Neural Network (MGCN) is a
hierarchical graph neural network that directly extracts features from the conformation and spatial information,
followed by multilevel interactions.
- **Message Passing Neural Network** [6]: Message Passing Neural Network (MPNN) uses an edge network (enn)
as the front end and Set2Set for output prediction.
### Example Usage of Pre-trained Models
```python
from dgl.data.chem import Tox21, smiles_to_bigraph, CanonicalAtomFeaturizer
from dgl import model_zoo
dataset = Tox21(smiles_to_bigraph, CanonicalAtomFeaturizer())
model = model_zoo.chem.load_pretrained('GCN_Tox21') # Pretrained model loaded
model.eval()
smiles, g, label, mask = dataset[0]
feats = g.ndata.pop('h')
label_pred = model(g, feats)
print(smiles) # CCOc1ccc2nc(S(N)(=O)=O)sc2c1
print(label_pred[:, mask != 0]) # Mask non-existing labels
# tensor([[-0.7956, 0.4054, 0.4288, -0.5565, -0.0911,
# 0.9981, -0.1663, 0.2311, -0.2376, 0.9196]])
```
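The outputs above are raw scores over the Tox21 tasks. Assuming they are logits (the negative values suggest they are not probabilities), per-task probabilities can be recovered with a sigmoid; the snippet below is an illustrative addition, not part of the original example:
```python
import torch

# Assumption: label_pred holds logits for the Tox21 tasks;
# a sigmoid converts them to per-task probabilities in [0, 1]
probs = torch.sigmoid(label_pred)
print(probs[:, mask != 0])
```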
## Generative Models
We use generative models for two different purposes when it comes to molecules:
- **Distribution Learning**: Given a collection of molecules, we want to model their distribution and generate new
molecules with similar properties.
- **Goal-directed Optimization**: Find molecules with desired properties.
For this model zoo, we focus only on generative models for molecular graphs. There are other generative models
that work with alternative representations like SMILES strings.
Generative models are known to be difficult to evaluate. [GuacaMol](https://github.com/BenevolentAI/guacamol) and
[MOSES](https://github.com/molecularsets/moses) are two recent efforts to benchmark generative models. Both
come with well-written accompanying review papers [7], [8].
### Models
- **Deep Generative Models of Graphs (DGMG)** [11]: A very general framework for graph distribution learning by
progressively adding atoms and bonds.
- **Junction Tree Variational Autoencoder for Molecular Graph Generation (JTNN)** [13]: JTNNs are able to incrementally
expand molecules while maintaining chemical valency at every step. They can be used for both molecule generation and
optimization.
### Example Usage of Pre-trained Models
```python
# We recommend running the code below with Jupyter notebooks
from IPython.display import SVG
from rdkit import Chem
from rdkit.Chem import Draw
from dgl import model_zoo
model = model_zoo.chem.load_pretrained('DGMG_ZINC_canonical')
model.eval()
mols = []
for i in range(4):
    SMILES = model(rdkit_mol=True)
    mols.append(Chem.MolFromSmiles(SMILES))
# Generating 4 molecules takes less than a second.
SVG(Draw.MolsToGridImage(mols, molsPerRow=4, subImgSize=(180, 150), useSVG=True))
```
![](https://data.dgl.ai/dgllife/dgmg/dgmg_model_zoo_example2.png)
## Binding affinity prediction
The interaction of drugs and proteins can be characterized in terms of binding affinity. Given a pair of ligand
(drug candidate) and protein with particular conformations, we are interested in predicting the
binding affinity between them.
### Models
- **Atomic Convolutional Networks** [14]: Constructs nearest neighbor graphs separately for the ligand, protein and complex
based on the 3D coordinates of the atoms and predicts the binding free energy.
## References
[1] Chen et al. (2018) The rise of deep learning in drug discovery. *Drug Discov Today* 6, 1241-1250.
[2] Vamathevan et al. (2019) Applications of machine learning in drug discovery and development.
*Nature Reviews Drug Discovery* 18, 463-477.
[3] Duvenaud et al. (2015) Convolutional networks on graphs for learning molecular fingerprints. *Advances in neural
information processing systems (NeurIPS)*, 2224-2232.
[4] Schütt et al. (2017) SchNet: A continuous-filter convolutional neural network for modeling quantum interactions.
*Advances in Neural Information Processing Systems (NeurIPS)*, 992-1002.
[5] Lu et al. (2019) Molecular Property Prediction: A Multilevel Quantum Interactions Modeling Perspective.
*The 33rd AAAI Conference on Artificial Intelligence*.
[6] Gilmer et al. (2017) Neural Message Passing for Quantum Chemistry. *Proceedings of the 34th International Conference on
Machine Learning* JMLR. 1263-1272.
[7] Brown et al. (2019) GuacaMol: Benchmarking Models for de Novo Molecular Design. *J. Chem. Inf. Model*, 2019, 59, 3,
1096-1108.
[8] Polykovskiy et al. (2019) Molecular Sets (MOSES): A Benchmarking Platform for Molecular Generation Models. *arXiv*.
[9] Kipf et al. (2017) Semi-Supervised Classification with Graph Convolutional Networks.
*The International Conference on Learning Representations (ICLR)*.
[10] Veličković et al. (2018) Graph Attention Networks.
*The International Conference on Learning Representations (ICLR)*.
[11] Li et al. (2018) Learning Deep Generative Models of Graphs. *arXiv preprint arXiv:1803.03324*.
[12] Goh et al. (2017) Deep learning for computational chemistry. *Journal of Computational Chemistry* 16, 1291-1307.
[13] Jin et al. (2018) Junction Tree Variational Autoencoder for Molecular Graph Generation.
*Proceedings of the 35th International Conference on Machine Learning (ICML)*, 2323-2332.
[14] Gomes et al. (2017) Atomic Convolutional Networks for Predicting Protein-Ligand Binding Affinity. *arXiv preprint arXiv:1703.10603*.
# Binding Affinity Prediction
## Datasets
- **PDBBind**: The PDBBind dataset in MoleculeNet [1] processed from the PDBBind database. The PDBBind
database consists of experimentally measured binding affinities for bio-molecular complexes [2], [3].
It provides detailed 3D Cartesian coordinates of both ligands and their target proteins derived from
experimental (e.g., X-ray crystallography) measurements. The availability of coordinates of the
protein-ligand complexes permits structure-based featurization that is aware of the protein-ligand
binding geometry. The authors of [1] use the "refined" and "core" subsets of the database [4], more carefully
processed for data artifacts, as additional benchmarking targets.
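For reference, the processed dataset can be loaded with the same `PDBBind` call used in the utilities code later in this commit; a minimal sketch, with keyword values mirroring the core-subset configurations:
```python
from dgl.data.chem import PDBBind

# Same call as in load_dataset() further below; 'core' and binding-pocket-only
# loading match the ACNN_PDBBind_core_* configurations
dataset = PDBBind(subset='core', load_binding_pocket=True, zero_padding=True)
```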
## Models
- **Atomic Convolutional Networks (ACNN)** [5]: Constructs nearest neighbor graphs separately for the ligand, protein and complex
based on the 3D coordinates of the atoms and predicts the binding free energy.
## Usage
Use `main.py` with arguments
```
-m {ACNN}, Model to use
-d {PDBBind_core_pocket_random, PDBBind_core_pocket_scaffold, PDBBind_core_pocket_stratified,
PDBBind_core_pocket_temporal, PDBBind_refined_pocket_random, PDBBind_refined_pocket_scaffold,
PDBBind_refined_pocket_stratified, PDBBind_refined_pocket_temporal}, dataset and splitting method to use
```
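For example, a representative invocation for training and evaluating ACNN on the core subset with a random split:
```
python main.py -m ACNN -d PDBBind_core_pocket_random
```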
## Performance
### PDBBind
#### ACNN
| Subset | Splitting Method | Test MAE | Test R2 |
| ------- | ---------------- | -------- | ------- |
| Core | Random | 1.7688 | 0.1511 |
| Core | Scaffold | 2.5420 | 0.1471 |
| Core | Stratified | 1.7419 | 0.1520 |
| Core | Temporal | 1.9543 | 0.1640 |
| Refined | Random | 1.1948 | 0.4373 |
| Refined | Scaffold | 1.4021 | 0.2086 |
| Refined | Stratified | 1.6376 | 0.3050 |
| Refined | Temporal | 1.2457 | 0.3438 |
## Speed
### ACNN
Compared to [DeepChem's implementation](https://github.com/joegomes/deepchem/tree/acdc), we achieve a speedup of
roughly 3.3x in training time per epoch (from 1.40s to 0.42s). If we do not care about the
randomness introduced by some kernel optimizations, the speedup is roughly 4.4x (from 1.40s to 0.32s).
## References
[1] Wu et al. (2017) MoleculeNet: a benchmark for molecular machine learning. *Chemical Science* 9, 513-530.
[2] Wang et al. (2004) The PDBbind database: collection of binding affinities for protein-ligand complexes
with known three-dimensional structures. *J Med Chem* 3;47(12):2977-80.
[3] Wang et al. (2005) The PDBbind database: methodologies and updates. *J Med Chem* 16;48(12):4111-9.
[4] Liu et al. (2015) PDB-wide collection of binding data: current status of the PDBbind database. *Bioinformatics* 1;31(3):405-12.
[5] Gomes et al. (2017) Atomic Convolutional Networks for Predicting Protein-Ligand Binding Affinity. *arXiv preprint arXiv:1703.10603*.
import numpy as np
import torch
ACNN_PDBBind_core_pocket_random = {
'dataset': 'PDBBind',
'subset': 'core',
'load_binding_pocket': True,
'random_seed': 123,
'frac_train': 0.8,
'frac_val': 0.,
'frac_test': 0.2,
'batch_size': 24,
'shuffle': False,
'hidden_sizes': [32, 32, 16],
'weight_init_stddevs': [1. / float(np.sqrt(32)), 1. / float(np.sqrt(32)),
1. / float(np.sqrt(16)), 0.01],
'dropouts': [0., 0., 0.],
'atomic_numbers_considered': torch.tensor([
1., 6., 7., 8., 9., 11., 12., 15., 16., 17., 20., 25., 30., 35., 53.]),
'radial': [[12.0], [0.0, 4.0, 8.0], [4.0]],
'lr': 0.001,
'num_epochs': 120,
'metrics': ['pearson_r2', 'mae'],
'split': 'random'
}
ACNN_PDBBind_core_pocket_scaffold = {
'dataset': 'PDBBind',
'subset': 'core',
'load_binding_pocket': True,
'random_seed': 123,
'frac_train': 0.8,
'frac_val': 0.,
'frac_test': 0.2,
'batch_size': 24,
'shuffle': False,
'hidden_sizes': [32, 32, 16],
'weight_init_stddevs': [1. / float(np.sqrt(32)), 1. / float(np.sqrt(32)),
1. / float(np.sqrt(16)), 0.01],
'dropouts': [0., 0., 0.],
'atomic_numbers_considered': torch.tensor([
1., 6., 7., 8., 9., 11., 12., 15., 16., 17., 20., 25., 30., 35., 53.]),
'radial': [[12.0], [0.0, 4.0, 8.0], [4.0]],
'lr': 0.001,
'num_epochs': 170,
'metrics': ['pearson_r2', 'mae'],
'split': 'scaffold'
}
ACNN_PDBBind_core_pocket_stratified = {
'dataset': 'PDBBind',
'subset': 'core',
'load_binding_pocket': True,
'random_seed': 123,
'frac_train': 0.8,
'frac_val': 0.,
'frac_test': 0.2,
'batch_size': 24,
'shuffle': False,
'hidden_sizes': [32, 32, 16],
'weight_init_stddevs': [1. / float(np.sqrt(32)), 1. / float(np.sqrt(32)),
1. / float(np.sqrt(16)), 0.01],
'dropouts': [0., 0., 0.],
'atomic_numbers_considered': torch.tensor([
1., 6., 7., 8., 9., 11., 12., 15., 16., 17., 20., 25., 30., 35., 53.]),
'radial': [[12.0], [0.0, 4.0, 8.0], [4.0]],
'lr': 0.001,
'num_epochs': 110,
'metrics': ['pearson_r2', 'mae'],
'split': 'stratified'
}
ACNN_PDBBind_core_pocket_temporal = {
'dataset': 'PDBBind',
'subset': 'core',
'load_binding_pocket': True,
'random_seed': 123,
'frac_train': 0.8,
'frac_val': 0.,
'frac_test': 0.2,
'batch_size': 24,
'shuffle': False,
'hidden_sizes': [32, 32, 16],
'weight_init_stddevs': [1. / float(np.sqrt(32)), 1. / float(np.sqrt(32)),
1. / float(np.sqrt(16)), 0.01],
'dropouts': [0., 0., 0.],
'atomic_numbers_considered': torch.tensor([
1., 6., 7., 8., 9., 11., 12., 15., 16., 17., 20., 25., 30., 35., 53.]),
'radial': [[12.0], [0.0, 4.0, 8.0], [4.0]],
'lr': 0.001,
'num_epochs': 80,
'metrics': ['pearson_r2', 'mae'],
'split': 'temporal'
}
ACNN_PDBBind_refined_pocket_random = {
'dataset': 'PDBBind',
'subset': 'refined',
'load_binding_pocket': True,
'random_seed': 123,
'frac_train': 0.8,
'frac_val': 0.,
'frac_test': 0.2,
'batch_size': 24,
'shuffle': False,
'hidden_sizes': [128, 128, 64],
'weight_init_stddevs': [0.125, 0.125, 0.177, 0.01],
'dropouts': [0.4, 0.4, 0.],
'atomic_numbers_considered': torch.tensor([
1., 6., 7., 8., 9., 11., 12., 15., 16., 17., 19., 20., 25., 26., 27., 28.,
29., 30., 34., 35., 38., 48., 53., 55., 80.]),
'radial': [[12.0], [0.0, 2.0, 4.0, 6.0, 8.0], [4.0]],
'lr': 0.001,
'num_epochs': 200,
'metrics': ['pearson_r2', 'mae'],
'split': 'random'
}
ACNN_PDBBind_refined_pocket_scaffold = {
'dataset': 'PDBBind',
'subset': 'refined',
'load_binding_pocket': True,
'random_seed': 123,
'frac_train': 0.8,
'frac_val': 0.,
'frac_test': 0.2,
'batch_size': 24,
'shuffle': False,
'hidden_sizes': [128, 128, 64],
'weight_init_stddevs': [0.125, 0.125, 0.177, 0.01],
'dropouts': [0.4, 0.4, 0.],
'atomic_numbers_considered': torch.tensor([
1., 6., 7., 8., 9., 11., 12., 15., 16., 17., 19., 20., 25., 26., 27., 28.,
29., 30., 34., 35., 38., 48., 53., 55., 80.]),
'radial': [[12.0], [0.0, 2.0, 4.0, 6.0, 8.0], [4.0]],
'lr': 0.001,
'num_epochs': 350,
'metrics': ['pearson_r2', 'mae'],
'split': 'scaffold'
}
ACNN_PDBBind_refined_pocket_stratified = {
'dataset': 'PDBBind',
'subset': 'refined',
'load_binding_pocket': True,
'random_seed': 123,
'frac_train': 0.8,
'frac_val': 0.,
'frac_test': 0.2,
'batch_size': 24,
'shuffle': False,
'hidden_sizes': [128, 128, 64],
'weight_init_stddevs': [0.125, 0.125, 0.177, 0.01],
'dropouts': [0.4, 0.4, 0.],
'atomic_numbers_considered': torch.tensor([
1., 6., 7., 8., 9., 11., 12., 15., 16., 17., 19., 20., 25., 26., 27., 28.,
29., 30., 34., 35., 38., 48., 53., 55., 80.]),
'radial': [[12.0], [0.0, 2.0, 4.0, 6.0, 8.0], [4.0]],
'lr': 0.001,
'num_epochs': 400,
'metrics': ['pearson_r2', 'mae'],
'split': 'stratified'
}
ACNN_PDBBind_refined_pocket_temporal = {
'dataset': 'PDBBind',
'subset': 'refined',
'load_binding_pocket': True,
'random_seed': 123,
'frac_train': 0.8,
'frac_val': 0.,
'frac_test': 0.2,
'batch_size': 24,
'shuffle': False,
'hidden_sizes': [128, 128, 64],
'weight_init_stddevs': [0.125, 0.125, 0.177, 0.01],
'dropouts': [0.4, 0.4, 0.],
'atomic_numbers_considered': torch.tensor([
1., 6., 7., 8., 9., 11., 12., 15., 16., 17., 19., 20., 25., 26., 27., 28.,
29., 30., 34., 35., 38., 48., 53., 55., 80.]),
'radial': [[12.0], [0.0, 2.0, 4.0, 6.0, 8.0], [4.0]],
'lr': 0.001,
'num_epochs': 350,
'metrics': ['pearson_r2', 'mae'],
'split': 'temporal'
}
experiment_configures = {
'ACNN_PDBBind_core_pocket_random': ACNN_PDBBind_core_pocket_random,
'ACNN_PDBBind_core_pocket_scaffold': ACNN_PDBBind_core_pocket_scaffold,
'ACNN_PDBBind_core_pocket_stratified': ACNN_PDBBind_core_pocket_stratified,
'ACNN_PDBBind_core_pocket_temporal': ACNN_PDBBind_core_pocket_temporal,
'ACNN_PDBBind_refined_pocket_random': ACNN_PDBBind_refined_pocket_random,
'ACNN_PDBBind_refined_pocket_scaffold': ACNN_PDBBind_refined_pocket_scaffold,
'ACNN_PDBBind_refined_pocket_stratified': ACNN_PDBBind_refined_pocket_stratified,
'ACNN_PDBBind_refined_pocket_temporal': ACNN_PDBBind_refined_pocket_temporal
}
def get_exp_configure(exp_name):
return experiment_configures[exp_name]
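# Illustrative usage (not part of the original file): main.py below builds the
# experiment name as '<model>_<dataset>' and looks it up here, e.g.
# get_exp_configure('ACNN_PDBBind_core_pocket_random')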
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from utils import set_random_seed, load_dataset, collate, load_model, Meter
def update_msg_from_scores(msg, scores):
for metric, score in scores.items():
msg += ', {} {:.4f}'.format(metric, score)
return msg
def run_a_train_epoch(args, epoch, model, data_loader,
loss_criterion, optimizer):
model.train()
train_meter = Meter(args['train_mean'], args['train_std'])
epoch_loss = 0
for batch_id, batch_data in enumerate(data_loader):
indices, ligand_mols, protein_mols, bg, labels = batch_data
labels, bg = labels.to(args['device']), bg.to(args['device'])
prediction = model(bg)
loss = loss_criterion(prediction, (labels - args['train_mean']) / args['train_std'])
epoch_loss += loss.data.item() * len(indices)
optimizer.zero_grad()
loss.backward()
optimizer.step()
train_meter.update(prediction, labels)
avg_loss = epoch_loss / len(data_loader.dataset)
total_scores = {metric: train_meter.compute_metric(metric) for metric in args['metrics']}
msg = 'epoch {:d}/{:d}, training | loss {:.4f}'.format(
epoch + 1, args['num_epochs'], avg_loss)
msg = update_msg_from_scores(msg, total_scores)
print(msg)
def run_an_eval_epoch(args, model, data_loader):
model.eval()
eval_meter = Meter(args['train_mean'], args['train_std'])
with torch.no_grad():
for batch_id, batch_data in enumerate(data_loader):
indices, ligand_mols, protein_mols, bg, labels = batch_data
labels, bg = labels.to(args['device']), bg.to(args['device'])
prediction = model(bg)
eval_meter.update(prediction, labels)
total_scores = {metric: eval_meter.compute_metric(metric) for metric in args['metrics']}
return total_scores
def main(args):
args['device'] = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
set_random_seed(args['random_seed'])
dataset, train_set, test_set = load_dataset(args)
args['train_mean'] = train_set.labels_mean.to(args['device'])
args['train_std'] = train_set.labels_std.to(args['device'])
train_loader = DataLoader(dataset=train_set,
batch_size=args['batch_size'],
shuffle=False,
collate_fn=collate)
test_loader = DataLoader(dataset=test_set,
batch_size=args['batch_size'],
shuffle=True,
collate_fn=collate)
model = load_model(args)
loss_fn = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=args['lr'])
model.to(args['device'])
for epoch in range(args['num_epochs']):
run_a_train_epoch(args, epoch, model, train_loader, loss_fn, optimizer)
test_scores = run_an_eval_epoch(args, model, test_loader)
test_msg = update_msg_from_scores('test results', test_scores)
print(test_msg)
if __name__ == '__main__':
import argparse
from configure import get_exp_configure
parser = argparse.ArgumentParser(description='Protein-Ligand Binding Affinity Prediction')
parser.add_argument('-m', '--model', type=str, choices=['ACNN'],
help='Model to use')
parser.add_argument('-d', '--dataset', type=str,
choices=['PDBBind_core_pocket_random', 'PDBBind_core_pocket_scaffold',
'PDBBind_core_pocket_stratified', 'PDBBind_core_pocket_temporal',
'PDBBind_refined_pocket_random', 'PDBBind_refined_pocket_scaffold',
'PDBBind_refined_pocket_stratified', 'PDBBind_refined_pocket_temporal'],
help='Dataset to use')
args = parser.parse_args().__dict__
args['exp'] = '_'.join([args['model'], args['dataset']])
args.update(get_exp_configure(args['exp']))
main(args)
import dgl
import numpy as np
import random
import torch
import torch.nn.functional as F
from dgl import model_zoo
from dgl.data.chem import PDBBind, RandomSplitter, ScaffoldSplitter, SingleTaskStratifiedSplitter
from dgl.data.utils import Subset
from itertools import accumulate
from scipy.stats import pearsonr
def set_random_seed(seed=0):
"""Set random seed.
Parameters
----------
seed : int
Random seed to use. Default to 0.
"""
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed(seed)
def load_dataset(args):
"""Load the dataset.
Parameters
----------
args : dict
Input arguments.
Returns
-------
dataset
Full dataset.
train_set
Train subset of the dataset.
test_set
Test subset of the dataset.
"""
assert args['dataset'] in ['PDBBind'], 'Unexpected dataset {}'.format(args['dataset'])
if args['dataset'] == 'PDBBind':
dataset = PDBBind(subset=args['subset'],
load_binding_pocket=args['load_binding_pocket'],
zero_padding=True)
# No validation set is used and frac_val = 0.
if args['split'] == 'random':
train_set, _, test_set = RandomSplitter.train_val_test_split(
dataset,
frac_train=args['frac_train'],
frac_val=args['frac_val'],
frac_test=args['frac_test'],
random_state=args['random_seed'])
elif args['split'] == 'scaffold':
train_set, _, test_set = ScaffoldSplitter.train_val_test_split(
dataset,
mols=dataset.ligand_mols,
sanitize=False,
frac_train=args['frac_train'],
frac_val=args['frac_val'],
frac_test=args['frac_test'])
elif args['split'] == 'stratified':
train_set, _, test_set = SingleTaskStratifiedSplitter.train_val_test_split(
dataset,
labels=dataset.labels,
task_id=0,
frac_train=args['frac_train'],
frac_val=args['frac_val'],
frac_test=args['frac_test'],
random_state=args['random_seed'])
elif args['split'] == 'temporal':
years = dataset.df['release_year'].values.astype(np.float32)
indices = np.argsort(years).tolist()
frac_list = np.array([args['frac_train'], args['frac_val'], args['frac_test']])
num_data = len(dataset)
lengths = (num_data * frac_list).astype(int)
lengths[-1] = num_data - np.sum(lengths[:-1])
train_set, val_set, test_set = [
Subset(dataset, list(indices[offset - length:offset]))
for offset, length in zip(accumulate(lengths), lengths)]
else:
raise ValueError('Expect the splitting method to be "random", "scaffold", '
'"stratified" or "temporal", got {}'.format(args['split']))
train_labels = torch.stack([train_set.dataset.labels[i] for i in train_set.indices])
train_set.labels_mean = train_labels.mean(dim=0)
train_set.labels_std = train_labels.std(dim=0)
return dataset, train_set, test_set
def collate(data):
indices, ligand_mols, protein_mols, graphs, labels = map(list, zip(*data))
bg = dgl.batch_hetero(graphs)
for nty in bg.ntypes:
bg.set_n_initializer(dgl.init.zero_initializer, ntype=nty)
for ety in bg.canonical_etypes:
bg.set_e_initializer(dgl.init.zero_initializer, etype=ety)
labels = torch.stack(labels, dim=0)
return indices, ligand_mols, protein_mols, bg, labels
def load_model(args):
assert args['model'] in ['ACNN'], 'Unexpected model {}'.format(args['model'])
if args['model'] == 'ACNN':
model = model_zoo.chem.ACNN(hidden_sizes=args['hidden_sizes'],
weight_init_stddevs=args['weight_init_stddevs'],
dropouts=args['dropouts'],
features_to_use=args['atomic_numbers_considered'],
radial=args['radial'])
return model
class Meter(object):
"""Track and summarize model performance on a dataset for (multi-label) prediction.
Parameters
----------
mean : torch.float32 tensor of shape (T)
Mean of existing training labels across tasks, T for the number of tasks
std : torch.float32 tensor of shape (T)
Std of existing training labels across tasks, T for the number of tasks
"""
def __init__(self, mean=None, std=None):
self.y_pred = []
self.y_true = []
if (mean is not None) and (std is not None):
self.mean = mean.cpu()
self.std = std.cpu()
else:
self.mean = None
self.std = None
def update(self, y_pred, y_true):
"""Update for the result of an iteration
Parameters
----------
y_pred : float32 tensor
Predicted molecule labels with shape (B, T),
B for batch size and T for the number of tasks
y_true : float32 tensor
Ground truth molecule labels with shape (B, T)
"""
self.y_pred.append(y_pred.detach().cpu())
self.y_true.append(y_true.detach().cpu())
def _finalize_labels_and_prediction(self):
"""Concatenate the labels and predictions.
If normalization was performed on the labels, undo the normalization.
"""
y_pred = torch.cat(self.y_pred, dim=0)
y_true = torch.cat(self.y_true, dim=0)
if (self.mean is not None) and (self.std is not None):
# To compensate for the imbalance between labels during training,
# we normalize the ground truth labels with training mean and std.
# We need to undo that for evaluation.
y_pred = y_pred * self.std + self.mean
return y_pred, y_true
def pearson_r2(self):
"""Compute squared Pearson correlation coefficient
Returns
-------
float
"""
y_pred, y_true = self._finalize_labels_and_prediction()
return pearsonr(y_true[:, 0].numpy(), y_pred[:, 0].numpy())[0] ** 2
def mae(self):
"""Compute MAE
Returns
-------
float
"""
y_pred, y_true = self._finalize_labels_and_prediction()
return F.l1_loss(y_true, y_pred).data.item()
def rmse(self):
"""
Compute RMSE
Returns
-------
float
"""
y_pred, y_true = self._finalize_labels_and_prediction()
return np.sqrt(F.mse_loss(y_pred, y_true).cpu().item())
def compute_metric(self, metric_name):
"""Compute metric
Parameters
----------
metric_name : str
Name for the metric to compute.
Returns
-------
float
Metric value
"""
assert metric_name in ['pearson_r2', 'mae', 'rmse'], \
'Expect metric name to be "pearson_r2", "mae" or "rmse", got {}'.format(metric_name)
if metric_name == 'pearson_r2':
return self.pearson_r2()
if metric_name == 'mae':
return self.mae()
if metric_name == 'rmse':
return self.rmse()
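# Illustrative usage of Meter (not part of the original file), following the
# training loop above:
#   meter = Meter(train_mean, train_std)
#   meter.update(prediction, labels)
#   print(meter.compute_metric('mae'))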
# Learning Deep Generative Models of Graphs (DGMG)
Yujia Li, Oriol Vinyals, Chris Dyer, Razvan Pascanu, and Peter Battaglia.
Learning Deep Generative Models of Graphs. *arXiv preprint arXiv:1803.03324*, 2018.
DGMG generates graphs by progressively adding nodes and edges as below:
![](https://user-images.githubusercontent.com/19576924/48605003-7f11e900-e9b6-11e8-8880-87362348e154.png)
For molecules, the nodes are atoms and the edges are bonds.
**Goal**: Given a set of real molecules, we want to learn their distribution and generate new molecules
with similar properties. See the `Evaluation` section for more details.
## Dataset
### Preprocessing
With our implementation, this model has several limitations:
1. Information about protonation and chirality is ignored during generation
2. Molecules containing `[N+]`, `[O-]`, etc. cannot be generated.
For example, the model can only generate `O=C1NC(=S)NC(=O)C1=CNC1=CC=C(N(=O)O)C=C1O` from
`O=C1NC(=S)NC(=O)C1=CNC1=CC=C([N+](=O)[O-])C=C1O` even with the correct decisions.
To avoid issues with validity and novelty, we filter out these molecules from the dataset.
### ChEMBL
The authors use the [ChEMBL database](https://www.ebi.ac.uk/chembl/). Since they
did not release the code, we use a subset from [Olivecrona et al.](https://github.com/MarcusOlivecrona/REINVENT),
another work on generative modeling.
The authors restrict their dataset to molecules with at most 20 heavy atoms, and use a training/validation
split of 130,830/26,166 examples. We use the same split but need to relax the limit from 20 to 23 heavy atoms
as we are using a different subset.
### ZINC
After the pre-processing, we are left with 232,464 molecules for training and 5,000 molecules for validation.
## Usage
### Training
Training autoregressive generative models tends to be very slow. According to the authors, they use multiple processes
to speed up training, and GPUs do not give much of a speed advantage. We follow their approach and perform multi-process
CPU training.
To start training, use `train.py` with required arguments
```
-d DATASET, dataset to use (default: None), built-in support exists for ChEMBL, ZINC
-o {random,canonical}, order to generate graphs (default: None)
```
and optional arguments
```
-s SEED, random seed (default: 0)
-np NUM_PROCESSES, number of processes to use (default: 32)
```
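For example, a representative invocation for training on ZINC with canonical order and 32 processes:
```
python train.py -d ZINC -o canonical -np 32
```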
Even though multi-processing yields a significant speedup compared to a single process, training can still take a long
time (several days). An epoch of training and validation can take up to an hour and a half on our machine. If not
necessary, we recommend using our pre-trained models.
Meanwhile, we make a checkpoint of our model whenever there is a performance improvement on the validation set so you
do not need to wait until the training terminates.
All training results can be found in `training_results`.
#### Dataset configuration
You can also use your own dataset with additional arguments
```
-tf TRAIN_FILE, Path to a file with one SMILES per line for training
data. This is only necessary if you want to use a new
dataset. (default: None)
-vf VAL_FILE, Path to a file with one SMILES per line for validation
data. This is only necessary if you want to use a new
dataset. (default: None)
```
#### Monitoring
We can monitor the training process with tensorboard as below:
![](https://data.dgl.ai/dgllife/dgmg/tensorboard.png)
To use TensorBoard, you need to install [tensorboardX](https://github.com/lanpa/tensorboardX) and
[TensorFlow](https://www.tensorflow.org/). You can launch TensorBoard with `tensorboard --logdir=.`
If you are training on a remote server, you can still use it with:
1. Launch it on the remote server with `tensorboard --logdir=. --port=A`
2. In the terminal of your local machine, type `ssh -NfL localhost:B:localhost:A username@your_remote_host_name`
3. Go to the address `localhost:B` in your browser
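For instance, with `A=6006` and `B=16006` (arbitrary port choices), the three steps become:
```
tensorboard --logdir=. --port=6006
ssh -NfL localhost:16006:localhost:6006 username@your_remote_host_name
# then open localhost:16006 in your local browser
```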
### Evaluation
To start evaluation, use `eval.py` with required arguments
```
-d DATASET, dataset to use (default: None), built-in support exists for ChEMBL, ZINC
-o {random,canonical}, order to generate graphs, used for naming evaluation directory (default: None)
-p MODEL_PATH, path to saved model (default: None). This is not needed if you want to use pretrained models.
-pr, Whether to use a pre-trained model (default: False)
```
and optional arguments
```
-s SEED, random seed (default: 0)
-ns NUM_SAMPLES, Number of molecules to generate (default: 100000)
-mn MAX_NUM_STEPS, Max number of steps allowed in generated molecules to
ensure termination (default: 400)
-np NUM_PROCESSES, number of processes to use (default: 32)
-gt GENERATION_TIME, max time (seconds) allowed for generation with
multiprocess (default: 600)
```
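For example, a representative invocation for sampling with the pre-trained ZINC model and canonical order:
```
python eval.py -d ZINC -o canonical -pr
```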
All evaluation results can be found in `eval_results`.
After the evaluation, 100000 molecules will be generated and stored in `generated_smiles.txt` under the `eval_results`
directory, with three statistics logged in `generation_stats.txt` under `eval_results`:
1. `Validity among all` gives the percentage of molecules that are valid
2. `Uniqueness among valid ones` gives the percentage of valid molecules that are unique
3. `Novelty among unique ones` gives the percentage of unique valid molecules that are novel (not seen in training data)
We also provide a Jupyter notebook where you can visualize the generated molecules
![](https://data.dgl.ai/model_zoo/drug_discovery/dgmg/DGMG_ZINC_canonical_vis.png)
and compare their property distributions against the training molecule property distributions
![](https://data.dgl.ai/model_zoo/drug_discovery/dgmg/DGMG_ZINC_canonical_dist.png)
You can download the notebook with `wget https://data.dgl.ai/dgllife/dgmg/eval_jupyter.ipynb`.
### Pre-trained models
Below are the statistics of the pre-trained models. With random order, training becomes significantly more difficult
as we now have `N^2` data points with `N` molecules.
| Pre-trained model | % valid | % unique among valid | % novel among unique |
| ------------------ | ------- | -------------------- | -------------------- |
| `ChEMBL_canonical` | 78.80 | 99.19 | 98.60 |
| `ChEMBL_random` | 29.09 | 99.87 | 100.00 |
| `ZINC_canonical` | 74.60 | 99.87 | 99.87 |
| `ZINC_random` | 12.37 | 99.38 | 100.00 |
import os
import pickle
import shutil
import torch
from dgl import model_zoo
from utils import MoleculeDataset, set_random_seed, download_data,\
mkdir_p, summarize_molecules, get_unique_smiles, get_novel_smiles
def generate_and_save(log_dir, num_samples, max_num_steps, model):
with open(os.path.join(log_dir, 'generated_smiles.txt'), 'w') as f:
for i in range(num_samples):
with torch.no_grad():
s = model(rdkit_mol=True, max_num_steps=max_num_steps)
f.write(s + '\n')
def prepare_for_evaluation(rank, args):
worker_seed = args['seed'] + rank * 10000
set_random_seed(worker_seed)
torch.set_num_threads(1)
# Setup dataset and data loader
dataset = MoleculeDataset(args['dataset'], subset_id=rank, n_subsets=args['num_processes'])
# Initialize model
if not args['pretrained']:
model = model_zoo.chem.DGMG(atom_types=dataset.atom_types,
bond_types=dataset.bond_types,
node_hidden_size=args['node_hidden_size'],
num_prop_rounds=args['num_propagation_rounds'], dropout=args['dropout'])
model.load_state_dict(torch.load(args['model_path'])['model_state_dict'])
else:
model = model_zoo.chem.load_pretrained('_'.join(['DGMG', args['dataset'], args['order']]), log=False)
model.eval()
worker_num_samples = args['num_samples'] // args['num_processes']
if rank == args['num_processes'] - 1:
worker_num_samples += args['num_samples'] % args['num_processes']
worker_log_dir = os.path.join(args['log_dir'], str(rank))
mkdir_p(worker_log_dir, log=False)
generate_and_save(worker_log_dir, worker_num_samples, args['max_num_steps'], model)
def remove_worker_tmp_dir(args):
for rank in range(args['num_processes']):
worker_path = os.path.join(args['log_dir'], str(rank))
try:
shutil.rmtree(worker_path)
except OSError:
print('Directory {} does not exist!'.format(worker_path))
def aggregate_and_evaluate(args):
print('Merging generated SMILES into a single file...')
smiles = []
for rank in range(args['num_processes']):
with open(os.path.join(args['log_dir'], str(rank), 'generated_smiles.txt'), 'r') as f:
rank_smiles = f.read().splitlines()
smiles.extend(rank_smiles)
with open(os.path.join(args['log_dir'], 'generated_smiles.txt'), 'w') as f:
for s in smiles:
f.write(s + '\n')
print('Removing temporary dirs...')
remove_worker_tmp_dir(args)
# Summarize training molecules
print('Summarizing training molecules...')
train_file = '_'.join([args['dataset'], 'DGMG_train.txt'])
if not os.path.exists(train_file):
download_data(args['dataset'], train_file)
with open(train_file, 'r') as f:
train_smiles = f.read().splitlines()
train_summary = summarize_molecules(train_smiles, args['num_processes'])
with open(os.path.join(args['log_dir'], 'train_summary.pickle'), 'wb') as f:
pickle.dump(train_summary, f)
# Summarize generated molecules
print('Summarizing generated molecules...')
generation_summary = summarize_molecules(smiles, args['num_processes'])
with open(os.path.join(args['log_dir'], 'generation_summary.pickle'), 'wb') as f:
pickle.dump(generation_summary, f)
# Stats computation
print('Preparing generation statistics...')
valid_generated_smiles = generation_summary['smile']
unique_generated_smiles = get_unique_smiles(valid_generated_smiles)
unique_train_smiles = get_unique_smiles(train_summary['smile'])
novel_generated_smiles = get_novel_smiles(unique_generated_smiles, unique_train_smiles)
with open(os.path.join(args['log_dir'], 'generation_stats.txt'), 'w') as f:
f.write('Total number of generated molecules: {:d}\n'.format(len(smiles)))
f.write('Validity among all: {:.4f}\n'.format(
len(valid_generated_smiles) / len(smiles)))
f.write('Uniqueness among valid ones: {:.4f}\n'.format(
len(unique_generated_smiles) / len(valid_generated_smiles)))
f.write('Novelty among unique ones: {:.4f}\n'.format(
len(novel_generated_smiles) / len(unique_generated_smiles)))
if __name__ == '__main__':
import argparse
import datetime
import time
from rdkit import rdBase
from utils import setup
parser = argparse.ArgumentParser(description='Evaluating DGMG for molecule generation',
formatter_class=argparse.ArgumentDefaultsHelpFormatter)
# configure
parser.add_argument('-s', '--seed', type=int, default=0, help='random seed')
# dataset and setting
parser.add_argument('-d', '--dataset',
help='dataset to use')
parser.add_argument('-o', '--order', choices=['random', 'canonical'],
help='order to generate graphs, used for naming evaluation directory')
# log
parser.add_argument('-l', '--log-dir', default='./eval_results',
help='folder to save evaluation results')
parser.add_argument('-p', '--model-path', type=str, default=None,
help='path to saved model')
parser.add_argument('-pr', '--pretrained', action='store_true',
help='Whether to use a pre-trained model')
parser.add_argument('-ns', '--num-samples', type=int, default=100000,
help='Number of molecules to generate')
parser.add_argument('-mn', '--max-num-steps', type=int, default=400,
help='Max number of steps allowed in generated molecules to ensure termination')
# multi-process
parser.add_argument('-np', '--num-processes', type=int, default=32,
help='number of processes to use')
parser.add_argument('-gt', '--generation-time', type=int, default=600,
help='max time (seconds) allowed for generation with multiprocess')
args = parser.parse_args()
args = setup(args, train=False)
rdBase.DisableLog('rdApp.error')
t1 = time.time()
if args['num_processes'] == 1:
prepare_for_evaluation(0, args)
else:
import multiprocessing as mp
procs = []
for rank in range(args['num_processes']):
p = mp.Process(target=prepare_for_evaluation, args=(rank, args,))
procs.append(p)
p.start()
while time.time() - t1 <= args['generation_time']:
if any(p.is_alive() for p in procs):
time.sleep(5)
else:
break
else:
print('Timeout, killing all processes.')
for p in procs:
p.terminate()
p.join()
t2 = time.time()
print('It took {} for generation.'.format(
datetime.timedelta(seconds=t2 - t1)))
aggregate_and_evaluate(args)
#
# calculation of synthetic accessibility score as described in:
#
# Estimation of Synthetic Accessibility Score of Drug-like Molecules
# based on Molecular Complexity and Fragment Contributions
# Peter Ertl and Ansgar Schuffenhauer
# Journal of Cheminformatics 1:8 (2009)
# http://www.jcheminf.com/content/1/1/8
#
# several small modifications to the original paper are included
# particularly slightly different formula for macrocyclic penalty
# and taking into account also molecule symmetry (fingerprint density)
#
# for a set of 10k diverse molecules the agreement between the original method
# as implemented in PipelinePilot and this implementation is r2 = 0.97
#
# peter ertl & greg landrum, september 2013
#
# A small modification is performed
#
# DGL team, August 2019
#
from __future__ import print_function
import math
import os
from rdkit import Chem
from rdkit.Chem import rdMolDescriptors
from rdkit.six.moves import cPickle
from rdkit.six import iteritems
from dgl.data.utils import download, _get_dgl_url, get_download_dir
_fscores = None
def readFragmentScores(name='fpscores'):
import gzip
global _fscores
fname = '{}.pkl.gz'.format(name)
download(_get_dgl_url(os.path.join('dataset', fname)), path=fname)
_fscores = cPickle.load(gzip.open(fname))
outDict = {}
for i in _fscores:
for j in range(1, len(i)):
outDict[i[j]] = float(i[0])
_fscores = outDict
def numBridgeheadsAndSpiro(mol):
nSpiro = rdMolDescriptors.CalcNumSpiroAtoms(mol)
nBridgehead = rdMolDescriptors.CalcNumBridgeheadAtoms(mol)
return nBridgehead, nSpiro
def calculateScore(m):
if _fscores is None:
readFragmentScores()
# fragment score
# 2 is the *radius* of the circular fingerprint
fp = rdMolDescriptors.GetMorganFingerprint(m, 2)
fps = fp.GetNonzeroElements()
score1 = 0.
nf = 0
for bitId, v in iteritems(fps):
nf += v
sfp = bitId
score1 += _fscores.get(sfp, -4) * v
# Check nf to avoid a ZeroDivisionError when the fingerprint has no bits set.
if nf != 0:
score1 /= nf
# features score
nAtoms = m.GetNumAtoms()
nChiralCenters = len(Chem.FindMolChiralCenters(m, includeUnassigned=True))
ri = m.GetRingInfo()
nBridgeheads, nSpiro = numBridgeheadsAndSpiro(m)
nMacrocycles = 0
for x in ri.AtomRings():
if len(x) > 8:
nMacrocycles += 1
sizePenalty = nAtoms**1.005 - nAtoms
stereoPenalty = math.log10(nChiralCenters + 1)
spiroPenalty = math.log10(nSpiro + 1)
bridgePenalty = math.log10(nBridgeheads + 1)
macrocyclePenalty = 0.
# ---------------------------------------
# This differs from the paper, which defines:
# macrocyclePenalty = math.log10(nMacrocycles+1)
# This form generates better results when 2 or more macrocycles are present
if nMacrocycles > 0:
macrocyclePenalty = math.log10(2)
score2 = 0. - sizePenalty - stereoPenalty - \
spiroPenalty - bridgePenalty - macrocyclePenalty
# correction for the fingerprint density
# not in the original publication, added in version 1.1
# to make highly symmetrical molecules easier to synthesize
score3 = 0.
if nAtoms > len(fps):
score3 = math.log(float(nAtoms) / len(fps)) * .5
sascore = score1 + score2 + score3
# need to transform "raw" value into scale between 1 and 10
min = -4.0
max = 2.5
sascore = 11. - (sascore - min + 1) / (max - min) * 9.
# smooth the 10-end
if sascore > 8.:
sascore = 8. + math.log(sascore + 1. - 9.)
if sascore > 10.:
sascore = 10.0
elif sascore < 1.:
sascore = 1.0
return sascore
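# Illustrative usage (not part of the original script):
#   from rdkit import Chem
#   mol = Chem.MolFromSmiles('CC(=O)Oc1ccccc1C(=O)O')
#   print(calculateScore(mol))  # synthetic accessibility score, scaled to [1, 10]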
def processMols(mols):
print('smiles\tName\tsa_score')
for i, m in enumerate(mols):
if m is None:
continue
s = calculateScore(m)
smiles = Chem.MolToSmiles(m)
print(smiles + "\t" + m.GetProp('_Name') + "\t%3f" % s)
if __name__ == '__main__':
import sys, time
t1 = time.time()
readFragmentScores("fpscores")
t2 = time.time()
suppl = Chem.SmilesMolSupplier(sys.argv[1])
t3 = time.time()
processMols(suppl)
t4 = time.time()
print('Reading took %.2f seconds. Calculating took %.2f seconds' % ((t2 - t1), (t4 - t3)),
file=sys.stderr)
#
# Copyright (c) 2013, Novartis Institutes for BioMedical Research Inc.
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are
# met:
#
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above
# copyright notice, this list of conditions and the following
# disclaimer in the documentation and/or other materials provided
# with the distribution.
# * Neither the name of Novartis Institutes for BioMedical Research Inc.
# nor the names of its contributors may be used to endorse or promote
# products derived from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
"""
Learning Deep Generative Models of Graphs
Paper: https://arxiv.org/pdf/1803.03324.pdf
"""
import datetime
import time
import torch
import torch.distributed as dist
from dgl import model_zoo
from torch.optim import Adam
from torch.utils.data import DataLoader
from utils import MoleculeDataset, Printer, set_random_seed, synchronize, launch_a_process
def evaluate(epoch, model, data_loader, printer):
model.eval()
batch_size = data_loader.batch_size
total_log_prob = 0
with torch.no_grad():
for i, data in enumerate(data_loader):
log_prob = model(actions=data, compute_log_prob=True).detach()
total_log_prob -= log_prob
if printer is not None:
prob = log_prob.detach().exp()
printer.update(epoch + 1, - log_prob / batch_size, prob / batch_size)
return total_log_prob / len(data_loader)
def main(rank, args):
"""
Parameters
----------
rank : int
Subprocess id
args : dict
Configuration
"""
if rank == 0:
t1 = time.time()
set_random_seed(args['seed'])
# Removing the line below will result in problems for multiprocessing
torch.set_num_threads(1)
# Setup dataset and data loader
dataset = MoleculeDataset(args['dataset'], args['order'], ['train', 'val'],
subset_id=rank, n_subsets=args['num_processes'])
# Note that currently the batch size for the loaders should only be 1.
train_loader = DataLoader(dataset.train_set, batch_size=args['batch_size'],
shuffle=True, collate_fn=dataset.collate)
val_loader = DataLoader(dataset.val_set, batch_size=args['batch_size'],
shuffle=True, collate_fn=dataset.collate)
if rank == 0:
try:
from tensorboardX import SummaryWriter
writer = SummaryWriter(args['log_dir'])
except ImportError:
print('If you want to use tensorboard, install tensorboardX with pip.')
writer = None
train_printer = Printer(args['nepochs'], len(dataset.train_set), args['batch_size'], writer)
val_printer = Printer(args['nepochs'], len(dataset.val_set), args['batch_size'])
else:
val_printer = None
# Initialize model
model = model_zoo.chem.DGMG(atom_types=dataset.atom_types,
bond_types=dataset.bond_types,
node_hidden_size=args['node_hidden_size'],
num_prop_rounds=args['num_propagation_rounds'],
dropout=args['dropout'])
if args['num_processes'] == 1:
from utils import Optimizer
optimizer = Optimizer(args['lr'], Adam(model.parameters(), lr=args['lr']))
else:
from utils import MultiProcessOptimizer
optimizer = MultiProcessOptimizer(args['num_processes'], args['lr'],
Adam(model.parameters(), lr=args['lr']))
if rank == 0:
t2 = time.time()
best_val_prob = 0
# Training
for epoch in range(args['nepochs']):
model.train()
if rank == 0:
print('Training')
for i, data in enumerate(train_loader):
log_prob = model(actions=data, compute_log_prob=True)
prob = log_prob.detach().exp()
loss_averaged = - log_prob
prob_averaged = prob
optimizer.backward_and_step(loss_averaged)
if rank == 0:
train_printer.update(epoch + 1, loss_averaged.item(), prob_averaged.item())
synchronize(args['num_processes'])
# Validation
val_log_prob = evaluate(epoch, model, val_loader, val_printer)
if args['num_processes'] > 1:
dist.all_reduce(val_log_prob, op=dist.ReduceOp.SUM)
val_log_prob /= args['num_processes']
# Strictly speaking, the computation of probability here is different from what is
# performed on the training set as we first take an average of log likelihood and then
# take the exponentiation. By Jensen's inequality, the resulting value is then a
# lower bound of the real probabilities.
val_prob = (- val_log_prob).exp().item()
val_log_prob = val_log_prob.item()
if val_prob >= best_val_prob:
if rank == 0:
torch.save({'model_state_dict': model.state_dict()}, args['checkpoint_dir'])
print('Old val prob {:.10f} | new val prob {:.10f} | model saved'.format(best_val_prob, val_prob))
best_val_prob = val_prob
elif epoch >= args['warmup_epochs']:
optimizer.decay_lr()
if rank == 0:
print('Validation')
if writer is not None:
writer.add_scalar('validation_log_prob', val_log_prob, epoch)
writer.add_scalar('validation_prob', val_prob, epoch)
writer.add_scalar('lr', optimizer.lr, epoch)
print('Validation log prob {:.4f} | prob {:.10f}'.format(val_log_prob, val_prob))
synchronize(args['num_processes'])
if rank == 0:
t3 = time.time()
print('It took {} to setup.'.format(datetime.timedelta(seconds=t2 - t1)))
print('It took {} to finish training.'.format(datetime.timedelta(seconds=t3 - t2)))
print('--------------------------------------------------------------------------')
print('On average, an epoch takes {}.'.format(datetime.timedelta(
seconds=(t3 - t2) / args['nepochs'])))
if __name__ == '__main__':
import argparse
from utils import setup
parser = argparse.ArgumentParser(description='Training DGMG for molecule generation',
formatter_class=argparse.ArgumentDefaultsHelpFormatter)
# configure
parser.add_argument('-s', '--seed', type=int, default=0, help='random seed')
parser.add_argument('-w', '--warmup-epochs', type=int, default=10,
help='Number of epochs where no lr decay is performed.')
# dataset and setting
parser.add_argument('-d', '--dataset',
help='dataset to use')
parser.add_argument('-o', '--order', choices=['random', 'canonical'],
help='order to generate graphs')
parser.add_argument('-tf', '--train-file', type=str, default=None,
help='Path to a file with one SMILES per line for training data. '
'This is only necessary if you want to use a new dataset.')
parser.add_argument('-vf', '--val-file', type=str, default=None,
help='Path to a file with one SMILES per line for validation data. '
'This is only necessary if you want to use a new dataset.')
# log
parser.add_argument('-l', '--log-dir', default='./training_results',
help='folder to save info like experiment configuration')
# multi-process
parser.add_argument('-np', '--num-processes', type=int, default=32,
help='number of processes to use')
parser.add_argument('-mi', '--master-ip', type=str, default='127.0.0.1')
parser.add_argument('-mp', '--master-port', type=str, default='12345')
args = parser.parse_args()
args = setup(args, train=True)
if args['num_processes'] == 1:
main(0, args)
else:
mp = torch.multiprocessing.get_context('spawn')
procs = []
for rank in range(args['num_processes']):
procs.append(mp.Process(target=launch_a_process, args=(rank, args, main), daemon=True))
procs[-1].start()
for p in procs:
p.join()
import datetime
import dgl
import math
import numpy as np
import os
import pickle
import random
import torch
import torch.distributed as dist
import torch.nn as nn
from collections import defaultdict
from datetime import timedelta
from dgl import DGLGraph
from dgl.data.utils import get_download_dir, download, _get_dgl_url
from dgl.model_zoo.chem.dgmg import MoleculeEnv
from multiprocessing import Pool
from pprint import pprint
from rdkit import Chem
from rdkit.Chem import AllChem
from rdkit.Chem.Crippen import MolLogP
from rdkit.Chem.QED import qed
from torch.utils.data import Dataset
from sascorer import calculateScore
########################################################################################################################
# configuration #
########################################################################################################################
def mkdir_p(path, log=True):
"""Create a directory for the specified path.
Parameters
----------
path : str
Path name
log : bool
Whether to print result for directory creation
"""
import errno
try:
os.makedirs(path)
if log:
print('Created directory {}'.format(path))
except OSError as exc:
if exc.errno == errno.EEXIST and os.path.isdir(path) and log:
print('Directory {} already exists.'.format(path))
else:
raise
def get_date_postfix():
"""Get a date based postfix for directory name.
Returns
-------
post_fix : str
"""
dt = datetime.datetime.now()
post_fix = '{}_{:02d}-{:02d}-{:02d}'.format(
dt.date(), dt.hour, dt.minute, dt.second)
return post_fix
def setup_log_dir(args):
"""Name and create directory for logging.
Parameters
----------
args : dict
Configuration
Returns
-------
log_dir : str
Path for logging directory
"""
date_postfix = get_date_postfix()
log_dir = os.path.join(
args['log_dir'],
'{}_{}_{}'.format(args['dataset'], args['order'], date_postfix))
mkdir_p(log_dir)
return log_dir
def save_arg_dict(args, filename='settings.txt'):
"""Save all experiment settings in a file.
Parameters
----------
args : dict
Configuration
filename : str
Name for the file to save settings
"""
def _format_value(v):
if isinstance(v, float):
return '{:.4f}'.format(v)
elif isinstance(v, int):
return '{:d}'.format(v)
else:
return '{}'.format(v)
save_path = os.path.join(args['log_dir'], filename)
with open(save_path, 'w') as f:
for key, value in args.items():
f.write('{}\t{}\n'.format(key, _format_value(value)))
print('Saved settings to {}'.format(save_path))
def configure(args):
"""Use default hyperparameters.
Parameters
----------
args : dict
Old configuration
Returns
-------
args : dict
Updated configuration
"""
configure = {
'node_hidden_size': 128,
'num_propagation_rounds': 2,
'lr': 1e-4,
'dropout': 0.2,
'nepochs': 400,
'batch_size': 1,
}
args.update(configure)
return args
def set_random_seed(seed):
"""Fix random seed for reproducible results.
Parameters
----------
seed : int
Random seed to use.
"""
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
def setup_dataset(args):
"""Dataset setup
For unsupported datasets, we need to perform data preprocessing.
Parameters
----------
args : dict
Configuration
"""
if args['dataset'] in ['ChEMBL', 'ZINC']:
print('Built-in support for dataset {} exists.'.format(args['dataset']))
else:
print('Configure for new dataset {}...'.format(args['dataset']))
configure_new_dataset(args['dataset'], args['train_file'], args['val_file'])
def setup(args, train=True):
"""Setup
Parameters
----------
args : argparse.Namespace
Configuration
train : bool
Whether the setup is for training or evaluation
"""
# Convert argparse.Namespace into a dict
args = args.__dict__.copy()
# Dataset
args = configure(args)
# Log
print('Prepare logging directory...')
log_dir = setup_log_dir(args)
args['log_dir'] = log_dir
save_arg_dict(args)
if train:
setup_dataset(args)
args['checkpoint_dir'] = os.path.join(log_dir, 'checkpoint.pth')
pprint(args)
return args
########################################################################################################################
# multi-process #
########################################################################################################################
def synchronize(num_processes):
"""Synchronize all processes.
Parameters
----------
num_processes : int
Number of subprocesses used
"""
if num_processes > 1:
dist.barrier()
def launch_a_process(rank, args, target, minutes=720):
"""Launch a subprocess for training.
Parameters
----------
rank : int
Subprocess id
args : dict
Configuration
target : callable
Target function for the subprocess
minutes : int
Timeout minutes for operations executed against the process group
"""
dist_init_method = 'tcp://{master_ip}:{master_port}'.format(
master_ip=args['master_ip'], master_port=args['master_port'])
dist.init_process_group(backend='gloo',
init_method=dist_init_method,
# If you have a larger dataset, you will need to increase it.
timeout=timedelta(minutes=minutes),
world_size=args['num_processes'],
rank=rank)
assert torch.distributed.get_rank() == rank
target(rank, args)
########################################################################################################################
# optimization #
########################################################################################################################
class Optimizer(nn.Module):
"""Wrapper for optimization
Parameters
----------
lr : float
Initial learning rate
optimizer
model optimizer
"""
def __init__(self, lr, optimizer):
super(Optimizer, self).__init__()
self.lr = lr
self.optimizer = optimizer
self._reset()
def _reset(self):
self.optimizer.zero_grad()
def backward_and_step(self, loss):
"""Backward and update model.
Parameters
----------
loss : torch.tensor consisting of a float only
"""
loss.backward()
self.optimizer.step()
self._reset()
def decay_lr(self, decay_rate=0.99):
"""Decay learning rate.
Parameters
----------
decay_rate : float
Multiply the current learning rate by the decay_rate
"""
self.lr *= decay_rate
for param_group in self.optimizer.param_groups:
param_group['lr'] = self.lr
class MultiProcessOptimizer(Optimizer):
"""Wrapper for optimization with multiprocess
Parameters
----------
n_processes : int
Number of processes used
lr : float
Initial learning rate
optimizer : torch.optim.Optimizer
Model optimizer
"""
def __init__(self, n_processes, lr, optimizer):
super(MultiProcessOptimizer, self).__init__(lr=lr, optimizer=optimizer)
self.n_processes = n_processes
def _sync_gradient(self):
"""Average gradients across all subprocesses."""
for param_group in self.optimizer.param_groups:
for p in param_group['params']:
if p.requires_grad and p.grad is not None:
dist.all_reduce(p.grad.data, op=dist.ReduceOp.SUM)
p.grad.data /= self.n_processes
def backward_and_step(self, loss):
"""Backward and update model.
Parameters
----------
loss : torch.Tensor consisting of a single float
"""
loss.backward()
self._sync_gradient()
self.optimizer.step()
self._reset()
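# A minimal usage sketch (illustration only, not part of the original training code):
# wrap a PyTorch optimizer so that the backward pass, the parameter update and the
# gradient reset happen in one call. `model`, `batch` and the learning rate are
# hypothetical placeholders.
#
#   from torch.optim import Adam
#   optimizer = Optimizer(lr=1e-4, optimizer=Adam(model.parameters(), lr=1e-4))
#   loss = model(batch)
#   optimizer.backward_and_step(loss)
#   optimizer.decay_lr(0.99)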
########################################################################################################################
# data #
########################################################################################################################
def initialize_neuralization_reactions():
"""Reference neuralization reactions
Code adapted from RDKit Cookbook, by Hans de Winter.
"""
patts = (
# Imidazoles
('[n+;H]', 'n'),
# Amines
('[N+;!H0]', 'N'),
# Carboxylic acids and alcohols
('[$([O-]);!$([O-][#7])]', 'O'),
# Thiols
('[S-;X1]', 'S'),
# Sulfonamides
('[$([N-;X2]S(=O)=O)]', 'N'),
# Enamines
('[$([N-;X2][C,N]=C)]', 'N'),
# Tetrazoles
('[n-]', '[n]'),
# Sulfoxides
('[$([S-]=O)]', 'S'),
# Amides
('[$([N-]C=O)]', 'N'),
)
return [(Chem.MolFromSmarts(x), Chem.MolFromSmiles(y, False)) for x, y in patts]
def neutralize_charges(mol, reactions=None):
"""Deprotonation for molecules.
Code adapted from RDKit Cookbook, by Hans de Winter.
DGMG currently cannot generate protonated molecules.
For example, it can only generate
CC(C)(C)CC1CCC[NH+]1Cc1nnc(-c2ccccc2F)o1
from
CC(C)(C)CC1CCCN1Cc1nnc(-c2ccccc2F)o1
even with correct decisions.
Deprotonation is therefore an important step to avoid
false novel molecules.
Parameters
----------
mol : Chem.rdchem.Mol
reactions : list of 2-tuples
Rules for deprotonation
Returns
-------
mol : Chem.rdchem.Mol
Deprotonated molecule
"""
if reactions is None:
reactions = initialize_neuralization_reactions()
for i, (reactant, product) in enumerate(reactions):
while mol.HasSubstructMatch(reactant):
rms = AllChem.ReplaceSubstructs(mol, reactant, product)
mol = rms[0]
return mol
def standardize_mol(mol):
"""Standardize molecule to avoid false novel molecule.
Kekulize and deprotonate molecules to avoid false novel molecules.
In addition to deprotonation, we also kekulize molecules to avoid
explicit Hs in the SMILES. Otherwise we will get false novel molecules
as well. For example, DGMG can only generate
O=S(=O)(NC1=CC=CC(C(F)(F)F)=C1)C1=CNC=N1
from
O=S(=O)(Nc1cccc(C(F)(F)F)c1)c1c[nH]cn1.
One downside is that we remove all explicit aromatic rings, although
explicitly predicting aromatic bonds might make learning easier for
the model.
"""
reactions = initialize_neuralization_reactions()
Chem.Kekulize(mol, clearAromaticFlags=True)
mol = neutralize_charges(mol, reactions)
return mol
def smiles_to_standard_mol(s):
"""Convert SMILES to a standard molecule.
Parameters
----------
s : str
SMILES
Returns
-------
Chem.rdchem.Mol
Standardized molecule
"""
mol = Chem.MolFromSmiles(s)
return standardize_mol(mol)
def mol_to_standard_smile(mol):
"""Standardize a molecule and convert it to a SMILES.
Parameters
----------
mol : Chem.rdchem.Mol
Returns
-------
str
SMILES
"""
return Chem.MolToSmiles(standardize_mol(mol))
def get_atom_and_bond_types(smiles, log=True):
"""Identify the atom types and bond types
appearing in this dataset.
Parameters
----------
smiles : list
List of smiles
log : bool
Whether to print the process of pre-processing.
Returns
-------
atom_types : list
E.g. ['C', 'N']
bond_types : list
E.g. [Chem.rdchem.BondType.SINGLE, Chem.rdchem.BondType.DOUBLE,
Chem.rdchem.BondType.TRIPLE, Chem.rdchem.BondType.AROMATIC]
"""
atom_types = set()
bond_types = set()
n_smiles = len(smiles)
for i, s in enumerate(smiles):
if log:
print('Processing smiles {:d}/{:d}'.format(i + 1, n_smiles))
mol = smiles_to_standard_mol(s)
if mol is None:
continue
for atom in mol.GetAtoms():
a_symbol = atom.GetSymbol()
if a_symbol not in atom_types:
atom_types.add(a_symbol)
for bond in mol.GetBonds():
b_type = bond.GetBondType()
if b_type not in bond_types:
bond_types.add(b_type)
return list(atom_types), list(bond_types)
def eval_decisions(env, decisions):
"""This function mimics the way DGMG generates a molecule and is
helpful for debugging and verification in data preprocessing.
Parameters
----------
env : MoleculeEnv
MDP environment for generating molecules
decisions : list of 2-tuples of int
A decision sequence for generating a molecule
Returns
-------
str
SMILES for the molecule generated with decisions
"""
env.reset(rdkit_mol=True)
t = 0
def whether_to_add_atom(t):
assert decisions[t][0] == 0
atom_type = decisions[t][1]
t += 1
return t, atom_type
def whether_to_add_bond(t):
assert decisions[t][0] == 1
bond_type = decisions[t][1]
t += 1
return t, bond_type
def decide_atom2(t):
assert decisions[t][0] == 2
dst = decisions[t][1]
t += 1
return t, dst
t, atom_type = whether_to_add_atom(t)
while atom_type != len(env.atom_types):
env.add_atom(atom_type)
t, bond_type = whether_to_add_bond(t)
while bond_type != len(env.bond_types):
t, dst = decide_atom2(t)
env.add_bond((env.num_atoms() - 1), dst, bond_type)
t, bond_type = whether_to_add_bond(t)
t, atom_type = whether_to_add_atom(t)
assert t == len(decisions)
return env.get_current_smiles()
def get_DGMG_smile(env, mol):
"""Mimics the reproduced SMILES with DGMG for a molecule.
Given a molecule, we are interested in what SMILES we will
get if we want to generate it with DGMG. This is an important
step to check false novel molecules.
Parameters
----------
env : MoleculeEnv
MDP environment for generating molecules
mol : Chem.rdchem.Mol
A molecule
Returns
-------
canonical_smile : str
SMILES of the generated molecule with a canonical decision sequence
random_smile : str
SMILES of the generated molecule with a random decision sequence
"""
canonical_decisions = env.get_decision_sequence(mol, list(range(mol.GetNumAtoms())))
canonical_smile = eval_decisions(env, canonical_decisions)
order = list(range(mol.GetNumAtoms()))
random.shuffle(order)
random_decisions = env.get_decision_sequence(mol, order)
random_smile = eval_decisions(env, random_decisions)
return canonical_smile, random_smile
def preprocess_dataset(atom_types, bond_types, smiles, max_num_atoms=23):
"""Preprocess the dataset
1. Standardize the SMILES of the dataset
2. Only keep the SMILES that DGMG can reproduce
3. Drop repeated SMILES
Parameters
----------
atom_types : list
The types of atoms appearing in a dataset. E.g. ['C', 'N']
bond_types : list
The types of bonds appearing in a dataset.
E.g. [Chem.rdchem.BondType.SINGLE, Chem.rdchem.BondType.DOUBLE,
Chem.rdchem.BondType.TRIPLE, Chem.rdchem.BondType.AROMATIC]
smiles : list of str
SMILES strings of the molecules to preprocess
max_num_atoms : int or None
If not None, molecules with more than max_num_atoms atoms are dropped. Default to 23.
Returns
-------
valid_smiles : list of str
SMILES left after preprocessing
"""
valid_smiles = []
env = MoleculeEnv(atom_types, bond_types)
for id, s in enumerate(smiles):
print('Processing {:d}/{:d}'.format(id + 1, len(smiles)))
raw_s = s.strip()
mol = smiles_to_standard_mol(raw_s)
if mol is None:
continue
standard_s = Chem.MolToSmiles(mol)
if (max_num_atoms is not None) and (mol.GetNumAtoms() > max_num_atoms):
continue
canonical_s, random_s = get_DGMG_smile(env, mol)
canonical_mol = Chem.MolFromSmiles(canonical_s)
random_mol = Chem.MolFromSmiles(random_s)
if (standard_s != canonical_s) or (canonical_s != random_s) or (canonical_mol is None) or (random_mol is None):
continue
valid_smiles.append(standard_s)
valid_smiles = list(set(valid_smiles))
return valid_smiles
def download_data(dataset, fname):
"""Download dataset if built-in support exists
Parameters
----------
dataset : str
Dataset name
fname : str
Name of dataset file
"""
if dataset not in ['ChEMBL', 'ZINC']:
# For dataset without built-in support, they should be locally processed.
return
data_path = fname
download(_get_dgl_url(os.path.join('dataset', fname)), path=data_path)
def load_smiles_from_file(f_name):
"""Load dataset into a list of SMILES
Parameters
----------
f_name : str
Path to a file of molecules, where each line of the file
is a molecule in SMILES format.
Returns
-------
smiles : list of str
List of molecules as SMILES
"""
with open(f_name, 'r') as f:
smiles = f.read().splitlines()
return smiles
def write_smiles_to_file(f_name, smiles):
"""Write dataset to a file.
Parameters
----------
f_name : str
Path to create a file of molecules, where each line of the file
is a molecule in SMILES format.
smiles : list of str
List of SMILES
"""
with open(f_name, 'w') as f:
for s in smiles:
f.write(s + '\n')
def configure_new_dataset(dataset, train_file, val_file):
"""Configure for a new dataset.
Parameters
----------
dataset : str
Dataset name
train_file : str
Path to a file with one SMILES a line for training data
val_file : str
Path to a file with one SMILES a line for validation data
"""
assert train_file is not None, 'Expect a file of SMILES for training, got None.'
assert val_file is not None, 'Expect a file of SMILES for validation, got None.'
train_smiles = load_smiles_from_file(train_file)
val_smiles = load_smiles_from_file(val_file)
all_smiles = train_smiles + val_smiles
# Get all atom and bond types in the dataset
path_to_atom_and_bond_types = '_'.join([dataset, 'atom_and_bond_types.pkl'])
if not os.path.exists(path_to_atom_and_bond_types):
atom_types, bond_types = get_atom_and_bond_types(all_smiles)
with open(path_to_atom_and_bond_types, 'wb') as f:
pickle.dump({'atom_types': atom_types, 'bond_types': bond_types}, f)
else:
with open(path_to_atom_and_bond_types, 'rb') as f:
type_info = pickle.load(f)
atom_types = type_info['atom_types']
bond_types = type_info['bond_types']
# Standardize training data
path_to_processed_train_data = '_'.join([dataset, 'DGMG', 'train.txt'])
if not os.path.exists(path_to_processed_train_data):
processed_train_smiles = preprocess_dataset(atom_types, bond_types, train_smiles, None)
write_smiles_to_file(path_to_processed_train_data, processed_train_smiles)
path_to_processed_val_data = '_'.join([dataset, 'DGMG', 'val.txt'])
if not os.path.exists(path_to_processed_val_data):
processed_val_smiles = preprocess_dataset(atom_types, bond_types, val_smiles, None)
write_smiles_to_file(path_to_processed_val_data, processed_val_smiles)
class MoleculeDataset(object):
"""Initialize and split the dataset.
Parameters
----------
dataset : str
Dataset name
order : None or str
Order to extract a decision sequence for generating a molecule. Defaults to None.
modes : None or list
List of subsets to use, which can contain 'train', 'val', corresponding to
training and validation. Defaults to None.
subset_id : int
With multiprocess training, we partition the training set into multiple subsets and
each process will use one subset only. This subset_id corresponds to subprocess id.
n_subsets : int
With multiprocess training, this corresponds to the number of total subprocesses.
"""
def __init__(self, dataset, order=None, modes=None, subset_id=0, n_subsets=1):
super(MoleculeDataset, self).__init__()
if modes is None:
modes = []
else:
assert order is not None, 'An order should be specified for extracting ' \
'decision sequences.'
assert order in ['random', 'canonical', None], \
"Unexpected order option to get sequences of graph generation decisions"
assert len(set(modes) - {'train', 'val'}) == 0, \
"modes should be a list, representing a subset of ['train', 'val']"
self.dataset = dataset
self.order = order
self.modes = modes
self.subset_id = subset_id
self.n_subsets = n_subsets
self._setup()
def collate(self, samples):
"""PyTorch's approach to batch multiple samples.
For auto-regressive generative models, we process one sample at a time.
Parameters
----------
samples : list
A list of length 1 consisting of the decision sequence for generating a molecule.
Returns
-------
list
List of 2-tuples, a decision sequence to generate a molecule
"""
assert len(samples) == 1
return samples[0]
def _create_a_subset(self, smiles):
"""Create a dataset from a subset of smiles.
Parameters
----------
smiles : list of str
List of molecules in SMILES format
"""
# We evenly divide the SMILES into multiple subsets for multiprocess training
subset_size = len(smiles) // self.n_subsets
return Subset(smiles[self.subset_id * subset_size: (self.subset_id + 1) * subset_size],
self.order, self.env)
def _setup(self):
"""
1. Instantiate an MDP environment for molecule generation
2. Download the dataset, which is a file of SMILES
3. Create subsets for training and validation
"""
if self.dataset == 'ChEMBL':
# For new datasets, get_atom_and_bond_types can be used to
# identify the atom and bond types in them.
self.atom_types = ['O', 'Cl', 'C', 'S', 'F', 'Br', 'N']
self.bond_types = [Chem.rdchem.BondType.SINGLE,
Chem.rdchem.BondType.DOUBLE,
Chem.rdchem.BondType.TRIPLE]
elif self.dataset == 'ZINC':
self.atom_types = ['Br', 'S', 'C', 'P', 'N', 'O', 'F', 'Cl', 'I']
self.bond_types = [Chem.rdchem.BondType.SINGLE,
Chem.rdchem.BondType.DOUBLE,
Chem.rdchem.BondType.TRIPLE]
else:
path_to_atom_and_bond_types = '_'.join([self.dataset, 'atom_and_bond_types.pkl'])
with open(path_to_atom_and_bond_types, 'rb') as f:
type_info = pickle.load(f)
self.atom_types = type_info['atom_types']
self.bond_types = type_info['bond_types']
self.env = MoleculeEnv(self.atom_types, self.bond_types)
dataset_prefix = self._dataset_prefix()
if 'train' in self.modes:
fname = '_'.join([dataset_prefix, 'train.txt'])
download_data(self.dataset, fname)
smiles = load_smiles_from_file(fname)
self.train_set = self._create_a_subset(smiles)
if 'val' in self.modes:
fname = '_'.join([dataset_prefix, 'val.txt'])
download_data(self.dataset, fname)
smiles = load_smiles_from_file(fname)
# We evenly divide the SMILES into multiple subsets for multiprocess training
self.val_set = self._create_a_subset(smiles)
def _dataset_prefix(self):
"""Get the prefix for the data files of supported datasets.
Returns
-------
str
Prefix for dataset file name
"""
return '_'.join([self.dataset, 'DGMG'])
class Subset(Dataset):
"""A set of molecules which can be used for training, validation, test.
Parameters
----------
smiles : list
List of SMILES for the dataset
order : str
Specifies how decision sequences for molecule generation
are obtained, can be either "random" or "canonical"
env : MoleculeEnv object
MDP environment for generating molecules
"""
def __init__(self, smiles, order, env):
super(Subset, self).__init__()
self.smiles = smiles
self.order = order
self.env = env
self._setup()
def _setup(self):
"""Convert SMILES into rdkit molecule objects.
Decision sequences are extracted if we use a fixed order.
"""
smiles_ = []
mols = []
for s in self.smiles:
m = smiles_to_standard_mol(s)
if m is None:
continue
smiles_.append(s)
mols.append(m)
self.smiles = smiles_
self.mols = mols
if self.order == 'random':
return
self.decisions = []
for m in self.mols:
self.decisions.append(
self.env.get_decision_sequence(m, list(range(m.GetNumAtoms())))
)
def __len__(self):
"""Get number of molecules in the dataset."""
return len(self.mols)
def __getitem__(self, item):
"""Get the decision sequence for generating the molecule indexed by item."""
if self.order == 'canonical':
return self.decisions[item]
else:
m = self.mols[item]
nodes = list(range(m.GetNumAtoms()))
random.shuffle(nodes)
return self.env.get_decision_sequence(m, nodes)
########################################################################################################################
# progress tracking #
########################################################################################################################
class Printer(object):
def __init__(self, num_epochs, dataset_size, batch_size, writer=None):
"""Wrapper to track the learning progress.
Parameters
----------
num_epochs : int
Number of epochs for training
dataset_size : int
Number of samples in the dataset
batch_size : int
Number of samples processed per update
writer : None or SummaryWriter
If not None, tensorboard will be used to visualize learning curves.
"""
super(Printer, self).__init__()
self.num_epochs = num_epochs
self.batch_size = batch_size
self.num_batches = math.ceil(dataset_size / batch_size)
self.count = 0
self.batch_count = 0
self.writer = writer
self._reset()
def _reset(self):
"""Reset when an epoch is completed."""
self.batch_loss = 0
self.batch_prob = 0
def _get_current_batch(self):
"""Get current batch index."""
remainder = self.batch_count % self.num_batches
if remainder == 0:
return self.num_batches
else:
return remainder
def update(self, epoch, loss, prob):
"""Update learning progress.
Parameters
----------
epoch : int
loss : float
prob : float
"""
self.count += 1
self.batch_loss += loss
self.batch_prob += prob
if self.count % self.batch_size == 0:
self.batch_count += 1
if self.writer is not None:
self.writer.add_scalar('train_log_prob', self.batch_loss, self.batch_count)
self.writer.add_scalar('train_prob', self.batch_prob, self.batch_count)
print('epoch {:d}/{:d}, batch {:d}/{:d}, loss {:.4f}, prob {:.4f}'.format(
epoch, self.num_epochs, self._get_current_batch(),
self.num_batches, self.batch_loss, self.batch_prob))
self._reset()
########################################################################################################################
# eval #
########################################################################################################################
def summarize_a_molecule(smile, checklist=None):
"""Get information about a molecule.
Parameters
----------
smile : str
Molecule in SMILES format
checklist : dict
Things to learn about the molecule
"""
if checklist is None:
checklist = {
'HBA': Chem.rdMolDescriptors.CalcNumHBA,
'HBD': Chem.rdMolDescriptors.CalcNumHBD,
'logP': MolLogP,
'SA': calculateScore,
'TPSA': Chem.rdMolDescriptors.CalcTPSA,
'QED': qed,
'NumAtoms': lambda mol: mol.GetNumAtoms(),
'NumBonds': lambda mol: mol.GetNumBonds()
}
summary = dict()
mol = Chem.MolFromSmiles(smile)
if mol is None:
summary.update({
'smile': smile,
'valid': False
})
for k in checklist.keys():
summary[k] = None
else:
mol = standardize_mol(mol)
summary.update({
'smile': Chem.MolToSmiles(mol),
'valid': True
})
Chem.SanitizeMol(mol)
for k, f in checklist.items():
summary[k] = f(mol)
return summary
def summarize_molecules(smiles, num_processes):
"""Summarize molecules with multiprocess.
Parameters
----------
smiles : list of str
List of molecules in SMILES for summarization
num_processes : int
Number of processes to use for summarization
Returns
-------
summary_for_valid : dict
Summary of all valid molecules, where
summary_for_valid[k] gives the values of all
valid molecules on item k.
"""
with Pool(processes=num_processes) as pool:
result = pool.map(summarize_a_molecule, smiles)
items = list(result[0].keys())
items.remove('valid')
summary_for_valid = defaultdict(list)
for summary in result:
if summary['valid']:
for k in items:
summary_for_valid[k].append(summary[k])
return summary_for_valid
def get_unique_smiles(smiles):
"""Given a list of smiles, return a list consisting of unique elements in it.
Parameters
----------
smiles : list of str
Molecules in SMILES
Returns
-------
list of str
Sublist where each SMILES occurs exactly once
"""
unique_set = set()
for mol_s in smiles:
if mol_s not in unique_set:
unique_set.add(mol_s)
return list(unique_set)
def get_novel_smiles(new_unique_smiles, reference_unique_smiles):
"""Get novel smiles which do not appear in the reference set.
Parameters
----------
new_unique_smiles : list of str
List of SMILES from which we want to identify novel ones
reference_unique_smiles : list of str
List of reference SMILES that we already have
"""
return set(new_unique_smiles).difference(set(reference_unique_smiles))
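# A minimal usage sketch of the evaluation helpers above (illustration only, with
# made-up SMILES): deduplicate generated molecules, find the ones that are novel
# with respect to a reference set, and summarize the valid ones with two processes.
#
#   generated = ['CCO', 'c1ccccc1', 'CCO', 'CC(=O)O']
#   reference = ['CCO']
#   unique_smiles = get_unique_smiles(generated)
#   novel_smiles = get_novel_smiles(unique_smiles, reference)
#   summary = summarize_molecules(unique_smiles, num_processes=2)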
# Junction Tree Variational Autoencoder for Molecular Graph Generation (JTNN)
Wengong Jin, Regina Barzilay, Tommi Jaakkola.
Junction Tree Variational Autoencoder for Molecular Graph Generation.
*arXiv preprint arXiv:1802.04364*, 2018.
JTNN uses the junction tree algorithm to form a tree from the molecular graph.
The model then encodes the tree and the graph into two separate vectors `z_G` and `z_T`. Details can
be found in the original paper. The overall process is sketched below (figure from the original paper):
![image](https://user-images.githubusercontent.com/8686776/63677300-3fb6d980-c81f-11e9-8a65-57c8b03aaf52.png)
**Goal**: JTNN is an autoencoder model that aims to learn hidden representations for molecular graphs.
These representations can be used for downstream tasks such as property prediction or molecule optimization.
## Dataset
### ZINC
> The ZINC database is a curated collection of commercially available chemical compounds
prepared especially for virtual screening. (introduction from Wikipedia)
Generally speaking, molecules in the ZINC dataset are more drug-like. We use ~220,000
molecules for training and 5,000 molecules for validation.
### Preprocessing
The class `JTNNDataset` processes a SMILES string into a dict containing the junction tree, the molecular
graph with encoded nodes (atoms) and edges (bonds), and other information for the model to use, as sketched below.
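For reference, here is a minimal sketch (not a full training script; it assumes the built-in ZINC split and
preprocessed vocabulary) of how `JTNNDataset` and `JTNNCollator` can be combined with a PyTorch `DataLoader`,
mirroring the usage in `train.py` and `reconstruct_eval.py`:
```
from torch.utils.data import DataLoader
from jtnn import JTNNDataset, JTNNCollator

# Built-in ZINC training split with the preprocessed vocabulary
dataset = JTNNDataset(data='train', vocab='vocab', training=True)
dataloader = DataLoader(dataset,
                        batch_size=40,
                        shuffle=True,
                        num_workers=4,
                        collate_fn=JTNNCollator(dataset.vocab, True),
                        drop_last=True)
batch = next(iter(dataloader))
# batch is a dict with mol_trees, mol_graph_batch, candidate and stereoisomer batches, etc.
```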
## Usage
### Training
To start training, use `python train.py`. By default, the script uses the ZINC dataset
with a preprocessed vocabulary and saves model checkpoints in the current working directory.
```
-s SAVE_PATH, Path to save checkpoint models, default to be current
working directory (default: ./)
-m MODEL_PATH, Path to load pre-trained model (default: None)
-b BATCH_SIZE, Batch size (default: 40)
-w HIDDEN_SIZE, Size of representation vectors (default: 200)
-l LATENT_SIZE, Latent size of node (atom) features and edge (bond)
features (default: 56)
-d DEPTH, Depth of message passing hops (default: 3)
-z BETA, Coefficient of KL Divergence term (default: 1.0)
-q LR, Learning Rate (default: 0.001)
```
The model is saved periodically.
All training checkpoints are stored at `SAVE_PATH`, passed on the command line or set by default.
#### Dataset configuration
If you want to use your own dataset, please create a file containing one SMILES per line
and pass the file path to the `-t` or `--train` option.
```
-t TRAIN, --train TRAIN
Training file name (default: train)
```
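For example, with hypothetical file names, the following command trains on a custom SMILES file and stores
checkpoints under `./checkpoints` (create the directory first):
```
python train.py -t my_smiles_train.txt -s ./checkpoints -b 40
```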
### Evaluation
To start evaluation, use `python reconstruct_eval.py` with the following arguments
```
-t TRAIN, Training file name (default: test)
-m MODEL_PATH, Pre-trained model to be loaded for evaluation. If not
specified, the pre-trained model from the model zoo is used
(default: None)
-w HIDDEN_SIZE, Hidden size of representation vector, should be
consistent with pre-trained model (default: 450)
-l LATENT_SIZE, Latent size of node (atom) features and edge (bond)
features, should be consistent with the pre-trained model
(default: 56)
-d DEPTH, Depth of message passing hops, should be consistent
with pre-trained model (default: 3)
```
The script prints the success rate of reconstructing the same molecules.
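For example (the checkpoint path below is hypothetical, and the sizes must match those used for training):
```
# Evaluate the pre-trained JTNN_ZINC model from the model zoo
python reconstruct_eval.py
# Evaluate your own checkpoint
python reconstruct_eval.py -m ./checkpoints/model.iter-4 -w 200 -l 56 -d 3
```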
### Pre-trained models
Below are the statistics of the pre-trained `JTNN_ZINC` model.

| Pre-trained model | % Reconstruction Accuracy |
| ----------------- | ------------------------- |
| `JTNN_ZINC`       | 73.7                      |
### Visualization
Here we draw some "neighbor" of a given molecule, by adding noises on the intermediate representations.
You can download the script with `wget https://data.dgl.ai/dgllife/jtnn_viz_neighbor_mol.ipynb`.
Please put this script at the current directory (`examples/pytorch/model_zoo/chem/generative_models/jtnn/`).
#### Given Molecule
![image](https://user-images.githubusercontent.com/8686776/63773593-0d37da00-c90e-11e9-8933-0abca4b430db.png)
#### Neighbor Molecules
![image](https://user-images.githubusercontent.com/8686776/63773602-1163f780-c90e-11e9-8341-5122dc0d0c82.png)
### Warnings from PyTorch 1.2
If you are using PyTorch 1.2, you might see a warning saying
`UserWarning: indexing with dtype torch.uint8 is now deprecated, please use a dtype torch.bool instead.` This is due to a new feature in PyTorch 1.2 and can be safely ignored.
from .mol_tree import Vocab
from .nnutils import cuda
from .datautils import JTNNDataset, JTNNCollator
from .chemutils import decode_stereo
import rdkit.Chem as Chem
import torch
from scipy.sparse import csr_matrix
from scipy.sparse.csgraph import minimum_spanning_tree
from collections import defaultdict
from rdkit.Chem.EnumerateStereoisomers import EnumerateStereoisomers
from dgl import DGLGraph
ELEM_LIST = ['C', 'N', 'O', 'S', 'F', 'Si', 'P', 'Cl', 'Br', 'Mg', 'Na',
'Ca', 'Fe', 'Al', 'I', 'B', 'K', 'Se', 'Zn', 'H', 'Cu', 'Mn', 'unknown']
MST_MAX_WEIGHT = 100
MAX_NCAND = 2000
def onek_encoding_unk(x, allowable_set):
if x not in allowable_set:
x = allowable_set[-1]
return [x == s for s in allowable_set]
def set_atommap(mol, num=0):
for atom in mol.GetAtoms():
atom.SetAtomMapNum(num)
def get_mol(smiles):
mol = Chem.MolFromSmiles(smiles)
if mol is None:
return None
Chem.Kekulize(mol)
return mol
def get_smiles(mol):
return Chem.MolToSmiles(mol, kekuleSmiles=True)
def decode_stereo(smiles2D):
mol = Chem.MolFromSmiles(smiles2D)
dec_isomers = list(EnumerateStereoisomers(mol))
dec_isomers = [Chem.MolFromSmiles(Chem.MolToSmiles(
mol, isomericSmiles=True)) for mol in dec_isomers]
smiles3D = [Chem.MolToSmiles(mol, isomericSmiles=True)
for mol in dec_isomers]
chiralN = [atom.GetIdx() for atom in dec_isomers[0].GetAtoms() if int(
atom.GetChiralTag()) > 0 and atom.GetSymbol() == "N"]
if len(chiralN) > 0:
for mol in dec_isomers:
for idx in chiralN:
mol.GetAtomWithIdx(idx).SetChiralTag(
Chem.rdchem.ChiralType.CHI_UNSPECIFIED)
smiles3D.append(Chem.MolToSmiles(mol, isomericSmiles=True))
return smiles3D
def sanitize(mol):
try:
smiles = get_smiles(mol)
mol = get_mol(smiles)
except Exception as e:
return None
return mol
def copy_atom(atom):
new_atom = Chem.Atom(atom.GetSymbol())
new_atom.SetFormalCharge(atom.GetFormalCharge())
new_atom.SetAtomMapNum(atom.GetAtomMapNum())
return new_atom
def copy_edit_mol(mol):
new_mol = Chem.RWMol(Chem.MolFromSmiles(''))
for atom in mol.GetAtoms():
new_atom = copy_atom(atom)
new_mol.AddAtom(new_atom)
for bond in mol.GetBonds():
a1 = bond.GetBeginAtom().GetIdx()
a2 = bond.GetEndAtom().GetIdx()
bt = bond.GetBondType()
new_mol.AddBond(a1, a2, bt)
return new_mol
def get_clique_mol(mol, atoms):
smiles = Chem.MolFragmentToSmiles(mol, atoms, kekuleSmiles=True)
new_mol = Chem.MolFromSmiles(smiles, sanitize=False)
new_mol = copy_edit_mol(new_mol).GetMol()
new_mol = sanitize(new_mol) # We assume this is not None
return new_mol
def tree_decomp(mol):
n_atoms = mol.GetNumAtoms()
if n_atoms == 1:
return [[0]], []
cliques = []
for bond in mol.GetBonds():
a1 = bond.GetBeginAtom().GetIdx()
a2 = bond.GetEndAtom().GetIdx()
if not bond.IsInRing():
cliques.append([a1, a2])
ssr = [list(x) for x in Chem.GetSymmSSSR(mol)]
cliques.extend(ssr)
nei_list = [[] for i in range(n_atoms)]
for i in range(len(cliques)):
for atom in cliques[i]:
nei_list[atom].append(i)
# Merge Rings with intersection > 2 atoms
for i in range(len(cliques)):
if len(cliques[i]) <= 2:
continue
for atom in cliques[i]:
for j in nei_list[atom]:
if i >= j or len(cliques[j]) <= 2:
continue
inter = set(cliques[i]) & set(cliques[j])
if len(inter) > 2:
cliques[i].extend(cliques[j])
cliques[i] = list(set(cliques[i]))
cliques[j] = []
cliques = [c for c in cliques if len(c) > 0]
nei_list = [[] for i in range(n_atoms)]
for i in range(len(cliques)):
for atom in cliques[i]:
nei_list[atom].append(i)
# Build edges and add singleton cliques
edges = defaultdict(int)
for atom in range(n_atoms):
if len(nei_list[atom]) <= 1:
continue
cnei = nei_list[atom]
bonds = [c for c in cnei if len(cliques[c]) == 2]
rings = [c for c in cnei if len(cliques[c]) > 4]
# In general, if len(cnei) >= 3, a singleton should be added, but 1 bond + 2 ring is currently not dealt with.
if len(bonds) > 2 or (len(bonds) == 2 and len(cnei) > 2):
cliques.append([atom])
c2 = len(cliques) - 1
for c1 in cnei:
edges[(c1, c2)] = 1
elif len(rings) > 2: # Multiple (n>2) complex rings
cliques.append([atom])
c2 = len(cliques) - 1
for c1 in cnei:
edges[(c1, c2)] = MST_MAX_WEIGHT - 1
else:
for i in range(len(cnei)):
for j in range(i + 1, len(cnei)):
c1, c2 = cnei[i], cnei[j]
inter = set(cliques[c1]) & set(cliques[c2])
if edges[(c1, c2)] < len(inter):
# cnei[i] < cnei[j] by construction
edges[(c1, c2)] = len(inter)
edges = [u + (MST_MAX_WEIGHT-v,) for u, v in edges.items()]
if len(edges) == 0:
return cliques, edges
# Compute Maximum Spanning Tree
row, col, data = list(zip(*edges))
n_clique = len(cliques)
clique_graph = csr_matrix((data, (row, col)), shape=(n_clique, n_clique))
junc_tree = minimum_spanning_tree(clique_graph)
row, col = junc_tree.nonzero()
edges = [(row[i], col[i]) for i in range(len(row))]
return (cliques, edges)
def atom_equal(a1, a2):
return a1.GetSymbol() == a2.GetSymbol() and a1.GetFormalCharge() == a2.GetFormalCharge()
# Bond type not considered because all aromatic (so SINGLE matches DOUBLE)
def ring_bond_equal(b1, b2, reverse=False):
b1 = (b1.GetBeginAtom(), b1.GetEndAtom())
if reverse:
b2 = (b2.GetEndAtom(), b2.GetBeginAtom())
else:
b2 = (b2.GetBeginAtom(), b2.GetEndAtom())
return atom_equal(b1[0], b2[0]) and atom_equal(b1[1], b2[1])
def attach_mols_nx(ctr_mol, neighbors, prev_nodes, nei_amap):
prev_nids = [node['nid'] for node in prev_nodes]
for nei_node in prev_nodes + neighbors:
nei_id, nei_mol = nei_node['nid'], nei_node['mol']
amap = nei_amap[nei_id]
for atom in nei_mol.GetAtoms():
if atom.GetIdx() not in amap:
new_atom = copy_atom(atom)
amap[atom.GetIdx()] = ctr_mol.AddAtom(new_atom)
if nei_mol.GetNumBonds() == 0:
nei_atom = nei_mol.GetAtomWithIdx(0)
ctr_atom = ctr_mol.GetAtomWithIdx(amap[0])
ctr_atom.SetAtomMapNum(nei_atom.GetAtomMapNum())
else:
for bond in nei_mol.GetBonds():
a1 = amap[bond.GetBeginAtom().GetIdx()]
a2 = amap[bond.GetEndAtom().GetIdx()]
if ctr_mol.GetBondBetweenAtoms(a1, a2) is None:
ctr_mol.AddBond(a1, a2, bond.GetBondType())
elif nei_id in prev_nids: # father node overrides
ctr_mol.RemoveBond(a1, a2)
ctr_mol.AddBond(a1, a2, bond.GetBondType())
return ctr_mol
def local_attach_nx(ctr_mol, neighbors, prev_nodes, amap_list):
ctr_mol = copy_edit_mol(ctr_mol)
nei_amap = {nei['nid']: {} for nei in prev_nodes + neighbors}
for nei_id, ctr_atom, nei_atom in amap_list:
nei_amap[nei_id][nei_atom] = ctr_atom
ctr_mol = attach_mols_nx(ctr_mol, neighbors, prev_nodes, nei_amap)
return ctr_mol.GetMol()
# This version records idx mapping between ctr_mol and nei_mol
def enum_attach_nx(ctr_mol, nei_node, amap, singletons):
nei_mol, nei_idx = nei_node['mol'], nei_node['nid']
att_confs = []
black_list = [atom_idx for nei_id, atom_idx,
_ in amap if nei_id in singletons]
ctr_atoms = [atom for atom in ctr_mol.GetAtoms() if atom.GetIdx()
not in black_list]
ctr_bonds = [bond for bond in ctr_mol.GetBonds()]
if nei_mol.GetNumBonds() == 0: # neighbor singleton
nei_atom = nei_mol.GetAtomWithIdx(0)
used_list = [atom_idx for _, atom_idx, _ in amap]
for atom in ctr_atoms:
if atom_equal(atom, nei_atom) and atom.GetIdx() not in used_list:
new_amap = amap + [(nei_idx, atom.GetIdx(), 0)]
att_confs.append(new_amap)
elif nei_mol.GetNumBonds() == 1: # neighbor is a bond
bond = nei_mol.GetBondWithIdx(0)
bond_val = int(bond.GetBondTypeAsDouble())
b1, b2 = bond.GetBeginAtom(), bond.GetEndAtom()
for atom in ctr_atoms:
# Optimize if atom is carbon (other atoms may change valence)
if atom.GetAtomicNum() == 6 and atom.GetTotalNumHs() < bond_val:
continue
if atom_equal(atom, b1):
new_amap = amap + [(nei_idx, atom.GetIdx(), b1.GetIdx())]
att_confs.append(new_amap)
elif atom_equal(atom, b2):
new_amap = amap + [(nei_idx, atom.GetIdx(), b2.GetIdx())]
att_confs.append(new_amap)
else:
# intersection is an atom
for a1 in ctr_atoms:
for a2 in nei_mol.GetAtoms():
if atom_equal(a1, a2):
# Optimize if atom is carbon (other atoms may change valence)
if a1.GetAtomicNum() == 6 and a1.GetTotalNumHs() + a2.GetTotalNumHs() < 4:
continue
new_amap = amap + [(nei_idx, a1.GetIdx(), a2.GetIdx())]
att_confs.append(new_amap)
# intersection is a bond
if ctr_mol.GetNumBonds() > 1:
for b1 in ctr_bonds:
for b2 in nei_mol.GetBonds():
if ring_bond_equal(b1, b2):
new_amap = amap + [(nei_idx, b1.GetBeginAtom().GetIdx(), b2.GetBeginAtom(
).GetIdx()), (nei_idx, b1.GetEndAtom().GetIdx(), b2.GetEndAtom().GetIdx())]
att_confs.append(new_amap)
if ring_bond_equal(b1, b2, reverse=True):
new_amap = amap + [(nei_idx, b1.GetBeginAtom().GetIdx(), b2.GetEndAtom(
).GetIdx()), (nei_idx, b1.GetEndAtom().GetIdx(), b2.GetBeginAtom().GetIdx())]
att_confs.append(new_amap)
return att_confs
# Try rings first: Speed-Up
def enum_assemble_nx(node, neighbors, prev_nodes=[], prev_amap=[]):
all_attach_confs = []
singletons = [nei_node['nid'] for nei_node in neighbors +
prev_nodes if nei_node['mol'].GetNumAtoms() == 1]
def search(cur_amap, depth):
if len(all_attach_confs) > MAX_NCAND:
return
if depth == len(neighbors):
all_attach_confs.append(cur_amap)
return
nei_node = neighbors[depth]
cand_amap = enum_attach_nx(node['mol'], nei_node, cur_amap, singletons)
cand_smiles = set()
candidates = []
for amap in cand_amap:
cand_mol = local_attach_nx(
node['mol'], neighbors[:depth+1], prev_nodes, amap)
cand_mol = sanitize(cand_mol)
if cand_mol is None:
continue
smiles = get_smiles(cand_mol)
if smiles in cand_smiles:
continue
cand_smiles.add(smiles)
candidates.append(amap)
if len(candidates) == 0:
return []
for new_amap in candidates:
search(new_amap, depth + 1)
search(prev_amap, 0)
cand_smiles = set()
candidates = []
for amap in all_attach_confs:
cand_mol = local_attach_nx(node['mol'], neighbors, prev_nodes, amap)
cand_mol = Chem.MolFromSmiles(Chem.MolToSmiles(cand_mol))
smiles = Chem.MolToSmiles(cand_mol)
if smiles in cand_smiles:
continue
cand_smiles.add(smiles)
Chem.Kekulize(cand_mol)
candidates.append((smiles, cand_mol, amap))
return candidates
# Only used for debugging purposes
def dfs_assemble_nx(graph, cur_mol, global_amap, fa_amap, cur_node_id, fa_node_id):
cur_node = graph.nodes_dict[cur_node_id]
fa_node = graph.nodes_dict[fa_node_id] if fa_node_id is not None else None
fa_nid = fa_node['nid'] if fa_node is not None else -1
prev_nodes = [fa_node] if fa_node is not None else []
children_id = [nei for nei in graph[cur_node_id]
if graph.nodes_dict[nei]['nid'] != fa_nid]
children = [graph.nodes_dict[nei] for nei in children_id]
neighbors = [nei for nei in children if nei['mol'].GetNumAtoms() > 1]
neighbors = sorted(
neighbors, key=lambda x: x['mol'].GetNumAtoms(), reverse=True)
singletons = [nei for nei in children if nei['mol'].GetNumAtoms() == 1]
neighbors = singletons + neighbors
cur_amap = [(fa_nid, a2, a1)
for nid, a1, a2 in fa_amap if nid == cur_node['nid']]
cands = enum_assemble_nx(
graph.nodes_dict[cur_node_id], neighbors, prev_nodes, cur_amap)
if len(cands) == 0:
return
cand_smiles, _, cand_amap = zip(*cands)
label_idx = cand_smiles.index(cur_node['label'])
label_amap = cand_amap[label_idx]
for nei_id, ctr_atom, nei_atom in label_amap:
if nei_id == fa_nid:
continue
global_amap[nei_id][nei_atom] = global_amap[cur_node['nid']][ctr_atom]
# father is already attached
cur_mol = attach_mols_nx(cur_mol, children, [], global_amap)
for nei_node_id, nei_node in zip(children_id, children):
if not nei_node['is_leaf']:
dfs_assemble_nx(graph, cur_mol, global_amap,
label_amap, nei_node_id, cur_node_id)
def mol2dgl_dec(cand_batch):
# Note that during graph decoding they don't predict stereochemistry-related
# characteristics (i.e. Chiral Atoms, E-Z, Cis-Trans). Instead, they decode
# the 2-D graph first, then enumerate all possible 3-D forms and find the
# one with highest score.
def atom_features(atom):
return (torch.Tensor(onek_encoding_unk(atom.GetSymbol(), ELEM_LIST)
+ onek_encoding_unk(atom.GetDegree(),
[0, 1, 2, 3, 4, 5])
+ onek_encoding_unk(atom.GetFormalCharge(),
[-1, -2, 1, 2, 0])
+ [atom.GetIsAromatic()]))
def bond_features(bond):
bt = bond.GetBondType()
return (torch.Tensor([bt == Chem.rdchem.BondType.SINGLE, bt == Chem.rdchem.BondType.DOUBLE, bt == Chem.rdchem.BondType.TRIPLE, bt == Chem.rdchem.BondType.AROMATIC, bond.IsInRing()]))
cand_graphs = []
tree_mess_source_edges = [] # map these edges from trees to...
tree_mess_target_edges = [] # these edges on candidate graphs
tree_mess_target_nodes = []
n_nodes = 0
atom_x = []
bond_x = []
for mol, mol_tree, ctr_node_id in cand_batch:
n_atoms = mol.GetNumAtoms()
g = DGLGraph()
for i, atom in enumerate(mol.GetAtoms()):
assert i == atom.GetIdx()
atom_x.append(atom_features(atom))
g.add_nodes(n_atoms)
bond_src = []
bond_dst = []
for i, bond in enumerate(mol.GetBonds()):
a1 = bond.GetBeginAtom()
a2 = bond.GetEndAtom()
begin_idx = a1.GetIdx()
end_idx = a2.GetIdx()
features = bond_features(bond)
bond_src.append(begin_idx)
bond_dst.append(end_idx)
bond_x.append(features)
bond_src.append(end_idx)
bond_dst.append(begin_idx)
bond_x.append(features)
x_nid, y_nid = a1.GetAtomMapNum(), a2.GetAtomMapNum()
# Tree node ID in the batch
x_bid = mol_tree.nodes_dict[x_nid - 1]['idx'] if x_nid > 0 else -1
y_bid = mol_tree.nodes_dict[y_nid - 1]['idx'] if y_nid > 0 else -1
if x_bid >= 0 and y_bid >= 0 and x_bid != y_bid:
if mol_tree.has_edge_between(x_bid, y_bid):
tree_mess_target_edges.append(
(begin_idx + n_nodes, end_idx + n_nodes))
tree_mess_source_edges.append((x_bid, y_bid))
tree_mess_target_nodes.append(end_idx + n_nodes)
if mol_tree.has_edge_between(y_bid, x_bid):
tree_mess_target_edges.append(
(end_idx + n_nodes, begin_idx + n_nodes))
tree_mess_source_edges.append((y_bid, x_bid))
tree_mess_target_nodes.append(begin_idx + n_nodes)
n_nodes += n_atoms
g.add_edges(bond_src, bond_dst)
cand_graphs.append(g)
return cand_graphs, torch.stack(atom_x), \
torch.stack(bond_x) if len(bond_x) > 0 else torch.zeros(0), \
torch.LongTensor(tree_mess_source_edges), \
torch.LongTensor(tree_mess_target_edges), \
torch.LongTensor(tree_mess_target_nodes)
def mol2dgl_enc(smiles):
def atom_features(atom):
return (torch.Tensor(onek_encoding_unk(atom.GetSymbol(), ELEM_LIST)
+ onek_encoding_unk(atom.GetDegree(),
[0, 1, 2, 3, 4, 5])
+ onek_encoding_unk(atom.GetFormalCharge(), [-1, -2, 1, 2, 0])
+ onek_encoding_unk(int(atom.GetChiralTag()), [0, 1, 2, 3])
+ [atom.GetIsAromatic()]))
def bond_features(bond):
bt = bond.GetBondType()
stereo = int(bond.GetStereo())
fbond = [bt == Chem.rdchem.BondType.SINGLE, bt == Chem.rdchem.BondType.DOUBLE, bt ==
Chem.rdchem.BondType.TRIPLE, bt == Chem.rdchem.BondType.AROMATIC, bond.IsInRing()]
fstereo = onek_encoding_unk(stereo, [0, 1, 2, 3, 4, 5])
return (torch.Tensor(fbond + fstereo))
n_edges = 0
atom_x = []
bond_x = []
mol = get_mol(smiles)
n_atoms = mol.GetNumAtoms()
n_bonds = mol.GetNumBonds()
graph = DGLGraph()
for i, atom in enumerate(mol.GetAtoms()):
assert i == atom.GetIdx()
atom_x.append(atom_features(atom))
graph.add_nodes(n_atoms)
bond_src = []
bond_dst = []
for i, bond in enumerate(mol.GetBonds()):
begin_idx = bond.GetBeginAtom().GetIdx()
end_idx = bond.GetEndAtom().GetIdx()
features = bond_features(bond)
bond_src.append(begin_idx)
bond_dst.append(end_idx)
bond_x.append(features)
# set up the reverse direction
bond_src.append(end_idx)
bond_dst.append(begin_idx)
bond_x.append(features)
graph.add_edges(bond_src, bond_dst)
n_edges += n_bonds
return graph, torch.stack(atom_x), \
torch.stack(bond_x) if len(bond_x) > 0 else torch.zeros(0)
import dgl
import os
import torch
from torch.utils.data import Dataset
from dgl.data.utils import download, extract_archive, get_download_dir, _get_dgl_url
from .mol_tree import Vocab, DGLMolTree
from .chemutils import mol2dgl_dec, mol2dgl_enc
ELEM_LIST = ['C', 'N', 'O', 'S', 'F', 'Si', 'P', 'Cl', 'Br', 'Mg', 'Na', 'Ca',
'Fe', 'Al', 'I', 'B', 'K', 'Se', 'Zn', 'H', 'Cu', 'Mn', 'unknown']
ATOM_FDIM_DEC = len(ELEM_LIST) + 6 + 5 + 1
BOND_FDIM_DEC = 5
MAX_NB = 10
PAPER = os.getenv('PAPER', False)
def _unpack_field(examples, field):
return [e[field] for e in examples]
def _set_node_id(mol_tree, vocab):
wid = []
for i, node in enumerate(mol_tree.nodes_dict):
mol_tree.nodes_dict[node]['idx'] = i
wid.append(vocab.get_index(mol_tree.nodes_dict[node]['smiles']))
return wid
class JTNNDataset(Dataset):
def __init__(self, data, vocab, training=True):
self.dir = get_download_dir()
self.zip_file_path='{}/jtnn.zip'.format(self.dir)
download(_get_dgl_url('dgllife/jtnn.zip'), path=self.zip_file_path)
extract_archive(self.zip_file_path, '{}/jtnn'.format(self.dir))
print('Loading data...')
if data in ['train', 'test']:
data_file = '{}/jtnn/{}.txt'.format(self.dir, data)
else:
data_file = data
with open(data_file) as f:
self.data = [line.strip("\r\n ").split()[0] for line in f]
self.vocab_file = '{}/jtnn/{}.txt'.format(self.dir, vocab)
print('Loading finished.')
print('\tNum samples:', len(self.data))
print('\tVocab file:', self.vocab_file)
self.training = training
self.vocab = Vocab([x.strip("\r\n ") for x in open(self.vocab_file)])
def __len__(self):
return len(self.data)
def __getitem__(self, idx):
smiles = self.data[idx]
mol_tree = DGLMolTree(smiles)
mol_tree.recover()
mol_tree.assemble()
wid = _set_node_id(mol_tree, self.vocab)
# prebuild the molecule graph
mol_graph, atom_x_enc, bond_x_enc = mol2dgl_enc(mol_tree.smiles)
result = {
'mol_tree': mol_tree,
'mol_graph': mol_graph,
'atom_x_enc': atom_x_enc,
'bond_x_enc': bond_x_enc,
'wid': wid,
}
if not self.training:
return result
# prebuild the candidate graph list
cands = []
for node_id, node in mol_tree.nodes_dict.items():
# fill in ground truth
if node['label'] not in node['cands']:
node['cands'].append(node['label'])
node['cand_mols'].append(node['label_mol'])
if node['is_leaf'] or len(node['cands']) == 1:
continue
cands.extend([(cand, mol_tree, node_id)
for cand in node['cand_mols']])
if len(cands) > 0:
cand_graphs, atom_x_dec, bond_x_dec, tree_mess_src_e, \
tree_mess_tgt_e, tree_mess_tgt_n = mol2dgl_dec(cands)
else:
cand_graphs = []
atom_x_dec = torch.zeros(0, ATOM_FDIM_DEC)
bond_x_dec = torch.zeros(0, BOND_FDIM_DEC)
tree_mess_src_e = torch.zeros(0, 2).long()
tree_mess_tgt_e = torch.zeros(0, 2).long()
tree_mess_tgt_n = torch.zeros(0).long()
# prebuild the stereoisomers
cands = mol_tree.stereo_cands
if len(cands) > 1:
if mol_tree.smiles3D not in cands:
cands.append(mol_tree.smiles3D)
stereo_graphs = [mol2dgl_enc(c) for c in cands]
stereo_cand_graphs, stereo_atom_x_enc, stereo_bond_x_enc = \
zip(*stereo_graphs)
stereo_atom_x_enc = torch.cat(stereo_atom_x_enc)
stereo_bond_x_enc = torch.cat(stereo_bond_x_enc)
stereo_cand_label = [(cands.index(mol_tree.smiles3D), len(cands))]
else:
stereo_cand_graphs = []
stereo_atom_x_enc = torch.zeros(0, atom_x_enc.shape[1])
stereo_bond_x_enc = torch.zeros(0, bond_x_enc.shape[1])
stereo_cand_label = []
result.update({
'cand_graphs': cand_graphs,
'atom_x_dec': atom_x_dec,
'bond_x_dec': bond_x_dec,
'tree_mess_src_e': tree_mess_src_e,
'tree_mess_tgt_e': tree_mess_tgt_e,
'tree_mess_tgt_n': tree_mess_tgt_n,
'stereo_cand_graphs': stereo_cand_graphs,
'stereo_atom_x_enc': stereo_atom_x_enc,
'stereo_bond_x_enc': stereo_bond_x_enc,
'stereo_cand_label': stereo_cand_label,
})
return result
class JTNNCollator(object):
def __init__(self, vocab, training):
self.vocab = vocab
self.training = training
@staticmethod
def _batch_and_set(graphs, atom_x, bond_x, flatten):
if flatten:
graphs = [g for f in graphs for g in f]
graph_batch = dgl.batch(graphs)
graph_batch.ndata['x'] = atom_x
graph_batch.edata.update({
'x': bond_x,
'src_x': atom_x.new(bond_x.shape[0], atom_x.shape[1]).zero_(),
})
return graph_batch
def __call__(self, examples):
# get list of trees
mol_trees = _unpack_field(examples, 'mol_tree')
wid = _unpack_field(examples, 'wid')
for _wid, mol_tree in zip(wid, mol_trees):
mol_tree.ndata['wid'] = torch.LongTensor(_wid)
# TODO: either support pickling or get around ctypes pointers using scipy
# batch molecule graphs
mol_graphs = _unpack_field(examples, 'mol_graph')
atom_x = torch.cat(_unpack_field(examples, 'atom_x_enc'))
bond_x = torch.cat(_unpack_field(examples, 'bond_x_enc'))
mol_graph_batch = self._batch_and_set(mol_graphs, atom_x, bond_x, False)
result = {
'mol_trees': mol_trees,
'mol_graph_batch': mol_graph_batch,
}
if not self.training:
return result
# batch candidate graphs
cand_graphs = _unpack_field(examples, 'cand_graphs')
cand_batch_idx = []
atom_x = torch.cat(_unpack_field(examples, 'atom_x_dec'))
bond_x = torch.cat(_unpack_field(examples, 'bond_x_dec'))
tree_mess_src_e = _unpack_field(examples, 'tree_mess_src_e')
tree_mess_tgt_e = _unpack_field(examples, 'tree_mess_tgt_e')
tree_mess_tgt_n = _unpack_field(examples, 'tree_mess_tgt_n')
n_graph_nodes = 0
n_tree_nodes = 0
for i in range(len(cand_graphs)):
tree_mess_tgt_e[i] += n_graph_nodes
tree_mess_src_e[i] += n_tree_nodes
tree_mess_tgt_n[i] += n_graph_nodes
n_graph_nodes += sum(g.number_of_nodes() for g in cand_graphs[i])
n_tree_nodes += mol_trees[i].number_of_nodes()
cand_batch_idx.extend([i] * len(cand_graphs[i]))
tree_mess_tgt_e = torch.cat(tree_mess_tgt_e)
tree_mess_src_e = torch.cat(tree_mess_src_e)
tree_mess_tgt_n = torch.cat(tree_mess_tgt_n)
cand_graph_batch = self._batch_and_set(cand_graphs, atom_x, bond_x, True)
# batch stereoisomers
stereo_cand_graphs = _unpack_field(examples, 'stereo_cand_graphs')
atom_x = torch.cat(_unpack_field(examples, 'stereo_atom_x_enc'))
bond_x = torch.cat(_unpack_field(examples, 'stereo_bond_x_enc'))
stereo_cand_batch_idx = []
for i in range(len(stereo_cand_graphs)):
stereo_cand_batch_idx.extend([i] * len(stereo_cand_graphs[i]))
if len(stereo_cand_batch_idx) > 0:
stereo_cand_labels = [
(label, length)
for ex in _unpack_field(examples, 'stereo_cand_label')
for label, length in ex
]
stereo_cand_labels, stereo_cand_lengths = zip(*stereo_cand_labels)
stereo_cand_graph_batch = self._batch_and_set(
stereo_cand_graphs, atom_x, bond_x, True)
else:
stereo_cand_labels = []
stereo_cand_lengths = []
stereo_cand_graph_batch = None
stereo_cand_batch_idx = []
result.update({
'cand_graph_batch': cand_graph_batch,
'cand_batch_idx': cand_batch_idx,
'tree_mess_tgt_e': tree_mess_tgt_e,
'tree_mess_src_e': tree_mess_src_e,
'tree_mess_tgt_n': tree_mess_tgt_n,
'stereo_cand_graph_batch': stereo_cand_graph_batch,
'stereo_cand_batch_idx': stereo_cand_batch_idx,
'stereo_cand_labels': stereo_cand_labels,
'stereo_cand_lengths': stereo_cand_lengths,
})
return result
import copy
import numpy as np
from dgl import DGLGraph
import rdkit.Chem as Chem
from .chemutils import get_clique_mol, tree_decomp, get_mol, get_smiles, \
set_atommap, enum_assemble_nx, decode_stereo
def get_slots(smiles):
mol = Chem.MolFromSmiles(smiles)
return [(atom.GetSymbol(), atom.GetFormalCharge(), atom.GetTotalNumHs()) for atom in mol.GetAtoms()]
class Vocab(object):
def __init__(self, smiles_list):
self.vocab = smiles_list
self.vmap = {x:i for i,x in enumerate(self.vocab)}
self.slots = [get_slots(smiles) for smiles in self.vocab]
def get_index(self, smiles):
return self.vmap[smiles]
def get_smiles(self, idx):
return self.vocab[idx]
def get_slots(self, idx):
return copy.deepcopy(self.slots[idx])
def size(self):
return len(self.vocab)
class DGLMolTree(DGLGraph):
def __init__(self, smiles):
DGLGraph.__init__(self)
self.nodes_dict = {}
if smiles is None:
return
self.smiles = smiles
self.mol = get_mol(smiles)
# Stereo Generation
mol = Chem.MolFromSmiles(smiles)
self.smiles3D = Chem.MolToSmiles(mol, isomericSmiles=True)
self.smiles2D = Chem.MolToSmiles(mol)
self.stereo_cands = decode_stereo(self.smiles2D)
# cliques: a list of list of atom indices
cliques, edges = tree_decomp(self.mol)
root = 0
for i, c in enumerate(cliques):
cmol = get_clique_mol(self.mol, c)
csmiles = get_smiles(cmol)
self.nodes_dict[i] = dict(
smiles=csmiles,
mol=get_mol(csmiles),
clique=c,
)
if min(c) == 0:
root = i
self.add_nodes(len(cliques))
# The clique with atom ID 0 becomes root
if root > 0:
for attr in self.nodes_dict[0]:
self.nodes_dict[0][attr], self.nodes_dict[root][attr] = \
self.nodes_dict[root][attr], self.nodes_dict[0][attr]
src = np.zeros((len(edges) * 2,), dtype='int')
dst = np.zeros((len(edges) * 2,), dtype='int')
for i, (_x, _y) in enumerate(edges):
x = 0 if _x == root else root if _x == 0 else _x
y = 0 if _y == root else root if _y == 0 else _y
src[2 * i] = x
dst[2 * i] = y
src[2 * i + 1] = y
dst[2 * i + 1] = x
self.add_edges(src, dst)
for i in self.nodes_dict:
self.nodes_dict[i]['nid'] = i + 1
if self.out_degree(i) > 1: # Leaf node mol is not marked
set_atommap(self.nodes_dict[i]['mol'], self.nodes_dict[i]['nid'])
self.nodes_dict[i]['is_leaf'] = (self.out_degree(i) == 1)
def treesize(self):
return self.number_of_nodes()
def _recover_node(self, i, original_mol):
node = self.nodes_dict[i]
clique = []
clique.extend(node['clique'])
if not node['is_leaf']:
for cidx in node['clique']:
original_mol.GetAtomWithIdx(cidx).SetAtomMapNum(node['nid'])
for j in self.successors(i).numpy():
nei_node = self.nodes_dict[j]
clique.extend(nei_node['clique'])
if nei_node['is_leaf']: # Leaf node, no need to mark
continue
for cidx in nei_node['clique']:
# allow singleton nodes to override the atom mapping
if cidx not in node['clique'] or len(nei_node['clique']) == 1:
atom = original_mol.GetAtomWithIdx(cidx)
atom.SetAtomMapNum(nei_node['nid'])
clique = list(set(clique))
label_mol = get_clique_mol(original_mol, clique)
node['label'] = Chem.MolToSmiles(Chem.MolFromSmiles(get_smiles(label_mol)))
node['label_mol'] = get_mol(node['label'])
for cidx in clique:
original_mol.GetAtomWithIdx(cidx).SetAtomMapNum(0)
return node['label']
def _assemble_node(self, i):
neighbors = [self.nodes_dict[j] for j in self.successors(i).numpy()
if self.nodes_dict[j]['mol'].GetNumAtoms() > 1]
neighbors = sorted(neighbors, key=lambda x: x['mol'].GetNumAtoms(), reverse=True)
singletons = [self.nodes_dict[j] for j in self.successors(i).numpy()
if self.nodes_dict[j]['mol'].GetNumAtoms() == 1]
neighbors = singletons + neighbors
cands = enum_assemble_nx(self.nodes_dict[i], neighbors)
if len(cands) > 0:
self.nodes_dict[i]['cands'], self.nodes_dict[i]['cand_mols'], _ = list(zip(*cands))
self.nodes_dict[i]['cands'] = list(self.nodes_dict[i]['cands'])
self.nodes_dict[i]['cand_mols'] = list(self.nodes_dict[i]['cand_mols'])
else:
self.nodes_dict[i]['cands'] = []
self.nodes_dict[i]['cand_mols'] = []
def recover(self):
for i in self.nodes_dict:
self._recover_node(i, self.mol)
def assemble(self):
for i in self.nodes_dict:
self._assemble_node(i)
import torch
import torch.nn as nn
import os
def cuda(tensor):
if torch.cuda.is_available() and not os.getenv('NOCUDA', None):
return tensor.cuda()
else:
return tensor
class GRUUpdate(nn.Module):
def __init__(self, hidden_size):
nn.Module.__init__(self)
self.hidden_size = hidden_size
self.W_z = nn.Linear(2 * hidden_size, hidden_size)
self.W_r = nn.Linear(hidden_size, hidden_size, bias=False)
self.U_r = nn.Linear(hidden_size, hidden_size)
self.W_h = nn.Linear(2 * hidden_size, hidden_size)
def update_zm(self, node):
src_x = node.data['src_x']
s = node.data['s']
rm = node.data['accum_rm']
z = torch.sigmoid(self.W_z(torch.cat([src_x, s], 1)))
m = torch.tanh(self.W_h(torch.cat([src_x, rm], 1)))
m = (1 - z) * s + z * m
return {'m': m, 'z': z}
def update_r(self, node, zm=None):
dst_x = node.data['dst_x']
m = node.data['m'] if zm is None else zm['m']
r_1 = self.W_r(dst_x)
r_2 = self.U_r(m)
r = torch.sigmoid(r_1 + r_2)
return {'r': r, 'rm': r * m}
def forward(self, node):
dic = self.update_zm(node)
dic.update(self.update_r(node, zm=dic))
return dic
def move_dgl_to_cuda(g):
g.ndata.update({k: cuda(g.ndata[k]) for k in g.ndata})
g.edata.update({k: cuda(g.edata[k]) for k in g.edata})
import torch
from torch.utils.data import DataLoader
import argparse
from dgl import model_zoo
import rdkit
from jtnn import *
def worker_init_fn(id_):
lg = rdkit.RDLogger.logger()
lg.setLevel(rdkit.RDLogger.CRITICAL)
worker_init_fn(None)
parser = argparse.ArgumentParser(description="Evaluation for JTNN",
formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument("-t", "--train", dest="train",
default='test', help='Training file name')
parser.add_argument("-v", "--vocab", dest="vocab",
default='vocab', help='Vocab file name')
parser.add_argument("-m", "--model", dest="model_path", default=None,
help="Pre-trained model to be loaded for evalutaion. If not specified,"
" would use pre-trained model from model zoo")
parser.add_argument("-w", "--hidden", dest="hidden_size", default=450,
help="Hidden size of representation vector, "
"should be consistent with pre-trained model")
parser.add_argument("-l", "--latent", dest="latent_size", default=56,
help="Latent Size of node(atom) features and edge(atom) features, "
"should be consistent with pre-trained model")
parser.add_argument("-d", "--depth", dest="depth", default=3,
help="Depth of message passing hops, "
"should be consistent with pre-trained model")
args = parser.parse_args()
dataset = JTNNDataset(data=args.train, vocab=args.vocab, training=False)
vocab_file = dataset.vocab_file
hidden_size = int(args.hidden_size)
latent_size = int(args.latent_size)
depth = int(args.depth)
model = model_zoo.chem.DGLJTNNVAE(vocab_file=vocab_file,
hidden_size=hidden_size,
latent_size=latent_size,
depth=depth)
if args.model_path is not None:
model.load_state_dict(torch.load(args.model_path))
else:
model = model_zoo.chem.load_pretrained("JTNN_ZINC")
model = cuda(model)
model.eval()
print("Model #Params: %dK" %
(sum([x.nelement() for x in model.parameters()]) / 1000,))
MAX_EPOCH = 100
PRINT_ITER = 20
def reconstruct():
dataset.training = False
dataloader = DataLoader(
dataset,
batch_size=1,
shuffle=False,
num_workers=0,
collate_fn=JTNNCollator(dataset.vocab, False),
drop_last=True,
worker_init_fn=worker_init_fn)
# Just an example of molecule decoding; in reality you may want to sample
# tree and molecule vectors.
acc = 0.0
tot = 0
with torch.no_grad():
for it, batch in enumerate(dataloader):
gt_smiles = batch['mol_trees'][0].smiles
# print(gt_smiles)
model.move_to_cuda(batch)
try:
_, tree_vec, mol_vec = model.encode(batch)
tree_mean = model.T_mean(tree_vec)
# Following Mueller et al.
tree_log_var = -torch.abs(model.T_var(tree_vec))
mol_mean = model.G_mean(mol_vec)
# Following Mueller et al.
mol_log_var = -torch.abs(model.G_var(mol_vec))
epsilon = torch.randn(1, model.latent_size // 2).cuda()
tree_vec = tree_mean + torch.exp(tree_log_var / 2) * epsilon
epsilon = torch.randn(1, model.latent_size // 2).cuda()
mol_vec = mol_mean + torch.exp(mol_log_var / 2) * epsilon
dec_smiles = model.decode(tree_vec, mol_vec)
if dec_smiles == gt_smiles:
acc += 1
tot += 1
except Exception as e:
print("Failed to encode: {}".format(gt_smiles))
print(e)
if it % 20 == 1:
print("Progress {}/{}; Current Reconstruction Accuracy: {:.4f}".format(it,
len(dataloader), acc / tot))
return acc / tot
if __name__ == '__main__':
reconstruct_acc = reconstruct()
print("Reconstruction Accuracy: {}".format(reconstruct_acc))
import torch
import torch.nn as nn
import torch.optim as optim
import torch.optim.lr_scheduler as lr_scheduler
from dgl import model_zoo
from torch.utils.data import DataLoader
import sys
import argparse
import rdkit
from jtnn import *
torch.multiprocessing.set_sharing_strategy('file_system')
def worker_init_fn(id_):
lg = rdkit.RDLogger.logger()
lg.setLevel(rdkit.RDLogger.CRITICAL)
worker_init_fn(None)
parser = argparse.ArgumentParser(description="Training for JTNN",
formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument("-t", "--train", dest="train", default='train', help='Training file name')
parser.add_argument("-v", "--vocab", dest="vocab", default='vocab', help='Vocab file name')
parser.add_argument("-s", "--save_dir", dest="save_path", default='./',
help="Path to save checkpoint models, default to be current working directory")
parser.add_argument("-m", "--model", dest="model_path", default=None,
help="Path to load pre-trained model")
parser.add_argument("-b", "--batch", dest="batch_size", default=40,
help="Batch size")
parser.add_argument("-w", "--hidden", dest="hidden_size", default=200,
help="Size of representation vectors")
parser.add_argument("-l", "--latent", dest="latent_size", default=56,
help="Latent Size of node(atom) features and edge(atom) features")
parser.add_argument("-d", "--depth", dest="depth", default=3,
help="Depth of message passing hops")
parser.add_argument("-z", "--beta", dest="beta", default=1.0,
help="Coefficient of KL Divergence term")
parser.add_argument("-q", "--lr", dest="lr", default=1e-3,
help="Learning Rate")
args = parser.parse_args()
dataset = JTNNDataset(data=args.train, vocab=args.vocab, training=True)
vocab_file = dataset.vocab_file
batch_size = int(args.batch_size)
hidden_size = int(args.hidden_size)
latent_size = int(args.latent_size)
depth = int(args.depth)
beta = float(args.beta)
lr = float(args.lr)
model = model_zoo.chem.DGLJTNNVAE(vocab_file=vocab_file,
hidden_size=hidden_size,
latent_size=latent_size,
depth=depth)
if args.model_path is not None:
model.load_state_dict(torch.load(args.model_path))
else:
for param in model.parameters():
if param.dim() == 1:
nn.init.constant_(param, 0)
else:
nn.init.xavier_normal_(param)
model = cuda(model)
print("Model #Params: %dK" % (sum([x.nelement() for x in model.parameters()]) / 1000,))
optimizer = optim.Adam(model.parameters(), lr=lr)
scheduler = lr_scheduler.ExponentialLR(optimizer, 0.9)
scheduler.step()
MAX_EPOCH = 100
PRINT_ITER = 20
def train():
dataset.training = True
dataloader = DataLoader(
dataset,
batch_size=batch_size,
shuffle=True,
num_workers=4,
collate_fn=JTNNCollator(dataset.vocab, True),
drop_last=True,
worker_init_fn=worker_init_fn)
for epoch in range(MAX_EPOCH):
word_acc, topo_acc, assm_acc, steo_acc = 0, 0, 0, 0
for it, batch in enumerate(dataloader):
model.zero_grad()
try:
loss, kl_div, wacc, tacc, sacc, dacc = model(batch, beta)
except:
print([t.smiles for t in batch['mol_trees']])
raise
loss.backward()
optimizer.step()
word_acc += wacc
topo_acc += tacc
assm_acc += sacc
steo_acc += dacc
if (it + 1) % PRINT_ITER == 0:
word_acc = word_acc / PRINT_ITER * 100
topo_acc = topo_acc / PRINT_ITER * 100
assm_acc = assm_acc / PRINT_ITER * 100
steo_acc = steo_acc / PRINT_ITER * 100
print("KL: %.1f, Word: %.2f, Topo: %.2f, Assm: %.2f, Steo: %.2f, Loss: %.6f" % (
kl_div, word_acc, topo_acc, assm_acc, steo_acc, loss.item()))
word_acc, topo_acc, assm_acc, steo_acc = 0, 0, 0, 0
sys.stdout.flush()
if (it + 1) % 1500 == 0: # Fast annealing
scheduler.step()
print("learning rate: %.6f" % scheduler.get_lr()[0])
torch.save(model.state_dict(),
args.save_path + "/model.iter-%d-%d" % (epoch, it + 1))
scheduler.step()
print("learning rate: %.6f" % scheduler.get_lr()[0])
torch.save(model.state_dict(), args.save_path + "/model.iter-" + str(epoch))
if __name__ == '__main__':
train()
print('# passes:', model.n_passes)
print('Total # nodes processed:', model.n_nodes_total)
print('Total # edges processed:', model.n_edges_total)
print('Total # tree nodes processed:', model.n_tree_nodes_total)
print('Graph decoder: # passes:', model.jtmpn.n_passes)
print('Graph decoder: Total # candidates processed:', model.jtmpn.n_samples_total)
print('Graph decoder: Total # nodes processed:', model.jtmpn.n_nodes_total)
print('Graph decoder: Total # edges processed:', model.jtmpn.n_edges_total)
print('Graph encoder: # passes:', model.mpn.n_passes)
print('Graph encoder: Total # candidates processed:', model.mpn.n_samples_total)
print('Graph encoder: Total # nodes processed:', model.mpn.n_nodes_total)
print('Graph encoder: Total # edges processed:', model.mpn.n_edges_total)
# Property Prediction
## Classification
Classification tasks require assigning discrete labels to a molecule, e.g. molecule toxicity.
### Datasets
- **Tox21**. The ["Toxicology in the 21st Century" (Tox21)](https://tripod.nih.gov/tox21/challenge/) initiative created
a public database measuring toxicity of compounds, which has been used in the 2014 Tox21 Data Challenge. The dataset
contains qualitative toxicity measurements for 8014 compounds on 12 different targets, including nuclear receptors and
stress response pathways. Each target yields a binary prediction problem. MoleculeNet [1] randomly splits the dataset
into training, validation and test set with a 80/10/10 ratio. By default we follow their split method.
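As an illustration only (not the actual script code), the following sketch produces an 80/10/10 random split with
PyTorch utilities; the list of indices is merely a stand-in for the dataset object.
```python
# Minimal sketch of an 80/10/10 random split; the list below is a stand-in
# for the actual Tox21 dataset object.
import torch
from torch.utils.data import random_split

dataset = list(range(8014))                # placeholder with the dataset size
n_train = int(0.8 * len(dataset))          # 6411
n_val = int(0.1 * len(dataset))            # 801
n_test = len(dataset) - n_train - n_val    # 802, the remainder
torch.manual_seed(0)                       # fix the seed for reproducibility
train_set, val_set, test_set = random_split(dataset, [n_train, n_val, n_test])
print(len(train_set), len(val_set), len(test_set))
```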
### Models
- **Graph Convolutional Networks** [2], [3]. Graph Convolutional Networks (GCN) are among the most popular graph neural
networks and can be easily extended to graph-level prediction. MoleculeNet [1] reports baseline results of graph
convolutions over multiple datasets.
- **Graph Attention Networks** [7]. Graph Attention Networks (GATs) incorporate multi-head attention into GCNs,
explicitly modeling the interactions between adjacent atoms.
### Usage
Use `classification.py` with arguments
```
-m {GCN, GAT}, MODEL, model to use
-d {Tox21}, DATASET, dataset to use
```
If you want to use the pre-trained model, simply add `-p`.
We use a GPU whenever one is available.
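For example, the following trains a GCN on Tox21 from scratch (run from the directory containing `classification.py`):
```
python classification.py -m GCN -d Tox21
```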
### Performance
#### GCN on Tox21
| Source | Averaged Test ROC-AUC Score |
| ---------------- | --------------------------- |
| MoleculeNet [1] | 0.829 |
| [DeepChem example](https://github.com/deepchem/deepchem/blob/master/examples/tox21/tox21_tensorgraph_graph_conv.py) | 0.813 |
| Pretrained model | 0.826 |
Note that the dataset is randomly split, so these numbers are for reference only and do not necessarily indicate
a real difference.
#### GAT on Tox21
| Source | Averaged Test ROC-AUC Score |
| ---------------- | --------------------------- |
| Pretrained model | 0.827 |
## Regression
Regression tasks require assigning continuous labels to a molecule, e.g. molecular energy.
### Datasets
- **Alchemy**. The [Alchemy Dataset](https://alchemy.tencent.com/) was introduced by the Tencent Quantum Lab to facilitate the development of new
machine learning models for chemistry and materials science. The dataset lists 12 quantum mechanical properties of 130,000+ organic
molecules comprising up to 12 heavy atoms (C, N, O, S, F and Cl), sampled from the [GDBMedChem](http://gdb.unibe.ch/downloads/) database.
These properties have been calculated using the open-source computational chemistry program Python-based Simulation of Chemistry Framework
([PySCF](https://github.com/pyscf/pyscf)). The Alchemy dataset expands on the volume and diversity of existing molecular datasets such as QM9.
- **PubChem BioAssay Aromaticity**. This dataset was introduced in
[Pushing the Boundaries of Molecular Representation for Drug Discovery with the Graph Attention Mechanism](https://www.ncbi.nlm.nih.gov/pubmed/31408336)
for the task of predicting the number of aromatic atoms in a molecule. It was constructed by sampling 3945 molecules with 0-40 aromatic atoms
from the PubChem BioAssay dataset.
### Models
- **Message Passing Neural Network** [6]. Message Passing Neural Networks (MPNNs) held state-of-the-art performance on
the QM9 dataset for some time.
- **SchNet** [4]. SchNet employs continuous filter convolutional layers to model quantum interactions in molecules
without requiring them to lie on grids.
- **Multilevel Graph Convolutional Neural Network** [5]. Multilevel Graph Convolutional Neural Networks (MGCN) are
hierarchical graph neural networks that extract features from conformation and spatial information and then model
multilevel interactions.
- **AttentiveFP** [8]. AttentiveFP combines attention and GRUs for better model capacity and shows competitive
performance across datasets.
### Usage
Use `regression.py` with arguments
```
-m {MPNN, SCHNET, MGCN, AttentiveFP}, Model to use
-d {Alchemy, Aromaticity}, Dataset to use
```
If you want to use a pre-trained model, simply add `-p`. Currently we only provide a pre-trained AttentiveFP model
on the PubChem BioAssay Aromaticity dataset.
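For example, to use the pre-trained AttentiveFP model on the aromaticity dataset:
```
python regression.py -m AttentiveFP -d Aromaticity -p
```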
### Performance
#### Alchemy
The Alchemy contest is still ongoing. Until the test set is fully released, we only report performance numbers
on the training and validation sets for reference.
| Model | Training MAE | Validation MAE |
| ---------- | ------------ | -------------- |
| SchNet [4] | 0.2665 | 0.6139 |
| MGCN [5] | 0.2395 | 0.6463 |
| MPNN [6] | 0.2452 | 0.6259 |
#### PubChem BioAssay Aromaticity
| Model | Test RMSE |
| --------------- | --------- |
| AttentiveFP [8] | 0.6998 |
## Interpretation
The authors of [8] visualize atom weights in the readout phase for possible interpretations, as in the figure below.
We provide a Jupyter notebook for the visualization, which you can download with
`wget https://data.dgl.ai/model_zoo/drug_discovery/AttentiveFP/atom_weight_visualization.ipynb`.
![](https://data.dgl.ai/dgllife/attentive_fp_vis_example.png)
## Dataset Customization
Generally we follow the practice of PyTorch.
A Dataset class should implement the `__getitem__(self, index)` and `__len__(self)` methods:
```python
class CustomDataset(object):
def __init__(self):
pass
def __getitem__(self, index):
"""
Parameters
----------
index : int
Index for the datapoint.
Returns
-------
str
SMILES for the molecule
DGLGraph
Constructed DGLGraph for the molecule
1D Tensor of dtype float32
Labels of the datapoint
"""
return self.smiles[index], self.graphs[index], self.labels[index]
def __len__(self):
return len(self.smiles)
```
We provide various methods for graph construction in `dgl.data.chem.utils.mol_to_graph`. If your dataset can
be converted to a pandas dataframe, e.g. a .csv file, you may use `MoleculeCSVDataset` in
`dgl.data.chem.datasets.csv_dataset`.
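As a minimal sketch, the snippet below constructs a `DGLGraph` for a single SMILES string with atom featurization.
The helpers `smiles_to_bigraph` and `CanonicalAtomFeaturizer` are assumed to be importable from `dgl.data.chem`;
double-check the exact names and arguments against the API reference of your installed version. The resulting graphs,
together with SMILES strings and labels, can then back a `CustomDataset` as above, or you can wrap a pandas dataframe
with `MoleculeCSVDataset` directly.
```python
# Hedged sketch: the imports and keyword argument below are assumptions based
# on the dgl.data.chem utilities mentioned above; verify against the API docs.
from dgl.data.chem import smiles_to_bigraph, CanonicalAtomFeaturizer

smiles = 'CCO'  # ethanol, as a toy example
g = smiles_to_bigraph(smiles, node_featurizer=CanonicalAtomFeaturizer())
print(g.number_of_nodes(), g.number_of_edges())  # 3 atoms, 2 bonds -> 4 directed edges
```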
## References
[1] Wu et al. (2017) MoleculeNet: a benchmark for molecular machine learning. *Chemical Science* 9, 513-530.
[2] Duvenaud et al. (2015) Convolutional networks on graphs for learning molecular fingerprints. *Advances in neural
information processing systems (NeurIPS)*, 2224-2232.
[3] Kipf et al. (2017) Semi-Supervised Classification with Graph Convolutional Networks.
*The International Conference on Learning Representations (ICLR)*.
[4] Schütt et al. (2017) SchNet: A continuous-filter convolutional neural network for modeling quantum interactions.
*Advances in Neural Information Processing Systems (NeurIPS)*, 992-1002.
[5] Lu et al. (2019) Molecular Property Prediction: A Multilevel Quantum Interactions Modeling Perspective.
*The 33rd AAAI Conference on Artificial Intelligence*.
[6] Gilmer et al. (2017) Neural Message Passing for Quantum Chemistry. *Proceedings of the 34th International Conference on
Machine Learning*, JMLR. 1263-1272.
[7] Veličković et al. (2018) Graph Attention Networks.
*The International Conference on Learning Representations (ICLR)*.
[8] Xiong et al. (2019) Pushing the Boundaries of Molecular Representation for Drug Discovery with the Graph
Attention Mechanism. *Journal of Medicinal Chemistry*.