Unverified commit 36c7b771 authored by Mufei Li, committed by GitHub

[LifeSci] Move to Independent Repo (#1592)

* Move LifeSci

* Remove doc
parent 94c67203
.. _apimodelzoo:
Model Zoo
=========
This section introduces complete models for various downstream tasks.
.. contents:: Contents
:local:
Building Blocks
---------------
MLP Predictor
`````````````
.. automodule:: dgllife.model.model_zoo.mlp_predictor
:members:
Molecular Property Prediction
-----------------------------
AttentiveFP Predictor
`````````````````````
.. automodule:: dgllife.model.model_zoo.attentivefp_predictor
:members:
GAT Predictor
`````````````
.. automodule:: dgllife.model.model_zoo.gat_predictor
:members:
GCN Predictor
`````````````
.. automodule:: dgllife.model.model_zoo.gcn_predictor
:members:
MGCN Predictor
``````````````
.. automodule:: dgllife.model.model_zoo.mgcn_predictor
:members:
MPNN Predictor
``````````````
.. automodule:: dgllife.model.model_zoo.mpnn_predictor
:members:
SchNet Predictor
````````````````
.. automodule:: dgllife.model.model_zoo.schnet_predictor
:members:
Weave Predictor
```````````````
.. automodule:: dgllife.model.model_zoo.weave_predictor
:members:
GIN Predictor
`````````````
.. automodule:: dgllife.model.model_zoo.gin_predictor
:members:
Generative Models
-----------------
DGMG
````
.. automodule:: dgllife.model.model_zoo.dgmg
:members:
JTNN
````
.. autoclass:: dgllife.model.model_zoo.jtnn.DGLJTNNVAE
:members:
Reaction Prediction
-------------------
WLN for Reaction Center Prediction
``````````````````````````````````
.. automodule:: dgllife.model.model_zoo.wln_reaction_center
:members:
WLN for Ranking Candidate Products
``````````````````````````````````
.. automodule:: dgllife.model.model_zoo.wln_reaction_ranking
:members:
Protein-Ligand Binding Affinity Prediction
-------------------------------------------
ACNN
````
.. automodule:: dgllife.model.model_zoo.acnn
:members:
.. _apiutilscomplexes:
Utils for protein-ligand complexes
==================================
Utilities in DGL-LifeSci for working with protein-ligand complexes.
.. autosummary::
:toctree: ../generated/
dgllife.utils.ACNN_graph_construction_and_featurization
.. _apiutilsmols:
Utils for Molecules
===================
Utilities in DGL-LifeSci for working with molecules.
RDKit Utils
-----------
RDKit utils for loading molecules and accessing their information.
.. autosummary::
:toctree: ../generated/
dgllife.utils.get_mol_3d_coordinates
dgllife.utils.load_molecule
dgllife.utils.multiprocess_load_molecules
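For instance, here is a minimal sketch of reading the 3D coordinates of a molecule with
``get_mol_3d_coordinates``; the RDKit conformer embedding below is only for illustration.

.. code:: python

    from rdkit import Chem
    from rdkit.Chem import AllChem

    from dgllife.utils import get_mol_3d_coordinates

    # Build a molecule and embed a 3D conformer with RDKit
    mol = Chem.AddHs(Chem.MolFromSmiles('CCO'))
    AllChem.EmbedMolecule(mol, randomSeed=0)

    # Read the conformer coordinates as a (num_atoms, 3) NumPy array
    coords = get_mol_3d_coordinates(mol)
    print(coords.shape)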
Graph Construction
------------------
The modeling of graph neural networks starts with constructing appropriate graph topologies. We provide
three common graph constructions:
* ``bigraph``: Bi-directed graphs corresponding exactly to molecular graphs
* ``complete_graph``: Graphs with all pairs of atoms connected
* ``nearest_neighbor_graph``: Graphs where each atom is connected to its k nearest atoms based on the 3D coordinates of the molecule
.. autosummary::
:toctree: ../generated/
dgllife.utils.mol_to_graph
dgllife.utils.smiles_to_bigraph
dgllife.utils.mol_to_bigraph
dgllife.utils.smiles_to_complete_graph
dgllife.utils.mol_to_complete_graph
dgllife.utils.k_nearest_neighbors
dgllife.utils.mol_to_nearest_neighbor_graph
dgllife.utils.smiles_to_nearest_neighbor_graph
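For example, below is a minimal sketch of constructing a bi-directed molecular graph from a SMILES string.
The canonical featurizers and their default feature fields (assumed to be ``'h'`` for atoms and ``'e'`` for
bonds) are described in the featurization section below.

.. code:: python

    from dgllife.utils import (smiles_to_bigraph, CanonicalAtomFeaturizer,
                               CanonicalBondFeaturizer)

    # Convert ethanol into a bi-directed DGLGraph with atom features stored in
    # g.ndata['h'] and bond features stored in g.edata['e']
    g = smiles_to_bigraph('CCO',
                          node_featurizer=CanonicalAtomFeaturizer(),
                          edge_featurizer=CanonicalBondFeaturizer())
    print(g.number_of_nodes(), g.number_of_edges())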
Featurization for Molecules
---------------------------
To apply graph neural networks, we need to prepare node and edge features for molecules. Intuitively,
these can be developed based on various descriptors (features) of atoms/bonds/molecules. In particular, we can
work with numerical descriptors directly or use ``one_hot_encoding`` for categorical descriptors. When using
multiple descriptors together, we can simply concatenate them with ``ConcatFeaturizer``.
General Utils
`````````````
.. autosummary::
:toctree: ../generated/
dgllife.utils.one_hot_encoding
dgllife.utils.ConcatFeaturizer
Featurization for Nodes
```````````````````````
We consider the following atom descriptors:
* type/atomic number
* degree (excluding neighboring hydrogen atoms)
* total degree (including neighboring hydrogen atoms)
* explicit valence
* implicit valence
* hybridization
* total number of neighboring hydrogen atoms
* formal charge
* number of radical electrons
* aromatic atom
* ring membership
* chirality
* mass
We can employ their numerical values directly or with one-hot encoding.
.. autosummary::
:toctree: ../generated/
dgllife.utils.atom_type_one_hot
dgllife.utils.atomic_number_one_hot
dgllife.utils.atomic_number
dgllife.utils.atom_degree_one_hot
dgllife.utils.atom_degree
dgllife.utils.atom_total_degree_one_hot
dgllife.utils.atom_total_degree
dgllife.utils.atom_explicit_valence_one_hot
dgllife.utils.atom_explicit_valence
dgllife.utils.atom_implicit_valence_one_hot
dgllife.utils.atom_implicit_valence
dgllife.utils.atom_hybridization_one_hot
dgllife.utils.atom_total_num_H_one_hot
dgllife.utils.atom_total_num_H
dgllife.utils.atom_formal_charge_one_hot
dgllife.utils.atom_formal_charge
dgllife.utils.atom_num_radical_electrons_one_hot
dgllife.utils.atom_num_radical_electrons
dgllife.utils.atom_is_aromatic_one_hot
dgllife.utils.atom_is_aromatic
dgllife.utils.atom_is_in_ring_one_hot
dgllife.utils.atom_is_in_ring
dgllife.utils.atom_chiral_tag_one_hot
dgllife.utils.atom_mass
To use featurization functions like the ones above for creating node features:
.. autosummary::
:toctree: ../generated/
dgllife.utils.BaseAtomFeaturizer
dgllife.utils.BaseAtomFeaturizer.feat_size
dgllife.utils.CanonicalAtomFeaturizer
dgllife.utils.CanonicalAtomFeaturizer.feat_size
dgllife.utils.PretrainAtomFeaturizer
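Below is a minimal sketch of composing a custom atom featurizer from the functions above; the chosen
descriptors, the allowable atom types and the feature name ``'h'`` are arbitrary choices for illustration.

.. code:: python

    from functools import partial

    from dgllife.utils import (BaseAtomFeaturizer, ConcatFeaturizer, atom_type_one_hot,
                               atom_degree_one_hot, atom_is_aromatic)

    # Concatenate several atom descriptors into a single feature vector stored under 'h'
    atom_featurizer = BaseAtomFeaturizer(
        {'h': ConcatFeaturizer([partial(atom_type_one_hot, allowable_set=['C', 'N', 'O']),
                                atom_degree_one_hot,
                                atom_is_aromatic])})
    print(atom_featurizer.feat_size('h'))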
Featurization for Edges
```````````````````````
We consider the following bond descriptors:
* type
* conjugated bond
* ring membership
* stereo configuration
.. autosummary::
:toctree: ../generated/
dgllife.utils.bond_type_one_hot
dgllife.utils.bond_is_conjugated_one_hot
dgllife.utils.bond_is_conjugated
dgllife.utils.bond_is_in_ring_one_hot
dgllife.utils.bond_is_in_ring
dgllife.utils.bond_stereo_one_hot
dgllife.utils.bond_direction_one_hot
To use featurization functions like the ones above for creating edge features:
.. autosummary::
:toctree: ../generated/
dgllife.utils.BaseBondFeaturizer
dgllife.utils.BaseBondFeaturizer.feat_size
dgllife.utils.CanonicalBondFeaturizer
dgllife.utils.CanonicalBondFeaturizer.feat_size
dgllife.utils.PretrainBondFeaturizer
.. _apiutilspipeline:
Model Development Pipeline
==========================
.. contents:: Contents
:local:
Model Evaluation
----------------
A utility class for evaluating model performance on (multi-label) supervised learning.
.. autoclass:: dgllife.utils.Meter
:members: update, compute_metric
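A minimal sketch of how the class can be used is given below; the random tensors stand in for model
predictions and labels on a batch of 4 molecules and 2 tasks.

.. code:: python

    import torch

    from dgllife.utils import Meter

    meter = Meter()
    # Accumulate predictions and ground truths batch by batch
    y_pred = torch.randn(4, 2)
    y_true = torch.randint(0, 2, (4, 2)).float()
    meter.update(y_pred, y_true)
    # Compute a metric over all accumulated batches, averaged across tasks
    print(meter.compute_metric('mae', 'mean'))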
Early Stopping
--------------
Early stopping is a standard practice for preventing models from overfitting and we provide a utility
class for handling it.
.. autoclass:: dgllife.utils.EarlyStopping
:members:
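A minimal sketch is given below, assuming a checkpoint file name and a toy sequence of validation losses;
``step`` checkpoints the model on improvement and returns ``True`` once ``patience`` epochs pass without one.

.. code:: python

    import torch.nn as nn

    from dgllife.utils import EarlyStopping

    model = nn.Linear(10, 1)
    stopper = EarlyStopping(mode='lower', patience=3, filename='early_stop.pth')
    for epoch, val_loss in enumerate([0.9, 0.8, 0.85, 0.84, 0.83, 0.82]):
        if stopper.step(val_loss, model):
            print('Early stopping at epoch {:d}'.format(epoch))
            break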
.. _apiutilssplitters:
Splitting Datasets
==================
We provide multiple splitting methods for datasets.
.. contents:: Contents
:local:
ConsecutiveSplitter
-------------------
.. autoclass:: dgllife.utils.ConsecutiveSplitter
:members: train_val_test_split, k_fold_split
RandomSplitter
--------------
.. autoclass:: dgllife.utils.RandomSplitter
:members: train_val_test_split, k_fold_split
MolecularWeightSplitter
-----------------------
.. autoclass:: dgllife.utils.MolecularWeightSplitter
:members: train_val_test_split, k_fold_split
ScaffoldSplitter
----------------
.. autoclass:: dgllife.utils.ScaffoldSplitter
:members: train_val_test_split, k_fold_split
SingleTaskStratifiedSplitter
----------------------------
.. autoclass:: dgllife.utils.SingleTaskStratifiedSplitter
:members: train_val_test_split, k_fold_split
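As a minimal sketch, below is a random 80/10/10 split of a toy dataset; any indexable dataset
(e.g. the datasets in ``dgllife.data``) should work the same way.

.. code:: python

    from dgllife.utils import RandomSplitter

    # A toy dataset for illustration
    dataset = list(range(10))
    train_set, val_set, test_set = RandomSplitter.train_val_test_split(
        dataset, frac_train=0.8, frac_val=0.1, frac_test=0.1, random_state=0)
    print(len(train_set), len(val_set), len(test_set))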
# -*- coding: utf-8 -*-
#
# Configuration file for the Sphinx documentation builder.
#
# This file does only contain a selection of the most common options. For a
# full list see the documentation:
# http://www.sphinx-doc.org/en/master/config
# -- Path setup --------------------------------------------------------------
# If extensions (or modules to document with autodoc) are in another directory,
# add these directories to sys.path here. If the directory is relative to the
# documentation root, use os.path.abspath to make it absolute, like shown here.
#
import os
import sys
sys.path.insert(0, os.path.abspath('../../python'))
# -- Project information -----------------------------------------------------
project = 'DGL-LifeSci'
copyright = '2020, DGL Team'
author = 'DGL Team'
import dgllife
version = dgllife.__version__
release = dgllife.__version__
# -- General configuration ---------------------------------------------------
# If your documentation needs a minimal Sphinx version, state it here.
#
# needs_sphinx = '1.0'
# Add any Sphinx extension module names here, as strings. They can be
# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
# ones.
extensions = [
'sphinx.ext.autodoc',
'sphinx.ext.autosummary',
'sphinx.ext.coverage',
'sphinx.ext.mathjax',
'sphinx.ext.napoleon',
'sphinx.ext.viewcode',
'sphinx.ext.intersphinx',
'sphinx.ext.graphviz',
'sphinx_gallery.gen_gallery',
]
# Add any paths that contain templates here, relative to this directory.
templates_path = ['_templates']
# The suffix(es) of source filenames.
# You can specify multiple suffix as a list of string:
#
source_suffix = ['.rst', '.md']
# The master toctree document.
master_doc = 'index'
# The language for content autogenerated by Sphinx. Refer to documentation
# for a list of supported languages.
#
# This is also used if you do content translation via gettext catalogs.
# Usually you set "language" from the command line for these cases.
language = None
# List of patterns, relative to source directory, that match files and
# directories to ignore when looking for source files.
# This pattern also affects html_static_path and html_extra_path.
exclude_patterns = []
# The name of the Pygments (syntax highlighting) style to use.
pygments_style = None
# -- Options for HTML output -------------------------------------------------
# The theme to use for HTML and HTML Help pages. See the documentation for
# a list of builtin themes.
#
html_theme = 'sphinx_rtd_theme'
# Theme options are theme-specific and customize the look and feel of a theme
# further. For a list of options available for each theme, see the
# documentation.
#
# html_theme_options = {}
# Add any paths that contain custom static files (such as style sheets) here,
# relative to this directory. They are copied after the builtin static files,
# so a file named "default.css" will overwrite the builtin "default.css".
html_static_path = ['_static']
# Custom sidebar templates, must be a dictionary that maps document names
# to template names.
#
# The default sidebars (for documents that don't match any pattern) are
# defined by theme itself. Builtin themes are using these templates by
# default: ``['localtoc.html', 'relations.html', 'sourcelink.html',
# 'searchbox.html']``.
#
# html_sidebars = {}
# -- Options for HTMLHelp output ---------------------------------------------
# Output file base name for HTML help builder.
htmlhelp_basename = 'dgllifedoc'
# -- Options for LaTeX output ------------------------------------------------
latex_elements = {
# The paper size ('letterpaper' or 'a4paper').
#
# 'papersize': 'letterpaper',
# The font size ('10pt', '11pt' or '12pt').
#
# 'pointsize': '10pt',
# Additional stuff for the LaTeX preamble.
#
# 'preamble': '',
# Latex figure (float) alignment
#
# 'figure_align': 'htbp',
}
# Grouping the document tree into LaTeX files. List of tuples
# (source start file, target name, title,
# author, documentclass [howto, manual, or own class]).
latex_documents = [
(master_doc, 'dgllife.tex', 'DGL-LifeSci Documentation',
'DGL Team', 'manual'),
]
# -- Options for manual page output ------------------------------------------
# One entry per manual page. List of tuples
# (source start file, name, description, authors, manual section).
man_pages = [
(master_doc, 'dgllife', 'DGL-LifeSci Documentation',
[author], 1)
]
# -- Options for Texinfo output ----------------------------------------------
# Grouping the document tree into Texinfo files. List of tuples
# (source start file, target name, title, author,
# dir menu entry, description, category)
texinfo_documents = [
(master_doc, 'dgllife', 'DGL-LifeSci Documentation',
author, 'dgllife', 'Application library for life science.',
'Miscellaneous'),
]
# -- Options for Epub output -------------------------------------------------
# Bibliographic Dublin Core info.
epub_title = project
# The unique identifier of the text. This can be a ISBN number
# or the project homepage.
#
# epub_identifier = ''
# A unique identification for the text.
#
# epub_uid = ''
# A list of files that should not be packed into the epub file.
epub_exclude_files = ['search.html']
# -- Extension configuration -------------------------------------------------
autosummary_generate = True
intersphinx_mapping = {
'python': ('https://docs.python.org/{.major}'.format(sys.version_info), None),
'numpy': ('http://docs.scipy.org/doc/numpy/', None),
'scipy': ('http://docs.scipy.org/doc/scipy/reference', None),
'matplotlib': ('http://matplotlib.org/', None),
'networkx' : ('https://networkx.github.io/documentation/stable', None),
}
# sphinx gallery configurations
from sphinx_gallery.sorting import FileNameSortKey
examples_dirs = [] # path to find sources
gallery_dirs = [] # path to generate docs
reference_url = {
'dgllife' : None,
'numpy': 'http://docs.scipy.org/doc/numpy/',
'scipy': 'http://docs.scipy.org/doc/scipy/reference',
'matplotlib': 'http://matplotlib.org/',
'networkx' : 'https://networkx.github.io/documentation/stable',
}
sphinx_gallery_conf = {
'backreferences_dir' : 'generated/backreferences',
'doc_module' : ('dgllife', 'numpy'),
'examples_dirs' : examples_dirs,
'gallery_dirs' : gallery_dirs,
'within_subsection_order' : FileNameSortKey,
'filename_pattern' : '.py',
'download_all_examples' : False,
}
DGL-LifeSci: Bringing Graph Neural Networks to Chemistry and Biology
===========================================================================================
DGL-LifeSci is a python package for applying graph neural networks to various tasks in chemistry
and biology, on top of PyTorch and DGL. It provides:
* Various utilities for data processing, training and evaluation.
* Efficient and flexible model implementations.
* Pre-trained models for use without training from scratch.
We cover various applications in our
`examples <https://github.com/dmlc/dgl/tree/master/apps/life_sci/examples>`_, including:
* `Molecular property prediction <https://github.com/dmlc/dgl/tree/master/apps/life_sci/examples/property_prediction>`_
* `Generative models <https://github.com/dmlc/dgl/tree/master/apps/life_sci/examples/generative_models>`_
* `Protein-ligand binding affinity prediction <https://github.com/dmlc/dgl/tree/master/apps/life_sci/examples/binding_affinity_prediction>`_
* `Reaction prediction <https://github.com/dmlc/dgl/tree/master/apps/life_sci/examples/reaction_prediction>`_
Get Started
------------
Follow the :doc:`instructions<install/index>` to install DGL-LifeSci.
.. toctree::
:maxdepth: 1
:caption: Installation
:hidden:
:glob:
install/index
.. toctree::
:maxdepth: 2
:caption: API Reference
:hidden:
:glob:
api/utils.mols
api/utils.splitters
api/utils.pipeline
api/utils.complexes
api/data
api/model.pretrain
api/model.gnn
api/model.readout
api/model.zoo
Free software
-------------
DGL-LifeSci is free software; you can redistribute it and/or modify it under the terms
of the Apache License 2.0. We welcome contributions. Join us on `GitHub <https://github.com/dmlc/dgl/tree/master/apps/life_sci>`_.
Index
-----
* :ref:`genindex`
Install DGL-LifeSci
===================
This topic explains how to install DGL-LifeSci. We recommend installing DGL-LifeSci by using ``conda`` or ``pip``.
System requirements
-------------------
DGL-LifeSci works with the following operating systems:
* Ubuntu 16.04
* macOS X
* Windows 10
DGL-LifeSci requires:
* Python 3.6 or later
* `DGL 0.4.3 or later <https://www.dgl.ai/pages/start.html>`_
* `PyTorch 1.2.0 or later <https://pytorch.org/>`_
If you have just installed DGL, a message will pop up the first time you use it:
.. code:: bash
DGL does not detect a valid backend option. Which backend would you like to work with?
Backend choice (pytorch, mxnet or tensorflow):
and you need to enter ``pytorch``.
Additionally, we require **RDKit 2018.09.3** for cheminformatics. We recommend installing it with
.. code:: bash
conda install -c conda-forge rdkit==2018.09.3
Other versions of RDKit are not tested.
Install from conda
----------------------
If ``conda`` is not yet installed, get either `miniconda <https://conda.io/miniconda.html>`_ or
the full `anaconda <https://www.anaconda.com/download/>`_.
.. code:: bash
conda install -c dglteam dgllife
Install from pip
----------------
.. code:: bash
pip install dgllife
.. _install-from-source:
Install from source
-------------------
To use the latest experimental features,
.. code:: bash
git clone https://github.com/dmlc/dgl.git
cd dgl/apps/life_sci/python
python setup.py install
# Work Implemented in DGL-LifeSci
We provide various examples across four applications -- property prediction, generative models, protein-ligand binding affinity prediction and reaction prediction.
## Datasets/Benchmarks
- MoleculeNet: A Benchmark for Molecular Machine Learning [[paper]](https://arxiv.org/abs/1703.00564), [[website]](http://moleculenet.ai/)
- [Tox21 with DGL](../python/dgllife/data/tox21.py)
- [PDBBind with DGL](../python/dgllife/data/pdbbind.py)
- Alchemy: A Quantum Chemistry Dataset for Benchmarking AI Models [[paper]](https://arxiv.org/abs/1906.09427), [[github]](https://github.com/tencent-alchemy/Alchemy)
- [Alchemy with DGL](../python/dgllife/data/alchemy.py)
## Property Prediction
- Molecular graph convolutions: moving beyond fingerprints (Weave) [[paper]](https://arxiv.org/abs/1603.00856), [[github]](https://github.com/deepchem/deepchem)
- [Weave Predictor with DGL](../python/dgllife/model/model_zoo/weave_predictor.py)
- [Example for Molecule Classification](property_prediction/classification.py)
- Semi-Supervised Classification with Graph Convolutional Networks (GCN) [[paper]](https://arxiv.org/abs/1609.02907), [[github]](https://github.com/tkipf/gcn)
- [GCN-Based Predictor with DGL](../python/dgllife/model/model_zoo/gcn_predictor.py)
- [Example for Molecule Classification](property_prediction/classification.py)
- Graph Attention Networks (GAT) [[paper]](https://arxiv.org/abs/1710.10903), [[github]](https://github.com/PetarV-/GAT)
- [GAT-Based Predictor with DGL](../python/dgllife/model/model_zoo/gat_predictor.py)
- [Example for Molecule Classification](property_prediction/classification.py)
- SchNet: A continuous-filter convolutional neural network for modeling quantum interactions [[paper]](https://arxiv.org/abs/1706.08566), [[github]](https://github.com/atomistic-machine-learning/SchNet)
- [SchNet with DGL](../python/dgllife/model/model_zoo/schnet_predictor.py)
- [Example for Molecule Regression](property_prediction/regression.py)
- Molecular Property Prediction: A Multilevel Quantum Interactions Modeling Perspective (MGCN) [[paper]](https://arxiv.org/abs/1906.11081)
- [MGCN with DGL](../python/dgllife/model/model_zoo/mgcn_predictor.py)
- [Example for Molecule Regression](property_prediction/regression.py)
- Neural Message Passing for Quantum Chemistry (MPNN) [[paper]](https://arxiv.org/abs/1704.01212), [[github]](https://github.com/brain-research/mpnn)
- [MPNN with DGL](../python/dgllife/model/model_zoo/mpnn_predictor.py)
- [Example for Molecule Regression](property_prediction/regression.py)
- Pushing the Boundaries of Molecular Representation for Drug Discovery with the Graph Attention Mechanism (AttentiveFP) [[paper]](https://pubs.acs.org/doi/abs/10.1021/acs.jmedchem.9b00959)
- [AttentiveFP with DGL](../python/dgllife/model/model_zoo/attentivefp_predictor.py)
- [Example for Molecule Regression](property_prediction/regression.py)
## Generative Models
- Learning Deep Generative Models of Graphs (DGMG) [[paper]](https://arxiv.org/abs/1803.03324)
- [DGMG with DGL](../python/dgllife/model/model_zoo/dgmg.py)
- [Example Training Script](generative_models/dgmg)
- Junction Tree Variational Autoencoder for Molecular Graph Generation (JTNN) [[paper]](https://arxiv.org/abs/1802.04364)
- [JTNN with DGL](../python/dgllife/model/model_zoo/jtnn)
- [Example Training Script](generative_models/jtnn)
## Binding Affinity Prediction
- Atomic Convolutional Networks for Predicting Protein-Ligand Binding Affinity (ACNN) [[paper]](https://arxiv.org/abs/1703.10603), [[github]](https://github.com/deepchem/deepchem/tree/master/contrib/atomicconv)
- [ACNN with DGL](../python/dgllife/model/model_zoo/acnn.py)
- [Example Training Script](binding_affinity_prediction)
## Reaction Prediction
- A graph-convolutional neural network model for the prediction of chemical reactivity [[paper]](https://pubs.rsc.org/en/content/articlelanding/2019/sc/c8sc04228d#!divAbstract), [[github]](https://github.com/connorcoley/rexgen_direct)
- An earlier version was published in NeurIPS 2017 as "Predicting Organic Reaction Outcomes with Weisfeiler-Lehman Network" [[paper]](https://arxiv.org/abs/1709.04555)
- [WLN with DGL for Reaction Center Prediction](../python/dgllife/model/model_zoo/wln_reaction_center.py)
- [Example Script](reaction_prediction/rexgen_direct)
# Binding Affinity Prediction
## Datasets
- **PDBBind**: The PDBBind dataset in MoleculeNet [1] processed from the PDBBind database. The PDBBind
database consists of experimentally measured binding affinities for bio-molecular complexes [2], [3].
It provides detailed 3D Cartesian coordinates of both ligands and their target proteins derived from
experimental (e.g., X-ray crystallography) measurements. The availability of coordinates of the
protein-ligand complexes permits structure-based featurization that is aware of the protein-ligand
binding geometry. The authors of [1] use the "refined" and "core" subsets of the database [4], more carefully
processed for data artifacts, as additional benchmarking targets.
## Models
- **Atomic Convolutional Networks (ACNN)** [5]: Constructs nearest neighbor graphs separately for the ligand, protein and complex
based on the 3D coordinates of the atoms and predicts the binding free energy.
## Usage
Use `main.py` with arguments
```
-m {ACNN}, Model to use
-d {PDBBind_core_pocket_random, PDBBind_core_pocket_scaffold, PDBBind_core_pocket_stratified,
PDBBind_core_pocket_temporal, PDBBind_refined_pocket_random, PDBBind_refined_pocket_scaffold,
PDBBind_refined_pocket_stratified, PDBBind_refined_pocket_temporal}, dataset and splitting method to use
```
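For example, a possible invocation for training ACNN on the core subset with a random split is:
```
python main.py -m ACNN -d PDBBind_core_pocket_random
```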
## Performance
### PDBBind
#### ACNN
| Subset | Splitting Method | Test MAE | Test R2 |
| ------- | ---------------- | -------- | ------- |
| Core | Random | 1.7688 | 0.1511 |
| Core | Scaffold | 2.5420 | 0.1471 |
| Core | Stratified | 1.7419 | 0.1520 |
| Core | Temporal | 1.9543 | 0.1640 |
| Refined | Random | 1.1948 | 0.4373 |
| Refined | Scaffold | 1.4021 | 0.2086 |
| Refined | Stratified | 1.6376 | 0.3050 |
| Refined | Temporal | 1.2457 | 0.3438 |
## Speed
### ACNN
Compared to [DeepChem's implementation](https://github.com/joegomes/deepchem/tree/acdc), we achieve a speedup of
roughly 3.3x in training time per epoch (from 1.40s to 0.42s). If we do not care about the
randomness introduced by some kernel optimizations, we can achieve a speedup of roughly 4.4x (from 1.40s to 0.32s).
## References
[1] Wu et al. (2017) MoleculeNet: a benchmark for molecular machine learning. *Chemical Science* 9, 513-530.
[2] Wang et al. (2004) The PDBbind database: collection of binding affinities for protein-ligand complexes
with known three-dimensional structures. *J Med Chem* 3;47(12):2977-80.
[3] Wang et al. (2005) The PDBbind database: methodologies and updates. *J Med Chem* 16;48(12):4111-9.
[4] Liu et al. (2015) PDB-wide collection of binding data: current status of the PDBbind database. *Bioinformatics* 1;31(3):405-12.
[5] Gomes et al. (2017) Atomic Convolutional Networks for Predicting Protein-Ligand Binding Affinity. *arXiv preprint arXiv:1703.10603*.
import numpy as np
import torch
ACNN_PDBBind_core_pocket_random = {
'dataset': 'PDBBind',
'subset': 'core',
'load_binding_pocket': True,
'random_seed': 123,
'frac_train': 0.8,
'frac_val': 0.,
'frac_test': 0.2,
'batch_size': 24,
'shuffle': False,
'hidden_sizes': [32, 32, 16],
'weight_init_stddevs': [1. / float(np.sqrt(32)), 1. / float(np.sqrt(32)),
1. / float(np.sqrt(16)), 0.01],
'dropouts': [0., 0., 0.],
'atomic_numbers_considered': torch.tensor([
1., 6., 7., 8., 9., 11., 12., 15., 16., 17., 20., 25., 30., 35., 53.]),
'radial': [[12.0], [0.0, 4.0, 8.0], [4.0]],
'lr': 0.001,
'num_epochs': 120,
'metrics': ['r2', 'mae'],
'split': 'random'
}
ACNN_PDBBind_core_pocket_scaffold = {
'dataset': 'PDBBind',
'subset': 'core',
'load_binding_pocket': True,
'random_seed': 123,
'frac_train': 0.8,
'frac_val': 0.,
'frac_test': 0.2,
'batch_size': 24,
'shuffle': False,
'hidden_sizes': [32, 32, 16],
'weight_init_stddevs': [1. / float(np.sqrt(32)), 1. / float(np.sqrt(32)),
1. / float(np.sqrt(16)), 0.01],
'dropouts': [0., 0., 0.],
'atomic_numbers_considered': torch.tensor([
1., 6., 7., 8., 9., 11., 12., 15., 16., 17., 20., 25., 30., 35., 53.]),
'radial': [[12.0], [0.0, 4.0, 8.0], [4.0]],
'lr': 0.001,
'num_epochs': 170,
'metrics': ['r2', 'mae'],
'split': 'scaffold'
}
ACNN_PDBBind_core_pocket_stratified = {
'dataset': 'PDBBind',
'subset': 'core',
'load_binding_pocket': True,
'random_seed': 123,
'frac_train': 0.8,
'frac_val': 0.,
'frac_test': 0.2,
'batch_size': 24,
'shuffle': False,
'hidden_sizes': [32, 32, 16],
'weight_init_stddevs': [1. / float(np.sqrt(32)), 1. / float(np.sqrt(32)),
1. / float(np.sqrt(16)), 0.01],
'dropouts': [0., 0., 0.],
'atomic_numbers_considered': torch.tensor([
1., 6., 7., 8., 9., 11., 12., 15., 16., 17., 20., 25., 30., 35., 53.]),
'radial': [[12.0], [0.0, 4.0, 8.0], [4.0]],
'lr': 0.001,
'num_epochs': 110,
'metrics': ['r2', 'mae'],
'split': 'stratified'
}
ACNN_PDBBind_core_pocket_temporal = {
'dataset': 'PDBBind',
'subset': 'core',
'load_binding_pocket': True,
'random_seed': 123,
'frac_train': 0.8,
'frac_val': 0.,
'frac_test': 0.2,
'batch_size': 24,
'shuffle': False,
'hidden_sizes': [32, 32, 16],
'weight_init_stddevs': [1. / float(np.sqrt(32)), 1. / float(np.sqrt(32)),
1. / float(np.sqrt(16)), 0.01],
'dropouts': [0., 0., 0.],
'atomic_numbers_considered': torch.tensor([
1., 6., 7., 8., 9., 11., 12., 15., 16., 17., 20., 25., 30., 35., 53.]),
'radial': [[12.0], [0.0, 4.0, 8.0], [4.0]],
'lr': 0.001,
'num_epochs': 80,
'metrics': ['r2', 'mae'],
'split': 'temporal'
}
ACNN_PDBBind_refined_pocket_random = {
'dataset': 'PDBBind',
'subset': 'refined',
'load_binding_pocket': True,
'random_seed': 123,
'frac_train': 0.8,
'frac_val': 0.,
'frac_test': 0.2,
'batch_size': 24,
'shuffle': False,
'hidden_sizes': [128, 128, 64],
'weight_init_stddevs': [0.125, 0.125, 0.177, 0.01],
'dropouts': [0.4, 0.4, 0.],
'atomic_numbers_considered': torch.tensor([
1., 6., 7., 8., 9., 11., 12., 15., 16., 17., 19., 20., 25., 26., 27., 28.,
29., 30., 34., 35., 38., 48., 53., 55., 80.]),
'radial': [[12.0], [0.0, 2.0, 4.0, 6.0, 8.0], [4.0]],
'lr': 0.001,
'num_epochs': 200,
'metrics': ['r2', 'mae'],
'split': 'random'
}
ACNN_PDBBind_refined_pocket_scaffold = {
'dataset': 'PDBBind',
'subset': 'refined',
'load_binding_pocket': True,
'random_seed': 123,
'frac_train': 0.8,
'frac_val': 0.,
'frac_test': 0.2,
'batch_size': 24,
'shuffle': False,
'hidden_sizes': [128, 128, 64],
'weight_init_stddevs': [0.125, 0.125, 0.177, 0.01],
'dropouts': [0.4, 0.4, 0.],
'atomic_numbers_considered': torch.tensor([
1., 6., 7., 8., 9., 11., 12., 15., 16., 17., 19., 20., 25., 26., 27., 28.,
29., 30., 34., 35., 38., 48., 53., 55., 80.]),
'radial': [[12.0], [0.0, 2.0, 4.0, 6.0, 8.0], [4.0]],
'lr': 0.001,
'num_epochs': 350,
'metrics': ['r2', 'mae'],
'split': 'scaffold'
}
ACNN_PDBBind_refined_pocket_stratified = {
'dataset': 'PDBBind',
'subset': 'refined',
'load_binding_pocket': True,
'random_seed': 123,
'frac_train': 0.8,
'frac_val': 0.,
'frac_test': 0.2,
'batch_size': 24,
'shuffle': False,
'hidden_sizes': [128, 128, 64],
'weight_init_stddevs': [0.125, 0.125, 0.177, 0.01],
'dropouts': [0.4, 0.4, 0.],
'atomic_numbers_considered': torch.tensor([
1., 6., 7., 8., 9., 11., 12., 15., 16., 17., 19., 20., 25., 26., 27., 28.,
29., 30., 34., 35., 38., 48., 53., 55., 80.]),
'radial': [[12.0], [0.0, 2.0, 4.0, 6.0, 8.0], [4.0]],
'lr': 0.001,
'num_epochs': 400,
'metrics': ['r2', 'mae'],
'split': 'stratified'
}
ACNN_PDBBind_refined_pocket_temporal = {
'dataset': 'PDBBind',
'subset': 'refined',
'load_binding_pocket': True,
'random_seed': 123,
'frac_train': 0.8,
'frac_val': 0.,
'frac_test': 0.2,
'batch_size': 24,
'shuffle': False,
'hidden_sizes': [128, 128, 64],
'weight_init_stddevs': [0.125, 0.125, 0.177, 0.01],
'dropouts': [0.4, 0.4, 0.],
'atomic_numbers_considered': torch.tensor([
1., 6., 7., 8., 9., 11., 12., 15., 16., 17., 19., 20., 25., 26., 27., 28.,
29., 30., 34., 35., 38., 48., 53., 55., 80.]),
'radial': [[12.0], [0.0, 2.0, 4.0, 6.0, 8.0], [4.0]],
'lr': 0.001,
'num_epochs': 350,
'metrics': ['r2', 'mae'],
'split': 'temporal'
}
experiment_configures = {
'ACNN_PDBBind_core_pocket_random': ACNN_PDBBind_core_pocket_random,
'ACNN_PDBBind_core_pocket_scaffold': ACNN_PDBBind_core_pocket_scaffold,
'ACNN_PDBBind_core_pocket_stratified': ACNN_PDBBind_core_pocket_stratified,
'ACNN_PDBBind_core_pocket_temporal': ACNN_PDBBind_core_pocket_temporal,
'ACNN_PDBBind_refined_pocket_random': ACNN_PDBBind_refined_pocket_random,
'ACNN_PDBBind_refined_pocket_scaffold': ACNN_PDBBind_refined_pocket_scaffold,
'ACNN_PDBBind_refined_pocket_stratified': ACNN_PDBBind_refined_pocket_stratified,
'ACNN_PDBBind_refined_pocket_temporal': ACNN_PDBBind_refined_pocket_temporal
}
def get_exp_configure(exp_name):
return experiment_configures[exp_name]
import torch
import torch.nn as nn
from dgllife.utils.eval import Meter
from torch.utils.data import DataLoader
from utils import set_random_seed, load_dataset, collate, load_model
def update_msg_from_scores(msg, scores):
for metric, score in scores.items():
msg += ', {} {:.4f}'.format(metric, score)
return msg
def run_a_train_epoch(args, epoch, model, data_loader,
loss_criterion, optimizer):
model.train()
train_meter = Meter(args['train_mean'], args['train_std'])
epoch_loss = 0
for batch_id, batch_data in enumerate(data_loader):
indices, ligand_mols, protein_mols, bg, labels = batch_data
labels, bg = labels.to(args['device']), bg.to(args['device'])
prediction = model(bg)
loss = loss_criterion(prediction, (labels - args['train_mean']) / args['train_std'])
epoch_loss += loss.data.item() * len(indices)
optimizer.zero_grad()
loss.backward()
optimizer.step()
train_meter.update(prediction, labels)
avg_loss = epoch_loss / len(data_loader.dataset)
total_scores = {metric: train_meter.compute_metric(metric, 'mean')
for metric in args['metrics']}
msg = 'epoch {:d}/{:d}, training | loss {:.4f}'.format(
epoch + 1, args['num_epochs'], avg_loss)
msg = update_msg_from_scores(msg, total_scores)
print(msg)
def run_an_eval_epoch(args, model, data_loader):
model.eval()
eval_meter = Meter(args['train_mean'], args['train_std'])
with torch.no_grad():
for batch_id, batch_data in enumerate(data_loader):
indices, ligand_mols, protein_mols, bg, labels = batch_data
labels, bg = labels.to(args['device']), bg.to(args['device'])
prediction = model(bg)
eval_meter.update(prediction, labels)
total_scores = {metric: eval_meter.compute_metric(metric, 'mean')
for metric in args['metrics']}
return total_scores
def main(args):
args['device'] = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
set_random_seed(args['random_seed'])
dataset, train_set, test_set = load_dataset(args)
args['train_mean'] = train_set.labels_mean.to(args['device'])
args['train_std'] = train_set.labels_std.to(args['device'])
train_loader = DataLoader(dataset=train_set,
batch_size=args['batch_size'],
shuffle=False,
collate_fn=collate)
test_loader = DataLoader(dataset=test_set,
batch_size=args['batch_size'],
shuffle=True,
collate_fn=collate)
model = load_model(args)
loss_fn = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=args['lr'])
model.to(args['device'])
for epoch in range(args['num_epochs']):
run_a_train_epoch(args, epoch, model, train_loader, loss_fn, optimizer)
test_scores = run_an_eval_epoch(args, model, test_loader)
test_msg = update_msg_from_scores('test results', test_scores)
print(test_msg)
if __name__ == '__main__':
import argparse
from configure import get_exp_configure
parser = argparse.ArgumentParser(description='Protein-Ligand Binding Affinity Prediction')
parser.add_argument('-m', '--model', type=str, choices=['ACNN'],
help='Model to use')
parser.add_argument('-d', '--dataset', type=str,
choices=['PDBBind_core_pocket_random', 'PDBBind_core_pocket_scaffold',
'PDBBind_core_pocket_stratified', 'PDBBind_core_pocket_temporal',
'PDBBind_refined_pocket_random', 'PDBBind_refined_pocket_scaffold',
'PDBBind_refined_pocket_stratified', 'PDBBind_refined_pocket_temporal'],
help='Dataset to use')
args = parser.parse_args().__dict__
args['exp'] = '_'.join([args['model'], args['dataset']])
args.update(get_exp_configure(args['exp']))
main(args)
import dgl
import numpy as np
import random
import torch
from dgl.data.utils import Subset
from dgllife.data import PDBBind
from dgllife.model import ACNN
from dgllife.utils import RandomSplitter, ScaffoldSplitter, SingleTaskStratifiedSplitter
from itertools import accumulate
def set_random_seed(seed=0):
"""Set random seed.
Parameters
----------
seed : int
Random seed to use. Default to 0.
"""
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed(seed)
def load_dataset(args):
"""Load the dataset.
Parameters
----------
args : dict
Input arguments.
Returns
-------
dataset
Full dataset.
train_set
Train subset of the dataset.
test_set
Test subset of the dataset.
"""
assert args['dataset'] in ['PDBBind'], 'Unexpected dataset {}'.format(args['dataset'])
if args['dataset'] == 'PDBBind':
dataset = PDBBind(subset=args['subset'],
load_binding_pocket=args['load_binding_pocket'],
zero_padding=True)
# No validation set is used and frac_val = 0.
if args['split'] == 'random':
train_set, _, test_set = RandomSplitter.train_val_test_split(
dataset,
frac_train=args['frac_train'],
frac_val=args['frac_val'],
frac_test=args['frac_test'],
random_state=args['random_seed'])
elif args['split'] == 'scaffold':
train_set, _, test_set = ScaffoldSplitter.train_val_test_split(
dataset,
mols=dataset.ligand_mols,
sanitize=False,
frac_train=args['frac_train'],
frac_val=args['frac_val'],
frac_test=args['frac_test'])
elif args['split'] == 'stratified':
train_set, _, test_set = SingleTaskStratifiedSplitter.train_val_test_split(
dataset,
labels=dataset.labels,
task_id=0,
frac_train=args['frac_train'],
frac_val=args['frac_val'],
frac_test=args['frac_test'],
random_state=args['random_seed'])
elif args['split'] == 'temporal':
years = dataset.df['release_year'].values.astype(np.float32)
indices = np.argsort(years).tolist()
frac_list = np.array([args['frac_train'], args['frac_val'], args['frac_test']])
num_data = len(dataset)
lengths = (num_data * frac_list).astype(int)
lengths[-1] = num_data - np.sum(lengths[:-1])
train_set, val_set, test_set = [
Subset(dataset, list(indices[offset - length:offset]))
for offset, length in zip(accumulate(lengths), lengths)]
else:
raise ValueError('Expect the splitting method to be "random", "scaffold", '
'"stratified" or "temporal", got {}'.format(args['split']))
train_labels = torch.stack([train_set.dataset.labels[i] for i in train_set.indices])
train_set.labels_mean = train_labels.mean(dim=0)
train_set.labels_std = train_labels.std(dim=0)
return dataset, train_set, test_set
def collate(data):
indices, ligand_mols, protein_mols, graphs, labels = map(list, zip(*data))
bg = dgl.batch_hetero(graphs)
for nty in bg.ntypes:
bg.set_n_initializer(dgl.init.zero_initializer, ntype=nty)
for ety in bg.canonical_etypes:
bg.set_e_initializer(dgl.init.zero_initializer, etype=ety)
labels = torch.stack(labels, dim=0)
return indices, ligand_mols, protein_mols, bg, labels
def load_model(args):
assert args['model'] in ['ACNN'], 'Unexpected model {}'.format(args['model'])
if args['model'] == 'ACNN':
model = ACNN(hidden_sizes=args['hidden_sizes'],
weight_init_stddevs=args['weight_init_stddevs'],
dropouts=args['dropouts'],
features_to_use=args['atomic_numbers_considered'],
radial=args['radial'])
return model
# Learning Deep Generative Models of Graphs (DGMG)
Yujia Li, Oriol Vinyals, Chris Dyer, Razvan Pascanu, and Peter Battaglia.
Learning Deep Generative Models of Graphs. *arXiv preprint arXiv:1803.03324*, 2018.
DGMG generates graphs by progressively adding nodes and edges as below:
![](https://user-images.githubusercontent.com/19576924/48605003-7f11e900-e9b6-11e8-8880-87362348e154.png)
For molecules, the nodes are atoms and the edges are bonds.
**Goal**: Given a set of real molecules, we want to learn their distribution and generate new molecules
with similar properties. See the `Evaluation` section for more details.
## Dataset
### Preprocessing
With our implementation, this model has several limitations:
1. Information about protonation and chirality is ignored during generation.
2. Molecules containing atoms like `[N+]` and `[O-]` cannot be generated.
For example, the model can only generate `O=C1NC(=S)NC(=O)C1=CNC1=CC=C(N(=O)O)C=C1O` from
`O=C1NC(=S)NC(=O)C1=CNC1=CC=C([N+](=O)[O-])C=C1O` even with the correct decisions.
To avoid issues with validity and novelty, we filter out such molecules from the dataset.
### ChEMBL
The authors use the [ChEMBL database](https://www.ebi.ac.uk/chembl/). Since they
did not release the code, we use a subset from [Olivecrona et al.](https://github.com/MarcusOlivecrona/REINVENT),
another work on generative modeling.
The authors restrict their dataset to molecules with at most 20 heavy atoms and use a training/validation
split of 130,830/26,166 examples. We use the same split but need to relax the limit from 20 to 23 heavy atoms
as we are using a different subset.
### ZINC
After the preprocessing, we are left with 232,464 molecules for training and 5,000 molecules for validation.
## Usage
### Training
Training auto-regressive generative models tends to be very slow. According to the authors, they use multiple processes
to speed up training and GPUs do not give much of a speed advantage. We follow their approach and perform multi-process
CPU training.
To start training, use `train.py` with required arguments
```
-d DATASET, dataset to use (default: None), built-in support exists for ChEMBL, ZINC
-o {random,canonical}, order to generate graphs (default: None)
```
and optional arguments
```
-s SEED, random seed (default: 0)
-np NUM_PROCESSES, number of processes to use (default: 32)
```
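For example, a possible invocation for multi-process training on ChEMBL with the canonical order is:
```
python train.py -d ChEMBL -o canonical -s 0 -np 32
```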
Even though multiprocessing yields a significant speedup compared to a single process, training can still take a long
time (several days). An epoch of training and validation can take up to an hour and a half on our machine. If not
necessary, we recommend using our pre-trained models.
Meanwhile, we make a checkpoint of our model whenever there is a performance improvement on the validation set so you
do not need to wait until the training terminates.
All training results can be found in `training_results`.
#### Dataset configuration
You can also use your own dataset with additional arguments
```
-tf TRAIN_FILE, Path to a file with one SMILES a line for training
data. This is only necessary if you want to use a new
dataset. (default: None)
-vf VAL_FILE, Path to a file with one SMILES a line for validation
data. This is only necessary if you want to use a new
dataset. (default: None)
```
#### Monitoring
We can monitor the training process with tensorboard as below:
![](https://data.dgl.ai/dgllife/dgmg/tensorboard.png)
To use tensorboard, you need to install [tensorboardX](https://github.com/lanpa/tensorboardX) and
[TensorFlow](https://www.tensorflow.org/). You can launch tensorboard with `tensorboard --logdir=.`
If you are training on a remote server, you can still use it with:
1. Launch it on the remote server with `tensorboard --logdir=. --port=A`
2. In the terminal of your local machine, type `ssh -NfL localhost:B:localhost:A username@your_remote_host_name`
3. Go to the address `localhost:B` in your browser
### Evaluation
To start evaluation, use `eval.py` with required arguments
```
-d DATASET, dataset to use (default: None), built-in support exists for ChEMBL, ZINC
-o {random,canonical}, order to generate graphs, used for naming evaluation directory (default: None)
-p MODEL_PATH, path to saved model (default: None). This is not needed if you want to use pretrained models.
-pr, Whether to use a pre-trained model (default: False)
```
and optional arguments
```
-s SEED, random seed (default: 0)
-ns NUM_SAMPLES, Number of molecules to generate (default: 100000)
-mn MAX_NUM_STEPS, Max number of steps allowed in generated molecules to
ensure termination (default: 400)
-np NUM_PROCESSES, number of processes to use (default: 32)
-gt GENERATION_TIME, max time (seconds) allowed for generation with
multiprocess (default: 600)
```
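For example, a possible invocation for sampling molecules with the pre-trained ZINC model is:
```
python eval.py -d ZINC -o canonical -pr -ns 100000
```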
All evaluation results can be found in `eval_results`.
After the evaluation, the generated molecules (100,000 by default) are stored in `generated_smiles.txt` under the
`eval_results` directory, with three statistics logged in `generation_stats.txt` in the same directory:
1. `Validity among all` gives the percentage of molecules that are valid
2. `Uniqueness among valid ones` gives the percentage of valid molecules that are unique
3. `Novelty among unique ones` gives the percentage of unique valid molecules that are novel (not seen in training data)
We also provide a Jupyter notebook where you can visualize the generated molecules
![](https://data.dgl.ai/dgllife/dgmg/DGMG_ZINC_canonical_vis.png)
and compare their property distributions against the training molecule property distributions
![](https://data.dgl.ai/dgllife/dgmg/DGMG_ZINC_canonical_dist.png)
You can download the notebook with `wget https://data.dgl.ai/dgllife/dgmg/eval_jupyter.ipynb`.
### Pre-trained models
The table below gives the statistics of the pre-trained models. With random order, training becomes significantly more
difficult as we now have `N^2` data points with `N` molecules.
| Pre-trained model | % valid | % unique among valid | % novel among unique |
| ------------------ | ------- | -------------------- | -------------------- |
| `ChEMBL_canonical` | 78.80 | 99.19 | 98.60 |
| `ChEMBL_random` | 29.09 | 99.87 | 100.00 |
| `ZINC_canonical` | 74.60 | 99.87 | 99.87 |
| `ZINC_random` | 12.37 | 99.38 | 100.00 |
import os
import pickle
import shutil
import torch
from dgllife.model import DGMG, load_pretrained
from utils import MoleculeDataset, set_random_seed, download_data,\
mkdir_p, summarize_molecules, get_unique_smiles, get_novel_smiles
def generate_and_save(log_dir, num_samples, max_num_steps, model):
with open(os.path.join(log_dir, 'generated_smiles.txt'), 'w') as f:
for i in range(num_samples):
with torch.no_grad():
s = model(rdkit_mol=True, max_num_steps=max_num_steps)
f.write(s + '\n')
def prepare_for_evaluation(rank, args):
worker_seed = args['seed'] + rank * 10000
set_random_seed(worker_seed)
torch.set_num_threads(1)
# Setup dataset and data loader
dataset = MoleculeDataset(args['dataset'], subset_id=rank, n_subsets=args['num_processes'])
# Initialize model
if not args['pretrained']:
model = DGMG(atom_types=dataset.atom_types,
bond_types=dataset.bond_types,
node_hidden_size=args['node_hidden_size'],
num_prop_rounds=args['num_propagation_rounds'], dropout=args['dropout'])
model.load_state_dict(torch.load(args['model_path'])['model_state_dict'])
else:
model = load_pretrained('_'.join(['DGMG', args['dataset'], args['order']]), log=False)
model.eval()
worker_num_samples = args['num_samples'] // args['num_processes']
if rank == args['num_processes'] - 1:
worker_num_samples += args['num_samples'] % args['num_processes']
worker_log_dir = os.path.join(args['log_dir'], str(rank))
mkdir_p(worker_log_dir, log=False)
generate_and_save(worker_log_dir, worker_num_samples, args['max_num_steps'], model)
def remove_worker_tmp_dir(args):
for rank in range(args['num_processes']):
worker_path = os.path.join(args['log_dir'], str(rank))
try:
shutil.rmtree(worker_path)
except OSError:
print('Directory {} does not exist!'.format(worker_path))
def aggregate_and_evaluate(args):
print('Merging generated SMILES into a single file...')
smiles = []
for rank in range(args['num_processes']):
with open(os.path.join(args['log_dir'], str(rank), 'generated_smiles.txt'), 'r') as f:
rank_smiles = f.read().splitlines()
smiles.extend(rank_smiles)
with open(os.path.join(args['log_dir'], 'generated_smiles.txt'), 'w') as f:
for s in smiles:
f.write(s + '\n')
print('Removing temporary dirs...')
remove_worker_tmp_dir(args)
# Summarize training molecules
print('Summarizing training molecules...')
train_file = '_'.join([args['dataset'], 'DGMG_train.txt'])
if not os.path.exists(train_file):
download_data(args['dataset'], train_file)
with open(train_file, 'r') as f:
train_smiles = f.read().splitlines()
train_summary = summarize_molecules(train_smiles, args['num_processes'])
with open(os.path.join(args['log_dir'], 'train_summary.pickle'), 'wb') as f:
pickle.dump(train_summary, f)
# Summarize generated molecules
print('Summarizing generated molecules...')
generation_summary = summarize_molecules(smiles, args['num_processes'])
with open(os.path.join(args['log_dir'], 'generation_summary.pickle'), 'wb') as f:
pickle.dump(generation_summary, f)
# Stats computation
print('Preparing generation statistics...')
valid_generated_smiles = generation_summary['smile']
unique_generated_smiles = get_unique_smiles(valid_generated_smiles)
unique_train_smiles = get_unique_smiles(train_summary['smile'])
novel_generated_smiles = get_novel_smiles(unique_generated_smiles, unique_train_smiles)
with open(os.path.join(args['log_dir'], 'generation_stats.txt'), 'w') as f:
f.write('Total number of generated molecules: {:d}\n'.format(len(smiles)))
f.write('Validity among all: {:.4f}\n'.format(
len(valid_generated_smiles) / len(smiles)))
f.write('Uniqueness among valid ones: {:.4f}\n'.format(
len(unique_generated_smiles) / len(valid_generated_smiles)))
f.write('Novelty among unique ones: {:.4f}\n'.format(
len(novel_generated_smiles) / len(unique_generated_smiles)))
if __name__ == '__main__':
import argparse
import datetime
import time
from rdkit import rdBase
from utils import setup
parser = argparse.ArgumentParser(description='Evaluating DGMG for molecule generation',
formatter_class=argparse.ArgumentDefaultsHelpFormatter)
# configure
parser.add_argument('-s', '--seed', type=int, default=0, help='random seed')
# dataset and setting
parser.add_argument('-d', '--dataset',
help='dataset to use')
parser.add_argument('-o', '--order', choices=['random', 'canonical'],
help='order to generate graphs, used for naming evaluation directory')
# log
parser.add_argument('-l', '--log-dir', default='./eval_results',
help='folder to save evaluation results')
parser.add_argument('-p', '--model-path', type=str, default=None,
help='path to saved model')
parser.add_argument('-pr', '--pretrained', action='store_true',
help='Whether to use a pre-trained model')
parser.add_argument('-ns', '--num-samples', type=int, default=100000,
help='Number of molecules to generate')
parser.add_argument('-mn', '--max-num-steps', type=int, default=400,
help='Max number of steps allowed in generated molecules to ensure termination')
# multi-process
parser.add_argument('-np', '--num-processes', type=int, default=32,
help='number of processes to use')
parser.add_argument('-gt', '--generation-time', type=int, default=600,
help='max time (seconds) allowed for generation with multiprocess')
args = parser.parse_args()
args = setup(args, train=False)
rdBase.DisableLog('rdApp.error')
t1 = time.time()
if args['num_processes'] == 1:
prepare_for_evaluation(0, args)
else:
import multiprocessing as mp
procs = []
for rank in range(args['num_processes']):
p = mp.Process(target=prepare_for_evaluation, args=(rank, args,))
procs.append(p)
p.start()
while time.time() - t1 <= args['generation_time']:
if any(p.is_alive() for p in procs):
time.sleep(5)
else:
break
else:
print('Timeout, killing all processes.')
for p in procs:
p.terminate()
p.join()
t2 = time.time()
print('It took {} for generation.'.format(
datetime.timedelta(seconds=t2 - t1)))
aggregate_and_evaluate(args)
#
# calculation of synthetic accessibility score as described in:
#
# Estimation of Synthetic Accessibility Score of Drug-like Molecules
# based on Molecular Complexity and Fragment Contributions
# Peter Ertl and Ansgar Schuffenhauer
# Journal of Cheminformatics 1:8 (2009)
# http://www.jcheminf.com/content/1/1/8
#
# several small modifications to the original paper are included
# particularly slightly different formula for macrocyclic penalty
# and taking into account also molecule symmetry (fingerprint density)
#
# for a set of 10k diverse molecules the agreement between the original method
# as implemented in PipelinePilot and this implementation is r2 = 0.97
#
# peter ertl & greg landrum, september 2013
#
# A small modification is performed
#
# DGL team, August 2019
#
from __future__ import print_function
import math
import os
from rdkit import Chem
from rdkit.Chem import rdMolDescriptors
from rdkit.six.moves import cPickle
from rdkit.six import iteritems
from dgl.data.utils import download, _get_dgl_url
_fscores = None
def readFragmentScores(name='fpscores'):
import gzip
global _fscores
fname = '{}.pkl.gz'.format(name)
download(_get_dgl_url(os.path.join('dataset', fname)), path=fname)
_fscores = cPickle.load(gzip.open(fname))
outDict = {}
for i in _fscores:
for j in range(1, len(i)):
outDict[i[j]] = float(i[0])
_fscores = outDict
def numBridgeheadsAndSpiro(mol):
nSpiro = rdMolDescriptors.CalcNumSpiroAtoms(mol)
nBridgehead = rdMolDescriptors.CalcNumBridgeheadAtoms(mol)
return nBridgehead, nSpiro
def calculateScore(m):
if _fscores is None:
readFragmentScores()
# fragment score
# 2 is the *radius* of the circular fingerprint
fp = rdMolDescriptors.GetMorganFingerprint(m, 2)
fps = fp.GetNonzeroElements()
score1 = 0.
nf = 0
for bitId, v in iteritems(fps):
nf += v
sfp = bitId
score1 += _fscores.get(sfp, -4) * v
# We add this check to avoid a ZeroDivisionError.
if nf != 0:
score1 /= nf
# features score
nAtoms = m.GetNumAtoms()
nChiralCenters = len(Chem.FindMolChiralCenters(m, includeUnassigned=True))
ri = m.GetRingInfo()
nBridgeheads, nSpiro = numBridgeheadsAndSpiro(m)
nMacrocycles = 0
for x in ri.AtomRings():
if len(x) > 8:
nMacrocycles += 1
sizePenalty = nAtoms**1.005 - nAtoms
stereoPenalty = math.log10(nChiralCenters + 1)
spiroPenalty = math.log10(nSpiro + 1)
bridgePenalty = math.log10(nBridgeheads + 1)
macrocyclePenalty = 0.
# ---------------------------------------
# This differs from the paper, which defines:
# macrocyclePenalty = math.log10(nMacrocycles+1)
# This form generates better results when 2 or more macrocycles are present
if nMacrocycles > 0:
macrocyclePenalty = math.log10(2)
score2 = 0. - sizePenalty - stereoPenalty - \
spiroPenalty - bridgePenalty - macrocyclePenalty
# correction for the fingerprint density
# not in the original publication, added in version 1.1
# to make highly symmetrical molecules easier to synthesize
score3 = 0.
if nAtoms > len(fps):
score3 = math.log(float(nAtoms) / len(fps)) * .5
sascore = score1 + score2 + score3
# need to transform "raw" value into scale between 1 and 10
min = -4.0
max = 2.5
sascore = 11. - (sascore - min + 1) / (max - min) * 9.
# smooth the 10-end
if sascore > 8.:
sascore = 8. + math.log(sascore + 1. - 9.)
if sascore > 10.:
sascore = 10.0
elif sascore < 1.:
sascore = 1.0
return sascore
def processMols(mols):
print('smiles\tName\tsa_score')
for i, m in enumerate(mols):
if m is None:
continue
s = calculateScore(m)
smiles = Chem.MolToSmiles(m)
print(smiles + "\t" + m.GetProp('_Name') + "\t%3f" % s)
if __name__ == '__main__':
import sys, time
t1 = time.time()
readFragmentScores("fpscores")
t2 = time.time()
suppl = Chem.SmilesMolSupplier(sys.argv[1])
t3 = time.time()
processMols(suppl)
t4 = time.time()
print('Reading took %.2f seconds. Calculating took %.2f seconds' % ((t2 - t1), (t4 - t3)),
file=sys.stderr)
#
# Copyright (c) 2013, Novartis Institutes for BioMedical Research Inc.
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are
# met:
#
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above
# copyright notice, this list of conditions and the following
# disclaimer in the documentation and/or other materials provided
# with the distribution.
# * Neither the name of Novartis Institutes for BioMedical Research Inc.
# nor the names of its contributors may be used to endorse or promote
# products derived from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
"""
Learning Deep Generative Models of Graphs
Paper: https://arxiv.org/pdf/1803.03324.pdf
"""
import datetime
import time
import torch
import torch.distributed as dist
from dgllife.model import DGMG
from torch.optim import Adam
from torch.utils.data import DataLoader
from utils import MoleculeDataset, Printer, set_random_seed, synchronize, launch_a_process
def evaluate(epoch, model, data_loader, printer):
model.eval()
batch_size = data_loader.batch_size
total_log_prob = 0
with torch.no_grad():
for i, data in enumerate(data_loader):
log_prob = model(actions=data, compute_log_prob=True).detach()
total_log_prob -= log_prob
if printer is not None:
prob = log_prob.detach().exp()
printer.update(epoch + 1, - log_prob / batch_size, prob / batch_size)
return total_log_prob / len(data_loader)
def main(rank, args):
"""
Parameters
----------
rank : int
Subprocess id
args : dict
Configuration
"""
if rank == 0:
t1 = time.time()
set_random_seed(args['seed'])
# Removing the line below will result in problems for multiprocessing
torch.set_num_threads(1)
# Setup dataset and data loader
dataset = MoleculeDataset(args['dataset'], args['order'], ['train', 'val'],
subset_id=rank, n_subsets=args['num_processes'])
# Note that currently the batch size for the loaders should only be 1.
train_loader = DataLoader(dataset.train_set, batch_size=args['batch_size'],
shuffle=True, collate_fn=dataset.collate)
val_loader = DataLoader(dataset.val_set, batch_size=args['batch_size'],
shuffle=True, collate_fn=dataset.collate)
if rank == 0:
try:
from tensorboardX import SummaryWriter
writer = SummaryWriter(args['log_dir'])
except ImportError:
print('If you want to use tensorboard, install tensorboardX with pip.')
writer = None
train_printer = Printer(args['nepochs'], len(dataset.train_set), args['batch_size'], writer)
val_printer = Printer(args['nepochs'], len(dataset.val_set), args['batch_size'])
else:
val_printer = None
# Initialize model
model = DGMG(atom_types=dataset.atom_types,
bond_types=dataset.bond_types,
node_hidden_size=args['node_hidden_size'],
num_prop_rounds=args['num_propagation_rounds'],
dropout=args['dropout'])
if args['num_processes'] == 1:
from utils import Optimizer
optimizer = Optimizer(args['lr'], Adam(model.parameters(), lr=args['lr']))
else:
from utils import MultiProcessOptimizer
optimizer = MultiProcessOptimizer(args['num_processes'], args['lr'],
Adam(model.parameters(), lr=args['lr']))
if rank == 0:
t2 = time.time()
best_val_prob = 0
# Training
for epoch in range(args['nepochs']):
model.train()
if rank == 0:
print('Training')
for i, data in enumerate(train_loader):
log_prob = model(actions=data, compute_log_prob=True)
prob = log_prob.detach().exp()
loss_averaged = - log_prob
prob_averaged = prob
optimizer.backward_and_step(loss_averaged)
if rank == 0:
train_printer.update(epoch + 1, loss_averaged.item(), prob_averaged.item())
synchronize(args['num_processes'])
# Validation
val_log_prob = evaluate(epoch, model, val_loader, val_printer)
if args['num_processes'] > 1:
dist.all_reduce(val_log_prob, op=dist.ReduceOp.SUM)
val_log_prob /= args['num_processes']
# Strictly speaking, the computation of probability here differs from what is
# performed on the training set: we first average the log-likelihoods and then
# exponentiate. By Jensen's inequality, the resulting value is a lower bound
# of the average probability.
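# For example, with per-sample log-likelihoods of -1 and -3,
# exp((-1 - 3) / 2) = exp(-2) ~ 0.135, while the average probability
# (exp(-1) + exp(-3)) / 2 ~ 0.209, so the reported value underestimates it.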
val_prob = (- val_log_prob).exp().item()
val_log_prob = val_log_prob.item()
if val_prob >= best_val_prob:
if rank == 0:
torch.save({'model_state_dict': model.state_dict()}, args['checkpoint_dir'])
print('Old val prob {:.10f} | new val prob {:.10f} | model saved'.format(best_val_prob, val_prob))
best_val_prob = val_prob
elif epoch >= args['warmup_epochs']:
optimizer.decay_lr()
if rank == 0:
print('Validation')
if writer is not None:
writer.add_scalar('validation_log_prob', val_log_prob, epoch)
writer.add_scalar('validation_prob', val_prob, epoch)
writer.add_scalar('lr', optimizer.lr, epoch)
print('Validation log prob {:.4f} | prob {:.10f}'.format(val_log_prob, val_prob))
synchronize(args['num_processes'])
if rank == 0:
t3 = time.time()
print('It took {} to setup.'.format(datetime.timedelta(seconds=t2 - t1)))
print('It took {} to finish training.'.format(datetime.timedelta(seconds=t3 - t2)))
print('--------------------------------------------------------------------------')
print('On average, an epoch takes {}.'.format(datetime.timedelta(
seconds=(t3 - t2) / args['nepochs'])))
if __name__ == '__main__':
import argparse
from utils import setup
parser = argparse.ArgumentParser(description='Training DGMG for molecule generation',
formatter_class=argparse.ArgumentDefaultsHelpFormatter)
# configure
parser.add_argument('-s', '--seed', type=int, default=0, help='random seed')
parser.add_argument('-w', '--warmup-epochs', type=int, default=10,
help='Number of epochs where no lr decay is performed.')
# dataset and setting
parser.add_argument('-d', '--dataset',
help='dataset to use')
parser.add_argument('-o', '--order', choices=['random', 'canonical'],
help='order to generate graphs')
parser.add_argument('-tf', '--train-file', type=str, default=None,
help='Path to a file with one SMILES string per line for training data. '
'This is only necessary if you want to use a new dataset.')
parser.add_argument('-vf', '--val-file', type=str, default=None,
help='Path to a file with one SMILES string per line for validation data. '
'This is only necessary if you want to use a new dataset.')
# log
parser.add_argument('-l', '--log-dir', default='./training_results',
help='folder to save info like experiment configuration')
# multi-process
parser.add_argument('-np', '--num-processes', type=int, default=32,
help='number of processes to use')
parser.add_argument('-mi', '--master-ip', type=str, default='127.0.0.1')
parser.add_argument('-mp', '--master-port', type=str, default='12345')
args = parser.parse_args()
args = setup(args, train=True)
if args['num_processes'] == 1:
main(0, args)
else:
mp = torch.multiprocessing.get_context('spawn')
procs = []
for rank in range(args['num_processes']):
procs.append(mp.Process(target=launch_a_process, args=(rank, args, main), daemon=True))
procs[-1].start()
for p in procs:
p.join()
import datetime
import math
import numpy as np
import os
import pickle
import random
import torch
import torch.distributed as dist
import torch.nn as nn
from collections import defaultdict
from datetime import timedelta
from dgl.data.utils import download, _get_dgl_url
from dgllife.model.model_zoo.dgmg import MoleculeEnv
from multiprocessing import Pool
from pprint import pprint
from rdkit import Chem
from rdkit.Chem import AllChem
from rdkit.Chem.Crippen import MolLogP
from rdkit.Chem.QED import qed
from torch.utils.data import Dataset
from sascorer import calculateScore
########################################################################################################################
# configuration #
########################################################################################################################
def mkdir_p(path, log=True):
"""Create a directory for the specified path.
Parameters
----------
path : str
Path name
log : bool
Whether to print result for directory creation
"""
import errno
try:
os.makedirs(path)
if log:
print('Created directory {}'.format(path))
except OSError as exc:
if exc.errno == errno.EEXIST and os.path.isdir(path) and log:
print('Directory {} already exists.'.format(path))
else:
raise
def get_date_postfix():
"""Get a date based postfix for directory name.
Returns
-------
post_fix : str
"""
dt = datetime.datetime.now()
post_fix = '{}_{:02d}-{:02d}-{:02d}'.format(
dt.date(), dt.hour, dt.minute, dt.second)
return post_fix
def setup_log_dir(args):
"""Name and create directory for logging.
Parameters
----------
args : dict
Configuration
Returns
-------
log_dir : str
Path for logging directory
"""
date_postfix = get_date_postfix()
log_dir = os.path.join(
args['log_dir'],
'{}_{}_{}'.format(args['dataset'], args['order'], date_postfix))
mkdir_p(log_dir)
return log_dir
def save_arg_dict(args, filename='settings.txt'):
"""Save all experiment settings in a file.
Parameters
----------
args : dict
Configuration
filename : str
Name for the file to save settings
"""
def _format_value(v):
if isinstance(v, float):
return '{:.4f}'.format(v)
elif isinstance(v, int):
return '{:d}'.format(v)
else:
return '{}'.format(v)
save_path = os.path.join(args['log_dir'], filename)
with open(save_path, 'w') as f:
for key, value in args.items():
f.write('{}\t{}\n'.format(key, _format_value(value)))
print('Saved settings to {}'.format(save_path))
def configure(args):
"""Use default hyperparameters.
Parameters
----------
args : dict
Old configuration
Returns
-------
args : dict
Updated configuration
"""
configure = {
'node_hidden_size': 128,
'num_propagation_rounds': 2,
'lr': 1e-4,
'dropout': 0.2,
'nepochs': 400,
'batch_size': 1,
}
args.update(configure)
return args
def set_random_seed(seed):
"""Fix random seed for reproducible results.
Parameters
----------
seed : int
Random seed to use.
"""
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
def setup_dataset(args):
"""Dataset setup
For datasets without built-in support, we need to perform data preprocessing.
Parameters
----------
args : dict
Configuration
"""
if args['dataset'] in ['ChEMBL', 'ZINC']:
print('Built-in support for dataset {} exists.'.format(args['dataset']))
else:
print('Configure for new dataset {}...'.format(args['dataset']))
configure_new_dataset(args['dataset'], args['train_file'], args['val_file'])
def setup(args, train=True):
"""Setup
Parameters
----------
args : argparse.Namespace
Configuration
train : bool
Whether the setup is for training or evaluation
"""
# Convert argparse.Namespace into a dict
args = args.__dict__.copy()
# Dataset
args = configure(args)
# Log
print('Prepare logging directory...')
log_dir = setup_log_dir(args)
args['log_dir'] = log_dir
save_arg_dict(args)
if train:
setup_dataset(args)
args['checkpoint_dir'] = os.path.join(log_dir, 'checkpoint.pth')
pprint(args)
return args
########################################################################################################################
# multi-process #
########################################################################################################################
def synchronize(num_processes):
"""Synchronize all processes.
Parameters
----------
num_processes : int
Number of subprocesses used
"""
if num_processes > 1:
dist.barrier()
def launch_a_process(rank, args, target, minutes=720):
"""Launch a subprocess for training.
Parameters
----------
rank : int
Subprocess id
args : dict
Configuration
target : callable
Target function for the subprocess
minutes : int
Timeout minutes for operations executed against the process group
"""
dist_init_method = 'tcp://{master_ip}:{master_port}'.format(
master_ip=args['master_ip'], master_port=args['master_port'])
dist.init_process_group(backend='gloo',
init_method=dist_init_method,
# If you have a larger dataset, you will need to increase the timeout.
timeout=timedelta(minutes=minutes),
world_size=args['num_processes'],
rank=rank)
assert torch.distributed.get_rank() == rank
target(rank, args)
########################################################################################################################
# optimization #
########################################################################################################################
class Optimizer(nn.Module):
"""Wrapper for optimization
Parameters
----------
lr : float
Initial learning rate
optimizer
model optimizer
"""
def __init__(self, lr, optimizer):
super(Optimizer, self).__init__()
self.lr = lr
self.optimizer = optimizer
self._reset()
def _reset(self):
self.optimizer.zero_grad()
def backward_and_step(self, loss):
"""Backward and update model.
Parameters
----------
loss : torch.tensor consisting of a float only
"""
loss.backward()
self.optimizer.step()
self._reset()
def decay_lr(self, decay_rate=0.99):
"""Decay learning rate.
Parameters
----------
decay_rate : float
Multiply the current learning rate by the decay_rate
"""
self.lr *= decay_rate
for param_group in self.optimizer.param_groups:
param_group['lr'] = self.lr
class MultiProcessOptimizer(Optimizer):
"""Wrapper for optimization with multiprocess
Parameters
----------
n_processes : int
Number of processes used
lr : float
Initial learning rate
optimizer
model optimizer
"""
def __init__(self, n_processes, lr, optimizer):
super(MultiProcessOptimizer, self).__init__(lr=lr, optimizer=optimizer)
self.n_processes = n_processes
def _sync_gradient(self):
"""Average gradients across all subprocesses."""
for param_group in self.optimizer.param_groups:
for p in param_group['params']:
if p.requires_grad and p.grad is not None:
dist.all_reduce(p.grad.data, op=dist.ReduceOp.SUM)
p.grad.data /= self.n_processes
def backward_and_step(self, loss):
"""Backward and update model.
Parameters
----------
loss : torch.tensor consisting of a float only
"""
loss.backward()
self._sync_gradient()
self.optimizer.step()
self._reset()
########################################################################################################################
# data #
########################################################################################################################
def initialize_neuralization_reactions():
"""Reference neuralization reactions
Code adapted from RDKit Cookbook, by Hans de Winter.
"""
patts = (
# Imidazoles
('[n+;H]', 'n'),
# Amines
('[N+;!H0]', 'N'),
# Carboxylic acids and alcohols
('[$([O-]);!$([O-][#7])]', 'O'),
# Thiols
('[S-;X1]', 'S'),
# Sulfonamides
('[$([N-;X2]S(=O)=O)]', 'N'),
# Enamines
('[$([N-;X2][C,N]=C)]', 'N'),
# Tetrazoles
('[n-]', '[n]'),
# Sulfoxides
('[$([S-]=O)]', 'S'),
# Amides
('[$([N-]C=O)]', 'N'),
)
return [(Chem.MolFromSmarts(x), Chem.MolFromSmiles(y, False)) for x, y in patts]
def neutralize_charges(mol, reactions=None):
"""Deprotonation for molecules.
Code adapted from RDKit Cookbook, by Hans de Winter.
DGMG currently cannot generate protonated molecules.
For example, even with correct decisions, it can only generate
CC(C)(C)CC1CCCN1Cc1nnc(-c2ccccc2F)o1
from
CC(C)(C)CC1CCC[NH+]1Cc1nnc(-c2ccccc2F)o1.
Deprotonation is therefore an important step to avoid
false novel molecules.
Parameters
----------
mol : Chem.rdchem.Mol
reactions : list of 2-tuples
Rules for deprotonation
Returns
-------
mol : Chem.rdchem.Mol
Deprotonated molecule
"""
if reactions is None:
reactions = initialize_neuralization_reactions()
for i, (reactant, product) in enumerate(reactions):
while mol.HasSubstructMatch(reactant):
rms = AllChem.ReplaceSubstructs(mol, reactant, product)
mol = rms[0]
return mol
def standardize_mol(mol):
"""Standardize molecule to avoid false novel molecule.
Kekulize and deprotonate molecules to avoid false novel molecules.
In addition to deprotonation, we also kekulize molecules to avoid
explicit Hs in the SMILES. Otherwise we will get false novel molecules
as well. For example, DGMG can only generate
O=S(=O)(NC1=CC=CC(C(F)(F)F)=C1)C1=CNC=N1
from
O=S(=O)(Nc1cccc(C(F)(F)F)c1)c1c[nH]cn1.
One downside is that we remove all explicit aromatic rings and to
explicitly predict aromatic bond might make the learning easier for
the model.
"""
reactions = initialize_neuralization_reactions()
Chem.Kekulize(mol, clearAromaticFlags=True)
mol = neutralize_charges(mol, reactions)
return mol
def smiles_to_standard_mol(s):
"""Convert SMILES to a standard molecule.
Parameters
----------
s : str
SMILES
Returns
-------
Chem.rdchem.Mol
Standardized molecule
"""
mol = Chem.MolFromSmiles(s)
return standardize_mol(mol)
def mol_to_standard_smile(mol):
"""Standardize a molecule and convert it to a SMILES.
Parameters
----------
mol : Chem.rdchem.Mol
Returns
-------
str
SMILES
"""
return Chem.MolToSmiles(standardize_mol(mol))
def get_atom_and_bond_types(smiles, log=True):
"""Identify the atom types and bond types
appearing in this dataset.
Parameters
----------
smiles : list
List of smiles
log : bool
Whether to print the process of pre-processing.
Returns
-------
atom_types : list
E.g. ['C', 'N']
bond_types : list
E.g. [Chem.rdchem.BondType.SINGLE, Chem.rdchem.BondType.DOUBLE,
Chem.rdchem.BondType.TRIPLE, Chem.rdchem.BondType.AROMATIC]
"""
atom_types = set()
bond_types = set()
n_smiles = len(smiles)
for i, s in enumerate(smiles):
if log:
print('Processing smiles {:d}/{:d}'.format(i + 1, n_smiles))
mol = smiles_to_standard_mol(s)
if mol is None:
continue
for atom in mol.GetAtoms():
a_symbol = atom.GetSymbol()
if a_symbol not in atom_types:
atom_types.add(a_symbol)
for bond in mol.GetBonds():
b_type = bond.GetBondType()
if b_type not in bond_types:
bond_types.add(b_type)
return list(atom_types), list(bond_types)
def eval_decisions(env, decisions):
"""This function mimics the way DGMG generates a molecule and is
helpful for debugging and verification in data preprocessing.
Parameters
----------
env : MoleculeEnv
MDP environment for generating molecules
decisions : list of 2-tuples of int
A decision sequence for generating a molecule
Returns
-------
str
SMILES for the molecule generated with decisions
"""
env.reset(rdkit_mol=True)
t = 0
def whether_to_add_atom(t):
assert decisions[t][0] == 0
atom_type = decisions[t][1]
t += 1
return t, atom_type
def whether_to_add_bond(t):
assert decisions[t][0] == 1
bond_type = decisions[t][1]
t += 1
return t, bond_type
def decide_atom2(t):
assert decisions[t][0] == 2
dst = decisions[t][1]
t += 1
return t, dst
t, atom_type = whether_to_add_atom(t)
while atom_type != len(env.atom_types):
env.add_atom(atom_type)
t, bond_type = whether_to_add_bond(t)
while bond_type != len(env.bond_types):
t, dst = decide_atom2(t)
env.add_bond((env.num_atoms() - 1), dst, bond_type)
t, bond_type = whether_to_add_bond(t)
t, atom_type = whether_to_add_atom(t)
assert t == len(decisions)
return env.get_current_smiles()
def get_DGMG_smile(env, mol):
"""Mimics the reproduced SMILES with DGMG for a molecule.
Given a molecule, we are interested in what SMILES we will
get if we want to generate it with DGMG. This is an important
step to check false novel molecules.
Parameters
----------
env : MoleculeEnv
MDP environment for generating molecules
mol : Chem.rdchem.Mol
A molecule
Returns
-------
canonical_smile : str
SMILES of the generated molecule with a canonical decision sequence
random_smile : str
SMILES of the generated molecule with a random decision sequence
"""
canonical_decisions = env.get_decision_sequence(mol, list(range(mol.GetNumAtoms())))
canonical_smile = eval_decisions(env, canonical_decisions)
order = list(range(mol.GetNumAtoms()))
random.shuffle(order)
random_decisions = env.get_decision_sequence(mol, order)
random_smile = eval_decisions(env, random_decisions)
return canonical_smile, random_smile
def preprocess_dataset(atom_types, bond_types, smiles, max_num_atoms=23):
"""Preprocess the dataset
1. Standardize the SMILES of the dataset
2. Only keep the SMILES that DGMG can reproduce
3. Drop repeated SMILES
Parameters
----------
atom_types : list
The types of atoms appearing in a dataset. E.g. ['C', 'N']
bond_types : list
The types of bonds appearing in a dataset.
E.g. [Chem.rdchem.BondType.SINGLE, Chem.rdchem.BondType.DOUBLE,
Chem.rdchem.BondType.TRIPLE, Chem.rdchem.BondType.AROMATIC]
smiles : list of str
SMILES strings to preprocess
max_num_atoms : int or None
Molecules with more atoms than this are dropped. If None, no size filtering is performed. Default to 23.
Returns
-------
valid_smiles : list of str
SMILES left after preprocessing
"""
valid_smiles = []
env = MoleculeEnv(atom_types, bond_types)
for id, s in enumerate(smiles):
print('Processing {:d}/{:d}'.format(id + 1, len(smiles)))
raw_s = s.strip()
mol = smiles_to_standard_mol(raw_s)
if mol is None:
continue
standard_s = Chem.MolToSmiles(mol)
if (max_num_atoms is not None) and (mol.GetNumAtoms() > max_num_atoms):
continue
canonical_s, random_s = get_DGMG_smile(env, mol)
canonical_mol = Chem.MolFromSmiles(canonical_s)
random_mol = Chem.MolFromSmiles(random_s)
if (standard_s != canonical_s) or (canonical_s != random_s) or (canonical_mol is None) or (random_mol is None):
continue
valid_smiles.append(standard_s)
valid_smiles = list(set(valid_smiles))
return valid_smiles
def download_data(dataset, fname):
"""Download dataset if built-in support exists
Parameters
----------
dataset : str
Dataset name
fname : str
Name of dataset file
"""
if dataset not in ['ChEMBL', 'ZINC']:
# Datasets without built-in support should be processed locally.
return
data_path = fname
download(_get_dgl_url(os.path.join('dataset', fname)), path=data_path)
def load_smiles_from_file(f_name):
"""Load dataset into a list of SMILES
Parameters
----------
f_name : str
Path to a file of molecules, where each line of the file
is a molecule in SMILES format.
Returns
-------
smiles : list of str
List of molecules as SMILES
"""
with open(f_name, 'r') as f:
smiles = f.read().splitlines()
return smiles
def write_smiles_to_file(f_name, smiles):
"""Write dataset to a file.
Parameters
----------
f_name : str
Path to create a file of molecules, where each line of the file
is a molecule in SMILES format.
smiles : list of str
List of SMILES
"""
with open(f_name, 'w') as f:
for s in smiles:
f.write(s + '\n')
def configure_new_dataset(dataset, train_file, val_file):
"""Configure for a new dataset.
Parameters
----------
dataset : str
Dataset name
train_file : str
Path to a file with one SMILES string per line for training data
val_file : str
Path to a file with one SMILES string per line for validation data
"""
assert train_file is not None, 'Expect a file of SMILES for training, got None.'
assert val_file is not None, 'Expect a file of SMILES for validation, got None.'
train_smiles = load_smiles_from_file(train_file)
val_smiles = load_smiles_from_file(val_file)
all_smiles = train_smiles + val_smiles
# Get all atom and bond types in the dataset
path_to_atom_and_bond_types = '_'.join([dataset, 'atom_and_bond_types.pkl'])
if not os.path.exists(path_to_atom_and_bond_types):
atom_types, bond_types = get_atom_and_bond_types(all_smiles)
with open(path_to_atom_and_bond_types, 'wb') as f:
pickle.dump({'atom_types': atom_types, 'bond_types': bond_types}, f)
else:
with open(path_to_atom_and_bond_types, 'rb') as f:
type_info = pickle.load(f)
atom_types = type_info['atom_types']
bond_types = type_info['bond_types']
# Standardize training data
path_to_processed_train_data = '_'.join([dataset, 'DGMG', 'train.txt'])
if not os.path.exists(path_to_processed_train_data):
processed_train_smiles = preprocess_dataset(atom_types, bond_types, train_smiles, None)
write_smiles_to_file(path_to_processed_train_data, processed_train_smiles)
path_to_processed_val_data = '_'.join([dataset, 'DGMG', 'val.txt'])
if not os.path.exists(path_to_processed_val_data):
processed_val_smiles = preprocess_dataset(atom_types, bond_types, val_smiles, None)
write_smiles_to_file(path_to_processed_val_data, processed_val_smiles)
class MoleculeDataset(object):
"""Initialize and split the dataset.
Parameters
----------
dataset : str
Dataset name
order : None or str
Order to extract a decision sequence for generating a molecule. Default to be None.
modes : None or list
List of subsets to use, which can contain 'train', 'val', corresponding to
training and validation. Default to be None.
subset_id : int
With multiprocess training, we partition the training set into multiple subsets and
each process will use one subset only. This subset_id corresponds to subprocess id.
n_subsets : int
With multiprocess training, this corresponds to the number of total subprocesses.
"""
def __init__(self, dataset, order=None, modes=None, subset_id=0, n_subsets=1):
super(MoleculeDataset, self).__init__()
if modes is None:
modes = []
else:
assert order is not None, 'An order should be specified for extracting ' \
'decision sequences.'
assert order in ['random', 'canonical', None], \
"Unexpected order option to get sequences of graph generation decisions"
assert len(set(modes) - {'train', 'val'}) == 0, \
"modes should be a list, representing a subset of ['train', 'val']"
self.dataset = dataset
self.order = order
self.modes = modes
self.subset_id = subset_id
self.n_subsets = n_subsets
self._setup()
def collate(self, samples):
"""PyTorch's approach to batch multiple samples.
For auto-regressive generative models, we process one sample at a time.
Parameters
----------
samples : list
A list of length 1 that consists of decision sequence to generate a molecule.
Returns
-------
list
List of 2-tuples, a decision sequence to generate a molecule
"""
assert len(samples) == 1
return samples[0]
def _create_a_subset(self, smiles):
"""Create a dataset from a subset of smiles.
Parameters
----------
smiles : list of str
List of molecules in SMILES format
"""
# With multiprocess training, we evenly divide the SMILES into multiple subsets
subset_size = len(smiles) // self.n_subsets
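# e.g. with 800 SMILES and n_subsets == 4, subset_size == 200 and subset 0
# takes smiles[0:200], subset 1 takes smiles[200:400], etc.; any remainder
# from the integer division is dropped.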
return Subset(smiles[self.subset_id * subset_size: (self.subset_id + 1) * subset_size],
self.order, self.env)
def _setup(self):
"""
1. Instantiate an MDP environment for molecule generation
2. Download the dataset, which is a file of SMILES
3. Create subsets for training and validation
"""
if self.dataset == 'ChEMBL':
# For new datasets, get_atom_and_bond_types can be used to
# identify the atom and bond types in them.
self.atom_types = ['O', 'Cl', 'C', 'S', 'F', 'Br', 'N']
self.bond_types = [Chem.rdchem.BondType.SINGLE,
Chem.rdchem.BondType.DOUBLE,
Chem.rdchem.BondType.TRIPLE]
elif self.dataset == 'ZINC':
self.atom_types = ['Br', 'S', 'C', 'P', 'N', 'O', 'F', 'Cl', 'I']
self.bond_types = [Chem.rdchem.BondType.SINGLE,
Chem.rdchem.BondType.DOUBLE,
Chem.rdchem.BondType.TRIPLE]
else:
path_to_atom_and_bond_types = '_'.join([self.dataset, 'atom_and_bond_types.pkl'])
with open(path_to_atom_and_bond_types, 'rb') as f:
type_info = pickle.load(f)
self.atom_types = type_info['atom_types']
self.bond_types = type_info['bond_types']
self.env = MoleculeEnv(self.atom_types, self.bond_types)
dataset_prefix = self._dataset_prefix()
if 'train' in self.modes:
fname = '_'.join([dataset_prefix, 'train.txt'])
download_data(self.dataset, fname)
smiles = load_smiles_from_file(fname)
self.train_set = self._create_a_subset(smiles)
if 'val' in self.modes:
fname = '_'.join([dataset_prefix, 'val.txt'])
download_data(self.dataset, fname)
smiles = load_smiles_from_file(fname)
# With multiprocess training, we evenly divide the SMILES into multiple subsets
self.val_set = self._create_a_subset(smiles)
def _dataset_prefix(self):
"""Get the prefix for the data files of supported datasets.
Returns
-------
str
Prefix for dataset file name
"""
return '_'.join([self.dataset, 'DGMG'])
class Subset(Dataset):
"""A set of molecules which can be used for training, validation, test.
Parameters
----------
smiles : list
List of SMILES for the dataset
order : str
Specifies how decision sequences for molecule generation
are obtained, can be either "random" or "canonical"
env : MoleculeEnv object
MDP environment for generating molecules
"""
def __init__(self, smiles, order, env):
super(Subset, self).__init__()
self.smiles = smiles
self.order = order
self.env = env
self._setup()
def _setup(self):
"""Convert SMILES into rdkit molecule objects.
Decision sequences are extracted if we use a fixed order.
"""
smiles_ = []
mols = []
for s in self.smiles:
m = smiles_to_standard_mol(s)
if m is None:
continue
smiles_.append(s)
mols.append(m)
self.smiles = smiles_
self.mols = mols
if self.order == 'random':
return
self.decisions = []
for m in self.mols:
self.decisions.append(
self.env.get_decision_sequence(m, list(range(m.GetNumAtoms())))
)
def __len__(self):
"""Get number of molecules in the dataset."""
return len(self.mols)
def __getitem__(self, item):
"""Get the decision sequence for generating the molecule indexed by item."""
if self.order == 'canonical':
return self.decisions[item]
else:
m = self.mols[item]
nodes = list(range(m.GetNumAtoms()))
random.shuffle(nodes)
return self.env.get_decision_sequence(m, nodes)
########################################################################################################################
# progress tracking #
########################################################################################################################
class Printer(object):
def __init__(self, num_epochs, dataset_size, batch_size, writer=None):
"""Wrapper to track the learning progress.
Parameters
----------
num_epochs : int
Number of epochs for training
dataset_size : int
batch_size : int
writer : None or SummaryWriter
If not None, tensorboard will be used to visualize learning curves.
"""
super(Printer, self).__init__()
self.num_epochs = num_epochs
self.batch_size = batch_size
self.num_batches = math.ceil(dataset_size / batch_size)
self.count = 0
self.batch_count = 0
self.writer = writer
self._reset()
def _reset(self):
"""Reset when an epoch is completed."""
self.batch_loss = 0
self.batch_prob = 0
def _get_current_batch(self):
"""Get current batch index."""
remainder = self.batch_count % self.num_batches
if remainder == 0:
return self.num_batches
else:
return remainder
def update(self, epoch, loss, prob):
"""Update learning progress.
Parameters
----------
epoch : int
loss : float
prob : float
"""
self.count += 1
self.batch_loss += loss
self.batch_prob += prob
if self.count % self.batch_size == 0:
self.batch_count += 1
if self.writer is not None:
self.writer.add_scalar('train_log_prob', self.batch_loss, self.batch_count)
self.writer.add_scalar('train_prob', self.batch_prob, self.batch_count)
print('epoch {:d}/{:d}, batch {:d}/{:d}, loss {:.4f}, prob {:.4f}'.format(
epoch, self.num_epochs, self._get_current_batch(),
self.num_batches, self.batch_loss, self.batch_prob))
self._reset()
########################################################################################################################
# eval #
########################################################################################################################
def summarize_a_molecule(smile, checklist=None):
"""Get information about a molecule.
Parameters
----------
smile : str
Molecule in SMILES format
checklist : dict
Things to learn about the molecule
Returns
-------
summary : dict
Summary of the molecule, including its SMILES, validity and the properties in checklist
"""
if checklist is None:
checklist = {
'HBA': Chem.rdMolDescriptors.CalcNumHBA,
'HBD': Chem.rdMolDescriptors.CalcNumHBD,
'logP': MolLogP,
'SA': calculateScore,
'TPSA': Chem.rdMolDescriptors.CalcTPSA,
'QED': qed,
'NumAtoms': lambda mol: mol.GetNumAtoms(),
'NumBonds': lambda mol: mol.GetNumBonds()
}
summary = dict()
mol = Chem.MolFromSmiles(smile)
if mol is None:
summary.update({
'smile': smile,
'valid': False
})
for k in checklist.keys():
summary[k] = None
else:
mol = standardize_mol(mol)
summary.update({
'smile': Chem.MolToSmiles(mol),
'valid': True
})
Chem.SanitizeMol(mol)
for k, f in checklist.items():
summary[k] = f(mol)
return summary
def summarize_molecules(smiles, num_processes):
"""Summarize molecules with multiprocess.
Parameters
----------
smiles : list of str
List of molecules in SMILES for summarization
num_processes : int
Number of processes to use for summarization
Returns
-------
summary_for_valid : dict
Summary of all valid molecules, where
summary_for_valid[k] gives the values of all
valid molecules on item k.
"""
with Pool(processes=num_processes) as pool:
result = pool.map(summarize_a_molecule, smiles)
items = list(result[0].keys())
items.remove('valid')
summary_for_valid = defaultdict(list)
for summary in result:
if summary['valid']:
for k in items:
summary_for_valid[k].append(summary[k])
return summary_for_valid
def get_unique_smiles(smiles):
"""Given a list of smiles, return a list consisting of unique elements in it.
Parameters
----------
smiles : list of str
Molecules in SMILES
Returns
-------
list of str
Sublist where each SMILES occurs exactly once
"""
unique_set = set()
for mol_s in smiles:
if mol_s not in unique_set:
unique_set.add(mol_s)
return list(unique_set)
def get_novel_smiles(new_unique_smiles, reference_unique_smiles):
"""Get novel smiles which do not appear in the reference set.
Parameters
----------
new_unique_smiles : list of str
List of SMILES from which we want to identify novel ones
reference_unique_smiles : list of str
List of reference SMILES that we already have
Returns
-------
set of str
Novel SMILES that do not appear in the reference set
"""
return set(new_unique_smiles).difference(set(reference_unique_smiles))
# Junction Tree Variational Autoencoder for Molecular Graph Generation (JTNN)
Wengong Jin, Regina Barzilay, Tommi Jaakkola.
Junction Tree Variational Autoencoder for Molecular Graph Generation.
*arXiv preprint arXiv:1802.04364*, 2018.
JTNN uses the junction tree algorithm to form a tree from the molecular graph.
The model then encodes the tree and the graph into two separate vectors `z_G` and `z_T`. Details can
be found in the original paper. The process is briefly illustrated below (figure from the original paper):
![image](https://user-images.githubusercontent.com/8686776/63677300-3fb6d980-c81f-11e9-8a65-57c8b03aaf52.png)
**Goal**: JTNN is an autoencoder model that aims to learn hidden representations for molecular graphs.
These representations can be used for downstream tasks such as property prediction or molecule optimization.
## Dataset
### ZINC
> The ZINC database is a curated collection of commercially available chemical compounds
prepared especially for virtual screening. (introduction from Wikipedia)
Generally speaking, molecules in the ZINC dataset are more drug-like. We use ~220,000
molecules for training and 5,000 molecules for validation.
### Preprocessing
The `JTNNDataset` class processes a SMILES string into a dict consisting of a junction tree, a molecular
graph with featurized nodes (atoms) and edges (bonds), and other information for the model to use.
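Below is a rough sketch of inspecting a processed sample. The constructor arguments shown are assumptions for illustration only; check `datautils.py` for the actual signature.
```
from datautils import JTNNDataset

# Assumed arguments (illustrative): a file with one SMILES string per line
# and a vocabulary file; the real signature may differ, see datautils.py.
dataset = JTNNDataset(data='train.txt', vocab='vocab.txt', training=True)

# Each sample is a dict holding the junction tree, the molecular graph and
# auxiliary information used by the model.
sample = dataset[0]
print(sample.keys())
```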
## Usage
### Training
To start training, use `python train.py`. By default, the script will use the ZINC dataset
with a preprocessed vocabulary and periodically save model checkpoints in the current working directory.
### Evaluation
To start evaluation, use `python reconstruct_eval.py`. By default, evaluation is performed with
DGL's pre-trained model. During evaluation, the program prints the success rate of
molecule reconstruction.
### Pre-trained models
The table below gives the statistics of our pre-trained `JTNN_ZINC` model.
| Pre-trained model | % Reconstruction Accuracy |
| ----------------- | ------------------------- |
| `JTNN_ZINC`       | 73.7                      |
### Visualization
Here we draw some "neighbors" of a given molecule by adding noise to its intermediate representations.
You can download the notebook with `wget https://data.dgl.ai/dgllife/jtnn_viz_neighbor_mol.ipynb`.
Please place it in the current directory (`examples/pytorch/model_zoo/chem/generative_models/jtnn/`).
#### Given Molecule
![image](https://user-images.githubusercontent.com/8686776/63773593-0d37da00-c90e-11e9-8933-0abca4b430db.png)
#### Neighbor Molecules
![image](https://user-images.githubusercontent.com/8686776/63773602-1163f780-c90e-11e9-8341-5122dc0d0c82.png)
### Dataset configuration
If you want to use your own dataset, please create a file with one SMILES string per line, as below
```
CCO
Fc1ccccc1
```
You can generate the vocabulary file corresponding to your dataset with `python vocab.py -d X -v Y`, where `X`
is the path to the dataset and `Y` is the path to the vocabulary file to save. An example vocabulary file
corresponding to the two molecules above will be
```
CC
CF
C1=CC=CC=C1
CO
```
If you want to develop a model based on DGL's pre-trained model, it's important to make sure that the
vocabulary generated above is a subset of the vocabulary used for the pre-trained model. `vocab.py` also
performs this check and prints the result in the terminal as follows:
```
The new vocabulary is a subset of the default vocabulary: True
```
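If you want to run this check yourself, below is a minimal sketch using plain Python sets. The file names `my_vocab.txt` and `default_vocab.txt` are placeholders for your generated vocabulary and the vocabulary of the pre-trained model.
```
# Minimal sketch: check whether one vocabulary file is a subset of another.
# Each file is assumed to contain one vocabulary entry per line.
with open('my_vocab.txt') as f:
    new_vocab = {line.strip() for line in f if line.strip()}
with open('default_vocab.txt') as f:
    default_vocab = {line.strip() for line in f if line.strip()}

print('The new vocabulary is a subset of the default vocabulary:',
      new_vocab.issubset(default_vocab))
```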
To train on this new dataset, run
```
python train.py -t X
```
where `X` is the path to the new dataset. If you want to use the vocabulary generated above, also add `-v Y`, where
`Y` is the path to the vocabulary file we just saved.
To evaluate on this new dataset, run `python reconstruct_eval.py` with the same arguments as above.
from .mol_tree import Vocab
from .datautils import JTNNDataset, JTNNCollator
from .chemutils import decode_stereo