Unverified Commit 73b2668f authored by Mufei Li, committed by GitHub

[Model Zoo] DGMG for molecule generation (#783)

* DGMG for molecule generation

* Fix CI check

* Fix for CI

* Trial for CI due to shared memory

* Better interface for dataset configuration

* Handle corner cases

* Update README

* Refactor

* Assorted CI and code fixes

* Finally
parent 11fb217a
# Learning Deep Generative Models of Graphs (DGMG)
Yujia Li, Oriol Vinyals, Chris Dyer, Razvan Pascanu, and Peter Battaglia.
Learning Deep Generative Models of Graphs. *arXiv preprint arXiv:1803.03324*, 2018.
DGMG generates graphs by progressively adding nodes and edges as below:
![](https://user-images.githubusercontent.com/19576924/48605003-7f11e900-e9b6-11e8-8880-87362348e154.png)
For molecules, the nodes are atoms and the edges are bonds.
**Goal**: Given a set of real molecules, we want to learn their distribution and generate new molecules
with similar properties. See the `Evaluation` section for more details.
## Dataset
### Preprocessing
With our implementation, this model has several limitations:
1. Information about protonation and chirality is ignored during generation.
2. Molecules containing charged atoms such as `[N+]` and `[O-]` cannot be generated.
For example, the model can only generate `O=C1NC(=S)NC(=O)C1=CNC1=CC=C(N(=O)O)C=C1O` from
`O=C1NC(=S)NC(=O)C1=CNC1=CC=C([N+](=O)[O-])C=C1O` even with the correct decisions.
To avoid issues with validity and novelty, we filter out such molecules from the dataset, as sketched below.
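As a rough illustration of the filtering idea (the actual pipeline lives in `utils.py`; the helper name below is hypothetical):

```
from rdkit import Chem

def has_formal_charge(s):
    # Treat unparsable SMILES as filtered out too.
    mol = Chem.MolFromSmiles(s)
    return mol is None or any(a.GetFormalCharge() != 0 for a in mol.GetAtoms())

smiles = ['O=C1NC(=S)NC(=O)C1=CNC1=CC=C([N+](=O)[O-])C=C1O', 'CCO']
kept = [s for s in smiles if not has_formal_charge(s)]
print(kept)  # only 'CCO' survives
```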
### ChEMBL
The authors use the [ChEMBL database](https://www.ebi.ac.uk/chembl/). Since they
did not release the code, we use a subset from [Olivecrona et al.](https://github.com/MarcusOlivecrona/REINVENT),
another work on generative modeling.
The authors restricted their dataset to molecules with at most 20 heavy atoms and used a training/validation
split of 130,830/26,166 examples. We use the same split, but relax the limit from 20 to 23 heavy atoms as we are
using a different subset.
### ZINC
After the pre-processing, we are left with 232,464 molecules for training and 5,000 molecules for validation.
## Usage
### Training
Training auto-regressive generative models tends to be very slow. According to the authors, they use multiple
processes to speed up training, and GPUs do not give much of a speed advantage. We follow their approach and
train with multiple CPU processes.
To start training, use `train.py` with required arguments
```
-d DATASET, dataset to use (default: None), built-in support exists for ChEMBL, ZINC
-o {random,canonical}, order to generate graphs (default: None)
```
and optional arguments
```
-s SEED, random seed (default: 0)
-np NUM_PROCESSES, number of processes to use (default: 32)
```
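For example, to train on ZINC with the canonical order:

```
python train.py -d ZINC -o canonical
```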
Even though multiprocessing yields a significant speedup compared to a single process, training can still take a
long time (several days). An epoch of training and validation can take up to an hour and a half on our machine.
If not necessary, we recommend that users use our pre-trained models.
Meanwhile, we checkpoint our model whenever there is a performance improvement on the validation set, so you
do not need to wait until training terminates.
All training results can be found in `training_results`.
#### Dataset configuration
You can also use your own dataset with additional arguments
```
-tf TRAIN_FILE, Path to a file with one SMILES per line for training
data. This is only necessary if you want to use a new
dataset. (default: None)
-vf VAL_FILE, Path to a file with one SMILES per line for validation
data. This is only necessary if you want to use a new
dataset. (default: None)
```
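For example, assuming `train.txt` and `val.txt` each hold one SMILES per line (file names hypothetical):

```
python train.py -d my_dataset -o canonical -tf train.txt -vf val.txt
```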
#### Monitoring
We can monitor the training process with tensorboard as below:
![](https://s3.us-east-2.amazonaws.com/dgl.ai/model_zoo/drug_discovery/dgmg/tensorboard.png)
To use tensorboard, you need to install [tensorboardX](https://github.com/lanpa/tensorboardX) and
[TensorFlow](https://www.tensorflow.org/). You can launch tensorboard with `tensorboard --logdir=.`.
If you are training on a remote server, you can still use it with:
1. Launch it on the remote server with `tensorboard --logdir=. --port=A`
2. In the terminal of your local machine, type `ssh -NfL localhost:B:localhost:A username@your_remote_host_name`
3. Go to the address `localhost:B` in your browser
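For example, with port 6006 on the server forwarded to local port 6007:

```
tensorboard --logdir=. --port=6006
ssh -NfL localhost:6007:localhost:6006 username@your_remote_host_name
```

Then open `localhost:6007` in your browser.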
### Evaluation
To start evaluation, use `eval.py` with required arguments
```
-d DATASET, dataset to use (default: None), built-in support exists for ChEMBL, ZINC
-o {random,canonical}, order to generate graphs, used for naming evaluation directory (default: None)
-p MODEL_PATH, path to a saved model (default: None). This is not needed if you use a pre-trained model.
-pr, whether to use a pre-trained model (default: False)
```
and optional arguments
```
-s SEED, random seed (default: 0)
-ns NUM_SAMPLES, Number of molecules to generate (default: 100000)
-mn MAX_NUM_STEPS, Max number of steps allowed in generated molecules to
ensure termination (default: 400)
-np NUM_PROCESSES, number of processes to use (default: 32)
-gt GENERATION_TIME, max time (seconds) allowed for generation with
multiprocess (default: 600)
```
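For example, to evaluate the pre-trained model on ZINC with the canonical order:

```
python eval.py -d ZINC -o canonical -pr
```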
All evaluation results can be found in `eval_results`.
After the evaluation, the generated molecules (100,000 by default) are stored in `generated_smiles.txt` under the
`eval_results` directory, with three statistics logged in `generation_stats.txt` under `eval_results`:
1. `Validity among all` gives the percentage of molecules that are valid
2. `Uniqueness among valid ones` gives the percentage of valid molecules that are unique
3. `Novelty among unique ones` gives the percentage of unique valid molecules that are novel (not seen in training data)
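As a minimal sketch of how the three ratios relate (toy data, not the actual implementation in `eval.py`):

```
from rdkit import Chem

generated = ['CCO', 'CCO', 'C1CC1', 'not-a-smiles']  # toy stand-ins
train_set = {'CCO'}

valid = [s for s in generated if Chem.MolFromSmiles(s) is not None]
unique = set(valid)
novel = unique - train_set

print(len(valid) / len(generated))  # Validity among all
print(len(unique) / len(valid))     # Uniqueness among valid ones
print(len(novel) / len(unique))     # Novelty among unique ones
```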
We also provide a jupyter notebook where you can visualize the generated molecules
![](https://s3.us-east-2.amazonaws.com/dgl.ai/model_zoo/drug_discovery/dgmg/DGMG_ZINC_canonical_vis.png)
and compare their property distributions against the training molecule property distributions
![](https://s3.us-east-2.amazonaws.com/dgl.ai/model_zoo/drug_discovery/dgmg/DGMG_ZINC_canonical_dist.png)
Download it with `wget https://s3.us-east-2.amazonaws.com/dgl.ai/model_zoo/drug_discovery/dgmg/eval_jupyter.ipynb`.
### Pre-trained models
The table below gives the statistics of the pre-trained models. With random order, training becomes significantly
more difficult, as we now have `N^2` data points for `N` molecules.
| Pre-trained model | % valid | % unique among valid | % novel among unique |
| ------------------ | ------- | -------------------- | -------------------- |
| `ChEMBL_canonical` | 78.80 | 99.19 | 98.60 |
| `ChEMBL_random` | 29.09 | 99.87 | 100.00 |
| `ZINC_canonical` | 74.60 | 99.87 | 99.87 |
| `ZINC_random` | 12.37 | 99.38 | 100.00 |
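If you only want to sample molecules from a pre-trained model in Python, a minimal sketch (mirroring what `eval.py` does internally) is:

```
import torch
from dgl import model_zoo

model = model_zoo.chem.load_pretrained('DGMG_ZINC_canonical')
model.eval()
with torch.no_grad():
    smiles = model(rdkit_mol=True, max_num_steps=400)
print(smiles)
```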
import os
import pickle
import shutil
import torch
from dgl import model_zoo
from utils import MoleculeDataset, set_random_seed, download_data,\
mkdir_p, summarize_molecules, get_unique_smiles, get_novel_smiles
def generate_and_save(log_dir, num_samples, max_num_steps, model):
with open(os.path.join(log_dir, 'generated_smiles.txt'), 'w') as f:
for i in range(num_samples):
with torch.no_grad():
s = model(rdkit_mol=True, max_num_steps=max_num_steps)
f.write(s + '\n')
def prepare_for_evaluation(rank, args):
worker_seed = args['seed'] + rank * 10000
set_random_seed(worker_seed)
torch.set_num_threads(1)
# Setup dataset and data loader
dataset = MoleculeDataset(args['dataset'], subset_id=rank, n_subsets=args['num_processes'])
# Initialize model
if not args['pretrained']:
model = model_zoo.chem.DGMG(atom_types=dataset.atom_types,
bond_types=dataset.bond_types,
node_hidden_size=args['node_hidden_size'],
num_prop_rounds=args['num_propagation_rounds'], dropout=args['dropout'])
model.load_state_dict(torch.load(args['model_path'])['model_state_dict'])
else:
model = model_zoo.chem.load_pretrained('_'.join(['DGMG', args['dataset'], args['order']]), log=False)
model.eval()
worker_num_samples = args['num_samples'] // args['num_processes']
if rank == args['num_processes'] - 1:
worker_num_samples += args['num_samples'] % args['num_processes']
worker_log_dir = os.path.join(args['log_dir'], str(rank))
mkdir_p(worker_log_dir, log=False)
generate_and_save(worker_log_dir, worker_num_samples, args['max_num_steps'], model)
def remove_worker_tmp_dir(args):
for rank in range(args['num_processes']):
worker_path = os.path.join(args['log_dir'], str(rank))
try:
shutil.rmtree(worker_path)
except OSError:
print('Directory {} does not exist!'.format(worker_path))
def aggregate_and_evaluate(args):
print('Merging generated SMILES into a single file...')
smiles = []
for rank in range(args['num_processes']):
with open(os.path.join(args['log_dir'], str(rank), 'generated_smiles.txt'), 'r') as f:
rank_smiles = f.read().splitlines()
smiles.extend(rank_smiles)
with open(os.path.join(args['log_dir'], 'generated_smiles.txt'), 'w') as f:
for s in smiles:
f.write(s + '\n')
print('Removing temporary dirs...')
remove_worker_tmp_dir(args)
# Summarize training molecules
print('Summarizing training molecules...')
train_file = '_'.join([args['dataset'], 'DGMG_train.txt'])
if not os.path.exists(train_file):
download_data(args['dataset'], train_file)
with open(train_file, 'r') as f:
train_smiles = f.read().splitlines()
train_summary = summarize_molecules(train_smiles, args['num_processes'])
with open(os.path.join(args['log_dir'], 'train_summary.pickle'), 'wb') as f:
pickle.dump(train_summary, f)
# Summarize generated molecules
print('Summarizing generated molecules...')
generation_summary = summarize_molecules(smiles, args['num_processes'])
with open(os.path.join(args['log_dir'], 'generation_summary.pickle'), 'wb') as f:
pickle.dump(generation_summary, f)
# Stats computation
print('Preparing generation statistics...')
valid_generated_smiles = generation_summary['smile']
unique_generated_smiles = get_unique_smiles(valid_generated_smiles)
unique_train_smiles = get_unique_smiles(train_summary['smile'])
novel_generated_smiles = get_novel_smiles(unique_generated_smiles, unique_train_smiles)
with open(os.path.join(args['log_dir'], 'generation_stats.txt'), 'w') as f:
f.write('Total number of generated molecules: {:d}\n'.format(len(smiles)))
f.write('Validity among all: {:.4f}\n'.format(
len(valid_generated_smiles) / len(smiles)))
f.write('Uniqueness among valid ones: {:.4f}\n'.format(
len(unique_generated_smiles) / len(valid_generated_smiles)))
f.write('Novelty among unique ones: {:.4f}\n'.format(
len(novel_generated_smiles) / len(unique_generated_smiles)))
if __name__ == '__main__':
import argparse
import datetime
import time
from rdkit import rdBase
from utils import setup
parser = argparse.ArgumentParser(description='Evaluating DGMG for molecule generation',
formatter_class=argparse.ArgumentDefaultsHelpFormatter)
# configure
parser.add_argument('-s', '--seed', type=int, default=0, help='random seed')
# dataset and setting
parser.add_argument('-d', '--dataset',
help='dataset to use')
parser.add_argument('-o', '--order', choices=['random', 'canonical'],
help='order to generate graphs, used for naming evaluation directory')
# log
parser.add_argument('-l', '--log-dir', default='./eval_results',
help='folder to save evaluation results')
parser.add_argument('-p', '--model-path', type=str, default=None,
help='path to saved model')
parser.add_argument('-pr', '--pretrained', action='store_true',
help='Whether to use a pre-trained model')
parser.add_argument('-ns', '--num-samples', type=int, default=100000,
help='Number of molecules to generate')
parser.add_argument('-mn', '--max-num-steps', type=int, default=400,
help='Max number of steps allowed in generated molecules to ensure termination')
# multi-process
parser.add_argument('-np', '--num-processes', type=int, default=32,
help='number of processes to use')
parser.add_argument('-gt', '--generation-time', type=int, default=600,
help='max time (seconds) allowed for generation with multiprocess')
args = parser.parse_args()
args = setup(args, train=False)
rdBase.DisableLog('rdApp.error')
t1 = time.time()
if args['num_processes'] == 1:
prepare_for_evaluation(0, args)
else:
import multiprocessing as mp
procs = []
for rank in range(args['num_processes']):
p = mp.Process(target=prepare_for_evaluation, args=(rank, args,))
procs.append(p)
p.start()
while time.time() - t1 <= args['generation_time']:
if any(p.is_alive() for p in procs):
time.sleep(5)
else:
break
else:
print('Timeout, killing all processes.')
for p in procs:
p.terminate()
p.join()
t2 = time.time()
print('It took {} for generation.'.format(
datetime.timedelta(seconds=t2 - t1)))
aggregate_and_evaluate(args)
#
# calculation of synthetic accessibility score as described in:
#
# Estimation of Synthetic Accessibility Score of Drug-like Molecules
# based on Molecular Complexity and Fragment Contributions
# Peter Ertl and Ansgar Schuffenhauer
# Journal of Cheminformatics 1:8 (2009)
# http://www.jcheminf.com/content/1/1/8
#
# several small modifications to the original paper are included
# particularly slightly different formula for macrocyclic penalty
# and taking into account also molecule symmetry (fingerprint density)
#
# for a set of 10k diverse molecules the agreement between the original method
# as implemented in PipelinePilot and this implementation is r2 = 0.97
#
# peter ertl & greg landrum, september 2013
#
# A small modification is performed
#
# DGL team, August 2019
#
from __future__ import print_function
import math
import os
from rdkit import Chem
from rdkit.Chem import rdMolDescriptors
from rdkit.six.moves import cPickle
from rdkit.six import iteritems
from dgl.data.utils import download, _get_dgl_url, get_download_dir
_fscores = None
def readFragmentScores(name='fpscores'):
import gzip
global _fscores
fname = '{}.pkl.gz'.format(name)
download(_get_dgl_url(os.path.join('dataset', fname)), path=fname)
_fscores = cPickle.load(gzip.open(fname))
outDict = {}
for i in _fscores:
for j in range(1, len(i)):
outDict[i[j]] = float(i[0])
_fscores = outDict
def numBridgeheadsAndSpiro(mol):
nSpiro = rdMolDescriptors.CalcNumSpiroAtoms(mol)
nBridgehead = rdMolDescriptors.CalcNumBridgeheadAtoms(mol)
return nBridgehead, nSpiro
def calculateScore(m):
if _fscores is None:
readFragmentScores()
# fragment score
# 2 is the *radius* of the circular fingerprint
fp = rdMolDescriptors.GetMorganFingerprint(m, 2)
fps = fp.GetNonzeroElements()
score1 = 0.
nf = 0
for bitId, v in iteritems(fps):
nf += v
sfp = bitId
score1 += _fscores.get(sfp, -4) * v
# We add the check below to avoid a ZeroDivisionError.
if nf != 0:
score1 /= nf
# features score
nAtoms = m.GetNumAtoms()
nChiralCenters = len(Chem.FindMolChiralCenters(m, includeUnassigned=True))
ri = m.GetRingInfo()
nBridgeheads, nSpiro = numBridgeheadsAndSpiro(m)
nMacrocycles = 0
for x in ri.AtomRings():
if len(x) > 8:
nMacrocycles += 1
sizePenalty = nAtoms**1.005 - nAtoms
stereoPenalty = math.log10(nChiralCenters + 1)
spiroPenalty = math.log10(nSpiro + 1)
bridgePenalty = math.log10(nBridgeheads + 1)
macrocyclePenalty = 0.
# ---------------------------------------
# This differs from the paper, which defines:
# macrocyclePenalty = math.log10(nMacrocycles+1)
# This form generates better results when 2 or more macrocycles are present
if nMacrocycles > 0:
macrocyclePenalty = math.log10(2)
score2 = 0. - sizePenalty - stereoPenalty - \
spiroPenalty - bridgePenalty - macrocyclePenalty
# correction for the fingerprint density
# not in the original publication, added in version 1.1
# to make highly symmetrical molecules easier to synthetise
score3 = 0.
if nAtoms > len(fps):
score3 = math.log(float(nAtoms) / len(fps)) * .5
sascore = score1 + score2 + score3
# need to transform "raw" value into scale between 1 and 10
min = -4.0
max = 2.5
sascore = 11. - (sascore - min + 1) / (max - min) * 9.
# smooth the 10-end
if sascore > 8.:
sascore = 8. + math.log(sascore + 1. - 9.)
if sascore > 10.:
sascore = 10.0
elif sascore < 1.:
sascore = 1.0
return sascore
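# Example usage (illustrative):
#   from rdkit import Chem
#   print(calculateScore(Chem.MolFromSmiles('CCO')))
# The score lies in [1, 10]; lower means easier to synthesize.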
def processMols(mols):
print('smiles\tName\tsa_score')
for i, m in enumerate(mols):
if m is None:
continue
s = calculateScore(m)
smiles = Chem.MolToSmiles(m)
print(smiles + "\t" + m.GetProp('_Name') + "\t%3f" % s)
if __name__ == '__main__':
import sys, time
t1 = time.time()
readFragmentScores("fpscores")
t2 = time.time()
suppl = Chem.SmilesMolSupplier(sys.argv[1])
t3 = time.time()
processMols(suppl)
t4 = time.time()
print('Reading took %.2f seconds. Calculating took %.2f seconds' % ((t2 - t1), (t4 - t3)),
file=sys.stderr)
#
# Copyright (c) 2013, Novartis Institutes for BioMedical Research Inc.
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are
# met:
#
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above
# copyright notice, this list of conditions and the following
# disclaimer in the documentation and/or other materials provided
# with the distribution.
# * Neither the name of Novartis Institutes for BioMedical Research Inc.
# nor the names of its contributors may be used to endorse or promote
# products derived from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
"""
Learning Deep Generative Models of Graphs
Paper: https://arxiv.org/pdf/1803.03324.pdf
"""
import datetime
import time
import torch
import torch.distributed as dist
from dgl import model_zoo
from torch.optim import Adam
from torch.utils.data import DataLoader
from utils import MoleculeDataset, Printer, set_random_seed, synchronize, launch_a_process
def evaluate(epoch, model, data_loader, printer):
model.eval()
batch_size = data_loader.batch_size
total_log_prob = 0
with torch.no_grad():
for i, data in enumerate(data_loader):
log_prob = model(actions=data, compute_log_prob=True).detach()
total_log_prob -= log_prob
if printer is not None:
prob = log_prob.detach().exp()
printer.update(epoch + 1, - log_prob / batch_size, prob / batch_size)
return total_log_prob / len(data_loader)
def main(rank, args):
"""
Parameters
----------
rank : int
Subprocess id
args : dict
Configuration
"""
if rank == 0:
t1 = time.time()
set_random_seed(args['seed'])
# Removing the line below will cause problems with multiprocessing.
torch.set_num_threads(1)
# Setup dataset and data loader
dataset = MoleculeDataset(args['dataset'], args['order'], ['train', 'val'],
subset_id=rank, n_subsets=args['num_processes'])
# Note that currently the batch size for the loaders should only be 1.
train_loader = DataLoader(dataset.train_set, batch_size=args['batch_size'],
shuffle=True, collate_fn=dataset.collate)
val_loader = DataLoader(dataset.val_set, batch_size=args['batch_size'],
shuffle=True, collate_fn=dataset.collate)
if rank == 0:
try:
from tensorboardX import SummaryWriter
writer = SummaryWriter(args['log_dir'])
except ImportError:
print('If you want to use tensorboard, install tensorboardX with pip.')
writer = None
train_printer = Printer(args['nepochs'], len(dataset.train_set), args['batch_size'], writer)
val_printer = Printer(args['nepochs'], len(dataset.val_set), args['batch_size'])
else:
val_printer = None
# Initialize model
model = model_zoo.chem.DGMG(atom_types=dataset.atom_types,
bond_types=dataset.bond_types,
node_hidden_size=args['node_hidden_size'],
num_prop_rounds=args['num_propagation_rounds'],
dropout=args['dropout'])
if args['num_processes'] == 1:
from utils import Optimizer
optimizer = Optimizer(args['lr'], Adam(model.parameters(), lr=args['lr']))
else:
from utils import MultiProcessOptimizer
optimizer = MultiProcessOptimizer(args['num_processes'], args['lr'],
Adam(model.parameters(), lr=args['lr']))
if rank == 0:
t2 = time.time()
best_val_prob = 0
# Training
for epoch in range(args['nepochs']):
model.train()
if rank == 0:
print('Training')
for i, data in enumerate(train_loader):
log_prob = model(actions=data, compute_log_prob=True)
prob = log_prob.detach().exp()
loss_averaged = - log_prob
prob_averaged = prob
optimizer.backward_and_step(loss_averaged)
if rank == 0:
train_printer.update(epoch + 1, loss_averaged.item(), prob_averaged.item())
synchronize(args['num_processes'])
# Validation
val_log_prob = evaluate(epoch, model, val_loader, val_printer)
if args['num_processes'] > 1:
dist.all_reduce(val_log_prob, op=dist.ReduceOp.SUM)
val_log_prob /= args['num_processes']
# Strictly speaking, the probability computed here differs from the one on the
# training set: we first average the log-likelihood over samples and then
# exponentiate. Since exp(E[log p]) <= E[p] by Jensen's inequality, the
# resulting value is a lower bound of the average probability.
val_prob = (- val_log_prob).exp().item()
val_log_prob = val_log_prob.item()
if val_prob >= best_val_prob:
if rank == 0:
torch.save({'model_state_dict': model.state_dict()}, args['checkpoint_dir'])
print('Old val prob {:.10f} | new val prob {:.10f} | model saved'.format(best_val_prob, val_prob))
best_val_prob = val_prob
elif epoch >= args['warmup_epochs']:
optimizer.decay_lr()
if rank == 0:
print('Validation')
if writer is not None:
writer.add_scalar('validation_log_prob', val_log_prob, epoch)
writer.add_scalar('validation_prob', val_prob, epoch)
writer.add_scalar('lr', optimizer.lr, epoch)
print('Validation log prob {:.4f} | prob {:.10f}'.format(val_log_prob, val_prob))
synchronize(args['num_processes'])
if rank == 0:
t3 = time.time()
print('It took {} to setup.'.format(datetime.timedelta(seconds=t2 - t1)))
print('It took {} to finish training.'.format(datetime.timedelta(seconds=t3 - t2)))
print('--------------------------------------------------------------------------')
print('On average, an epoch takes {}.'.format(datetime.timedelta(
seconds=(t3 - t2) / args['nepochs'])))
if __name__ == '__main__':
import argparse
from utils import setup
parser = argparse.ArgumentParser(description='Training DGMG for molecule generation',
formatter_class=argparse.ArgumentDefaultsHelpFormatter)
# configure
parser.add_argument('-s', '--seed', type=int, default=0, help='random seed')
parser.add_argument('-w', '--warmup-epochs', type=int, default=10,
help='Number of epochs where no lr decay is performed.')
# dataset and setting
parser.add_argument('-d', '--dataset',
help='dataset to use')
parser.add_argument('-o', '--order', choices=['random', 'canonical'],
help='order to generate graphs')
parser.add_argument('-tf', '--train-file', type=str, default=None,
help='Path to a file with one SMILES per line for training data. '
'This is only necessary if you want to use a new dataset.')
parser.add_argument('-vf', '--val-file', type=str, default=None,
help='Path to a file with one SMILES per line for validation data. '
'This is only necessary if you want to use a new dataset.')
# log
parser.add_argument('-l', '--log-dir', default='./training_results',
help='folder to save info like experiment configuration')
# multi-process
parser.add_argument('-np', '--num-processes', type=int, default=32,
help='number of processes to use')
parser.add_argument('-mi', '--master-ip', type=str, default='127.0.0.1')
parser.add_argument('-mp', '--master-port', type=str, default='12345')
args = parser.parse_args()
args = setup(args, train=True)
if args['num_processes'] == 1:
main(0, args)
else:
mp = torch.multiprocessing.get_context('spawn')
procs = []
for rank in range(args['num_processes']):
procs.append(mp.Process(target=launch_a_process, args=(rank, args, main), daemon=True))
procs[-1].start()
for p in procs:
p.join()
import datetime
import dgl
import math
import numpy as np
import os
import pickle
import random
import torch
import torch.distributed as dist
import torch.nn as nn
from collections import defaultdict
from datetime import timedelta
from dgl import DGLGraph
from dgl.data.utils import get_download_dir, download, _get_dgl_url
from dgl.model_zoo.chem.dgmg import MoleculeEnv
from multiprocessing import Pool
from pprint import pprint
from rdkit import Chem
from rdkit.Chem import AllChem
from rdkit.Chem.Crippen import MolLogP
from rdkit.Chem.QED import qed
from torch.utils.data import Dataset
from sascorer import calculateScore
########################################################################################################################
# configuration #
########################################################################################################################
def mkdir_p(path, log=True):
"""Create a directory for the specified path.
Parameters
----------
path : str
Path name
log : bool
Whether to print result for directory creation
"""
import errno
try:
os.makedirs(path)
if log:
print('Created directory {}'.format(path))
except OSError as exc:
if exc.errno == errno.EEXIST and os.path.isdir(path) and log:
print('Directory {} already exists.'.format(path))
else:
raise
def get_date_postfix():
"""Get a date based postfix for directory name.
Returns
-------
post_fix : str
"""
dt = datetime.datetime.now()
post_fix = '{}_{:02d}-{:02d}-{:02d}'.format(
dt.date(), dt.hour, dt.minute, dt.second)
return post_fix
def setup_log_dir(args):
"""Name and create directory for logging.
Parameters
----------
args : dict
Configuration
Returns
-------
log_dir : str
Path for logging directory
"""
date_postfix = get_date_postfix()
log_dir = os.path.join(
args['log_dir'],
'{}_{}_{}'.format(args['dataset'], args['order'], date_postfix))
mkdir_p(log_dir)
return log_dir
def save_arg_dict(args, filename='settings.txt'):
"""Save all experiment settings in a file.
Parameters
----------
args : dict
Configuration
filename : str
Name for the file to save settings
"""
def _format_value(v):
if isinstance(v, float):
return '{:.4f}'.format(v)
elif isinstance(v, int):
return '{:d}'.format(v)
else:
return '{}'.format(v)
save_path = os.path.join(args['log_dir'], filename)
with open(save_path, 'w') as f:
for key, value in args.items():
f.write('{}\t{}\n'.format(key, _format_value(value)))
print('Saved settings to {}'.format(save_path))
def configure(args):
"""Use default hyperparameters.
Parameters
----------
args : dict
Old configuration
Returns
-------
args : dict
Updated configuration
"""
configure = {
'node_hidden_size': 128,
'num_propagation_rounds': 2,
'lr': 1e-4,
'dropout': 0.2,
'nepochs': 400,
'batch_size': 1,
}
args.update(configure)
return args
def set_random_seed(seed):
"""Fix random seed for reproducible results.
Parameters
----------
seed : int
Random seed to use.
"""
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
def setup_dataset(args):
"""Dataset setup
For unsupported datasets, we need to perform data preprocessing.
Parameters
----------
args : dict
Configuration
"""
if args['dataset'] in ['ChEMBL', 'ZINC']:
print('Built-in support for dataset {} exists.'.format(args['dataset']))
else:
print('Configure for new dataset {}...'.format(args['dataset']))
configure_new_dataset(args['dataset'], args['train_file'], args['val_file'])
def setup(args, train=True):
"""Setup
Parameters
----------
args : argparse.Namespace
Configuration
train : bool
Whether the setup is for training or evaluation
"""
# Convert argparse.Namespace into a dict
args = args.__dict__.copy()
# Default hyperparameters
args = configure(args)
# Log
print('Prepare logging directory...')
log_dir = setup_log_dir(args)
args['log_dir'] = log_dir
save_arg_dict(args)
if train:
setup_dataset(args)
args['checkpoint_dir'] = os.path.join(log_dir, 'checkpoint.pth')
pprint(args)
return args
########################################################################################################################
# multi-process #
########################################################################################################################
def synchronize(num_processes):
"""Synchronize all processes.
Parameters
----------
num_processes : int
Number of subprocesses used
"""
if num_processes > 1:
dist.barrier()
def launch_a_process(rank, args, target, minutes=720):
"""Launch a subprocess for training.
Parameters
----------
rank : int
Subprocess id
args : dict
Configuration
target : callable
Target function for the subprocess
minutes : int
Timeout minutes for operations executed against the process group
"""
dist_init_method = 'tcp://{master_ip}:{master_port}'.format(
master_ip=args['master_ip'], master_port=args['master_port'])
dist.init_process_group(backend='gloo',
init_method=dist_init_method,
# If you have a larger dataset, you will need to increase it.
timeout=timedelta(minutes=minutes),
world_size=args['num_processes'],
rank=rank)
assert torch.distributed.get_rank() == rank
target(rank, args)
########################################################################################################################
# optimization #
########################################################################################################################
class Optimizer(nn.Module):
"""Wrapper for optimization
Parameters
----------
lr : float
Initial learning rate
optimizer
model optimizer
"""
def __init__(self, lr, optimizer):
super(Optimizer, self).__init__()
self.lr = lr
self.optimizer = optimizer
self._reset()
def _reset(self):
self.optimizer.zero_grad()
def backward_and_step(self, loss):
"""Backward and update model.
Parameters
----------
loss : torch.tensor consisting of a float only
"""
loss.backward()
self.optimizer.step()
self._reset()
def decay_lr(self, decay_rate=0.99):
"""Decay learning rate.
Parameters
----------
decay_rate : float
Multiply the current learning rate by the decay_rate
"""
self.lr *= decay_rate
for param_group in self.optimizer.param_groups:
param_group['lr'] = self.lr
class MultiProcessOptimizer(Optimizer):
"""Wrapper for optimization with multiprocess
Parameters
----------
n_processes : int
Number of processes used
lr : float
Initial learning rate
optimizer
model optimizer
"""
def __init__(self, n_processes, lr, optimizer):
super(MultiProcessOptimizer, self).__init__(lr=lr, optimizer=optimizer)
self.n_processes = n_processes
def _sync_gradient(self):
"""Average gradients across all subprocesses."""
for param_group in self.optimizer.param_groups:
for p in param_group['params']:
if p.requires_grad and p.grad is not None:
dist.all_reduce(p.grad.data, op=dist.ReduceOp.SUM)
p.grad.data /= self.n_processes
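# Averaging gradients across the workers is equivalent to taking the
# gradient of the averaged loss, so each update matches what a single
# process would compute on the combined data.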
def backward_and_step(self, loss):
"""Backward and update model.
Parameters
----------
loss : torch.tensor consisting of a float only
"""
loss.backward()
self._sync_gradient()
self.optimizer.step()
self._reset()
########################################################################################################################
# data #
########################################################################################################################
def initialize_neutralization_reactions():
"""Reference neutralization reactions
Code adapted from RDKit Cookbook, by Hans de Winter.
"""
patts = (
# Imidazoles
('[n+;H]', 'n'),
# Amines
('[N+;!H0]', 'N'),
# Carboxylic acids and alcohols
('[$([O-]);!$([O-][#7])]', 'O'),
# Thiols
('[S-;X1]', 'S'),
# Sulfonamides
('[$([N-;X2]S(=O)=O)]', 'N'),
# Enamines
('[$([N-;X2][C,N]=C)]', 'N'),
# Tetrazoles
('[n-]', '[n]'),
# Sulfoxides
('[$([S-]=O)]', 'S'),
# Amides
('[$([N-]C=O)]', 'N'),
)
return [(Chem.MolFromSmarts(x), Chem.MolFromSmiles(y, False)) for x, y in patts]
def neutralize_charges(mol, reactions=None):
"""Deprotonation for molecules.
Code adapted from RDKit Cookbook, by Hans de Winter.
DGMG currently cannot generate protonated molecules.
For example, it can only generate
CC(C)(C)CC1CCC[NH+]1Cc1nnc(-c2ccccc2F)o1
from
CC(C)(C)CC1CCCN1Cc1nnc(-c2ccccc2F)o1
even with correct decisions.
Deprotonation is therefore an important step to avoid
false novel molecules.
Parameters
----------
mol : Chem.rdchem.Mol
reactions : list of 2-tuples
Rules for deprotonation
Returns
-------
mol : Chem.rdchem.Mol
Deprotonated molecule
"""
if reactions is None:
reactions = initialize_neutralization_reactions()
for i, (reactant, product) in enumerate(reactions):
while mol.HasSubstructMatch(reactant):
rms = AllChem.ReplaceSubstructs(mol, reactant, product)
mol = rms[0]
return mol
def standardize_mol(mol):
"""Standardize molecule to avoid false novel molecule.
Kekulize and deprotonate molecules to avoid false novel molecules.
In addition to deprotonation, we also kekulize molecules to avoid
explicit Hs in the SMILES. Otherwise we will get false novel molecules
as well. For example, DGMG can only generate
O=S(=O)(NC1=CC=CC(C(F)(F)F)=C1)C1=CNC=N1
from
O=S(=O)(Nc1cccc(C(F)(F)F)c1)c1c[nH]cn1.
One downside is that we remove all explicit aromatic rings and to
explicitly predict aromatic bond might make the learning easier for
the model.
"""
reactions = initialize_neutralization_reactions()
Chem.Kekulize(mol, clearAromaticFlags=True)
mol = neutralize_charges(mol, reactions)
return mol
def smiles_to_standard_mol(s):
"""Convert SMILES to a standard molecule.
Parameters
----------
s : str
SMILES
Returns
-------
Chem.rdchem.Mol
Standardized molecule
"""
mol = Chem.MolFromSmiles(s)
return standardize_mol(mol)
def mol_to_standard_smile(mol):
"""Standardize a molecule and convert it to a SMILES.
Parameters
----------
mol : Chem.rdchem.Mol
Returns
-------
str
SMILES
"""
return Chem.MolToSmiles(standardize_mol(mol))
def get_atom_and_bond_types(smiles, log=True):
"""Identify the atom types and bond types
appearing in this dataset.
Parameters
----------
smiles : list
List of smiles
log : bool
Whether to print the process of pre-processing.
Returns
-------
atom_types : list
E.g. ['C', 'N']
bond_types : list
E.g. [Chem.rdchem.BondType.SINGLE, Chem.rdchem.BondType.DOUBLE,
Chem.rdchem.BondType.TRIPLE, Chem.rdchem.BondType.AROMATIC]
"""
atom_types = set()
bond_types = set()
n_smiles = len(smiles)
for i, s in enumerate(smiles):
if log:
print('Processing smile {:d}/{:d}'.format(i + 1, n_smiles))
mol = smiles_to_standard_mol(s)
if mol is None:
continue
for atom in mol.GetAtoms():
a_symbol = atom.GetSymbol()
if a_symbol not in atom_types:
atom_types.add(a_symbol)
for bond in mol.GetBonds():
b_type = bond.GetBondType()
if b_type not in bond_types:
bond_types.add(b_type)
return list(atom_types), list(bond_types)
def eval_decisions(env, decisions):
"""This function mimics the way DGMG generates a molecule and is
helpful for debugging and verification in data preprocessing.
Parameters
----------
env : MoleculeEnv
MDP environment for generating molecules
decisions : list of 2-tuples of int
A decision sequence for generating a molecule
Returns
-------
str
SMILES for the molecule generated with decisions
"""
env.reset(rdkit_mol=True)
t = 0
def whether_to_add_atom(t):
assert decisions[t][0] == 0
atom_type = decisions[t][1]
t += 1
return t, atom_type
def whether_to_add_bond(t):
assert decisions[t][0] == 1
bond_type = decisions[t][1]
t += 1
return t, bond_type
def decide_atom2(t):
assert decisions[t][0] == 2
dst = decisions[t][1]
t += 1
return t, dst
t, atom_type = whether_to_add_atom(t)
while atom_type != len(env.atom_types):
env.add_atom(atom_type)
t, bond_type = whether_to_add_bond(t)
while bond_type != len(env.bond_types):
t, dst = decide_atom2(t)
env.add_bond((env.num_atoms() - 1), dst, bond_type)
t, bond_type = whether_to_add_bond(t)
t, atom_type = whether_to_add_atom(t)
assert t == len(decisions)
return env.get_current_smiles()
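# Illustration (hypothetical types): with atom_types=['C'] and
# bond_types=[Chem.rdchem.BondType.SINGLE], the decision sequence
#   [(0, 0),  # add an atom of type 0 ('C')
#    (1, 1),  # 1 == len(bond_types): stop adding bonds for this atom
#    (0, 0),  # add a second 'C'
#    (1, 0),  # add a bond of type 0 (SINGLE) from the new atom...
#    (2, 0),  # ...to atom 0
#    (1, 1),  # stop adding bonds
#    (0, 1)]  # 1 == len(atom_types): stop adding atoms
# makes eval_decisions return 'CC' (ethane).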
def get_DGMG_smile(env, mol):
"""Mimics the reproduced SMILE with DGMG for a molecule.
Given a molecule, we are interested in what SMILES we will
get if we want to generate it with DGMG. This is an important
step to check false novel molecules.
Parameters
----------
env : MoleculeEnv
MDP environment for generating molecules
mol : Chem.rdchem.Mol
A molecule
Returns
-------
canonical_smile : str
SMILES of the generated molecule with a canonical decision sequence
random_smile : str
SMILES of the generated molecule with a random decision sequence
"""
canonical_decisions = env.get_decision_sequence(mol, list(range(mol.GetNumAtoms())))
canonical_smile = eval_decisions(env, canonical_decisions)
order = list(range(mol.GetNumAtoms()))
random.shuffle(order)
random_decisions = env.get_decision_sequence(mol, order)
random_smile = eval_decisions(env, random_decisions)
return canonical_smile, random_smile
def preprocess_dataset(atom_types, bond_types, smiles, max_num_atoms=23):
"""Preprocess the dataset
1. Standardize the SMILES of the dataset
2. Only keep the SMILES that DGMG can reproduce
3. Drop repeated SMILES
Parameters
----------
atom_types : list
The types of atoms appearing in a dataset. E.g. ['C', 'N']
bond_types : list
The types of bonds appearing in a dataset.
E.g. [Chem.rdchem.BondType.SINGLE, Chem.rdchem.BondType.DOUBLE,
Chem.rdchem.BondType.TRIPLE, Chem.rdchem.BondType.AROMATIC]
smiles : list of str
SMILES of the molecules in the dataset
max_num_atoms : int or None
Maximum number of heavy atoms allowed in a molecule; larger molecules are dropped
Returns
-------
valid_smiles : list of str
SMILES left after preprocessing
"""
valid_smiles = []
env = MoleculeEnv(atom_types, bond_types)
for id, s in enumerate(smiles):
print('Processing {:d}/{:d}'.format(id + 1, len(smiles)))
raw_s = s.strip()
mol = smiles_to_standard_mol(raw_s)
if mol is None:
continue
if (max_num_atoms is not None) and (mol.GetNumAtoms() > max_num_atoms):
continue
canonical_s, random_s = get_DGMG_smile(env, mol)
canonical_mol = Chem.MolFromSmiles(canonical_s)
random_mol = Chem.MolFromSmiles(random_s)
if (raw_s != canonical_s) or (canonical_s != random_s) or (canonical_mol is None) or (random_mol is None):
continue
valid_smiles.append(Chem.MolToSmiles(mol))
valid_smiles = list(set(valid_smiles))
return valid_smiles
def download_data(dataset, fname):
"""Download dataset if built-in support exists
Parameters
----------
dataset : str
Dataset name
fname : str
Name of dataset file
"""
if dataset not in ['ChEMBL', 'ZINC']:
# For dataset without built-in support, they should be locally processed.
return
data_path = fname
download(_get_dgl_url(os.path.join('dataset', fname)), path=data_path)
def load_smiles_from_file(f_name):
"""Load dataset into a list of SMILES
Parameters
----------
f_name : str
Path to a file of molecules, where each line of the file
is a molecule in SMILES format.
Returns
-------
smiles : list of str
List of molecules as SMILES
"""
with open(f_name, 'r') as f:
smiles = f.read().splitlines()
return smiles
def write_smiles_to_file(f_name, smiles):
"""Write dataset to a file.
Parameters
----------
f_name : str
Path to create a file of molecules, where each line of the file
is a molecule in SMILES format.
smiles : list of str
List of SMILES
"""
with open(f_name, 'w') as f:
for s in smiles:
f.write(s + '\n')
def configure_new_dataset(dataset, train_file, val_file):
"""Configure for a new dataset.
Parameters
----------
dataset : str
Dataset name
train_file : str
Path to a file with one SMILES per line for training data
val_file : str
Path to a file with one SMILES per line for validation data
"""
assert train_file is not None, 'Expect a file of SMILES for training, got None.'
assert val_file is not None, 'Expect a file of SMILES for validation, got None.'
train_smiles = load_smiles_from_file(train_file)
val_smiles = load_smiles_from_file(val_file)
all_smiles = train_smiles + val_smiles
# Get all atom and bond types in the dataset
path_to_atom_and_bond_types = '_'.join([dataset, 'atom_and_bond_types.pkl'])
if not os.path.exists(path_to_atom_and_bond_types):
atom_types, bond_types = get_atom_and_bond_types(all_smiles)
with open(path_to_atom_and_bond_types, 'wb') as f:
pickle.dump({'atom_types': atom_types, 'bond_types': bond_types}, f)
else:
with open(path_to_atom_and_bond_types, 'rb') as f:
type_info = pickle.load(f)
atom_types = type_info['atom_types']
bond_types = type_info['bond_types']
# Standardize training data
path_to_processed_train_data = '_'.join([dataset, 'DGMG', 'train.txt'])
if not os.path.exists(path_to_processed_train_data):
processed_train_smiles = preprocess_dataset(atom_types, bond_types, train_smiles, None)
write_smiles_to_file(path_to_processed_train_data, processed_train_smiles)
path_to_processed_val_data = '_'.join([dataset, 'DGMG', 'val.txt'])
if not os.path.exists(path_to_processed_val_data):
processed_val_smiles = preprocess_dataset(atom_types, bond_types, val_smiles, None)
write_smiles_to_file(path_to_processed_val_data, processed_val_smiles)
class MoleculeDataset(object):
"""Initialize and split the dataset.
Parameters
----------
dataset : str
Dataset name
order : None or str
Order to extract a decision sequence for generating a molecule. Default to be None.
modes : None or list
List of subsets to use, which can contain 'train', 'val', corresponding to
training and validation. Default to be None.
subset_id : int
With multiprocess training, we partition the training set into multiple subsets and
each process will use one subset only. This subset_id corresponds to subprocess id.
n_subsets : int
With multiprocess training, this corresponds to the number of total subprocesses.
"""
def __init__(self, dataset, order=None, modes=None, subset_id=0, n_subsets=1):
super(MoleculeDataset, self).__init__()
if modes is None:
modes = []
else:
assert order is not None, 'An order should be specified for extracting ' \
'decision sequences.'
assert order in ['random', 'canonical', None], \
"Unexpected order option to get sequences of graph generation decisions"
assert len(set(modes) - {'train', 'val'}) == 0, \
"modes should be a list, representing a subset of ['train', 'val']"
self.dataset = dataset
self.order = order
self.modes = modes
self.subset_id = subset_id
self.n_subsets = n_subsets
self._setup()
def collate(self, samples):
"""PyTorch's approach to batch multiple samples.
For auto-regressive generative models, we process one sample at a time.
Parameters
----------
samples : list
A list of length 1 that consists of decision sequence to generate a molecule.
Returns
-------
list
List of 2-tuples, a decision sequence to generate a molecule
"""
assert len(samples) == 1
return samples[0]
def _create_a_subset(self, smiles):
"""Create a dataset from a subset of smiles.
Parameters
----------
smiles : list of str
List of molecules in SMILES format
"""
# We evenly divide the smiles into multiple subsets for multiprocessing
subset_size = len(smiles) // self.n_subsets
return Subset(smiles[self.subset_id * subset_size: (self.subset_id + 1) * subset_size],
self.order, self.env)
def _setup(self):
"""
1. Instantiate an MDP environment for molecule generation
2. Download the dataset, which is a file of SMILES
3. Create subsets for training and validation
"""
if self.dataset == 'ChEMBL':
# For new datasets, get_atom_and_bond_types can be used to
# identify the atom and bond types in them.
self.atom_types = ['O', 'Cl', 'C', 'S', 'F', 'Br', 'N']
self.bond_types = [Chem.rdchem.BondType.SINGLE,
Chem.rdchem.BondType.DOUBLE,
Chem.rdchem.BondType.TRIPLE]
elif self.dataset == 'ZINC':
self.atom_types = ['Br', 'S', 'C', 'P', 'N', 'O', 'F', 'Cl', 'I']
self.bond_types = [Chem.rdchem.BondType.SINGLE,
Chem.rdchem.BondType.DOUBLE,
Chem.rdchem.BondType.TRIPLE]
else:
path_to_atom_and_bond_types = '_'.join([self.dataset, 'atom_and_bond_types.pkl'])
with open(path_to_atom_and_bond_types, 'rb') as f:
type_info = pickle.load(f)
self.atom_types = type_info['atom_types']
self.bond_types = type_info['bond_types']
self.env = MoleculeEnv(self.atom_types, self.bond_types)
dataset_prefix = self._dataset_prefix()
if 'train' in self.modes:
fname = '_'.join([dataset_prefix, 'train.txt'])
download_data(self.dataset, fname)
smiles = load_smiles_from_file(fname)
self.train_set = self._create_a_subset(smiles)
if 'val' in self.modes:
fname = '_'.join([dataset_prefix, 'val.txt'])
download_data(self.dataset, fname)
smiles = load_smiles_from_file(fname)
# We evenly divide the smiles into multiple subsets for multiprocessing
self.val_set = self._create_a_subset(smiles)
def _dataset_prefix(self):
"""Get the prefix for the data files of supported datasets.
Returns
-------
str
Prefix for dataset file name
"""
return '_'.join([self.dataset, 'DGMG'])
class Subset(Dataset):
"""A set of molecules which can be used for training, validation, test.
Parameters
----------
smiles : list
List of SMILES for the dataset
order : str
Specifies how decision sequences for molecule generation
are obtained, can be either "random" or "canonical"
env : MoleculeEnv object
MDP environment for generating molecules
"""
def __init__(self, smiles, order, env):
super(Subset, self).__init__()
self.smiles = smiles
self.order = order
self.env = env
self._setup()
def _setup(self):
"""Convert SMILES into rdkit molecule objects.
Decision sequences are extracted if we use a fixed order.
"""
smiles_ = []
mols = []
for s in self.smiles:
m = smiles_to_standard_mol(s)
if m is None:
continue
smiles_.append(s)
mols.append(m)
self.smiles = smiles_
self.mols = mols
if self.order == 'random':
return
self.decisions = []
for m in self.mols:
self.decisions.append(
self.env.get_decision_sequence(m, list(range(m.GetNumAtoms())))
)
def __len__(self):
"""Get number of molecules in the dataset."""
return len(self.mols)
def __getitem__(self, item):
"""Get the decision sequence for generating the molecule indexed by item."""
if self.order == 'canonical':
return self.decisions[item]
else:
m = self.mols[item]
nodes = list(range(m.GetNumAtoms()))
random.shuffle(nodes)
return self.env.get_decision_sequence(m, nodes)
########################################################################################################################
# progress tracking #
########################################################################################################################
class Printer(object):
def __init__(self, num_epochs, dataset_size, batch_size, writer=None):
"""Wrapper to track the learning progress.
Parameters
----------
num_epochs : int
Number of epochs for training
dataset_size : int
batch_size : int
writer : None or SummaryWriter
If not None, tensorboard will be used to visualize learning curves.
"""
super(Printer, self).__init__()
self.num_epochs = num_epochs
self.batch_size = batch_size
self.num_batches = math.ceil(dataset_size / batch_size)
self.count = 0
self.batch_count = 0
self.writer = writer
self._reset()
def _reset(self):
"""Reset when an epoch is completed."""
self.batch_loss = 0
self.batch_prob = 0
def _get_current_batch(self):
"""Get current batch index."""
remainder = self.batch_count % self.num_batches
if remainder == 0:
return self.num_batches
else:
return remainder
def update(self, epoch, loss, prob):
"""Update learning progress.
Parameters
----------
epoch : int
loss : float
prob : float
"""
self.count += 1
self.batch_loss += loss
self.batch_prob += prob
if self.count % self.batch_size == 0:
self.batch_count += 1
if self.writer is not None:
self.writer.add_scalar('train_log_prob', self.batch_loss, self.batch_count)
self.writer.add_scalar('train_prob', self.batch_prob, self.batch_count)
print('epoch {:d}/{:d}, batch {:d}/{:d}, loss {:.4f}, prob {:.4f}'.format(
epoch, self.num_epochs, self._get_current_batch(),
self.num_batches, self.batch_loss, self.batch_prob))
self._reset()
########################################################################################################################
# eval #
########################################################################################################################
def summarize_a_molecule(smile, checklist=None):
"""Get information about a molecule.
Parameters
----------
smile : str
Molecule in SMILES format
checklist : dict
Things to learn about the molecule
Returns
-------
summary : dict
Mapping from item names to computed values; 'valid' is False
if the SMILES cannot be parsed
"""
if checklist is None:
checklist = {
'HBA': Chem.rdMolDescriptors.CalcNumHBA,
'HBD': Chem.rdMolDescriptors.CalcNumHBD,
'logP': MolLogP,
'SA': calculateScore,
'TPSA': Chem.rdMolDescriptors.CalcTPSA,
'QED': qed,
'NumAtoms': lambda mol: mol.GetNumAtoms(),
'NumBonds': lambda mol: mol.GetNumBonds()
}
summary = dict()
mol = Chem.MolFromSmiles(smile)
if mol is None:
summary.update({
'smile': smile,
'valid': False
})
for k in checklist.keys():
summary[k] = None
else:
mol = standardize_mol(mol)
summary.update({
'smile': Chem.MolToSmiles(mol),
'valid': True
})
Chem.SanitizeMol(mol)
for k, f in checklist.items():
summary[k] = f(mol)
return summary
def summarize_molecules(smiles, num_processes):
"""Summarize molecules with multiprocess.
Parameters
----------
smiles : list of str
List of molecules in SMILES for summarization
num_processes : int
Number of processes to use for summarization
Returns
-------
summary_for_valid : dict
Summary of all valid molecules, where
summary_for_valid[k] gives the values of all
valid molecules on item k.
"""
with Pool(processes=num_processes) as pool:
result = pool.map(summarize_a_molecule, smiles)
items = list(result[0].keys())
items.remove('valid')
summary_for_valid = defaultdict(list)
for summary in result:
if summary['valid']:
for k in items:
summary_for_valid[k].append(summary[k])
return summary_for_valid
def get_unique_smiles(smiles):
"""Given a list of SMILES, return a list consisting of its unique elements.
Parameters
----------
smiles : list of str
Molecules in SMILES
Returns
-------
list of str
Sublist where each SMILES occurs exactly once
"""
return list(set(smiles))
def get_novel_smiles(new_unique_smiles, reference_unique_smiles):
"""Get novel smiles which do not appear in the reference set.
Parameters
----------
new_unique_smiles : list of str
List of SMILES from which we want to identify novel ones
reference_unique_smiles : list of str
List of reference SMILES that we already have
Returns
-------
set of str
SMILES from new_unique_smiles that are absent from reference_unique_smiles
"""
return set(new_unique_smiles).difference(set(reference_unique_smiles))
{
"cells": [
{
"cell_type": "code",
"execution_count": 57,
"metadata": {},
"outputs": [],
"source": [
"from dgl import model_zoo\n",
"import torch\n",
"import rdkit\n",
"from rdkit import Chem\n",
"from rdkit.Chem.Draw import IPythonConsole\n",
"from dgl.data.chem.utils import smile2graph\n",
"import dgl"
]
},
{
"cell_type": "code",
"execution_count": 58,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Loading pretrained model...\n"
]
},
{
"data": {
"text/plain": [
"GCNClassifier(\n",
" (gcn_layers): ModuleList(\n",
" (0): GCNLayer(\n",
" (graph_conv): GraphConv(in=74, out=64, normalization=False, activation=<function relu at 0x7efd7f46e158>)\n",
" (dropout): Dropout(p=0.0)\n",
" (res_connection): Linear(in_features=74, out_features=64, bias=True)\n",
" (bn_layer): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" )\n",
" (1): GCNLayer(\n",
" (graph_conv): GraphConv(in=64, out=64, normalization=False, activation=<function relu at 0x7efd7f46e158>)\n",
" (dropout): Dropout(p=0.0)\n",
" (res_connection): Linear(in_features=64, out_features=64, bias=True)\n",
" (bn_layer): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" )\n",
" )\n",
" (atom_weighting): Sequential(\n",
" (0): Linear(in_features=64, out_features=1, bias=True)\n",
" (1): Sigmoid()\n",
" )\n",
" (soft_classifier): MLPBinaryClassifier(\n",
" (predict): Sequential(\n",
" (0): Dropout(p=0.0)\n",
" (1): Linear(in_features=128, out_features=64, bias=True)\n",
" (2): ReLU()\n",
" (3): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (4): Linear(in_features=64, out_features=12, bias=True)\n",
" )\n",
" )\n",
")"
]
},
"execution_count": 58,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model = model_zoo.chem.load_pretrained(\"GCN_Tox21\")\n",
"model.eval()\n",
"model"
]
},
{
"cell_type": "code",
"execution_count": 59,
"metadata": {},
"outputs": [],
"source": [
"tasks = ['NR-AR', 'NR-AR-LBD', 'NR-AhR', 'NR-Aromatase',\n",
" 'NR-ER', 'NR-ER-LBD', 'NR-PPAR-gamma', 'SR-ARE',\n",
" 'SR-ATAD5', 'SR-HSE', 'SR-MMP', 'SR-p53']"
]
},
{
"cell_type": "code",
"execution_count": 60,
"metadata": {},
"outputs": [],
"source": [
"smiles = \"CC[NH+](CC)c1ccc(/C=C2\\Oc3c(ccc(OCC(N)=O)c3C)C2=O)cc1\""
]
},
{
"cell_type": "code",
"execution_count": 61,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAcIAAACWCAIAAADCEh9HAAAABmJLR0QA/wD/AP+gvaeTAAAgAElEQVR4nO3deViU19k/8O8sjBBWlaCM4IJgBFxxF3wxkQQX1JiKOy4xmqa2Y2xqqfHNi7bmDVlsp6mvqIkL7hoXhLhChAR+0YgI4pAYIQJVdlHZt5k5vz8OTgm4IDzMM8D9uXL1Yh6Gc+6xcnv2I2GMgRBCSEtJxQ6AEELaN0qjhBDSKpRGCSGkVSiNEkJIq1AaJYSQVqE0+gy3b9+eNWtWcXGx2IEQ8pzOn8fLL8PfH8uWoaJC7Gg6MgkteHq61157LTo62s3N7fTp025ubmKHQ0jz5Odj6lTExsLWFrt2ITUVarXYMXVYlEafIS8vb8aMGVevXu3WrduJEyd8fX3FjoiQZtixAyUlWLsWAPR6eHjg5k2xY+qwqFP/DI6OjnFxcTNnzrx//76/v//+/fvFjoiQZsjNhVJZ/7VUCqkUWq2oAXVklEafzdLS8vjx48HBwTU1NYsXL96wYQM14YmpUyqRm1v/tV4PnQ5yuagBdWSURptFJpOFhoZu375dJpNt3Lhx2bJltbW1YgdFyJNNn47Dh1FaCgB792LyZLED6shobPT5nD9/fs6cOaWlpd7e3hEREfb29mJHRMgTnDmDTz6BQgGlEp9/jvBwfP01oqKgUIgdWUdDafS5paamBgQE3Llzx9XV9cyZMzR9T0xXSQnMzPDCCwAwbBiuX8epU5gxQ+ywOhrq1D+3IUOGXL582cvLKyMjY9y4cQkJCWJHRMjj/O1v6NkTBw/Wv5w/HwAOHRIxoo6K0mhLKJXKb7/9NiAgoLi42M/P7xD91SQmyMkJ1dU4fLj+5fz5kEgQGYnyclHD6oAojbaQlZVVRETEqlWrampqFi5cuGHDBrEjIuTXfvMbWFggNhY5OQDQuze8vVFZiVOnxI6so6E02nIymWzLli1qtVoikWzcuHH58uV1dXViB0XIIzY2mDwZej2++qr+CfXr2wZNMQng5MmTixYtqqys9PPzO3bsmK2trdgREQIA+OorzJmD0aPxww8AUFSEXr0AIDcXtMhEONQaFcCsWbNiY2N79OgRExPj4+OTnZ0tdkSEAACmT4etLa5cQXo6ALz4IiZNQl0djh0TO7IOhdKoMEaPHn3p0iUPDw+NRjNu3LirV6+KHREhgLl5/fKmI0fqn1C/vg1QGhVMv379Ll++PGXKlLy8vIkTJ56igXxiCnjePHCg/uWsWbCwYMnJ5XzeiQiB0qiQrK2tIyIili5dWlFR8e6771ZXV4sdEen0Xn0VDg64eRPXrwOAtfXRFSvstdpt1CAVDqVRgSkUit27d48dO7awsPD48eNih0M6PbmcBQbeHz784rlz/IHZxIn3q6posbOAKI22lcrKSkdHR7GjIAT/b+7c7snJb4aF8WU5U6dOtbOzu3bt2o8//ih2aB0EpVHh6fV6jUYDYMiQIWLHQgi8fXz69u2bnZ39/fffA+jSpcsbb7wB4OjRo2KH1kFQGhXerVu3ysvLe/fuTec/EVMgkUjmzJkDwNCRnz9/PoADhnkn0jqURoWXnJwMYPjw4WIHQkg9njePHj2q1WoBvPLKK0qlMiMjIykpSezQOgJKo8JLSUkBMGzYMLEDIaTesGHDPD09i4qKYmJiAEil0tmzZ6NB+5S0BqVR4VEaJSZo7ty5aNKvP3TokE6nEzOsDoH21AuvZ8+eBQUFWVlZffr0ETsWQur98ssvbm5uVlZWBQUFFhYWANzc3DIyMmJjYydOnCh2dO0btUYFlpOTU1BQYGdn17t3b7FjIeQ/+vfvP3LkyLKystOnT+t0uoSEhPHjx3t7e0ullARaiy4LFFhammzChP1OTpkSiUTsWAj5lfnz5ycmJu7ateuDDz4wNzfno0/z5s0bMWLEiBEjPD09PTw8PDw86K/u86JOvcA2bcIHH+CPf8TmzWKHQsiv5eXlff7558ePH09PT3dxcbG3t79x40ZVVVXD90ya9E9ANXw4hg3DsGF46SW6mPnZ6E9IYCkpAEDTS8QE2dvbX7t2LT093d3d/fvvv7ezs9PpdNnZ2WlpaUlJSUlJSYmJiTrdmLg4fPNN/Y9YWGDwYAwbhhEjMHo0RozAzz/D1RXbtsHcHMOGQa3Gnj0AsG0btFr8/vfifTzxUBoVWHIyANCaUWKCVq9efeHCBXt7+6ioKDs7OwAymczFxcXFxWX69On8PXfu1KWkIDkZKSlISUFmJq5cwZUrcHPD0aMYOBCbNyMsTNSPYXoojQqptBSZmejSBS+9JHYohPxaaGhoWFiYhYVFVFRU//79n/Q2Z2czZ2c8Sqp4+BA8q/LL7ceNQ3o6CguNEnH7QWlUSCkpYAyDB8PMTOxQyK9VVVV9+eWXV69eDQ8PFzsWERw7dmz9+vVSqXT//v1jx45t/g/a2WHiRPAFUXzAatUqbNkCpbL+DefO1X83NxcqlaBBtx8dba1Dfn6+iLVTj95kyeXyjRs37t27l58a06kkJiYuWbJEr9d/+umn/FCS1pg5E9HRqKiofzl5MuLiEBeHP/6xtXG2Xx0qjX788ceenp7ffvutWAHQ/JLJMjMz65zbHzMzM6dPn15ZWfnWW2/9UYhUJ5Vi6VLs3t36kjqOjpNGGWOXLl26f/++v7//wYMHRYmBWqOmzHCsUedZ5FdSUjJjxoyCgoLJkyeHCTcxtGQJioqEKqxDYB2IVqsNDg4GIJFIQkJC9Hq9MWuvqWEKBZNKWVmZMaslzaXT6ZycnABcunRJ7FiMoba2dtKkSQA8PT0fPnxonEq/+oq99RabMYPdv2+cCk2CbMOGDeLmcQFJpVI/Pz+lUnn27NnY2NisrKypU6fKZDLj1C6RYP58vPoqBg0yToXk+Ugkktzc3EuXLllaWk6ZMkXscNoWY+zNN9+MiIhQKpVxcXEODg7GqdfDAzNmoKICNTXo1884dZoAsfN4mzh37pyNjQ0Ab2/voqIiscMhpoJffO3g4FBXVyd2LG2LN49eeOGFK1euGLnqqiq2eDGrqTFytWLqmGmUMXb9+nVnZ2cArq6ut27dEiWG774TpVryNO7u7gAuXLggchznzrGJE9lrr7GlS1l5ubBlHz58WCKRyGSyiIgIYUt+pro69s47LCvLyNWKrONMMTUyZMiQy5cve3l5ZWRkjBs3Lj4+3vgxnD9v/DrJMzS6TkMc+flYtw4RETh/HhMmYP16AcuOj49fsmQJY+wf//jHzJkzBSy5OT77DDdv4qOP6qdbOwux83jbKisrCwgIANClS5eDBw8KUmZyMpNKWXo6Y4yFhbHdu1lyMluypP67YWHsX/9iiYns7beZlxd7+2127Jgg1RJhpKenA7CxsamsrBQtiO3b2Sef1H+t07GXXmKMsZ072aFD7KefmFbb4oIzMjJefPFFAGvWrBEiUNIsHbY1yllZWUVERKxataqmpmbhwoVCzafxncVPMXIktm3DlCnYtg2/+Y0gdRJhuLq6enl5lZaWnnt0b7sIc
<base64-encoded PNG data omitted: RDKit 2D drawing of the molecule>\n",
"text/plain": [
"<rdkit.Chem.rdchem.Mol at 0x7efd736958f0>"
]
},
"execution_count": 61,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"m = Chem.MolFromSmiles(smiles)\n",
"m"
]
},
{
"cell_type": "code",
"execution_count": 62,
"metadata": {},
"outputs": [],
"source": [
"g = smile2graph(smiles)"
]
},
{
"cell_type": "code",
"execution_count": 63,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"DGLGraph(num_nodes=28, num_edges=60,\n",
" ndata_schemes={'h': Scheme(shape=(74,), dtype=torch.float32)}\n",
" edata_schemes={})"
]
},
"execution_count": 63,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"g"
]
},
{
"cell_type": "code",
"execution_count": 64,
"metadata": {},
"outputs": [],
"source": [
"bg = dgl.batch([g])"
]
},
{
"cell_type": "code",
"execution_count": 65,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"torch.Size([28, 74])"
]
},
"execution_count": 65,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"bg.ndata['h'].shape"
]
},
{
"cell_type": "code",
"execution_count": 66,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/ubuntu/playground/mz_dgl/python/dgl/base.py:18: UserWarning: Initializer is not set. Use zero initializer instead. To suppress this warning, use `set_initializer` to explicitly specify which initializer to use.\n",
" warnings.warn(msg, warn_type)\n"
]
}
],
"source": [
"logits = model(bg.ndata['h'], bg)"
]
},
{
"cell_type": "code",
"execution_count": 67,
"metadata": {},
"outputs": [],
"source": [
"preds = logits.data.numpy() > 0.5"
]
},
{
"cell_type": "code",
"execution_count": 68,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>NR-AR</th>\n",
" <th>NR-AR-LBD</th>\n",
" <th>NR-AhR</th>\n",
" <th>NR-Aromatase</th>\n",
" <th>NR-ER</th>\n",
" <th>NR-ER-LBD</th>\n",
" <th>NR-PPAR-gamma</th>\n",
" <th>SR-ARE</th>\n",
" <th>SR-ATAD5</th>\n",
" <th>SR-HSE</th>\n",
" <th>SR-MMP</th>\n",
" <th>SR-p53</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>True</td>\n",
" <td>False</td>\n",
" <td>True</td>\n",
" <td>False</td>\n",
" <td>True</td>\n",
" <td>False</td>\n",
" <td>True</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>True</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" NR-AR NR-AR-LBD NR-AhR NR-Aromatase NR-ER NR-ER-LBD NR-PPAR-gamma \\\n",
"0 False False True False True False True \n",
"\n",
" SR-ARE SR-ATAD5 SR-HSE SR-MMP SR-p53 \n",
"0 False True False False True "
]
},
"execution_count": 68,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import pandas as pd\n",
"pd.DataFrame(preds, columns=tasks)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Environment (conda_miniconda3-latest)",
"language": "python",
"name": "conda_miniconda3-latest"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.1"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
@@ -47,7 +47,7 @@ def split_dataset(dataset, frac_list=None, shuffle=False, random_state=None):
return [Subset(dataset, indices[offset - length:offset]) for offset, length in zip(accumulate(lengths), lengths)]
def download(url, path=None, overwrite=False, sha1_hash=None, retries=5, verify_ssl=True, log=True):
"""Download a given URL.
Code borrowed from mxnet/gluon/utils.py
@@ -68,6 +68,8 @@ def download(url, path=None, overwrite=False, sha1_hash=None, retries=5, verify_
The number of times to attempt downloading in case of failure or non 200 return codes.
verify_ssl : bool, default True
Verify SSL certificates.
log : bool, default True
Whether to print download progress
Returns
-------
@@ -100,7 +102,8 @@ def download(url, path=None, overwrite=False, sha1_hash=None, retries=5, verify_
# Disable pylint too broad Exception
# pylint: disable=W0703
try:
if log:
print('Downloading %s from %s...' % (fname, url))
r = requests.get(url, stream=True, verify=verify_ssl)
if r.status_code != 200:
raise RuntimeError("Failed downloading url %s" % url)
@@ -119,8 +122,9 @@ def download(url, path=None, overwrite=False, sha1_hash=None, retries=5, verify_
if retries <= 0:
raise e
else:
print("download failed, retrying, {} attempt{} left"
.format(retries, 's' if retries > 1 else ''))
if log:
print("download failed, retrying, {} attempt{} left"
.format(retries, 's' if retries > 1 else ''))
return fname
@@ -44,6 +44,24 @@ molecular graph topology, which may be viewed as a learned fingerprint [3].
- **Graph Convolutional Network**: Graph Convolutional Networks (GCN) are among the most popular graph neural
networks and can be easily extended for graph-level prediction.
## Generative Models
For molecules, we use generative models for two different purposes:
- **Distribution Learning**: Given a collection of molecules, we want to model their distribution and generate new
molecules with similar properties.
- **Goal-directed Optimization**: Find molecules with desired properties.
For this model zoo, we focus only on generative models for molecular graphs. There are other generative models
that work with alternative representations such as SMILES.
Generative models are known to be difficult to evaluate. [GuacaMol](https://github.com/BenevolentAI/guacamol) and
[MOSES](https://github.com/molecularsets/moses) are two recent efforts to benchmark generative models, each
accompanied by a well-written review paper [4], [5].
### Models
- **Deep Generative Models of Graphs (DGMG)**: A very general framework for learning graph distributions by
progressively adding atoms and bonds; see the usage sketch below.
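
As a minimal sketch (assuming the package layout in this commit, where the model zoo exposes
`load_pretrained`), sampling a molecule from a pretrained DGMG looks like:

```
from dgl.model_zoo.chem import load_pretrained  # import path assumed

model = load_pretrained('DGMG_ZINC_canonical')
model.eval()
# Sample a molecule and return it as a SMILES string
smiles = model(rdkit_mol=True)
```
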
## References
[1] Chen et al. (2018) The rise of deep learning in drug discovery. *Drug Discov Today* 6, 1241-1250.
@@ -53,3 +71,8 @@ networks and they can be easily extended for graph level prediction.
[3] Duvenaud et al. (2015) Convolutional networks on graphs for learning molecular fingerprints. *Advances in neural
information processing systems (NeurIPS)*, 2224-2232.
[4] Brown et al. (2019) GuacaMol: Benchmarking Models for de Novo Molecular Design. *J. Chem. Inf. Model*, 2019, 59, 3,
1096-1108.
[5] Polykovskiy et al. (2019) Molecular Sets (MOSES): A Benchmarking Platform for Molecular Generation Models. *arXiv*.
@@ -2,4 +2,5 @@
"""Model Zoo Package"""
from .gcn import GCNClassifier
from .dgmg import DGMG
from .pretrain import load_pretrained
# pylint: disable=C0103, W0622, R1710, W0104
"""
Learning Deep Generative Models of Graphs
https://arxiv.org/pdf/1803.03324.pdf
"""
from functools import partial
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
from torch.distributions import Categorical
import dgl
from dgl import DGLGraph
try:
from rdkit import Chem
except ImportError:
pass
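# RDKit is optional at import time; it is only required when
# Chem.rdchem.Mol objects are maintained (rdkit_mol=True).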
class MoleculeEnv(object):
"""MDP environment for generating molecules.
Parameters
----------
atom_types : list
E.g. ['C', 'N']
bond_types : list
E.g. [Chem.rdchem.BondType.SINGLE, Chem.rdchem.BondType.DOUBLE,
Chem.rdchem.BondType.TRIPLE, Chem.rdchem.BondType.AROMATIC]
"""
def __init__(self, atom_types, bond_types):
super(MoleculeEnv, self).__init__()
self.atom_types = atom_types
self.bond_types = bond_types
self.atom_type_to_id = dict()
self.bond_type_to_id = dict()
for id, a_type in enumerate(atom_types):
self.atom_type_to_id[a_type] = id
for id, b_type in enumerate(bond_types):
self.bond_type_to_id[b_type] = id
def get_decision_sequence(self, mol, atom_order):
"""Extract a decision sequence with which DGMG can generate the
molecule with a specified atom order.
Parameters
----------
mol : Chem.rdchem.Mol
atom_order : list
Specifies a mapping between the original atom
indices and the new atom indices. In particular,
atom_order[i] is re-labeled as i.
Returns
-------
decisions : list
decisions[i] is a 2-tuple (i, j)
- If i = 0, j specifies the type of the atom to add
(self.atom_types[j]) or termination (j = len(self.atom_types))
- If i = 1, j specifies the type of the bond to add
(self.bond_types[j]) or termination (j = len(self.bond_types))
- If i = 2, j specifies the destination atom id for the bond to add.
With the formulation of DGMG, atom j must already exist when this
decision is made.
"""
decisions = []
old2new = dict()
for new_id, old_id in enumerate(atom_order):
atom = mol.GetAtomWithIdx(old_id)
a_type = atom.GetSymbol()
decisions.append((0, self.atom_type_to_id[a_type]))
for bond in atom.GetBonds():
u = bond.GetBeginAtomIdx()
v = bond.GetEndAtomIdx()
if v == old_id:
u, v = v, u
if v in old2new:
decisions.append((1, self.bond_type_to_id[bond.GetBondType()]))
decisions.append((2, old2new[v]))
decisions.append((1, len(self.bond_types)))
old2new[old_id] = new_id
decisions.append((0, len(self.atom_types)))
return decisions
def reset(self, rdkit_mol=False):
"""Setup for generating a new molecule
Parameters
----------
rdkit_mol : bool
Whether to keep a Chem.rdchem.Mol object so
that we know what molecule is being generated
"""
self.dgl_graph = DGLGraph()
# If there are some features for nodes and edges,
# zero tensors will be set for those of new nodes and edges.
self.dgl_graph.set_n_initializer(dgl.frame.zero_initializer)
self.dgl_graph.set_e_initializer(dgl.frame.zero_initializer)
self.mol = None
if rdkit_mol:
# RWMol is a molecule class that is intended to be edited.
self.mol = Chem.RWMol(Chem.MolFromSmiles(''))
def num_atoms(self):
"""Get the number of atoms for the current molecule.
Returns
-------
int
"""
return self.dgl_graph.number_of_nodes()
def add_atom(self, type):
"""Add an atom of the specified type.
Parameters
----------
type : int
Should be in the range of [0, len(self.atom_types) - 1]
"""
self.dgl_graph.add_nodes(1)
if self.mol is not None:
self.mol.AddAtom(Chem.Atom(self.atom_types[type]))
def add_bond(self, u, v, type, bi_direction=True):
"""Add a bond of the specified type between atom u and v.
Parameters
----------
u : int
Index for the first atom
v : int
Index for the second atom
type : int
Index for the bond type
bi_direction : bool
Whether to add edges for both directions in the DGLGraph.
If not, we will only add the edge (u, v).
"""
if bi_direction:
self.dgl_graph.add_edges([u, v], [v, u])
else:
self.dgl_graph.add_edge(u, v)
if self.mol is not None:
self.mol.AddBond(u, v, self.bond_types[type])
def get_current_smiles(self):
"""Get the generated molecule in SMILES
Returns
-------
s : str
SMILES
"""
assert self.mol is not None, 'Expected an initialized Chem.rdchem.Mol object.'
s = Chem.MolToSmiles(self.mol)
return s
class GraphEmbed(nn.Module):
"""Compute a molecule representations out of atom representations.
Parameters
----------
node_hidden_size : int
Size of atom representation
"""
def __init__(self, node_hidden_size):
super(GraphEmbed, self).__init__()
# Setting from the paper
self.graph_hidden_size = 2 * node_hidden_size
# Embed graphs
self.node_gating = nn.Sequential(
nn.Linear(node_hidden_size, 1),
nn.Sigmoid()
)
self.node_to_graph = nn.Linear(node_hidden_size,
self.graph_hidden_size)
def forward(self, g):
"""
Parameters
----------
g : DGLGraph
Current molecule graph
Returns
-------
tensor of dtype float32 and shape (1, self.graph_hidden_size)
Computed representation for the current molecule graph
"""
if g.number_of_nodes() == 0:
# Use a zero tensor for an empty molecule.
return torch.zeros(1, self.graph_hidden_size)
else:
# Node features are stored as hv in ndata.
hvs = g.ndata['hv']
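# Gated sum readout: weight each atom by a learned sigmoid gate,
# project it to the graph space and sum over atoms.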
return (self.node_gating(hvs) *
self.node_to_graph(hvs)).sum(0, keepdim=True)
class GraphProp(nn.Module):
"""Perform message passing over a molecule graph and update its atom representations.
Parameters
----------
num_prop_rounds : int
Number of message passing rounds performed each time propagation is triggered
node_hidden_size : int
Size of atom representation
edge_hidden_size : int
Size of bond representation
"""
def __init__(self, num_prop_rounds, node_hidden_size, edge_hidden_size):
super(GraphProp, self).__init__()
self.num_prop_rounds = num_prop_rounds
# Setting from the paper
self.node_activation_hidden_size = 2 * node_hidden_size
message_funcs = []
self.reduce_funcs = []
node_update_funcs = []
for t in range(num_prop_rounds):
# input being [hv, hu, xuv]
message_funcs.append(nn.Linear(2 * node_hidden_size + edge_hidden_size,
self.node_activation_hidden_size))
self.reduce_funcs.append(partial(self.dgmg_reduce, round=t))
node_update_funcs.append(
nn.GRUCell(self.node_activation_hidden_size,
node_hidden_size))
self.message_funcs = nn.ModuleList(message_funcs)
self.node_update_funcs = nn.ModuleList(node_update_funcs)
def dgmg_msg(self, edges):
"""For an edge u->v, send a message concat([h_u, x_uv])
Parameters
----------
edges : batch of edges
Returns
-------
dict
Dictionary containing messages for the edge batch,
with the messages being tensors of shape (B, F1),
B for the number of edges and F1 for the message size.
"""
return {'m': torch.cat([edges.src['hv'],
edges.data['he']],
dim=1)}
def dgmg_reduce(self, nodes, round):
"""Aggregate messages.
Parameters
----------
nodes : batch of nodes
round : int
Update round
Returns
-------
dict
Dictionary containing aggregated messages for each node
in the batch, with the messages being tensors of shape
(B, F2), B for the number of nodes and F2 for the aggregated
message size
"""
hv_old = nodes.data['hv']
m = nodes.mailbox['m']
# Make copies of original atom representations to match the
# number of messages.
message = torch.cat([
hv_old.unsqueeze(1).expand(-1, m.size(1), -1), m], dim=2)
node_activation = (self.message_funcs[round](message)).sum(1)
return {'a': node_activation}
def forward(self, g):
"""
Parameters
----------
g : DGLGraph
"""
if g.number_of_edges() == 0:
return
else:
for t in range(self.num_prop_rounds):
g.update_all(message_func=self.dgmg_msg,
reduce_func=self.reduce_funcs[t])
g.ndata['hv'] = self.node_update_funcs[t](
g.ndata['a'], g.ndata['hv'])
class AddNode(nn.Module):
"""Stop or add an atom of a particular type.
Parameters
----------
env : MoleculeEnv
Environment for generating molecules
graph_embed_func : callable taking g as input
Function for computing molecule representation
node_hidden_size : int
Size of atom representation
dropout : float
Probability for dropout
"""
def __init__(self, env, graph_embed_func, node_hidden_size, dropout):
super(AddNode, self).__init__()
self.env = env
n_node_types = len(env.atom_types)
self.graph_op = {'embed': graph_embed_func}
self.stop = n_node_types
self.add_node = nn.Sequential(
nn.Linear(graph_embed_func.graph_hidden_size, graph_embed_func.graph_hidden_size),
nn.Dropout(p=dropout),
nn.Linear(graph_embed_func.graph_hidden_size, n_node_types + 1)
)
# If a node is to be added, initialize its hv
self.node_type_embed = nn.Embedding(n_node_types, node_hidden_size)
self.initialize_hv = nn.Linear(node_hidden_size + \
graph_embed_func.graph_hidden_size,
node_hidden_size)
self.init_node_activation = torch.zeros(1, 2 * node_hidden_size)
self.dropout = nn.Dropout(p=dropout)
def _initialize_node_repr(self, g, node_type, graph_embed):
"""Initialize atom representation
Parameters
----------
g : DGLGraph
node_type : int
Index for the type of the new atom
graph_embed : tensor of dtype float32
Molecule representation
"""
num_nodes = g.number_of_nodes()
hv_init = torch.cat([
self.node_type_embed(torch.LongTensor([node_type])),
graph_embed], dim=1)
hv_init = self.dropout(hv_init)
hv_init = self.initialize_hv(hv_init)
g.nodes[num_nodes - 1].data['hv'] = hv_init
g.nodes[num_nodes - 1].data['a'] = self.init_node_activation
def prepare_log_prob(self, compute_log_prob):
"""Setup for returning log likelihood
Parameters
----------
compute_log_prob : bool
Whether to compute log likelihood
"""
if compute_log_prob:
self.log_prob = []
self.compute_log_prob = compute_log_prob
def forward(self, action=None):
"""
Parameters
----------
action : None or int
If None, a new action will be sampled. If not None,
teacher forcing will be used to enforce the decision of the
corresponding action.
Returns
-------
stop : bool
Whether we stop adding new atoms
"""
g = self.env.dgl_graph
graph_embed = self.graph_op['embed'](g)
logits = self.add_node(graph_embed).view(1, -1)
probs = F.softmax(logits, dim=1)
if action is None:
action = Categorical(probs).sample().item()
stop = bool(action == self.stop)
if not stop:
self.env.add_atom(action)
self._initialize_node_repr(g, action, graph_embed)
if self.compute_log_prob:
sample_log_prob = F.log_softmax(logits, dim=1)[:, action: action + 1]
self.log_prob.append(sample_log_prob)
return stop
class AddEdge(nn.Module):
"""Stop or add a bond of a particular type.
Parameters
----------
env : MoleculeEnv
Environment for generating molecules
graph_embed_func : callable taking g as input
Function for computing molecule representation
node_hidden_size : int
Size of atom representation
dropout : float
Probability for dropout
"""
def __init__(self, env, graph_embed_func, node_hidden_size, dropout):
super(AddEdge, self).__init__()
self.env = env
n_bond_types = len(env.bond_types)
self.stop = n_bond_types
self.graph_op = {'embed': graph_embed_func}
self.add_edge = nn.Sequential(
nn.Linear(graph_embed_func.graph_hidden_size + node_hidden_size,
graph_embed_func.graph_hidden_size + node_hidden_size),
nn.Dropout(p=dropout),
nn.Linear(graph_embed_func.graph_hidden_size + node_hidden_size, n_bond_types + 1)
)
def prepare_log_prob(self, compute_log_prob):
"""Setup for returning log likelihood
Parameters
----------
compute_log_prob : bool
Whether to compute log likelihood
"""
if compute_log_prob:
self.log_prob = []
self.compute_log_prob = compute_log_prob
def forward(self, action=None):
"""
Parameters
----------
action : None or int
If None, a new action will be sampled. If not None,
teacher forcing will be used to enforce the decision of the
corresponding action.
Returns
-------
stop : bool
Whether we stop adding new bonds
action : int
The type for the new bond
"""
g = self.env.dgl_graph
graph_embed = self.graph_op['embed'](g)
src_embed = g.nodes[g.number_of_nodes() - 1].data['hv']
logits = self.add_edge(
torch.cat([graph_embed, src_embed], dim=1))
probs = F.softmax(logits, dim=1)
if action is None:
action = Categorical(probs).sample().item()
stop = bool(action == self.stop)
if self.compute_log_prob:
sample_log_prob = F.log_softmax(logits, dim=1)[:, action: action + 1]
self.log_prob.append(sample_log_prob)
return stop, action
class ChooseDestAndUpdate(nn.Module):
"""Choose the atom to connect for the new bond.
Parameters
----------
env : MoleculeEnv
Environment for generating molecules
graph_prop_func : callable taking g as input
Function for performing message passing
and updating atom representations
node_hidden_size : int
Size of atom representation
dropout : float
Probability for dropout
"""
def __init__(self, env, graph_prop_func, node_hidden_size, dropout):
super(ChooseDestAndUpdate, self).__init__()
self.env = env
n_bond_types = len(self.env.bond_types)
# To be used for one-hot encoding of bond type
self.bond_embedding = torch.eye(n_bond_types)
self.graph_op = {'prop': graph_prop_func}
self.choose_dest = nn.Sequential(
nn.Linear(2 * node_hidden_size + n_bond_types, 2 * node_hidden_size + n_bond_types),
nn.Dropout(p=dropout),
nn.Linear(2 * node_hidden_size + n_bond_types, 1)
)
def _initialize_edge_repr(self, g, src_list, dest_list, edge_embed):
"""Initialize bond representation
Parameters
----------
g : DGLGraph
src_list : list of int
source atoms for new bonds
dest_list : list of int
destination atoms for new bonds
edge_embed : 2D tensor of dtype float32
Embeddings for the new bonds
"""
g.edges[src_list, dest_list].data['he'] = edge_embed.expand(len(src_list), -1)
def prepare_log_prob(self, compute_log_prob):
"""Setup for returning log likelihood
Parameters
----------
compute_log_prob : bool
Whether to compute log likelihood
"""
if compute_log_prob:
self.log_prob = []
self.compute_log_prob = compute_log_prob
def forward(self, bond_type, dest):
"""
Parameters
----------
bond_type : int
The type for the new bond
dest : int or None
If None, a new action will be sampled. If not None,
teacher forcing will be used to enforce the decision of the
corresponding action.
"""
g = self.env.dgl_graph
src = g.number_of_nodes() - 1
possible_dests = range(src)
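# Every previously added atom is a candidate destination for the new bond.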
src_embed_expand = g.nodes[src].data['hv'].expand(src, -1)
possible_dests_embed = g.nodes[possible_dests].data['hv']
edge_embed = self.bond_embedding[bond_type: bond_type + 1]
dests_scores = self.choose_dest(
torch.cat([possible_dests_embed,
src_embed_expand,
edge_embed.expand(src, -1)], dim=1)).view(1, -1)
dests_probs = F.softmax(dests_scores, dim=1)
if dest is None:
dest = Categorical(dests_probs).sample().item()
if not g.has_edge_between(src, dest):
# For undirected graphs, we add edges for both directions
# so that we can perform graph propagation.
src_list = [src, dest]
dest_list = [dest, src]
self.env.add_bond(src, dest, bond_type)
self._initialize_edge_repr(g, src_list, dest_list, edge_embed)
# Perform message passing when new bonds are added.
self.graph_op['prop'](g)
if self.compute_log_prob:
if dests_probs.nelement() > 1:
self.log_prob.append(
F.log_softmax(dests_scores, dim=1)[:, dest: dest + 1])
def weights_init(m):
'''Function to initialize weights for models
Code from https://gist.github.com/jeasinema/ed9236ce743c8efaf30fa2ff732749f5
Usage:
model = Model()
model.apply(weights_init)
'''
if isinstance(m, nn.Linear):
init.xavier_normal_(m.weight.data)
init.normal_(m.bias.data)
elif isinstance(m, nn.GRUCell):
for param in m.parameters():
if len(param.shape) >= 2:
init.orthogonal_(param.data)
else:
init.normal_(param.data)
def dgmg_message_weight_init(m):
"""Weight initialization for graph propagation module
These are suggested by the authors. This should only be used for
the message passing functions, i.e. fe's in the paper.
"""
def _weight_init(m):
if isinstance(m, nn.Linear):
init.normal_(m.weight.data, std=1./10)
init.normal_(m.bias.data, std=1./10)
else:
raise ValueError('Expected the input to be of type nn.Linear!')
if isinstance(m, nn.ModuleList):
for layer in m:
layer.apply(_weight_init)
else:
m.apply(_weight_init)
class DGMG(nn.Module):
"""DGMG model
Users only need to initialize an instance of this class.
Parameters
----------
atom_types : list
E.g. ['C', 'N']
bond_types : list
E.g. [Chem.rdchem.BondType.SINGLE, Chem.rdchem.BondType.DOUBLE,
Chem.rdchem.BondType.TRIPLE, Chem.rdchem.BondType.AROMATIC]
node_hidden_size : int
Size of atom representation
num_prop_rounds : int
Number of message passing rounds performed each time propagation is triggered
dropout : float
Probability for dropout
"""
def __init__(self, atom_types, bond_types, node_hidden_size, num_prop_rounds, dropout):
super(DGMG, self).__init__()
self.env = MoleculeEnv(atom_types, bond_types)
# Graph embedding module
self.graph_embed = GraphEmbed(node_hidden_size)
# Graph propagation module
# For one-hot encoding, edge_hidden_size is just the number of bond types
self.graph_prop = GraphProp(num_prop_rounds, node_hidden_size, len(self.env.bond_types))
# Actions
self.add_node_agent = AddNode(
self.env, self.graph_embed, node_hidden_size, dropout)
self.add_edge_agent = AddEdge(
self.env, self.graph_embed, node_hidden_size, dropout)
self.choose_dest_agent = ChooseDestAndUpdate(
self.env, self.graph_prop, node_hidden_size, dropout)
# Weight initialization
self.init_weights()
def init_weights(self):
"""Initialize model weights"""
self.graph_embed.apply(weights_init)
self.graph_prop.apply(weights_init)
self.add_node_agent.apply(weights_init)
self.add_edge_agent.apply(weights_init)
self.choose_dest_agent.apply(weights_init)
self.graph_prop.message_funcs.apply(dgmg_message_weight_init)
def count_step(self):
"""Increment the step by 1."""
self.step_count += 1
def prepare_log_prob(self, compute_log_prob):
"""Setup for returning log likelihood
Parameters
----------
compute_log_prob : bool
Whether to compute log likelihood
"""
self.compute_log_prob = compute_log_prob
self.add_node_agent.prepare_log_prob(compute_log_prob)
self.add_edge_agent.prepare_log_prob(compute_log_prob)
self.choose_dest_agent.prepare_log_prob(compute_log_prob)
def add_node_and_update(self, a=None):
"""Decide if to add a new atom.
If a new atom should be added, update the graph.
Parameters
----------
a : None or int
If None, a new action will be sampled. If not None,
teacher forcing will be used to enforce the decision of the
corresponding action.
"""
self.count_step()
return self.add_node_agent(a)
def add_edge_or_not(self, a=None):
"""Decide if to add a new bond.
Parameters
----------
a : None or int
If None, a new action will be sampled. If not None,
teacher forcing will be used to enforce the decision of the
corresponding action.
"""
self.count_step()
return self.add_edge_agent(a)
def choose_dest_and_update(self, bond_type, a=None):
"""Choose destination and connect it to the latest atom.
Add edges for both directions and update the graph.
Parameters
----------
bond_type : int
The type of the new bond to add
a : None or int
If None, a new action will be sampled. If not None,
teacher forcing will be used to enforce the decision of the
corresponding action.
"""
self.count_step()
self.choose_dest_agent(bond_type, a)
def get_log_prob(self):
"""Compute the log likelihood for the decision sequence,
typically corresponding to the generation of a molecule.
Returns
-------
torch.tensor consisting of a float only
"""
return torch.cat(self.add_node_agent.log_prob).sum()\
+ torch.cat(self.add_edge_agent.log_prob).sum()\
+ torch.cat(self.choose_dest_agent.log_prob).sum()
def teacher_forcing(self, actions):
"""Generate a molecule according to a sequence of actions.
Parameters
----------
actions : list of 2-tuples of int
actions[t] gives (i, j), the action to execute by DGMG at timestep t.
- If i = 0, j specifies either the type of the atom to add or termination
- If i = 1, j specifies either the type of the bond to add or termination
- If i = 2, j specifies the destination atom id for the bond to add.
With the formulation of DGMG, atom j must already exist when this decision is made.
"""
stop_node = self.add_node_and_update(a=actions[self.step_count][1])
while not stop_node:
# A new atom was just added.
stop_edge, bond_type = self.add_edge_or_not(a=actions[self.step_count][1])
while not stop_edge:
# A new bond is to be added.
self.choose_dest_and_update(bond_type, a=actions[self.step_count][1])
stop_edge, bond_type = self.add_edge_or_not(a=actions[self.step_count][1])
stop_node = self.add_node_and_update(a=actions[self.step_count][1])
def rollout(self, max_num_steps):
"""Sample a molecule from the distribution learned by DGMG."""
stop_node = self.add_node_and_update()
while (not stop_node) and (self.step_count <= max_num_steps):
stop_edge, bond_type = self.add_edge_or_not()
if self.env.num_atoms() == 1:
stop_edge = True
while (not stop_edge) and (self.step_count <= max_num_steps):
self.choose_dest_and_update(bond_type)
stop_edge, bond_type = self.add_edge_or_not()
stop_node = self.add_node_and_update()
def forward(self, actions=None, rdkit_mol=False, compute_log_prob=False, max_num_steps=400):
"""
Parameters
----------
actions : list of 2-tuples or None.
If actions are not None, generate a molecule according to actions.
Otherwise, a molecule will be generated based on sampled actions.
rdkit_mol : bool
Whether to maintain a Chem.rdchem.Mol object. This brings extra
computational cost, but is necessary if we want to know which
molecule is generated, e.g. to get its SMILES.
compute_log_prob : bool
Whether to compute log likelihood
max_num_steps : int
Maximum number of steps allowed. This only comes into effect
during inference and prevents the model from never terminating.
Returns
-------
torch.tensor consisting of a float only, optional
The log likelihood for the actions taken
str, optional
The generated molecule in the form of SMILES
"""
# Initialize an empty molecule
self.step_count = 0
self.env.reset(rdkit_mol=rdkit_mol)
self.prepare_log_prob(compute_log_prob)
if actions is not None:
# A sequence of decisions is given, use teacher forcing
self.teacher_forcing(actions)
else:
# Sample a molecule from the distribution learned by DGMG
self.rollout(max_num_steps)
if compute_log_prob and rdkit_mol:
return self.get_log_prob(), self.env.get_current_smiles()
if compute_log_prob:
return self.get_log_prob()
if rdkit_mol:
return self.env.get_current_smiles()
"""Utilities for using pretrained models."""
import torch
from .dgmg import DGMG
from .gcn import GCNClassifier
from ...data.utils import _get_dgl_url, download
URL = {
'GCN_Tox21' : 'pre_trained/gcn_tox21.pth',
'DGMG_ChEMBL_canonical' : 'pre_trained/dgmg_ChEMBL_canonical.pth',
'DGMG_ChEMBL_random' : 'pre_trained/dgmg_ChEMBL_random.pth',
'DGMG_ZINC_canonical' : 'pre_trained/dgmg_ZINC_canonical.pth',
'DGMG_ZINC_random' : 'pre_trained/dgmg_ZINC_random.pth'
}
try:
from rdkit import Chem
except ImportError:
pass
def download_and_load_checkpoint(model_name, model, model_postfix,
local_pretrained_path='pre_trained.pth', log=True):
"""Download pretrained model checkpoint
Parameters
----------
model_name : str
Name of the model
model : nn.Module
Instantiated model instance
model_postfix : str
Postfix for pretrained model checkpoint
local_pretrained_path : str
Local name for the downloaded model checkpoint
log : bool
Whether to print progress for model loading
Returns
-------
model : nn.Module
Pretrained model
"""
url_to_pretrained = _get_dgl_url(model_postfix)
local_pretrained_path = '_'.join([model_name, local_pretrained_path])
download(url_to_pretrained, path=local_pretrained_path, log=log)
checkpoint = torch.load(local_pretrained_path)
model.load_state_dict(checkpoint['model_state_dict'])
if log:
print('Pretrained model loaded')
return model
def load_pretrained(model_name, log=True):
"""Load a pretrained model
Parameters
----------
model_name : str
log : bool
Whether to print progress for model loading
Returns
-------
model
"""
if model_name == "GCN_Tox21":
print('Loading pretrained model...')
url_to_pretrained = _get_dgl_url('pre_trained/gcn_tox21.pth')
local_pretrained_path = 'pre_trained.pth'
download(url_to_pretrained, path=local_pretrained_path)
if model_name not in URL:
raise RuntimeError("Cannot find a pretrained model with name {}".format(model_name))
if model_name == 'GCN_Tox21':
model = GCNClassifier(in_feats=74,
gcn_hidden_feats=[64, 64],
n_tasks=12,
classifier_hidden_feats=64)
elif model_name.startswith('DGMG'):
if model_name.startswith('DGMG_ChEMBL'):
atom_types = ['O', 'Cl', 'C', 'S', 'F', 'Br', 'N']
elif model_name.startswith('DGMG_ZINC'):
atom_types = ['Br', 'S', 'C', 'P', 'N', 'O', 'F', 'Cl', 'I']
bond_types = [Chem.rdchem.BondType.SINGLE,
Chem.rdchem.BondType.DOUBLE,
Chem.rdchem.BondType.TRIPLE]
model = DGMG(atom_types=atom_types,
bond_types=bond_types,
node_hidden_size=128,
num_prop_rounds=2,
dropout=0.2)
return download_and_load_checkpoint(model_name, model, URL[model_name], log=log)