Unverified Commit 73b2668f authored by Mufei Li and committed by GitHub

[Model Zoo] DGMG for molecule generation (#783)

* DGMG for molecule generation

* Fix CI check

* Fix for CI

* Trial for CI due to shared memory

* Update

* Better interface for dataset configuration

* Update

* Handle corner cases

* Update README

* Fix

* Fix

* Fix

* Fix

* Fix

* Refactor

* Fix

* Fix

* Fix

* Fix

* Fix

* Fix

* Update

* Fix

* Fix

* Fix

* Fix

* Fix

* Finally
parent 11fb217a
# Learning Deep Generative Models of Graphs (DGMG)
Yujia Li, Oriol Vinyals, Chris Dyer, Razvan Pascanu, and Peter Battaglia.
Learning Deep Generative Models of Graphs. *arXiv preprint arXiv:1803.03324*, 2018.
DGMG generates graphs by progressively adding nodes and edges as below:
![](https://user-images.githubusercontent.com/19576924/48605003-7f11e900-e9b6-11e8-8880-87362348e154.png)
For molecules, the nodes are atoms and the edges are bonds.
**Goal**: Given a set of real molecules, we want to learn their distribution and generate new molecules
with similar properties. See the `Evaluation` section for more details.
## Dataset
### Preprocessing
With our implementation, this model has several limitations:
1. Information about protonation and chirality is ignored during generation.
2. Molecules containing charged atoms such as `[N+]` and `[O-]` cannot be generated. For example, even with the
correct decisions, the model can only generate `O=C1NC(=S)NC(=O)C1=CNC1=CC=C(N(=O)O)C=C1O` from
`O=C1NC(=S)NC(=O)C1=CNC1=CC=C([N+](=O)[O-])C=C1O`.

To avoid issues with validity and novelty, we filter these molecules out of the dataset.
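As an illustration only (a minimal sketch with RDKit, not the preprocessing code shipped with this example), such
molecules can be detected and dropped by checking formal charges:

```python
from rdkit import Chem

def keep_molecule(smiles):
    """Keep a molecule only if it parses and carries no formal charges."""
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        # Invalid SMILES
        return False
    # Drop molecules with charged atoms such as [N+] or [O-]
    return all(atom.GetFormalCharge() == 0 for atom in mol.GetAtoms())

dataset = ['O=C1NC(=S)NC(=O)C1=CNC1=CC=C([N+](=O)[O-])C=C1O', 'CCO']
filtered = [s for s in dataset if keep_molecule(s)]  # only 'CCO' survives
```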
### ChEMBL
The authors use the [ChEMBL database](https://www.ebi.ac.uk/chembl/). Since they
did not release their code, we use a subset from [Olivecrona et al.](https://github.com/MarcusOlivecrona/REINVENT),
another work on generative modeling.

The authors restrict their dataset to molecules with at most 20 heavy atoms and use a training/validation
split of 130,830/26,166 examples. We use the same split sizes but need to relax the limit from 20 to 23 heavy atoms
since we are using a different subset.
### ZINC
After pre-processing, we are left with 232,464 molecules for training and 5,000 molecules for validation.
## Usage
### Training
Training auto-regressive generative models tends to be very slow. According to the authors, they speed up training
with multiple processes, and GPUs do not give much of a speed advantage. We follow their approach and perform
multi-process CPU training.
To start training, use `train.py` with required arguments
```
-d DATASET, dataset to use (default: None), built-in support exists for ChEMBL, ZINC
-o {random,canonical}, order to generate graphs (default: None)
```
and optional arguments
```
-s SEED, random seed (default: 0)
-np NUM_PROCESSES, number of processes to use (default: 32)
```
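For example, an illustrative invocation for training on ZINC with canonical ordering:

```
python train.py -d ZINC -o canonical
```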
Even though multiprocessing yields a significant speedup compared to a single process, training can still take a long
time (several days). An epoch of training and validation can take up to an hour and a half on our machine. Unless
necessary, we recommend using our pre-trained models.

Meanwhile, we checkpoint the model whenever there is a performance improvement on the validation set, so you do not
need to wait until training terminates.
All training results can be found in `training_results`.
#### Dataset configuration
You can also use your own dataset with the additional arguments below (see the example after the list):
```
-tf TRAIN_FILE, Path to a file with one SMILES per line for training
                data. This is only necessary if you want to use a new
                dataset. (default: None)
-vf VAL_FILE, Path to a file with one SMILES per line for validation
              data. This is only necessary if you want to use a new
              dataset. (default: None)
```
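For reference, a SMILES file is plain text with one molecule per line, e.g. (hypothetical contents and file names):

```
CCO
c1ccccc1
CC(=O)Oc1ccccc1C(=O)O
```

You would then pass the files with `-tf train.txt -vf val.txt`.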
#### Monitoring
We can monitor the training process with tensorboard as below:
![](https://s3.us-east-2.amazonaws.com/dgl.ai/model_zoo/drug_discovery/dgmg/tensorboard.png)
To use tensorboard, you need to install [tensorboardX](https://github.com/lanpa/tensorboardX) and
[TensorFlow](https://www.tensorflow.org/). You can launch tensorboard with `tensorboard --logdir=.`
If you are training on a remote server, you can still use it as follows:
1. Launch it on the remote server with `tensorboard --logdir=. --port=A`
2. In the terminal of your local machine, type `ssh -NfL localhost:B:localhost:A username@your_remote_host_name`
3. Go to the address `localhost:B` in your browser
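For example, with the hypothetical port choices `A=6006` and `B=16006`, step 2 becomes
`ssh -NfL localhost:16006:localhost:6006 username@your_remote_host_name` and you would then visit `localhost:16006`.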
### Evaluation
To start evaluation, use `eval.py` with required arguments
```
-d DATASET, dataset to use (default: None), built-in support exists for ChEMBL, ZINC
-o {random,canonical}, order to generate graphs, used for naming evaluation directory (default: None)
-p MODEL_PATH, path to saved model (default: None). This is not needed if you want to use pretrained models.
-pr, Whether to use a pre-trained model (default: False)
```
and optional arguments
```
-s SEED, random seed (default: 0)
-ns NUM_SAMPLES, Number of molecules to generate (default: 100000)
-mn MAX_NUM_STEPS, Max number of steps allowed in generated molecules to
ensure termination (default: 400)
-np NUM_PROCESSES, number of processes to use (default: 32)
-gt GENERATION_TIME, max time (seconds) allowed for generation with
multiprocess (default: 600)
```
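For example, an illustrative invocation for sampling with the pre-trained model on ZINC:

```
python eval.py -d ZINC -o canonical -pr
```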
All evaluation results can be found in `eval_results`.

After evaluation, the generated molecules (100,000 by default) are stored in `generated_smiles.txt` under the
`eval_results` directory, with three statistics logged in `generation_stats.txt` under `eval_results`:
1. `Validity among all` gives the percentage of molecules that are valid
2. `Uniqueness among valid ones` gives the percentage of valid molecules that are unique
3. `Novelty among unique ones` gives the percentage of unique valid molecules that are novel (not seen in training data)
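Conceptually, the three statistics can be computed as in the sketch below (using RDKit directly; the actual script
relies on helpers such as `get_unique_smiles` and `get_novel_smiles`):

```python
from rdkit import Chem

def canonical(smiles):
    """Return the canonical SMILES for a valid molecule, None otherwise."""
    mol = Chem.MolFromSmiles(smiles)
    return None if mol is None else Chem.MolToSmiles(mol)

def generation_stats(generated, train_smiles):
    # Validity: molecules RDKit can parse
    valid = [c for c in (canonical(s) for s in generated) if c is not None]
    # Uniqueness: deduplicate on canonical SMILES
    unique = set(valid)
    # Novelty: unique molecules never seen in the training data
    novel = unique - {canonical(s) for s in train_smiles}
    return {'validity among all': len(valid) / len(generated),
            'uniqueness among valid ones': len(unique) / len(valid),
            'novelty among unique ones': len(novel) / len(unique)}
```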
We also provide a Jupyter notebook where you can visualize the generated molecules
![](https://s3.us-east-2.amazonaws.com/dgl.ai/model_zoo/drug_discovery/dgmg/DGMG_ZINC_canonical_vis.png)
and compare their property distributions against the training molecule property distributions
![](https://s3.us-east-2.amazonaws.com/dgl.ai/model_zoo/drug_discovery/dgmg/DGMG_ZINC_canonical_dist.png)
Download it with `wget https://s3.us-east-2.amazonaws.com/dgl.ai/model_zoo/drug_discovery/dgmg/eval_jupyter.ipynb`.
### Pre-trained models
Below are the statistics of the pre-trained models. With random order, training becomes significantly more difficult
as we now have `N^2` data points with `N` molecules.
| Pre-trained model | % valid | % unique among valid | % novel among unique |
| ------------------ | ------- | -------------------- | -------------------- |
| `ChEMBL_canonical` | 78.80 | 99.19 | 98.60 |
| `ChEMBL_random` | 29.09 | 99.87 | 100.00 |
| `ZINC_canonical` | 74.60 | 99.87 | 99.87 |
| `ZINC_random` | 12.37 | 99.38 | 100.00 |
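To sample from a pre-trained model programmatically, here is a minimal sketch based on `load_pretrained` and the
forward signature used in `eval.py` below; the model name follows the table above:

```python
import torch
from dgl import model_zoo

# Load the DGMG model pre-trained on ZINC with canonical ordering
model = model_zoo.chem.load_pretrained('DGMG_ZINC_canonical')
model.eval()
with torch.no_grad():
    # Returns a SMILES string; max_num_steps bounds the decision sequence
    smiles = model(rdkit_mol=True, max_num_steps=400)
print(smiles)
```

`eval.py`, the evaluation script described above: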
import os
import pickle
import shutil

import torch

from dgl import model_zoo

from utils import MoleculeDataset, set_random_seed, download_data,\
    mkdir_p, summarize_molecules, get_unique_smiles, get_novel_smiles

def generate_and_save(log_dir, num_samples, max_num_steps, model):
    with open(os.path.join(log_dir, 'generated_smiles.txt'), 'w') as f:
        for i in range(num_samples):
            with torch.no_grad():
                s = model(rdkit_mol=True, max_num_steps=max_num_steps)
                f.write(s + '\n')

def prepare_for_evaluation(rank, args):
    worker_seed = args['seed'] + rank * 10000
    set_random_seed(worker_seed)
    torch.set_num_threads(1)

    # Setup dataset and data loader
    dataset = MoleculeDataset(args['dataset'], subset_id=rank, n_subsets=args['num_processes'])

    # Initialize model
    if not args['pretrained']:
        model = model_zoo.chem.DGMG(atom_types=dataset.atom_types,
                                    bond_types=dataset.bond_types,
                                    node_hidden_size=args['node_hidden_size'],
                                    num_prop_rounds=args['num_propagation_rounds'],
                                    dropout=args['dropout'])
        model.load_state_dict(torch.load(args['model_path'])['model_state_dict'])
    else:
        model = model_zoo.chem.load_pretrained('_'.join(['DGMG', args['dataset'], args['order']]), log=False)
    model.eval()

    # Assign the leftover samples to the last worker so the totals add up
    worker_num_samples = args['num_samples'] // args['num_processes']
    if rank == args['num_processes'] - 1:
        worker_num_samples += args['num_samples'] % args['num_processes']

    worker_log_dir = os.path.join(args['log_dir'], str(rank))
    mkdir_p(worker_log_dir, log=False)
    generate_and_save(worker_log_dir, worker_num_samples, args['max_num_steps'], model)

def remove_worker_tmp_dir(args):
    for rank in range(args['num_processes']):
        worker_path = os.path.join(args['log_dir'], str(rank))
        try:
            shutil.rmtree(worker_path)
        except OSError:
            print('Directory {} does not exist!'.format(worker_path))

def aggregate_and_evaluate(args):
    print('Merging generated SMILES into a single file...')
    smiles = []
    for rank in range(args['num_processes']):
        with open(os.path.join(args['log_dir'], str(rank), 'generated_smiles.txt'), 'r') as f:
            rank_smiles = f.read().splitlines()
        smiles.extend(rank_smiles)

    with open(os.path.join(args['log_dir'], 'generated_smiles.txt'), 'w') as f:
        for s in smiles:
            f.write(s + '\n')

    print('Removing temporary dirs...')
    remove_worker_tmp_dir(args)

    # Summarize training molecules
    print('Summarizing training molecules...')
    train_file = '_'.join([args['dataset'], 'DGMG_train.txt'])
    if not os.path.exists(train_file):
        download_data(args['dataset'], train_file)
    with open(train_file, 'r') as f:
        train_smiles = f.read().splitlines()
    train_summary = summarize_molecules(train_smiles, args['num_processes'])
    with open(os.path.join(args['log_dir'], 'train_summary.pickle'), 'wb') as f:
        pickle.dump(train_summary, f)

    # Summarize generated molecules
    print('Summarizing generated molecules...')
    generation_summary = summarize_molecules(smiles, args['num_processes'])
    with open(os.path.join(args['log_dir'], 'generation_summary.pickle'), 'wb') as f:
        pickle.dump(generation_summary, f)

    # Stats computation
    print('Preparing generation statistics...')
    valid_generated_smiles = generation_summary['smile']
    unique_generated_smiles = get_unique_smiles(valid_generated_smiles)
    unique_train_smiles = get_unique_smiles(train_summary['smile'])
    novel_generated_smiles = get_novel_smiles(unique_generated_smiles, unique_train_smiles)
    with open(os.path.join(args['log_dir'], 'generation_stats.txt'), 'w') as f:
        f.write('Total number of generated molecules: {:d}\n'.format(len(smiles)))
        f.write('Validity among all: {:.4f}\n'.format(
            len(valid_generated_smiles) / len(smiles)))
        f.write('Uniqueness among valid ones: {:.4f}\n'.format(
            len(unique_generated_smiles) / len(valid_generated_smiles)))
        f.write('Novelty among unique ones: {:.4f}\n'.format(
            len(novel_generated_smiles) / len(unique_generated_smiles)))

if __name__ == '__main__':
    import argparse
    import datetime
    import time

    from rdkit import rdBase

    from utils import setup

    parser = argparse.ArgumentParser(description='Evaluating DGMG for molecule generation',
                                     formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    # configure
    parser.add_argument('-s', '--seed', type=int, default=0, help='random seed')

    # dataset and setting
    parser.add_argument('-d', '--dataset',
                        help='dataset to use')
    parser.add_argument('-o', '--order', choices=['random', 'canonical'],
                        help='order to generate graphs, used for naming evaluation directory')

    # log
    parser.add_argument('-l', '--log-dir', default='./eval_results',
                        help='folder to save evaluation results')
    parser.add_argument('-p', '--model-path', type=str, default=None,
                        help='path to saved model')
    parser.add_argument('-pr', '--pretrained', action='store_true',
                        help='Whether to use a pre-trained model')
    parser.add_argument('-ns', '--num-samples', type=int, default=100000,
                        help='Number of molecules to generate')
    parser.add_argument('-mn', '--max-num-steps', type=int, default=400,
                        help='Max number of steps allowed in generated molecules to ensure termination')

    # multi-process
    parser.add_argument('-np', '--num-processes', type=int, default=32,
                        help='number of processes to use')
    parser.add_argument('-gt', '--generation-time', type=int, default=600,
                        help='max time (seconds) allowed for generation with multiprocess')

    args = parser.parse_args()
    args = setup(args, train=False)
    rdBase.DisableLog('rdApp.error')

    t1 = time.time()
    if args['num_processes'] == 1:
        prepare_for_evaluation(0, args)
    else:
        import multiprocessing as mp

        procs = []
        for rank in range(args['num_processes']):
            p = mp.Process(target=prepare_for_evaluation, args=(rank, args,))
            procs.append(p)
            p.start()
        # Poll the workers until they all finish; the else branch of the while
        # loop fires only on timeout, i.e. when the loop exits without break.
        while time.time() - t1 <= args['generation_time']:
            if any(p.is_alive() for p in procs):
                time.sleep(5)
            else:
                break
        else:
            print('Timeout, killing all processes.')
        for p in procs:
            p.terminate()
            p.join()
    t2 = time.time()
    print('It took {} for generation.'.format(
        datetime.timedelta(seconds=t2 - t1)))
    aggregate_and_evaluate(args)
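The synthetic accessibility (SA) scorer bundled with this example, adapted from RDKit's contrib implementation
(presumably `sascorer.py`):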
#
# calculation of synthetic accessibility score as described in:
#
# Estimation of Synthetic Accessibility Score of Drug-like Molecules
# based on Molecular Complexity and Fragment Contributions
# Peter Ertl and Ansgar Schuffenhauer
# Journal of Cheminformatics 1:8 (2009)
# http://www.jcheminf.com/content/1/1/8
#
# several small modifications to the original paper are included
# particularly a slightly different formula for the macrocyclic penalty
# and also taking into account molecule symmetry (fingerprint density)
#
# for a set of 10k diverse molecules the agreement between the original method
# as implemented in PipelinePilot and this implementation is r2 = 0.97
#
# peter ertl & greg landrum, september 2013
#
# A small modification is performed
#
# DGL team, August 2019
#
from __future__ import print_function

import math
import os

from rdkit import Chem
from rdkit.Chem import rdMolDescriptors
from rdkit.six.moves import cPickle
from rdkit.six import iteritems

from dgl.data.utils import download, _get_dgl_url, get_download_dir

_fscores = None

def readFragmentScores(name='fpscores'):
    import gzip
    global _fscores
    fname = '{}.pkl.gz'.format(name)
    download(_get_dgl_url(os.path.join('dataset', fname)), path=fname)
    _fscores = cPickle.load(gzip.open(fname))
    outDict = {}
    for i in _fscores:
        for j in range(1, len(i)):
            outDict[i[j]] = float(i[0])
    _fscores = outDict

def numBridgeheadsAndSpiro(mol):
    nSpiro = rdMolDescriptors.CalcNumSpiroAtoms(mol)
    nBridgehead = rdMolDescriptors.CalcNumBridgeheadAtoms(mol)
    return nBridgehead, nSpiro

def calculateScore(m):
    if _fscores is None:
        readFragmentScores()

    # fragment score
    # 2 is the *radius* of the circular fingerprint
    fp = rdMolDescriptors.GetMorganFingerprint(m, 2)
    fps = fp.GetNonzeroElements()
    score1 = 0.
    nf = 0
    for bitId, v in iteritems(fps):
        nf += v
        sfp = bitId
        score1 += _fscores.get(sfp, -4) * v
    # Guard added to avoid a ZeroDivisionError for empty fingerprints.
    if nf != 0:
        score1 /= nf

    # features score
    nAtoms = m.GetNumAtoms()
    nChiralCenters = len(Chem.FindMolChiralCenters(m, includeUnassigned=True))
    ri = m.GetRingInfo()
    nBridgeheads, nSpiro = numBridgeheadsAndSpiro(m)
    nMacrocycles = 0
    for x in ri.AtomRings():
        if len(x) > 8:
            nMacrocycles += 1

    sizePenalty = nAtoms**1.005 - nAtoms
    stereoPenalty = math.log10(nChiralCenters + 1)
    spiroPenalty = math.log10(nSpiro + 1)
    bridgePenalty = math.log10(nBridgeheads + 1)
    macrocyclePenalty = 0.
    # ---------------------------------------
    # This differs from the paper, which defines:
    # macrocyclePenalty = math.log10(nMacrocycles+1)
    # This form generates better results when 2 or more macrocycles are present
    if nMacrocycles > 0:
        macrocyclePenalty = math.log10(2)

    score2 = 0. - sizePenalty - stereoPenalty - \
        spiroPenalty - bridgePenalty - macrocyclePenalty

    # correction for the fingerprint density
    # not in the original publication, added in version 1.1
    # to make highly symmetrical molecules easier to synthesize
    score3 = 0.
    if nAtoms > len(fps):
        score3 = math.log(float(nAtoms) / len(fps)) * .5

    sascore = score1 + score2 + score3

    # need to transform "raw" value into scale between 1 and 10
    min = -4.0
    max = 2.5
    sascore = 11. - (sascore - min + 1) / (max - min) * 9.
    # smooth the 10-end
    if sascore > 8.:
        sascore = 8. + math.log(sascore + 1. - 9.)
    if sascore > 10.:
        sascore = 10.0
    elif sascore < 1.:
        sascore = 1.0

    return sascore

def processMols(mols):
    print('smiles\tName\tsa_score')
    for i, m in enumerate(mols):
        if m is None:
            continue
        s = calculateScore(m)
        smiles = Chem.MolToSmiles(m)
        print(smiles + "\t" + m.GetProp('_Name') + "\t%3f" % s)

if __name__ == '__main__':
    import sys, time

    t1 = time.time()
    readFragmentScores("fpscores")
    t2 = time.time()

    suppl = Chem.SmilesMolSupplier(sys.argv[1])
    t3 = time.time()
    processMols(suppl)
    t4 = time.time()

    print('Reading took %.2f seconds. Calculating took %.2f seconds' % ((t2 - t1), (t4 - t3)),
          file=sys.stderr)
#
# Copyright (c) 2013, Novartis Institutes for BioMedical Research Inc.
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are
# met:
#
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above
# copyright notice, this list of conditions and the following
# disclaimer in the documentation and/or other materials provided
# with the distribution.
# * Neither the name of Novartis Institutes for BioMedical Research Inc.
# nor the names of its contributors may be used to endorse or promote
# products derived from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
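`train.py`, the training script described above: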
"""
Learning Deep Generative Models of Graphs
Paper: https://arxiv.org/pdf/1803.03324.pdf
"""
import datetime
import time
import torch
import torch.distributed as dist
from dgl import model_zoo
from torch.optim import Adam
from torch.utils.data import DataLoader
from utils import MoleculeDataset, Printer, set_random_seed, synchronize, launch_a_process
def evaluate(epoch, model, data_loader, printer):
model.eval()
batch_size = data_loader.batch_size
total_log_prob = 0
with torch.no_grad():
for i, data in enumerate(data_loader):
log_prob = model(actions=data, compute_log_prob=True).detach()
total_log_prob -= log_prob
if printer is not None:
prob = log_prob.detach().exp()
printer.update(epoch + 1, - log_prob / batch_size, prob / batch_size)
return total_log_prob / len(data_loader)
def main(rank, args):
"""
Parameters
----------
rank : int
Subprocess id
args : dict
Configuration
"""
if rank == 0:
t1 = time.time()
set_random_seed(args['seed'])
# Remove the line below will result in problems for multiprocess
torch.set_num_threads(1)
# Setup dataset and data loader
dataset = MoleculeDataset(args['dataset'], args['order'], ['train', 'val'],
subset_id=rank, n_subsets=args['num_processes'])
# Note that currently the batch size for the loaders should only be 1.
train_loader = DataLoader(dataset.train_set, batch_size=args['batch_size'],
shuffle=True, collate_fn=dataset.collate)
val_loader = DataLoader(dataset.val_set, batch_size=args['batch_size'],
shuffle=True, collate_fn=dataset.collate)
if rank == 0:
try:
from tensorboardX import SummaryWriter
writer = SummaryWriter(args['log_dir'])
except ImportError:
print('If you want to use tensorboard, install tensorboardX with pip.')
writer = None
train_printer = Printer(args['nepochs'], len(dataset.train_set), args['batch_size'], writer)
val_printer = Printer(args['nepochs'], len(dataset.val_set), args['batch_size'])
else:
val_printer = None
# Initialize model
model = model_zoo.chem.DGMG(atom_types=dataset.atom_types,
bond_types=dataset.bond_types,
node_hidden_size=args['node_hidden_size'],
num_prop_rounds=args['num_propagation_rounds'],
dropout=args['dropout'])
if args['num_processes'] == 1:
from utils import Optimizer
optimizer = Optimizer(args['lr'], Adam(model.parameters(), lr=args['lr']))
else:
from utils import MultiProcessOptimizer
optimizer = MultiProcessOptimizer(args['num_processes'], args['lr'],
Adam(model.parameters(), lr=args['lr']))
if rank == 0:
t2 = time.time()
best_val_prob = 0
# Training
for epoch in range(args['nepochs']):
model.train()
if rank == 0:
print('Training')
for i, data in enumerate(train_loader):
log_prob = model(actions=data, compute_log_prob=True)
prob = log_prob.detach().exp()
loss_averaged = - log_prob
prob_averaged = prob
optimizer.backward_and_step(loss_averaged)
if rank == 0:
train_printer.update(epoch + 1, loss_averaged.item(), prob_averaged.item())
synchronize(args['num_processes'])
# Validation
val_log_prob = evaluate(epoch, model, val_loader, val_printer)
if args['num_processes'] > 1:
dist.all_reduce(val_log_prob, op=dist.ReduceOp.SUM)
val_log_prob /= args['num_processes']
# Strictly speaking, the computation of probability here is different from what is
# performed on the training set as we first take an average of log likelihood and then
# take the exponentiation. By Jensen's inequality, the resulting value is then a
# lower bound of the real probabilities.
val_prob = (- val_log_prob).exp().item()
val_log_prob = val_log_prob.item()
if val_prob >= best_val_prob:
if rank == 0:
torch.save({'model_state_dict': model.state_dict()}, args['checkpoint_dir'])
print('Old val prob {:.10f} | new val prob {:.10f} | model saved'.format(best_val_prob, val_prob))
best_val_prob = val_prob
elif epoch >= args['warmup_epochs']:
optimizer.decay_lr()
if rank == 0:
print('Validation')
if writer is not None:
writer.add_scalar('validation_log_prob', val_log_prob, epoch)
writer.add_scalar('validation_prob', val_prob, epoch)
writer.add_scalar('lr', optimizer.lr, epoch)
print('Validation log prob {:.4f} | prob {:.10f}'.format(val_log_prob, val_prob))
synchronize(args['num_processes'])
if rank == 0:
t3 = time.time()
print('It took {} to setup.'.format(datetime.timedelta(seconds=t2 - t1)))
print('It took {} to finish training.'.format(datetime.timedelta(seconds=t3 - t2)))
print('--------------------------------------------------------------------------')
print('On average, an epoch takes {}.'.format(datetime.timedelta(
seconds=(t3 - t2) / args['nepochs'])))
if __name__ == '__main__':
import argparse
from utils import setup
parser = argparse.ArgumentParser(description='Training DGMG for molecule generation',
formatter_class=argparse.ArgumentDefaultsHelpFormatter)
# configure
parser.add_argument('-s', '--seed', type=int, default=0, help='random seed')
parser.add_argument('-w', '--warmup-epochs', type=int, default=10,
help='Number of epochs where no lr decay is performed.')
# dataset and setting
parser.add_argument('-d', '--dataset',
help='dataset to use')
parser.add_argument('-o', '--order', choices=['random', 'canonical'],
help='order to generate graphs')
parser.add_argument('-tf', '--train-file', type=str, default=None,
help='Path to a file with one SMILES a line for training data. '
'This is only necessary if you want to use a new dataset.')
parser.add_argument('-vf', '--val-file', type=str, default=None,
help='Path to a file with one SMILES a line for validation data. '
'This is only necessary if you want to use a new dataset.')
# log
parser.add_argument('-l', '--log-dir', default='./training_results',
help='folder to save info like experiment configuration')
# multi-process
parser.add_argument('-np', '--num-processes', type=int, default=32,
help='number of processes to use')
parser.add_argument('-mi', '--master-ip', type=str, default='127.0.0.1')
parser.add_argument('-mp', '--master-port', type=str, default='12345')
args = parser.parse_args()
args = setup(args, train=True)
if args['num_processes'] == 1:
main(0, args)
else:
mp = torch.multiprocessing.get_context('spawn')
procs = []
for rank in range(args['num_processes']):
procs.append(mp.Process(target=launch_a_process, args=(rank, args, main), daemon=True))
procs[-1].start()
for p in procs:
p.join()
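Diff of the dataset download utilities (presumably `dgl/data/utils.py`), adding a `log` flag to `download`: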
@@ -47,7 +47,7 @@ def split_dataset(dataset, frac_list=None, shuffle=False, random_state=None):
     return [Subset(dataset, indices[offset - length:offset]) for offset, length in zip(accumulate(lengths), lengths)]
 
-def download(url, path=None, overwrite=False, sha1_hash=None, retries=5, verify_ssl=True):
+def download(url, path=None, overwrite=False, sha1_hash=None, retries=5, verify_ssl=True, log=True):
     """Download a given URL.
 
     Codes borrowed from mxnet/gluon/utils.py
@@ -68,6 +68,8 @@ def download(url, path=None, overwrite=False, sha1_hash=None, retries=5, verify_
         The number of times to attempt downloading in case of failure or non 200 return codes.
     verify_ssl : bool, default True
         Verify SSL certificates.
+    log : bool, default True
+        Whether to print the progress for download
 
     Returns
     -------
@@ -100,7 +102,8 @@ def download(url, path=None, overwrite=False, sha1_hash=None, retries=5, verify_
             # Disable pylint too broad Exception
             # pylint: disable=W0703
             try:
-                print('Downloading %s from %s...' % (fname, url))
+                if log:
+                    print('Downloading %s from %s...' % (fname, url))
                 r = requests.get(url, stream=True, verify=verify_ssl)
                 if r.status_code != 200:
                     raise RuntimeError("Failed downloading url %s" % url)
@@ -119,8 +122,9 @@ def download(url, path=None, overwrite=False, sha1_hash=None, retries=5, verify_
                 if retries <= 0:
                     raise e
                 else:
-                    print("download failed, retrying, {} attempt{} left"
-                          .format(retries, 's' if retries > 1 else ''))
+                    if log:
+                        print("download failed, retrying, {} attempt{} left"
+                              .format(retries, 's' if retries > 1 else ''))
 
     return fname
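Diff of the model zoo chemistry README (presumed file), adding a section on generative models: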
@@ -44,6 +44,24 @@ molecular graph topology, which may be viewed as a learned fingerprint [3].
 - **Graph Convolutional Network**: Graph Convolutional Networks (GCN) have been one of the most popular graph neural
 networks and they can be easily extended for graph level prediction.
 
+## Generative Models
+
+We use generative models for two different purposes when it comes to molecules:
+
+- **Distribution Learning**: Given a collection of molecules, we want to model their distribution and generate new
+molecules with similar properties.
+- **Goal-directed Optimization**: Find molecules with desired properties.
+
+For this model zoo, we will only focus on generative models for molecular graphs. There are other generative models
+working with alternative representations like SMILES.
+
+Generative models are known to be difficult to evaluate. [GuacaMol](https://github.com/BenevolentAI/guacamol) and
+[MOSES](https://github.com/molecularsets/moses) have been two recent efforts to benchmark generative models. There
+are also two accompanying review papers that are well written [4], [5].
+
+### Models
+
+- **Deep Generative Models of Graphs (DGMG)**: A very general framework for graph distribution learning by progressively
+adding atoms and bonds.
+
 ## References
 
 [1] Chen et al. (2018) The rise of deep learning in drug discovery. *Drug Discov Today* 6, 1241-1250.
@@ -53,3 +71,8 @@ networks and they can be easily extended for graph level prediction.
 [3] Duvenaud et al. (2015) Convolutional networks on graphs for learning molecular fingerprints. *Advances in neural
 information processing systems (NeurIPS)*, 2224-2232.
+
+[4] Brown et al. (2019) GuacaMol: Benchmarking Models for de Novo Molecular Design. *J. Chem. Inf. Model.* 59(3),
+1096-1108.
+
+[5] Polykovskiy et al. (2019) Molecular Sets (MOSES): A Benchmarking Platform for Molecular Generation Models. *arXiv*.
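Diff of the model zoo package `__init__.py`, exporting `DGMG`: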
@@ -2,4 +2,5 @@
 """Model Zoo Package"""
 from .gcn import GCNClassifier
+from .dgmg import DGMG
 from .pretrain import load_pretrained
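Diff of the pretrained model utilities (presumably `pretrain.py`), extending `load_pretrained` to the DGMG checkpoints: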
"""Utilities for using pretrained models.""" """Utilities for using pretrained models."""
import torch import torch
from .dgmg import DGMG
from .gcn import GCNClassifier from .gcn import GCNClassifier
from ...data.utils import _get_dgl_url, download from ...data.utils import _get_dgl_url, download
def load_pretrained(model_name): URL = {
'GCN_Tox21' : 'pre_trained/gcn_tox21.pth',
'DGMG_ChEMBL_canonical' : 'pre_trained/dgmg_ChEMBL_canonical.pth',
'DGMG_ChEMBL_random' : 'pre_trained/dgmg_ChEMBL_random.pth',
'DGMG_ZINC_canonical' : 'pre_trained/dgmg_ZINC_canonical.pth',
'DGMG_ZINC_random' : 'pre_trained/dgmg_ZINC_random.pth'
}
try:
from rdkit import Chem
except ImportError:
pass
def download_and_load_checkpoint(model_name, model, model_postfix,
local_pretrained_path='pre_trained.pth', log=True):
"""Download pretrained model checkpoint
Parameters
----------
model_name : str
Name of the model
model : nn.Module
Instantiated model instance
model_postfix : str
Postfix for pretrained model checkpoint
local_pretrained_path : str
Local name for the downloaded model checkpoint
log : bool
Whether to print progress for model loading
Returns
-------
model : nn.Module
Pretrained model
"""
url_to_pretrained = _get_dgl_url(model_postfix)
local_pretrained_path = '_'.join([model_name, local_pretrained_path])
download(url_to_pretrained, path=local_pretrained_path, log=log)
checkpoint = torch.load(local_pretrained_path)
model.load_state_dict(checkpoint['model_state_dict'])
return model
def load_pretrained(model_name, log=True):
"""Load a pretrained model """Load a pretrained model
Parameters Parameters
---------- ----------
model_name : str model_name : str
log : bool
Whether to print progress for model loading
Returns Returns
------- -------
model model
""" """
if model_name == "GCN_Tox21": if model_name not in URL:
print('Loading pretrained model...') return RuntimeError("Cannot find a pretrained model with name {}".format(model_name))
url_to_pretrained = _get_dgl_url('pre_trained/gcn_tox21.pth')
local_pretrained_path = 'pre_trained.pth' if model_name == 'GCN_Tox21':
download(url_to_pretrained, path=local_pretrained_path)
model = GCNClassifier(in_feats=74, model = GCNClassifier(in_feats=74,
gcn_hidden_feats=[64, 64], gcn_hidden_feats=[64, 64],
n_tasks=12, n_tasks=12,
classifier_hidden_feats=64) classifier_hidden_feats=64)
checkpoint = torch.load(local_pretrained_path) elif model_name.startswith('DGMG'):
model.load_state_dict(checkpoint['model_state_dict']) if model_name.startswith('DGMG_ChEMBL'):
return model atom_types = ['O', 'Cl', 'C', 'S', 'F', 'Br', 'N']
else: elif model_name.startswith('DGMG_ZINC'):
raise RuntimeError("Cannot find a pretrained model with name {}".format(model_name)) atom_types = ['Br', 'S', 'C', 'P', 'N', 'O', 'F', 'Cl', 'I']
bond_types = [Chem.rdchem.BondType.SINGLE,
Chem.rdchem.BondType.DOUBLE,
Chem.rdchem.BondType.TRIPLE]
model = DGMG(atom_types=atom_types,
bond_types=bond_types,
node_hidden_size=128,
num_prop_rounds=2,
dropout=0.2)
if log:
print('Pretrained model loaded')
return download_and_load_checkpoint(model_name, model, URL[model_name], log=log)