Unverified Commit 36c7b771 authored by Mufei Li, committed by GitHub

[LifeSci] Move to Independent Repo (#1592)

* Move LifeSci

* Remove doc
parent 94c67203
.. _apimodelzoo:
Model Zoo
=========
This section introduces complete models for various downstream tasks.
.. contents:: Contents
:local:
Building Blocks
---------------
MLP Predictor
`````````````
.. automodule:: dgllife.model.model_zoo.mlp_predictor
:members:
Molecular Property Prediction
-----------------------------
AttentiveFP Predictor
`````````````````````
.. automodule:: dgllife.model.model_zoo.attentivefp_predictor
:members:
GAT Predictor
`````````````
.. automodule:: dgllife.model.model_zoo.gat_predictor
:members:
GCN Predictor
`````````````
.. automodule:: dgllife.model.model_zoo.gcn_predictor
:members:
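For illustration, here is a minimal sketch of running a GCN-based predictor on a single molecule. It assumes the constructor arguments ``in_feats``, ``hidden_feats`` and ``n_tasks`` and a ``forward(bg, feats)`` call signature; see the class documentation above for the authoritative interface.

.. code:: python

    import dgl
    from dgllife.model import GCNPredictor
    from dgllife.utils import smiles_to_bigraph, CanonicalAtomFeaturizer

    atom_featurizer = CanonicalAtomFeaturizer()
    g = smiles_to_bigraph('CCO', node_featurizer=atom_featurizer)
    # Constructor/forward signatures below are assumptions; see the module docs above
    model = GCNPredictor(in_feats=atom_featurizer.feat_size(),
                         hidden_feats=[64, 64], n_tasks=1)
    bg = dgl.batch([g])                  # a batch with a single molecule
    preds = model(bg, bg.ndata['h'])     # shape (1, n_tasks)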
MGCN Predictor
``````````````
.. automodule:: dgllife.model.model_zoo.mgcn_predictor
:members:
MPNN Predictor
``````````````
.. automodule:: dgllife.model.model_zoo.mpnn_predictor
:members:
SchNet Predictor
````````````````
.. automodule:: dgllife.model.model_zoo.schnet_predictor
:members:
Weave Predictor
```````````````
.. automodule:: dgllife.model.model_zoo.weave_predictor
:members:
GIN Predictor
`````````````
.. automodule:: dgllife.model.model_zoo.gin_predictor
:members:
Generative Models
-----------------
DGMG
````
.. automodule:: dgllife.model.model_zoo.dgmg
:members:
JTNN
````
.. autoclass:: dgllife.model.model_zoo.jtnn.DGLJTNNVAE
:members:
Reaction Prediction
-------------------
WLN for Reaction Center Prediction
``````````````````````````````````
.. automodule:: dgllife.model.model_zoo.wln_reaction_center
:members:
WLN for Ranking Candidate Products
``````````````````````````````````
.. automodule:: dgllife.model.model_zoo.wln_reaction_ranking
:members:
Protein-Ligand Binding Affinity Prediction
------------------------------------------
ACNN
````
.. automodule:: dgllife.model.model_zoo.acnn
:members:
.. _apiutilscomplexes:
Utils for protein-ligand complexes
==================================
Utilities in DGL-LifeSci for working with protein-ligand complexes.
.. autosummary::
:toctree: ../generated/
dgllife.utils.ACNN_graph_construction_and_featurization
.. _apiutilsmols:
Utils for Molecules
===================
Utilities in DGL-LifeSci for working with molecules.
RDKit Utils
-----------
RDKit utils for loading molecules and accessing their information.
.. autosummary::
:toctree: ../generated/
dgllife.utils.get_mol_3d_coordinates
dgllife.utils.load_molecule
dgllife.utils.multiprocess_load_molecules
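For example, a minimal sketch of loading a molecule and retrieving its 3D coordinates. The file path is a placeholder, and the ``remove_hs`` keyword argument is an assumption about the function's interface.

.. code:: python

    from dgllife.utils import load_molecule, get_mol_3d_coordinates

    # 'ligand.sdf' is a hypothetical path; the remove_hs keyword is assumed
    mol, coords = load_molecule('ligand.sdf', remove_hs=True)
    coords = get_mol_3d_coordinates(mol)  # None if the molecule has no conformer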
Graph Construction
------------------
Modeling with graph neural networks starts with constructing appropriate graph topologies. We provide
three common graph constructions:
* ``bigraph``: Bi-directed graphs corresponding exactly to molecular graphs
* ``complete_graph``: Graphs with all pairs of atoms connected
* ``nearest_neighbor_graph``: Graphs where each atom is connected to its k closest atoms based on molecular coordinates
.. autosummary::
:toctree: ../generated/
dgllife.utils.mol_to_graph
dgllife.utils.smiles_to_bigraph
dgllife.utils.mol_to_bigraph
dgllife.utils.smiles_to_complete_graph
dgllife.utils.mol_to_complete_graph
dgllife.utils.k_nearest_neighbors
dgllife.utils.mol_to_nearest_neighbor_graph
dgllife.utils.smiles_to_nearest_neighbor_graph
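As a minimal sketch, the snippet below constructs a bi-directed molecular graph from a SMILES string, optionally attaching node features (the ``node_featurizer`` keyword is an assumption; see the function docs above):

.. code:: python

    from dgllife.utils import smiles_to_bigraph, CanonicalAtomFeaturizer

    # Bi-directed graph for aspirin: nodes are atoms, edges are bonds
    g = smiles_to_bigraph('CC(=O)OC1=CC=CC=C1C(=O)O',
                          node_featurizer=CanonicalAtomFeaturizer())
    print(g.number_of_nodes())  # number of atoms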
Featurization for Molecules
---------------------------
To apply graph neural networks, we need to prepare node and edge features for molecules. Intuitively,
they can be developed from various descriptors (features) of atoms/bonds/molecules. In particular, we can
work with numerical descriptors directly or use ``one_hot_encoding`` for categorical descriptors. When using
multiple descriptors together, we can simply concatenate them with ``ConcatFeaturizer``.
General Utils
`````````````
.. autosummary::
:toctree: ../generated/
dgllife.utils.one_hot_encoding
dgllife.utils.ConcatFeaturizer
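A minimal sketch of combining several atom descriptors into a single feature vector with ``ConcatFeaturizer``; the descriptor functions come from the node featurization list below, and passing them as a list to ``ConcatFeaturizer`` is an assumption based on its intended use:

.. code:: python

    from functools import partial
    from rdkit import Chem
    from dgllife.utils import (ConcatFeaturizer, atom_type_one_hot,
                               atom_degree_one_hot, atom_is_aromatic)

    # Passing a list of descriptor functions is assumed here
    atom_featurizer = ConcatFeaturizer([
        partial(atom_type_one_hot, allowable_set=['C', 'N', 'O']),
        atom_degree_one_hot,
        atom_is_aromatic
    ])
    mol = Chem.MolFromSmiles('CCO')
    feats = atom_featurizer(mol.GetAtomWithIdx(0))  # concatenated feature list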
Featurization for Nodes
```````````````````````
We consider the following atom descriptors:
* type/atomic number
* degree (excluding neighboring hydrogen atoms)
* total degree (including neighboring hydrogen atoms)
* explicit valence
* implicit valence
* hybridization
* total number of neighboring hydrogen atoms
* formal charge
* number of radical electrons
* aromatic atom
* ring membership
* chirality
* mass
We can employ their numerical values directly or with one-hot encoding.
.. autosummary::
:toctree: ../generated/
dgllife.utils.atom_type_one_hot
dgllife.utils.atomic_number_one_hot
dgllife.utils.atomic_number
dgllife.utils.atom_degree_one_hot
dgllife.utils.atom_degree
dgllife.utils.atom_total_degree_one_hot
dgllife.utils.atom_total_degree
dgllife.utils.atom_explicit_valence_one_hot
dgllife.utils.atom_explicit_valence
dgllife.utils.atom_implicit_valence_one_hot
dgllife.utils.atom_implicit_valence
dgllife.utils.atom_hybridization_one_hot
dgllife.utils.atom_total_num_H_one_hot
dgllife.utils.atom_total_num_H
dgllife.utils.atom_formal_charge_one_hot
dgllife.utils.atom_formal_charge
dgllife.utils.atom_num_radical_electrons_one_hot
dgllife.utils.atom_num_radical_electrons
dgllife.utils.atom_is_aromatic_one_hot
dgllife.utils.atom_is_aromatic
dgllife.utils.atom_is_in_ring_one_hot
dgllife.utils.atom_is_in_ring
dgllife.utils.atom_chiral_tag_one_hot
dgllife.utils.atom_mass
To use featurization methods like the ones above for creating node features:
.. autosummary::
:toctree: ../generated/
dgllife.utils.BaseAtomFeaturizer
dgllife.utils.BaseAtomFeaturizer.feat_size
dgllife.utils.CanonicalAtomFeaturizer
dgllife.utils.CanonicalAtomFeaturizer.feat_size
dgllife.utils.PretrainAtomFeaturizer
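For instance, a minimal sketch of featurizing every atom of a molecule with ``CanonicalAtomFeaturizer`` (the ``atom_data_field`` argument and the returned dictionary layout are assumptions based on common usage):

.. code:: python

    from rdkit import Chem
    from dgllife.utils import CanonicalAtomFeaturizer

    mol = Chem.MolFromSmiles('CCO')
    # atom_data_field and the dict-of-tensors return value are assumed
    featurizer = CanonicalAtomFeaturizer(atom_data_field='h')
    feats = featurizer(mol)['h']        # tensor of shape (num_atoms, feat_size)
    print(featurizer.feat_size('h'))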
Featurization for Edges
```````````````````````
We consider the following bond descriptors:
* type
* conjugated bond
* ring membership
* stereo configuration
.. autosummary::
:toctree: ../generated/
dgllife.utils.bond_type_one_hot
dgllife.utils.bond_is_conjugated_one_hot
dgllife.utils.bond_is_conjugated
dgllife.utils.bond_is_in_ring_one_hot
dgllife.utils.bond_is_in_ring
dgllife.utils.bond_stereo_one_hot
dgllife.utils.bond_direction_one_hot
To use featurization methods like the ones above for creating edge features:
.. autosummary::
:toctree: ../generated/
dgllife.utils.BaseBondFeaturizer
dgllife.utils.BaseBondFeaturizer.feat_size
dgllife.utils.CanonicalBondFeaturizer
dgllife.utils.CanonicalBondFeaturizer.feat_size
dgllife.utils.PretrainBondFeaturizer
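Similarly, a sketch of computing bond features with ``CanonicalBondFeaturizer`` (the ``bond_data_field`` argument is an assumption):

.. code:: python

    from rdkit import Chem
    from dgllife.utils import CanonicalBondFeaturizer

    mol = Chem.MolFromSmiles('CCO')
    # bond_data_field is assumed; features are duplicated for both edge directions
    featurizer = CanonicalBondFeaturizer(bond_data_field='e')
    feats = featurizer(mol)['e']        # one row per directed edge (2 * num_bonds)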
.. _apiutilspipeline:
Model Development Pipeline
==========================
.. contents:: Contents
:local:
Model Evaluation
----------------
A utility class for evaluating model performance on (multi-label) supervised learning.
.. autoclass:: dgllife.utils.Meter
:members: update, compute_metric
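A minimal sketch of its usage, mirroring how the examples in this repository call ``update`` and ``compute_metric``; constructing the meter without normalization statistics is assumed to be supported, and the random tensors are placeholders for predictions and labels:

.. code:: python

    import torch
    from dgllife.utils import Meter

    meter = Meter()  # constructing without mean/std is assumed
    # Placeholder predictions and labels of shape (batch_size, num_tasks)
    meter.update(torch.randn(4, 1), torch.randn(4, 1))
    print(meter.compute_metric('mae', 'mean'))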
Early Stopping
--------------
Early stopping is a standard practice for preventing models from overfitting, and we provide a utility
class for handling it.
.. autoclass:: dgllife.utils.EarlyStopping
:members:
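A minimal sketch of using it in a training loop; the constructor arguments and the ``step``/``load_checkpoint`` methods are assumptions, and the model and validation scores are placeholders:

.. code:: python

    import torch.nn as nn
    from dgllife.utils import EarlyStopping

    model = nn.Linear(4, 1)                            # placeholder model
    # Constructor and method signatures below are assumptions
    stopper = EarlyStopping(mode='higher', patience=10, filename='model.pth')
    for epoch in range(50):
        val_score = 1. - 1. / (epoch + 1)              # placeholder validation score
        early_stop = stopper.step(val_score, model)    # checkpoints on improvement
        if early_stop:
            break
    stopper.load_checkpoint(model)                     # reload the best checkpoint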
.. _apiutilssplitters:
Splitting Datasets
==================
We provide multiple splitting methods for datasets.
.. contents:: Contents
:local:
ConsecutiveSplitter
-------------------
.. autoclass:: dgllife.utils.ConsecutiveSplitter
:members: train_val_test_split, k_fold_split
RandomSplitter
--------------
.. autoclass:: dgllife.utils.RandomSplitter
:members: train_val_test_split, k_fold_split
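A minimal sketch of an 80/10/10 random split; a plain list stands in for a real dataset object, and the call mirrors the usage in the binding affinity example of this repository:

.. code:: python

    from dgllife.utils import RandomSplitter

    dataset = list(range(10))  # placeholder for a dataset with __len__ and __getitem__
    train_set, val_set, test_set = RandomSplitter.train_val_test_split(
        dataset, frac_train=0.8, frac_val=0.1, frac_test=0.1, random_state=0)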
MolecularWeightSplitter
-----------------------
.. autoclass:: dgllife.utils.MolecularWeightSplitter
:members: train_val_test_split, k_fold_split
ScaffoldSplitter
----------------
.. autoclass:: dgllife.utils.ScaffoldSplitter
:members: train_val_test_split, k_fold_split
SingleTaskStratifiedSplitter
----------------------------
.. autoclass:: dgllife.utils.SingleTaskStratifiedSplitter
:members: train_val_test_split, k_fold_split
# -*- coding: utf-8 -*-
#
# Configuration file for the Sphinx documentation builder.
#
# This file does only contain a selection of the most common options. For a
# full list see the documentation:
# http://www.sphinx-doc.org/en/master/config
# -- Path setup --------------------------------------------------------------
# If extensions (or modules to document with autodoc) are in another directory,
# add these directories to sys.path here. If the directory is relative to the
# documentation root, use os.path.abspath to make it absolute, like shown here.
#
import os
import sys
sys.path.insert(0, os.path.abspath('../../python'))
# -- Project information -----------------------------------------------------
project = 'DGL-LifeSci'
copyright = '2020, DGL Team'
author = 'DGL Team'
import dgllife
version = dgllife.__version__
release = dgllife.__version__
# -- General configuration ---------------------------------------------------
# If your documentation needs a minimal Sphinx version, state it here.
#
# needs_sphinx = '1.0'
# Add any Sphinx extension module names here, as strings. They can be
# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
# ones.
extensions = [
'sphinx.ext.autodoc',
'sphinx.ext.autosummary',
'sphinx.ext.coverage',
'sphinx.ext.mathjax',
'sphinx.ext.napoleon',
'sphinx.ext.viewcode',
'sphinx.ext.intersphinx',
'sphinx.ext.graphviz',
'sphinx_gallery.gen_gallery',
]
# Add any paths that contain templates here, relative to this directory.
templates_path = ['_templates']
# The suffix(es) of source filenames.
# You can specify multiple suffix as a list of string:
#
source_suffix = ['.rst', '.md']
# The master toctree document.
master_doc = 'index'
# The language for content autogenerated by Sphinx. Refer to documentation
# for a list of supported languages.
#
# This is also used if you do content translation via gettext catalogs.
# Usually you set "language" from the command line for these cases.
language = None
# List of patterns, relative to source directory, that match files and
# directories to ignore when looking for source files.
# This pattern also affects html_static_path and html_extra_path.
exclude_patterns = []
# The name of the Pygments (syntax highlighting) style to use.
pygments_style = None
# -- Options for HTML output -------------------------------------------------
# The theme to use for HTML and HTML Help pages. See the documentation for
# a list of builtin themes.
#
html_theme = 'sphinx_rtd_theme'
# Theme options are theme-specific and customize the look and feel of a theme
# further. For a list of options available for each theme, see the
# documentation.
#
# html_theme_options = {}
# Add any paths that contain custom static files (such as style sheets) here,
# relative to this directory. They are copied after the builtin static files,
# so a file named "default.css" will overwrite the builtin "default.css".
html_static_path = ['_static']
# Custom sidebar templates, must be a dictionary that maps document names
# to template names.
#
# The default sidebars (for documents that don't match any pattern) are
# defined by theme itself. Builtin themes are using these templates by
# default: ``['localtoc.html', 'relations.html', 'sourcelink.html',
# 'searchbox.html']``.
#
# html_sidebars = {}
# -- Options for HTMLHelp output ---------------------------------------------
# Output file base name for HTML help builder.
htmlhelp_basename = 'dgllifedoc'
# -- Options for LaTeX output ------------------------------------------------
latex_elements = {
# The paper size ('letterpaper' or 'a4paper').
#
# 'papersize': 'letterpaper',
# The font size ('10pt', '11pt' or '12pt').
#
# 'pointsize': '10pt',
# Additional stuff for the LaTeX preamble.
#
# 'preamble': '',
# Latex figure (float) alignment
#
# 'figure_align': 'htbp',
}
# Grouping the document tree into LaTeX files. List of tuples
# (source start file, target name, title,
# author, documentclass [howto, manual, or own class]).
latex_documents = [
(master_doc, 'dgllife.tex', 'DGL-LifeSci Documentation',
'DGL Team', 'manual'),
]
# -- Options for manual page output ------------------------------------------
# One entry per manual page. List of tuples
# (source start file, name, description, authors, manual section).
man_pages = [
(master_doc, 'dgllife', 'DGL-LifeSci Documentation',
[author], 1)
]
# -- Options for Texinfo output ----------------------------------------------
# Grouping the document tree into Texinfo files. List of tuples
# (source start file, target name, title, author,
# dir menu entry, description, category)
texinfo_documents = [
(master_doc, 'dgllife', 'DGL-LifeSci Documentation',
author, 'dgllife', 'Application library for life science.',
'Miscellaneous'),
]
# -- Options for Epub output -------------------------------------------------
# Bibliographic Dublin Core info.
epub_title = project
# The unique identifier of the text. This can be a ISBN number
# or the project homepage.
#
# epub_identifier = ''
# A unique identification for the text.
#
# epub_uid = ''
# A list of files that should not be packed into the epub file.
epub_exclude_files = ['search.html']
# -- Extension configuration -------------------------------------------------
autosummary_generate = True
intersphinx_mapping = {
'python': ('https://docs.python.org/{.major}'.format(sys.version_info), None),
'numpy': ('http://docs.scipy.org/doc/numpy/', None),
'scipy': ('http://docs.scipy.org/doc/scipy/reference', None),
'matplotlib': ('http://matplotlib.org/', None),
'networkx' : ('https://networkx.github.io/documentation/stable', None),
}
# sphinx gallery configurations
from sphinx_gallery.sorting import FileNameSortKey
examples_dirs = [] # path to find sources
gallery_dirs = [] # path to generate docs
reference_url = {
'dgllife' : None,
'numpy': 'http://docs.scipy.org/doc/numpy/',
'scipy': 'http://docs.scipy.org/doc/scipy/reference',
'matplotlib': 'http://matplotlib.org/',
'networkx' : 'https://networkx.github.io/documentation/stable',
}
sphinx_gallery_conf = {
'backreferences_dir' : 'generated/backreferences',
'doc_module' : ('dgllife', 'numpy'),
'examples_dirs' : examples_dirs,
'gallery_dirs' : gallery_dirs,
'within_subsection_order' : FileNameSortKey,
'filename_pattern' : '.py',
'download_all_examples' : False,
}
DGL-LifeSci: Bringing Graph Neural Networks to Chemistry and Biology
===========================================================================================
DGL-LifeSci is a Python package for applying graph neural networks to various tasks in chemistry
and biology, on top of PyTorch and DGL. It provides:
* Various utilities for data processing, training and evaluation.
* Efficient and flexible model implementations.
* Pre-trained models for use without training from scratch.
We cover various applications in our
`examples <https://github.com/dmlc/dgl/tree/master/apps/life_sci/examples>`_, including:
* `Molecular property prediction <https://github.com/dmlc/dgl/tree/master/apps/life_sci/examples/property_prediction>`_
* `Generative models <https://github.com/dmlc/dgl/tree/master/apps/life_sci/examples/generative_models>`_
* `Protein-ligand binding affinity prediction <https://github.com/dmlc/dgl/tree/master/apps/life_sci/examples/binding_affinity_prediction>`_
* `Reaction prediction <https://github.com/dmlc/dgl/tree/master/apps/life_sci/examples/reaction_prediction>`_
Get Started
------------
Follow the :doc:`instructions<install/index>` to install DGL-LifeSci.
.. toctree::
:maxdepth: 1
:caption: Installation
:hidden:
:glob:
install/index
.. toctree::
:maxdepth: 2
:caption: API Reference
:hidden:
:glob:
api/utils.mols
api/utils.splitters
api/utils.pipeline
api/utils.complexes
api/data
api/model.pretrain
api/model.gnn
api/model.readout
api/model.zoo
Free software
-------------
DGL-LifeSci is free software; you can redistribute it and/or modify it under the terms
of the Apache License 2.0. We welcome contributions. Join us on `GitHub <https://github.com/dmlc/dgl/tree/master/apps/life_sci>`_.
Index
-----
* :ref:`genindex`
Install DGL-LifeSci
===================
This topic explains how to install DGL-LifeSci. We recommend installing DGL-LifeSci by using ``conda`` or ``pip``.
System requirements
-------------------
DGL-LifeSci works with the following operating systems:
* Ubuntu 16.04
* macOS
* Windows 10
DGL-LifeSci requires:
* Python 3.6 or later
* `DGL 0.4.3 or later <https://www.dgl.ai/pages/start.html>`_
* `PyTorch 1.2.0 or later <https://pytorch.org/>`_
If you have just installed DGL, the first time you use it, a message will pop up as follows:
.. code:: bash
DGL does not detect a valid backend option. Which backend would you like to work with?
Backend choice (pytorch, mxnet or tensorflow):
and you need to enter ``pytorch``.
Additionally, we require **RDKit 2018.09.3** for cheminformatics. We recommend installing it with
.. code:: bash
conda install -c conda-forge rdkit==2018.09.3
Other versions of RDKit have not been tested.
Install from conda
----------------------
If ``conda`` is not yet installed, get either `miniconda <https://conda.io/miniconda.html>`_ or
the full `anaconda <https://www.anaconda.com/download/>`_.
.. code:: bash
conda install -c dglteam dgllife
Install from pip
----------------
.. code:: bash
pip install dgllife
.. _install-from-source:
Install from source
-------------------
To use the latest experimental features,
.. code:: bash
git clone https://github.com/dmlc/dgl.git
cd dgl/apps/life_sci/python
python setup.py install
# Work Implemented in DGL-LifeSci
We provide various examples across four applications -- property prediction, generative models, protein-ligand binding affinity prediction and reaction prediction.
## Datasets/Benchmarks
- MoleculeNet: A Benchmark for Molecular Machine Learning [[paper]](https://arxiv.org/abs/1703.00564), [[website]](http://moleculenet.ai/)
- [Tox21 with DGL](../python/dgllife/data/tox21.py)
- [PDBBind with DGL](../python/dgllife/data/pdbbind.py)
- Alchemy: A Quantum Chemistry Dataset for Benchmarking AI Models [[paper]](https://arxiv.org/abs/1906.09427), [[github]](https://github.com/tencent-alchemy/Alchemy)
- [Alchemy with DGL](../python/dgllife/data/alchemy.py)
## Property Prediction
- Molecular graph convolutions: moving beyond fingerprints (Weave) [[paper]](https://arxiv.org/abs/1603.00856), [[github]](https://github.com/deepchem/deepchem)
- [Weave Predictor with DGL](../python/dgllife/model/model_zoo/weave_predictor.py)
- [Example for Molecule Classification](property_prediction/classification.py)
- Semi-Supervised Classification with Graph Convolutional Networks (GCN) [[paper]](https://arxiv.org/abs/1609.02907), [[github]](https://github.com/tkipf/gcn)
- [GCN-Based Predictor with DGL](../python/dgllife/model/model_zoo/gcn_predictor.py)
- [Example for Molecule Classification](property_prediction/classification.py)
- Graph Attention Networks (GAT) [[paper]](https://arxiv.org/abs/1710.10903), [[github]](https://github.com/PetarV-/GAT)
- [GAT-Based Predictor with DGL](../python/dgllife/model/model_zoo/gat_predictor.py)
- [Example for Molecule Classification](property_prediction/classification.py)
- SchNet: A continuous-filter convolutional neural network for modeling quantum interactions [[paper]](https://arxiv.org/abs/1706.08566), [[github]](https://github.com/atomistic-machine-learning/SchNet)
- [SchNet with DGL](../python/dgllife/model/model_zoo/schnet_predictor.py)
- [Example for Molecule Regression](property_prediction/regression.py)
- Molecular Property Prediction: A Multilevel Quantum Interactions Modeling Perspective (MGCN) [[paper]](https://arxiv.org/abs/1906.11081)
- [MGCN with DGL](../python/dgllife/model/model_zoo/mgcn_predictor.py)
- [Example for Molecule Regression](property_prediction/regression.py)
- Neural Message Passing for Quantum Chemistry (MPNN) [[paper]](https://arxiv.org/abs/1704.01212), [[github]](https://github.com/brain-research/mpnn)
- [MPNN with DGL](../python/dgllife/model/model_zoo/mpnn_predictor.py)
- [Example for Molecule Regression](property_prediction/regression.py)
- Pushing the Boundaries of Molecular Representation for Drug Discovery with the Graph Attention Mechanism (AttentiveFP) [[paper]](https://pubs.acs.org/doi/abs/10.1021/acs.jmedchem.9b00959)
- [AttentiveFP with DGL](../python/dgllife/model/model_zoo/attentivefp_predictor.py)
- [Example for Molecule Regression](property_prediction/regression.py)
## Generative Models
- Learning Deep Generative Models of Graphs (DGMG) [[paper]](https://arxiv.org/abs/1803.03324)
- [DGMG with DGL](../python/dgllife/model/model_zoo/dgmg.py)
- [Example Training Script](generative_models/dgmg)
- Junction Tree Variational Autoencoder for Molecular Graph Generation (JTNN) [[paper]](https://arxiv.org/abs/1802.04364)
- [JTNN with DGL](../python/dgllife/model/model_zoo/jtnn)
- [Example Training Script](generative_models/jtnn)
## Binding Affinity Prediction
- Atomic Convolutional Networks for Predicting Protein-Ligand Binding Affinity (ACNN) [[paper]](https://arxiv.org/abs/1703.10603), [[github]](https://github.com/deepchem/deepchem/tree/master/contrib/atomicconv)
- [ACNN with DGL](../python/dgllife/model/model_zoo/acnn.py)
- [Example Training Script](binding_affinity_prediction)
## Reaction Prediction
- A graph-convolutional neural network model for the prediction of chemical reactivity [[paper]](https://pubs.rsc.org/en/content/articlelanding/2019/sc/c8sc04228d#!divAbstract), [[github]](https://github.com/connorcoley/rexgen_direct)
- An earlier version was published in NeurIPS 2017 as "Predicting Organic Reaction Outcomes with Weisfeiler-Lehman Network" [[paper]](https://arxiv.org/abs/1709.04555)
- [WLN with DGL for Reaction Center Prediction](../python/dgllife/model/model_zoo/wln_reaction_center.py)
- [Example Script](reaction_prediction/rexgen_direct)
# Binding Affinity Prediction
## Datasets
- **PDBBind**: The PDBBind dataset in MoleculeNet [1] processed from the PDBBind database. The PDBBind
database consists of experimentally measured binding affinities for bio-molecular complexes [2], [3].
It provides detailed 3D Cartesian coordinates of both ligands and their target proteins derived from
experimental (e.g., X-ray crystallography) measurements. The availability of coordinates of the
protein-ligand complexes permits structure-based featurization that is aware of the protein-ligand
binding geometry. The authors of [1] use the "refined" and "core" subsets of the database [4], more carefully
processed for data artifacts, as additional benchmarking targets.
## Models
- **Atomic Convolutional Networks (ACNN)** [5]: Constructs nearest neighbor graphs separately for the ligand, protein and complex
based on the 3D coordinates of the atoms and predicts the binding free energy.
## Usage
Use `main.py` with arguments
```
-m {ACNN}, Model to use
-d {PDBBind_core_pocket_random, PDBBind_core_pocket_scaffold, PDBBind_core_pocket_stratified,
PDBBind_core_pocket_temporal, PDBBind_refined_pocket_random, PDBBind_refined_pocket_scaffold,
PDBBind_refined_pocket_stratified, PDBBind_refined_pocket_temporal}, dataset and splitting method to use
```
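For example, to train ACNN on the PDBBind core subset with a random split:
```
python main.py -m ACNN -d PDBBind_core_pocket_random
```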
## Performance
### PDBBind
#### ACNN
| Subset | Splitting Method | Test MAE | Test R2 |
| ------- | ---------------- | -------- | ------- |
| Core | Random | 1.7688 | 0.1511 |
| Core | Scaffold | 2.5420 | 0.1471 |
| Core | Stratified | 1.7419 | 0.1520 |
| Core | Temporal | 1.9543 | 0.1640 |
| Refined | Random | 1.1948 | 0.4373 |
| Refined | Scaffold | 1.4021 | 0.2086 |
| Refined | Stratified | 1.6376 | 0.3050 |
| Refined | Temporal | 1.2457 | 0.3438 |
## Speed
### ACNN
Compared to [DeepChem's implementation](https://github.com/joegomes/deepchem/tree/acdc), we achieve a speedup of
roughly 3.3x in training time per epoch (from 1.40 s to 0.42 s). If we do not care about the
randomness introduced by some kernel optimizations, we can achieve a speedup of roughly 4.4x (from 1.40 s to 0.32 s).
## References
[1] Wu et al. (2017) MoleculeNet: a benchmark for molecular machine learning. *Chemical Science* 9, 513-530.
[2] Wang et al. (2004) The PDBbind database: collection of binding affinities for protein-ligand complexes
with known three-dimensional structures. *J Med Chem* 3;47(12):2977-80.
[3] Wang et al. (2005) The PDBbind database: methodologies and updates. *J Med Chem* 16;48(12):4111-9.
[4] Liu et al. (2015) PDB-wide collection of binding data: current status of the PDBbind database. *Bioinformatics* 1;31(3):405-12.
[5] Gomes et al. (2017) Atomic Convolutional Networks for Predicting Protein-Ligand Binding Affinity. *arXiv preprint arXiv:1703.10603*.
import numpy as np
import torch
ACNN_PDBBind_core_pocket_random = {
'dataset': 'PDBBind',
'subset': 'core',
'load_binding_pocket': True,
'random_seed': 123,
'frac_train': 0.8,
'frac_val': 0.,
'frac_test': 0.2,
'batch_size': 24,
'shuffle': False,
'hidden_sizes': [32, 32, 16],
'weight_init_stddevs': [1. / float(np.sqrt(32)), 1. / float(np.sqrt(32)),
1. / float(np.sqrt(16)), 0.01],
'dropouts': [0., 0., 0.],
'atomic_numbers_considered': torch.tensor([
1., 6., 7., 8., 9., 11., 12., 15., 16., 17., 20., 25., 30., 35., 53.]),
'radial': [[12.0], [0.0, 4.0, 8.0], [4.0]],
'lr': 0.001,
'num_epochs': 120,
'metrics': ['r2', 'mae'],
'split': 'random'
}
ACNN_PDBBind_core_pocket_scaffold = {
'dataset': 'PDBBind',
'subset': 'core',
'load_binding_pocket': True,
'random_seed': 123,
'frac_train': 0.8,
'frac_val': 0.,
'frac_test': 0.2,
'batch_size': 24,
'shuffle': False,
'hidden_sizes': [32, 32, 16],
'weight_init_stddevs': [1. / float(np.sqrt(32)), 1. / float(np.sqrt(32)),
1. / float(np.sqrt(16)), 0.01],
'dropouts': [0., 0., 0.],
'atomic_numbers_considered': torch.tensor([
1., 6., 7., 8., 9., 11., 12., 15., 16., 17., 20., 25., 30., 35., 53.]),
'radial': [[12.0], [0.0, 4.0, 8.0], [4.0]],
'lr': 0.001,
'num_epochs': 170,
'metrics': ['r2', 'mae'],
'split': 'scaffold'
}
ACNN_PDBBind_core_pocket_stratified = {
'dataset': 'PDBBind',
'subset': 'core',
'load_binding_pocket': True,
'random_seed': 123,
'frac_train': 0.8,
'frac_val': 0.,
'frac_test': 0.2,
'batch_size': 24,
'shuffle': False,
'hidden_sizes': [32, 32, 16],
'weight_init_stddevs': [1. / float(np.sqrt(32)), 1. / float(np.sqrt(32)),
1. / float(np.sqrt(16)), 0.01],
'dropouts': [0., 0., 0.],
'atomic_numbers_considered': torch.tensor([
1., 6., 7., 8., 9., 11., 12., 15., 16., 17., 20., 25., 30., 35., 53.]),
'radial': [[12.0], [0.0, 4.0, 8.0], [4.0]],
'lr': 0.001,
'num_epochs': 110,
'metrics': ['r2', 'mae'],
'split': 'stratified'
}
ACNN_PDBBind_core_pocket_temporal = {
'dataset': 'PDBBind',
'subset': 'core',
'load_binding_pocket': True,
'random_seed': 123,
'frac_train': 0.8,
'frac_val': 0.,
'frac_test': 0.2,
'batch_size': 24,
'shuffle': False,
'hidden_sizes': [32, 32, 16],
'weight_init_stddevs': [1. / float(np.sqrt(32)), 1. / float(np.sqrt(32)),
1. / float(np.sqrt(16)), 0.01],
'dropouts': [0., 0., 0.],
'atomic_numbers_considered': torch.tensor([
1., 6., 7., 8., 9., 11., 12., 15., 16., 17., 20., 25., 30., 35., 53.]),
'radial': [[12.0], [0.0, 4.0, 8.0], [4.0]],
'lr': 0.001,
'num_epochs': 80,
'metrics': ['r2', 'mae'],
'split': 'temporal'
}
ACNN_PDBBind_refined_pocket_random = {
'dataset': 'PDBBind',
'subset': 'refined',
'load_binding_pocket': True,
'random_seed': 123,
'frac_train': 0.8,
'frac_val': 0.,
'frac_test': 0.2,
'batch_size': 24,
'shuffle': False,
'hidden_sizes': [128, 128, 64],
'weight_init_stddevs': [0.125, 0.125, 0.177, 0.01],
'dropouts': [0.4, 0.4, 0.],
'atomic_numbers_considered': torch.tensor([
1., 6., 7., 8., 9., 11., 12., 15., 16., 17., 19., 20., 25., 26., 27., 28.,
29., 30., 34., 35., 38., 48., 53., 55., 80.]),
'radial': [[12.0], [0.0, 2.0, 4.0, 6.0, 8.0], [4.0]],
'lr': 0.001,
'num_epochs': 200,
'metrics': ['r2', 'mae'],
'split': 'random'
}
ACNN_PDBBind_refined_pocket_scaffold = {
'dataset': 'PDBBind',
'subset': 'refined',
'load_binding_pocket': True,
'random_seed': 123,
'frac_train': 0.8,
'frac_val': 0.,
'frac_test': 0.2,
'batch_size': 24,
'shuffle': False,
'hidden_sizes': [128, 128, 64],
'weight_init_stddevs': [0.125, 0.125, 0.177, 0.01],
'dropouts': [0.4, 0.4, 0.],
'atomic_numbers_considered': torch.tensor([
1., 6., 7., 8., 9., 11., 12., 15., 16., 17., 19., 20., 25., 26., 27., 28.,
29., 30., 34., 35., 38., 48., 53., 55., 80.]),
'radial': [[12.0], [0.0, 2.0, 4.0, 6.0, 8.0], [4.0]],
'lr': 0.001,
'num_epochs': 350,
'metrics': ['r2', 'mae'],
'split': 'scaffold'
}
ACNN_PDBBind_refined_pocket_stratified = {
'dataset': 'PDBBind',
'subset': 'refined',
'load_binding_pocket': True,
'random_seed': 123,
'frac_train': 0.8,
'frac_val': 0.,
'frac_test': 0.2,
'batch_size': 24,
'shuffle': False,
'hidden_sizes': [128, 128, 64],
'weight_init_stddevs': [0.125, 0.125, 0.177, 0.01],
'dropouts': [0.4, 0.4, 0.],
'atomic_numbers_considered': torch.tensor([
1., 6., 7., 8., 9., 11., 12., 15., 16., 17., 19., 20., 25., 26., 27., 28.,
29., 30., 34., 35., 38., 48., 53., 55., 80.]),
'radial': [[12.0], [0.0, 2.0, 4.0, 6.0, 8.0], [4.0]],
'lr': 0.001,
'num_epochs': 400,
'metrics': ['r2', 'mae'],
'split': 'stratified'
}
ACNN_PDBBind_refined_pocket_temporal = {
'dataset': 'PDBBind',
'subset': 'refined',
'load_binding_pocket': True,
'random_seed': 123,
'frac_train': 0.8,
'frac_val': 0.,
'frac_test': 0.2,
'batch_size': 24,
'shuffle': False,
'hidden_sizes': [128, 128, 64],
'weight_init_stddevs': [0.125, 0.125, 0.177, 0.01],
'dropouts': [0.4, 0.4, 0.],
'atomic_numbers_considered': torch.tensor([
1., 6., 7., 8., 9., 11., 12., 15., 16., 17., 19., 20., 25., 26., 27., 28.,
29., 30., 34., 35., 38., 48., 53., 55., 80.]),
'radial': [[12.0], [0.0, 2.0, 4.0, 6.0, 8.0], [4.0]],
'lr': 0.001,
'num_epochs': 350,
'metrics': ['r2', 'mae'],
'split': 'temporal'
}
experiment_configures = {
'ACNN_PDBBind_core_pocket_random': ACNN_PDBBind_core_pocket_random,
'ACNN_PDBBind_core_pocket_scaffold': ACNN_PDBBind_core_pocket_scaffold,
'ACNN_PDBBind_core_pocket_stratified': ACNN_PDBBind_core_pocket_stratified,
'ACNN_PDBBind_core_pocket_temporal': ACNN_PDBBind_core_pocket_temporal,
'ACNN_PDBBind_refined_pocket_random': ACNN_PDBBind_refined_pocket_random,
'ACNN_PDBBind_refined_pocket_scaffold': ACNN_PDBBind_refined_pocket_scaffold,
'ACNN_PDBBind_refined_pocket_stratified': ACNN_PDBBind_refined_pocket_stratified,
'ACNN_PDBBind_refined_pocket_temporal': ACNN_PDBBind_refined_pocket_temporal
}
def get_exp_configure(exp_name):
return experiment_configures[exp_name]
import torch
import torch.nn as nn
from dgllife.utils.eval import Meter
from torch.utils.data import DataLoader
from utils import set_random_seed, load_dataset, collate, load_model
def update_msg_from_scores(msg, scores):
for metric, score in scores.items():
msg += ', {} {:.4f}'.format(metric, score)
return msg
def run_a_train_epoch(args, epoch, model, data_loader,
loss_criterion, optimizer):
model.train()
train_meter = Meter(args['train_mean'], args['train_std'])
epoch_loss = 0
for batch_id, batch_data in enumerate(data_loader):
indices, ligand_mols, protein_mols, bg, labels = batch_data
labels, bg = labels.to(args['device']), bg.to(args['device'])
prediction = model(bg)
loss = loss_criterion(prediction, (labels - args['train_mean']) / args['train_std'])
epoch_loss += loss.data.item() * len(indices)
optimizer.zero_grad()
loss.backward()
optimizer.step()
train_meter.update(prediction, labels)
avg_loss = epoch_loss / len(data_loader.dataset)
total_scores = {metric: train_meter.compute_metric(metric, 'mean')
for metric in args['metrics']}
msg = 'epoch {:d}/{:d}, training | loss {:.4f}'.format(
epoch + 1, args['num_epochs'], avg_loss)
msg = update_msg_from_scores(msg, total_scores)
print(msg)
def run_an_eval_epoch(args, model, data_loader):
model.eval()
eval_meter = Meter(args['train_mean'], args['train_std'])
with torch.no_grad():
for batch_id, batch_data in enumerate(data_loader):
indices, ligand_mols, protein_mols, bg, labels = batch_data
labels, bg = labels.to(args['device']), bg.to(args['device'])
prediction = model(bg)
eval_meter.update(prediction, labels)
total_scores = {metric: eval_meter.compute_metric(metric, 'mean')
for metric in args['metrics']}
return total_scores
def main(args):
args['device'] = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
set_random_seed(args['random_seed'])
dataset, train_set, test_set = load_dataset(args)
args['train_mean'] = train_set.labels_mean.to(args['device'])
args['train_std'] = train_set.labels_std.to(args['device'])
train_loader = DataLoader(dataset=train_set,
batch_size=args['batch_size'],
shuffle=False,
collate_fn=collate)
test_loader = DataLoader(dataset=test_set,
batch_size=args['batch_size'],
shuffle=True,
collate_fn=collate)
model = load_model(args)
loss_fn = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=args['lr'])
model.to(args['device'])
for epoch in range(args['num_epochs']):
run_a_train_epoch(args, epoch, model, train_loader, loss_fn, optimizer)
test_scores = run_an_eval_epoch(args, model, test_loader)
test_msg = update_msg_from_scores('test results', test_scores)
print(test_msg)
if __name__ == '__main__':
import argparse
from configure import get_exp_configure
parser = argparse.ArgumentParser(description='Protein-Ligand Binding Affinity Prediction')
parser.add_argument('-m', '--model', type=str, choices=['ACNN'],
help='Model to use')
parser.add_argument('-d', '--dataset', type=str,
choices=['PDBBind_core_pocket_random', 'PDBBind_core_pocket_scaffold',
'PDBBind_core_pocket_stratified', 'PDBBind_core_pocket_temporal',
'PDBBind_refined_pocket_random', 'PDBBind_refined_pocket_scaffold',
'PDBBind_refined_pocket_stratified', 'PDBBind_refined_pocket_temporal'],
help='Dataset to use')
args = parser.parse_args().__dict__
args['exp'] = '_'.join([args['model'], args['dataset']])
args.update(get_exp_configure(args['exp']))
main(args)
import dgl
import numpy as np
import random
import torch
from dgl.data.utils import Subset
from dgllife.data import PDBBind
from dgllife.model import ACNN
from dgllife.utils import RandomSplitter, ScaffoldSplitter, SingleTaskStratifiedSplitter
from itertools import accumulate
def set_random_seed(seed=0):
"""Set random seed.
Parameters
----------
seed : int
Random seed to use. Default to 0.
"""
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed(seed)
def load_dataset(args):
"""Load the dataset.
Parameters
----------
args : dict
Input arguments.
Returns
-------
dataset
Full dataset.
train_set
Train subset of the dataset.
val_set
Validation subset of the dataset.
"""
assert args['dataset'] in ['PDBBind'], 'Unexpected dataset {}'.format(args['dataset'])
if args['dataset'] == 'PDBBind':
dataset = PDBBind(subset=args['subset'],
load_binding_pocket=args['load_binding_pocket'],
zero_padding=True)
# No validation set is used and frac_val = 0.
if args['split'] == 'random':
train_set, _, test_set = RandomSplitter.train_val_test_split(
dataset,
frac_train=args['frac_train'],
frac_val=args['frac_val'],
frac_test=args['frac_test'],
random_state=args['random_seed'])
elif args['split'] == 'scaffold':
train_set, _, test_set = ScaffoldSplitter.train_val_test_split(
dataset,
mols=dataset.ligand_mols,
sanitize=False,
frac_train=args['frac_train'],
frac_val=args['frac_val'],
frac_test=args['frac_test'])
elif args['split'] == 'stratified':
train_set, _, test_set = SingleTaskStratifiedSplitter.train_val_test_split(
dataset,
labels=dataset.labels,
task_id=0,
frac_train=args['frac_train'],
frac_val=args['frac_val'],
frac_test=args['frac_test'],
random_state=args['random_seed'])
elif args['split'] == 'temporal':
years = dataset.df['release_year'].values.astype(np.float32)
indices = np.argsort(years).tolist()
frac_list = np.array([args['frac_train'], args['frac_val'], args['frac_test']])
num_data = len(dataset)
lengths = (num_data * frac_list).astype(int)
lengths[-1] = num_data - np.sum(lengths[:-1])
train_set, val_set, test_set = [
Subset(dataset, list(indices[offset - length:offset]))
for offset, length in zip(accumulate(lengths), lengths)]
else:
raise ValueError('Expect the splitting method to be "random", "scaffold", '
'"stratified" or "temporal", got {}'.format(args['split']))
train_labels = torch.stack([train_set.dataset.labels[i] for i in train_set.indices])
train_set.labels_mean = train_labels.mean(dim=0)
train_set.labels_std = train_labels.std(dim=0)
return dataset, train_set, test_set
def collate(data):
indices, ligand_mols, protein_mols, graphs, labels = map(list, zip(*data))
bg = dgl.batch_hetero(graphs)
for nty in bg.ntypes:
bg.set_n_initializer(dgl.init.zero_initializer, ntype=nty)
for ety in bg.canonical_etypes:
bg.set_e_initializer(dgl.init.zero_initializer, etype=ety)
labels = torch.stack(labels, dim=0)
return indices, ligand_mols, protein_mols, bg, labels
def load_model(args):
assert args['model'] in ['ACNN'], 'Unexpected model {}'.format(args['model'])
if args['model'] == 'ACNN':
model = ACNN(hidden_sizes=args['hidden_sizes'],
weight_init_stddevs=args['weight_init_stddevs'],
dropouts=args['dropouts'],
features_to_use=args['atomic_numbers_considered'],
radial=args['radial'])
return model
# Learning Deep Generative Models of Graphs (DGMG)
Yujia Li, Oriol Vinyals, Chris Dyer, Razvan Pascanu, and Peter Battaglia.
Learning Deep Generative Models of Graphs. *arXiv preprint arXiv:1803.03324*, 2018.
DGMG generates graphs by progressively adding nodes and edges as below:
![](https://user-images.githubusercontent.com/19576924/48605003-7f11e900-e9b6-11e8-8880-87362348e154.png)
For molecules, the nodes are atoms and the edges are bonds.
**Goal**: Given a set of real molecules, we want to learn their distribution and generate new molecules
with similar properties. See the `Evaluation` section for more details.
## Dataset
### Preprocessing
With our implementation, this model has several limitations:
1. Information about protonation and chirality is ignored during generation
2. Molecules containing `[N+]`, `[O-]`, etc. cannot be generated.
For example, the model can only generate `O=C1NC(=S)NC(=O)C1=CNC1=CC=C(N(=O)O)C=C1O` from
`O=C1NC(=S)NC(=O)C1=CNC1=CC=C([N+](=O)[O-])C=C1O` even with the correct decisions.
To avoid issues about validity and novelty, we filter out these molecules from the dataset.
### ChEMBL
The authors use the [ChEMBL database](https://www.ebi.ac.uk/chembl/). Since they
did not release the code, we use a subset from [Olivecrona et al.](https://github.com/MarcusOlivecrona/REINVENT),
another work on generative modeling.
The authors restrict their dataset to molecules with at most 20 heavy atoms, and use a training/validation
split of 130,830/26,166 examples. We use the same split but relax the limit from 20 to 23 heavy atoms as we are using
a different subset.
### ZINC
After the pre-processing, we are left with 232464 molecules for training and 5000 molecules for validation.
## Usage
### Training
Training auto-regressive generative models tends to be very slow. According to the authors, they use multiprocessing to
speed up training, and GPUs do not give much of a speed advantage. We follow their approach and perform multiprocess CPU
training.
To start training, use `train.py` with required arguments
```
-d DATASET, dataset to use (default: None), built-in support exists for ChEMBL, ZINC
-o {random,canonical}, order to generate graphs (default: None)
```
and optional arguments
```
-s SEED, random seed (default: 0)
-np NUM_PROCESSES, number of processes to use (default: 32)
```
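For example, training on ZINC with canonical ordering and 32 processes:
```
python train.py -d ZINC -o canonical -np 32
```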
Even though multiprocessing yields a significant speedup compared to a single process, training can still take a long
time (several days). An epoch of training and validation can take up to one and a half hours on our machine. Unless
necessary, we recommend using our pre-trained models.
Meanwhile, we make a checkpoint of our model whenever there is a performance improvement on the validation set so you
do not need to wait until the training terminates.
All training results can be found in `training_results`.
#### Dataset configuration
You can also use your own dataset with additional arguments
```
-tf TRAIN_FILE, Path to a file with one SMILES a line for training
data. This is only necessary if you want to use a new
dataset. (default: None)
-vf VAL_FILE, Path to a file with one SMILES a line for validation
data. This is only necessary if you want to use a new
dataset. (default: None)
```
#### Monitoring
We can monitor the training process with tensorboard as below:
![](https://data.dgl.ai/dgllife/dgmg/tensorboard.png)
To use tensorboard, you need to install [tensorboardX](https://github.com/lanpa/tensorboardX) and
[TensorFlow](https://www.tensorflow.org/). You can launch tensorboard with `tensorboard --logdir=.`.
If you are training on a remote server, you can still use it with:
1. Launch it on the remote server with `tensorboard --logdir=. --port=A`
2. In the terminal of your local machine, type `ssh -NfL localhost:B:localhost:A username@your_remote_host_name`
3. Go to the address `localhost:B` in your browser
### Evaluation
To start evaluation, use `eval.py` with required arguments
```
-d DATASET, dataset to use (default: None), built-in support exists for ChEMBL, ZINC
-o {random,canonical}, order to generate graphs, used for naming evaluation directory (default: None)
-p MODEL_PATH, path to saved model (default: None). This is not needed if you want to use pretrained models.
-pr, Whether to use a pre-trained model (default: False)
```
and optional arguments
```
-s SEED, random seed (default: 0)
-ns NUM_SAMPLES, Number of molecules to generate (default: 100000)
-mn MAX_NUM_STEPS, Max number of steps allowed in generated molecules to
ensure termination (default: 400)
-np NUM_PROCESSES, number of processes to use (default: 32)
-gt GENERATION_TIME, max time (seconds) allowed for generation with
multiprocess (default: 600)
```
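For example, to generate molecules with a pre-trained model trained on ZINC with canonical ordering:
```
python eval.py -d ZINC -o canonical -pr
```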
All evaluation results can be found in `eval_results`.
After evaluation, the generated molecules (100000 by default) are stored in `generated_smiles.txt` under the `eval_results`
directory, with three statistics logged in `generation_stats.txt` under `eval_results`:
1. `Validity among all` gives the percentage of molecules that are valid
2. `Uniqueness among valid ones` gives the percentage of valid molecules that are unique
3. `Novelty among unique ones` gives the percentage of unique valid molecules that are novel (not seen in training data)
We also provide a jupyter notebook where you can visualize the generated molecules
![](https://data.dgl.ai/dgllife/dgmg/DGMG_ZINC_canonical_vis.png)
and compare their property distributions against the training molecule property distributions
![](https://data.dgl.ai/dgllife/dgmg/DGMG_ZINC_canonical_dist.png)
You can download the notebook with `wget https://data.dgl.ai/dgllife/dgmg/eval_jupyter.ipynb`.
### Pre-trained models
The table below gives the statistics of the pre-trained models. With random order, training becomes significantly more difficult
as we now have `N^2` data points with `N` molecules.
| Pre-trained model | % valid | % unique among valid | % novel among unique |
| ------------------ | ------- | -------------------- | -------------------- |
| `ChEMBL_canonical` | 78.80 | 99.19 | 98.60 |
| `ChEMBL_random` | 29.09 | 99.87 | 100.00 |
| `ZINC_canonical` | 74.60 | 99.87 | 99.87 |
| `ZINC_random` | 12.37 | 99.38 | 100.00 |
import os
import pickle
import shutil
import torch
from dgllife.model import DGMG, load_pretrained
from utils import MoleculeDataset, set_random_seed, download_data,\
mkdir_p, summarize_molecules, get_unique_smiles, get_novel_smiles
def generate_and_save(log_dir, num_samples, max_num_steps, model):
with open(os.path.join(log_dir, 'generated_smiles.txt'), 'w') as f:
for i in range(num_samples):
with torch.no_grad():
s = model(rdkit_mol=True, max_num_steps=max_num_steps)
f.write(s + '\n')
def prepare_for_evaluation(rank, args):
worker_seed = args['seed'] + rank * 10000
set_random_seed(worker_seed)
torch.set_num_threads(1)
# Setup dataset and data loader
dataset = MoleculeDataset(args['dataset'], subset_id=rank, n_subsets=args['num_processes'])
# Initialize model
if not args['pretrained']:
model = DGMG(atom_types=dataset.atom_types,
bond_types=dataset.bond_types,
node_hidden_size=args['node_hidden_size'],
num_prop_rounds=args['num_propagation_rounds'], dropout=args['dropout'])
model.load_state_dict(torch.load(args['model_path'])['model_state_dict'])
else:
model = load_pretrained('_'.join(['DGMG', args['dataset'], args['order']]), log=False)
model.eval()
worker_num_samples = args['num_samples'] // args['num_processes']
if rank == args['num_processes'] - 1:
worker_num_samples += args['num_samples'] % args['num_processes']
worker_log_dir = os.path.join(args['log_dir'], str(rank))
mkdir_p(worker_log_dir, log=False)
generate_and_save(worker_log_dir, worker_num_samples, args['max_num_steps'], model)
def remove_worker_tmp_dir(args):
for rank in range(args['num_processes']):
worker_path = os.path.join(args['log_dir'], str(rank))
try:
shutil.rmtree(worker_path)
except OSError:
print('Directory {} does not exist!'.format(worker_path))
def aggregate_and_evaluate(args):
print('Merging generated SMILES into a single file...')
smiles = []
for rank in range(args['num_processes']):
with open(os.path.join(args['log_dir'], str(rank), 'generated_smiles.txt'), 'r') as f:
rank_smiles = f.read().splitlines()
smiles.extend(rank_smiles)
with open(os.path.join(args['log_dir'], 'generated_smiles.txt'), 'w') as f:
for s in smiles:
f.write(s + '\n')
print('Removing temporary dirs...')
remove_worker_tmp_dir(args)
# Summarize training molecules
print('Summarizing training molecules...')
train_file = '_'.join([args['dataset'], 'DGMG_train.txt'])
if not os.path.exists(train_file):
download_data(args['dataset'], train_file)
with open(train_file, 'r') as f:
train_smiles = f.read().splitlines()
train_summary = summarize_molecules(train_smiles, args['num_processes'])
with open(os.path.join(args['log_dir'], 'train_summary.pickle'), 'wb') as f:
pickle.dump(train_summary, f)
# Summarize generated molecules
print('Summarizing generated molecules...')
generation_summary = summarize_molecules(smiles, args['num_processes'])
with open(os.path.join(args['log_dir'], 'generation_summary.pickle'), 'wb') as f:
pickle.dump(generation_summary, f)
# Stats computation
print('Preparing generation statistics...')
valid_generated_smiles = generation_summary['smile']
unique_generated_smiles = get_unique_smiles(valid_generated_smiles)
unique_train_smiles = get_unique_smiles(train_summary['smile'])
novel_generated_smiles = get_novel_smiles(unique_generated_smiles, unique_train_smiles)
with open(os.path.join(args['log_dir'], 'generation_stats.txt'), 'w') as f:
f.write('Total number of generated molecules: {:d}\n'.format(len(smiles)))
f.write('Validity among all: {:.4f}\n'.format(
len(valid_generated_smiles) / len(smiles)))
f.write('Uniqueness among valid ones: {:.4f}\n'.format(
len(unique_generated_smiles) / len(valid_generated_smiles)))
f.write('Novelty among unique ones: {:.4f}\n'.format(
len(novel_generated_smiles) / len(unique_generated_smiles)))
if __name__ == '__main__':
import argparse
import datetime
import time
from rdkit import rdBase
from utils import setup
parser = argparse.ArgumentParser(description='Evaluating DGMG for molecule generation',
formatter_class=argparse.ArgumentDefaultsHelpFormatter)
# configure
parser.add_argument('-s', '--seed', type=int, default=0, help='random seed')
# dataset and setting
parser.add_argument('-d', '--dataset',
help='dataset to use')
parser.add_argument('-o', '--order', choices=['random', 'canonical'],
help='order to generate graphs, used for naming evaluation directory')
# log
parser.add_argument('-l', '--log-dir', default='./eval_results',
help='folder to save evaluation results')
parser.add_argument('-p', '--model-path', type=str, default=None,
help='path to saved model')
parser.add_argument('-pr', '--pretrained', action='store_true',
help='Whether to use a pre-trained model')
parser.add_argument('-ns', '--num-samples', type=int, default=100000,
help='Number of molecules to generate')
parser.add_argument('-mn', '--max-num-steps', type=int, default=400,
help='Max number of steps allowed in generated molecules to ensure termination')
# multi-process
parser.add_argument('-np', '--num-processes', type=int, default=32,
help='number of processes to use')
parser.add_argument('-gt', '--generation-time', type=int, default=600,
help='max time (seconds) allowed for generation with multiprocess')
args = parser.parse_args()
args = setup(args, train=False)
rdBase.DisableLog('rdApp.error')
t1 = time.time()
if args['num_processes'] == 1:
prepare_for_evaluation(0, args)
else:
import multiprocessing as mp
procs = []
for rank in range(args['num_processes']):
p = mp.Process(target=prepare_for_evaluation, args=(rank, args,))
procs.append(p)
p.start()
while time.time() - t1 <= args['generation_time']:
if any(p.is_alive() for p in procs):
time.sleep(5)
else:
break
else:
print('Timeout, killing all processes.')
for p in procs:
p.terminate()
p.join()
t2 = time.time()
print('It took {} for generation.'.format(
datetime.timedelta(seconds=t2 - t1)))
aggregate_and_evaluate(args)
#
# calculation of synthetic accessibility score as described in:
#
# Estimation of Synthetic Accessibility Score of Drug-like Molecules
# based on Molecular Complexity and Fragment Contributions
# Peter Ertl and Ansgar Schuffenhauer
# Journal of Cheminformatics 1:8 (2009)
# http://www.jcheminf.com/content/1/1/8
#
# several small modifications to the original paper are included
# particularly slightly different formula for macrocyclic penalty
# and taking into account also molecule symmetry (fingerprint density)
#
# for a set of 10k diverse molecules the agreement between the original method
# as implemented in PipelinePilot and this implementation is r2 = 0.97
#
# peter ertl & greg landrum, september 2013
#
# A small modification is performed
#
# DGL team, August 2019
#
from __future__ import print_function
import math
import os
from rdkit import Chem
from rdkit.Chem import rdMolDescriptors
from rdkit.six.moves import cPickle
from rdkit.six import iteritems
from dgl.data.utils import download, _get_dgl_url
_fscores = None
def readFragmentScores(name='fpscores'):
import gzip
global _fscores
fname = '{}.pkl.gz'.format(name)
download(_get_dgl_url(os.path.join('dataset', fname)), path=fname)
_fscores = cPickle.load(gzip.open(fname))
outDict = {}
for i in _fscores:
for j in range(1, len(i)):
outDict[i[j]] = float(i[0])
_fscores = outDict
def numBridgeheadsAndSpiro(mol):
nSpiro = rdMolDescriptors.CalcNumSpiroAtoms(mol)
nBridgehead = rdMolDescriptors.CalcNumBridgeheadAtoms(mol)
return nBridgehead, nSpiro
def calculateScore(m):
if _fscores is None:
readFragmentScores()
# fragment score
# 2 is the *radius* of the circular fingerprint
fp = rdMolDescriptors.GetMorganFingerprint(m, 2)
fps = fp.GetNonzeroElements()
score1 = 0.
nf = 0
for bitId, v in iteritems(fps):
nf += v
sfp = bitId
score1 += _fscores.get(sfp, -4) * v
# We add the check below to avoid ZeroDivisionError.
if nf != 0:
score1 /= nf
# features score
nAtoms = m.GetNumAtoms()
nChiralCenters = len(Chem.FindMolChiralCenters(m, includeUnassigned=True))
ri = m.GetRingInfo()
nBridgeheads, nSpiro = numBridgeheadsAndSpiro(m)
nMacrocycles = 0
for x in ri.AtomRings():
if len(x) > 8:
nMacrocycles += 1
sizePenalty = nAtoms**1.005 - nAtoms
stereoPenalty = math.log10(nChiralCenters + 1)
spiroPenalty = math.log10(nSpiro + 1)
bridgePenalty = math.log10(nBridgeheads + 1)
macrocyclePenalty = 0.
# ---------------------------------------
# This differs from the paper, which defines:
# macrocyclePenalty = math.log10(nMacrocycles+1)
# This form generates better results when 2 or more macrocycles are present
if nMacrocycles > 0:
macrocyclePenalty = math.log10(2)
score2 = 0. - sizePenalty - stereoPenalty - \
spiroPenalty - bridgePenalty - macrocyclePenalty
# correction for the fingerprint density
# not in the original publication, added in version 1.1
# to make highly symmetrical molecules easier to synthesise
score3 = 0.
if nAtoms > len(fps):
score3 = math.log(float(nAtoms) / len(fps)) * .5
sascore = score1 + score2 + score3
# need to transform "raw" value into scale between 1 and 10
min = -4.0
max = 2.5
sascore = 11. - (sascore - min + 1) / (max - min) * 9.
# smooth the 10-end
if sascore > 8.:
sascore = 8. + math.log(sascore + 1. - 9.)
if sascore > 10.:
sascore = 10.0
elif sascore < 1.:
sascore = 1.0
return sascore
def processMols(mols):
print('smiles\tName\tsa_score')
for i, m in enumerate(mols):
if m is None:
continue
s = calculateScore(m)
smiles = Chem.MolToSmiles(m)
print(smiles + "\t" + m.GetProp('_Name') + "\t%3f" % s)
if __name__ == '__main__':
import sys, time
t1 = time.time()
readFragmentScores("fpscores")
t2 = time.time()
suppl = Chem.SmilesMolSupplier(sys.argv[1])
t3 = time.time()
processMols(suppl)
t4 = time.time()
print('Reading took %.2f seconds. Calculating took %.2f seconds' % ((t2 - t1), (t4 - t3)),
file=sys.stderr)
#
# Copyright (c) 2013, Novartis Institutes for BioMedical Research Inc.
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are
# met:
#
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above
# copyright notice, this list of conditions and the following
# disclaimer in the documentation and/or other materials provided
# with the distribution.
# * Neither the name of Novartis Institutes for BioMedical Research Inc.
# nor the names of its contributors may be used to endorse or promote
# products derived from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
"""
Learning Deep Generative Models of Graphs
Paper: https://arxiv.org/pdf/1803.03324.pdf
"""
import datetime
import time
import torch
import torch.distributed as dist
from dgllife.model import DGMG
from torch.optim import Adam
from torch.utils.data import DataLoader
from utils import MoleculeDataset, Printer, set_random_seed, synchronize, launch_a_process
def evaluate(epoch, model, data_loader, printer):
model.eval()
batch_size = data_loader.batch_size
total_log_prob = 0
with torch.no_grad():
for i, data in enumerate(data_loader):
log_prob = model(actions=data, compute_log_prob=True).detach()
total_log_prob -= log_prob
if printer is not None:
prob = log_prob.detach().exp()
printer.update(epoch + 1, - log_prob / batch_size, prob / batch_size)
return total_log_prob / len(data_loader)
def main(rank, args):
"""
Parameters
----------
rank : int
Subprocess id
args : dict
Configuration
"""
if rank == 0:
t1 = time.time()
set_random_seed(args['seed'])
# Removing the line below will result in problems for multiprocessing
torch.set_num_threads(1)
# Setup dataset and data loader
dataset = MoleculeDataset(args['dataset'], args['order'], ['train', 'val'],
subset_id=rank, n_subsets=args['num_processes'])
# Note that currently the batch size for the loaders should only be 1.
train_loader = DataLoader(dataset.train_set, batch_size=args['batch_size'],
shuffle=True, collate_fn=dataset.collate)
val_loader = DataLoader(dataset.val_set, batch_size=args['batch_size'],
shuffle=True, collate_fn=dataset.collate)
if rank == 0:
try:
from tensorboardX import SummaryWriter
writer = SummaryWriter(args['log_dir'])
except ImportError:
print('If you want to use tensorboard, install tensorboardX with pip.')
writer = None
train_printer = Printer(args['nepochs'], len(dataset.train_set), args['batch_size'], writer)
val_printer = Printer(args['nepochs'], len(dataset.val_set), args['batch_size'])
else:
val_printer = None
# Initialize model
model = DGMG(atom_types=dataset.atom_types,
bond_types=dataset.bond_types,
node_hidden_size=args['node_hidden_size'],
num_prop_rounds=args['num_propagation_rounds'],
dropout=args['dropout'])
if args['num_processes'] == 1:
from utils import Optimizer
optimizer = Optimizer(args['lr'], Adam(model.parameters(), lr=args['lr']))
else:
from utils import MultiProcessOptimizer
optimizer = MultiProcessOptimizer(args['num_processes'], args['lr'],
Adam(model.parameters(), lr=args['lr']))
if rank == 0:
t2 = time.time()
best_val_prob = 0
# Training
for epoch in range(args['nepochs']):
model.train()
if rank == 0:
print('Training')
for i, data in enumerate(train_loader):
log_prob = model(actions=data, compute_log_prob=True)
prob = log_prob.detach().exp()
loss_averaged = - log_prob
prob_averaged = prob
optimizer.backward_and_step(loss_averaged)
if rank == 0:
train_printer.update(epoch + 1, loss_averaged.item(), prob_averaged.item())
synchronize(args['num_processes'])
# Validation
val_log_prob = evaluate(epoch, model, val_loader, val_printer)
if args['num_processes'] > 1:
dist.all_reduce(val_log_prob, op=dist.ReduceOp.SUM)
val_log_prob /= args['num_processes']
        # Strictly speaking, the computation of probability here differs from what is
        # performed on the training set: we first average the log-likelihood and then
        # exponentiate. By Jensen's inequality, the resulting value is a lower bound
        # on the average probability.
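        # Concretely, exp(mean(log p)) <= mean(exp(log p)) = mean(p), since exp is convex.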
val_prob = (- val_log_prob).exp().item()
val_log_prob = val_log_prob.item()
if val_prob >= best_val_prob:
if rank == 0:
torch.save({'model_state_dict': model.state_dict()}, args['checkpoint_dir'])
print('Old val prob {:.10f} | new val prob {:.10f} | model saved'.format(best_val_prob, val_prob))
best_val_prob = val_prob
elif epoch >= args['warmup_epochs']:
optimizer.decay_lr()
if rank == 0:
print('Validation')
if writer is not None:
writer.add_scalar('validation_log_prob', val_log_prob, epoch)
writer.add_scalar('validation_prob', val_prob, epoch)
writer.add_scalar('lr', optimizer.lr, epoch)
print('Validation log prob {:.4f} | prob {:.10f}'.format(val_log_prob, val_prob))
synchronize(args['num_processes'])
if rank == 0:
t3 = time.time()
print('It took {} to setup.'.format(datetime.timedelta(seconds=t2 - t1)))
print('It took {} to finish training.'.format(datetime.timedelta(seconds=t3 - t2)))
print('--------------------------------------------------------------------------')
print('On average, an epoch takes {}.'.format(datetime.timedelta(
seconds=(t3 - t2) / args['nepochs'])))
if __name__ == '__main__':
import argparse
from utils import setup
parser = argparse.ArgumentParser(description='Training DGMG for molecule generation',
formatter_class=argparse.ArgumentDefaultsHelpFormatter)
# configure
parser.add_argument('-s', '--seed', type=int, default=0, help='random seed')
parser.add_argument('-w', '--warmup-epochs', type=int, default=10,
help='Number of epochs where no lr decay is performed.')
# dataset and setting
parser.add_argument('-d', '--dataset',
help='dataset to use')
parser.add_argument('-o', '--order', choices=['random', 'canonical'],
help='order to generate graphs')
    parser.add_argument('-tf', '--train-file', type=str, default=None,
                        help='Path to a file with one SMILES per line for training data. '
                             'This is only necessary if you want to use a new dataset.')
    parser.add_argument('-vf', '--val-file', type=str, default=None,
                        help='Path to a file with one SMILES per line for validation data. '
                             'This is only necessary if you want to use a new dataset.')
# log
parser.add_argument('-l', '--log-dir', default='./training_results',
help='folder to save info like experiment configuration')
# multi-process
parser.add_argument('-np', '--num-processes', type=int, default=32,
help='number of processes to use')
parser.add_argument('-mi', '--master-ip', type=str, default='127.0.0.1')
parser.add_argument('-mp', '--master-port', type=str, default='12345')
args = parser.parse_args()
args = setup(args, train=True)
if args['num_processes'] == 1:
main(0, args)
else:
mp = torch.multiprocessing.get_context('spawn')
procs = []
for rank in range(args['num_processes']):
procs.append(mp.Process(target=launch_a_process, args=(rank, args, main), daemon=True))
procs[-1].start()
for p in procs:
p.join()
# Junction Tree Variational Autoencoder for Molecular Graph Generation (JTNN)
Wengong Jin, Regina Barzilay, Tommi Jaakkola.
Junction Tree Variational Autoencoder for Molecular Graph Generation.
*arXiv preprint arXiv:1802.04364*, 2018.
JTNN uses the junction tree algorithm to form a tree from the molecular graph. The model then
encodes the tree and the graph into two separate vectors `z_T` and `z_G`. Details can
be found in the original paper. A brief illustration of the process is shown below (from the original paper):
![image](https://user-images.githubusercontent.com/8686776/63677300-3fb6d980-c81f-11e9-8a65-57c8b03aaf52.png)
**Goal**: JTNN is an auto-encoder model, aiming to learn hidden representations for molecular graphs.
These representations can be used for downstream tasks such as property prediction or molecule optimization.
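As a quick sanity check, the pre-trained model can be loaded directly from DGL-LifeSci. This is a minimal sketch, not part of this example's scripts, and assumes the checkpoint is registered under the name `JTNN_ZINC`:

```python
from dgllife.model import load_pretrained

# Assumption: the checkpoint name 'JTNN_ZINC' matches the pre-trained model
# referenced later in this README.
model = load_pretrained('JTNN_ZINC')
model.eval()
```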
## Dataset
### ZINC
> The ZINC database is a curated collection of commercially available chemical compounds
prepared especially for virtual screening. (introduction from Wikipedia)
Generally speaking, molecules in the ZINC dataset are more drug-like. We use ~220,000
molecules for training and 5,000 molecules for validation.
### Preprocessing
The `JTNNDataset` class processes a SMILES string into a dict consisting of a junction tree, a graph with
encoded nodes (atoms) and edges (bonds), and other information for the model to use.
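For illustration only, a hypothetical sketch of inspecting one processed sample is shown below; the import path and constructor argument names are assumptions, so consult `datautils.py` for the actual signature:

```python
from datautils import JTNNDataset

# Hypothetical argument names; see datautils.py for the actual signature.
dataset = JTNNDataset(data='train.txt', vocab='vocab.txt', training=True)
sample = dataset[0]           # one dict per SMILES string
print(sorted(sample.keys()))  # e.g. junction tree, molecular graph, auxiliary info
```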
## Usage
### Training
To start training, use `python train.py`. By default, the script uses the ZINC dataset
with a preprocessed vocabulary and periodically saves model checkpoints in the current working directory.
### Evaluation
To start evaluation, use `python reconstruct_eval.py`. By default, evaluation is performed with
DGL's pre-trained model. During evaluation, the program prints the success rate of
molecule reconstruction.
### Pre-trained models
The table below gives the statistics of our pre-trained `JTNN_ZINC` model.

| Pre-trained model | % Reconstruction Accuracy |
| ----------------- | ------------------------- |
| `JTNN_ZINC`       | 73.7                      |
### Visualization
Here we draw some "neighbors" of a given molecule by adding noise to its intermediate representation.
You can download the notebook with `wget https://data.dgl.ai/dgllife/jtnn_viz_neighbor_mol.ipynb`.
Please put this notebook in the current directory (`examples/pytorch/model_zoo/chem/generative_models/jtnn/`).
#### Given Molecule
![image](https://user-images.githubusercontent.com/8686776/63773593-0d37da00-c90e-11e9-8933-0abca4b430db.png)
#### Neighbor Molecules
![image](https://user-images.githubusercontent.com/8686776/63773602-1163f780-c90e-11e9-8341-5122dc0d0c82.png)
### Dataset configuration
If you want to use your own dataset, please create a file with one SMILES per line as below
```
CCO
Fc1ccccc1
```
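For instance, such a file can be written programmatically. This is a minimal sketch; the file name is arbitrary:

```python
# Write a custom dataset file with one SMILES string per line.
smiles_list = ['CCO', 'Fc1ccccc1']
with open('my_dataset.txt', 'w') as f:
    f.write('\n'.join(smiles_list) + '\n')
```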
You can generate the vocabulary file corresponding to your dataset with `python vocab.py -d X -v Y`, where `X`
is the path to the dataset and `Y` is the path for saving the vocabulary file. An example vocabulary file
corresponding to the two molecules above would be
```
CC
CF
C1=CC=CC=C1
CO
```
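If you need to load a saved vocabulary file in Python, a minimal sketch is shown below; it assumes, as in the original JTNN code, that `Vocab` accepts a list of vocabulary SMILES, and the import path is an assumption:

```python
from mol_tree import Vocab

# Assumption: Vocab takes a list of vocabulary SMILES, one per line in the file.
with open('vocab.txt') as f:
    vocab = Vocab([line.strip() for line in f])
```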
If you want to develop a model based on DGL's pre-trained model, it is important to make sure that the vocabulary
generated above is a subset of the vocabulary used for the pre-trained model. When running `vocab.py` as above, we
also check whether the new vocabulary is a subset of the vocabulary used for the pre-trained model and print the
result in the terminal as follows:
```
The new vocabulary is a subset of the default vocabulary: True
```
To train on this new dataset, run
```
python train.py -t X
```
where `X` is the path to the new dataset. If you want to use the vocabulary generated above, also add `-v Y`, where
`Y` is the path to the vocabulary file you just saved.
To evaluate on this new dataset, run `python reconstruct_eval.py` with the same arguments as above.
from .mol_tree import Vocab
from .datautils import JTNNDataset, JTNNCollator
from .chemutils import decode_stereo