"examples/vscode:/vscode.git/clone" did not exist on "b2b531e0412f1207a8d17ea24cfbece77490053e"
Unverified Commit 828a5e5b authored by Mufei Li's avatar Mufei Li Committed by GitHub
Browse files

[DGL-LifeSci] Migration and Refactor (#1226)

# DGL-LifeSci
## Introduction
Deep learning on graphs has been a rising trend in the past few years. Life science is rich in
graphs, such as molecular graphs and biological networks, making it an important area for applying
deep learning on graphs. `dgllife` is a DGL-based package for various applications in life science
with graph neural networks.
We provide various functionalities, including but not limited to methods for graph construction,
featurization, and evaluation, model architectures, training scripts and pre-trained models.
## Dependencies
For the time being, we only support PyTorch.
Depending on the features you want to use, you may need to manually install the following dependencies:
- RDKit 2018.09.3
  - We recommend installation with `conda install -c conda-forge rdkit==2018.09.3`. For other installation recipes,
    see the [official documentation](https://www.rdkit.org/docs/Install.html).
- MDTraj
  - We recommend installation with `conda install -c conda-forge mdtraj`. For alternative ways of installation,
    see the [official documentation](http://mdtraj.org/1.9.3/installation.html).
## Organization
For a full list of work implemented in DGL-LifeSci, **see implemented.md**.
```
dgllife
    data
        csv_dataset.py
        ...
    model
        gnn
        model_zoo
        readout
        pretrain.py
    utils
        complex_to_graph.py
        early_stop.py
        eval.py
        featurizers.py
        mol_to_graph.py
        rdkit_utils.py
        splitters.py
```
### `data`
The directory consists of interfaces for working with several datasets. Additionally, one can adapt any
`.csv` dataset to DGL with `MoleculeCSVDataset` in `csv_dataset.py`, as sketched below.
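Below is a minimal sketch of adapting a custom `.csv` file; the file name `my_data.csv` and its
column names are hypothetical, and we reuse the graph construction utilities shown later in this README.
```python
import pandas as pd

from dgllife.data import MoleculeCSVDataset
from dgllife.utils import smiles_to_bigraph, CanonicalAtomFeaturizer

# Hypothetical file with a 'smiles' column and one label column per task
df = pd.read_csv('my_data.csv')
dataset = MoleculeCSVDataset(df,
                             smiles_to_graph=smiles_to_bigraph,
                             node_featurizer=CanonicalAtomFeaturizer(),
                             edge_featurizer=None,
                             smiles_column='smiles',
                             cache_file_path='my_data_dglgraph.bin')
smiles, g, labels, mask = dataset[0]
```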
### `model`
- `gnn` implements several graph neural networks for message passing and updating node representations.
- `readout` implements several methods for computing graph representations out of node representations.
In the context of molecules, they may be viewed as learned fingerprints.
- `model_zoo` implements several models for molecular property prediction, molecule generation and
protein-ligand binding affinity prediction. Many of them are based on modules in `gnn` and `readout`.
- `pretrain.py` contains APIs for loading pre-trained models.
### `utils`
- `complex_to_graph.py` contains utils for graph construction and featurization of protein-ligand complexes.
- `early_stop.py` contains utils for early stopping.
- `eval.py` contains utils for evaluating models on property prediction.
- `featurizers.py` contains utils for featurizing molecular graphs.
- `mol_to_graph.py` contains several ways for graph construction of molecules.
- `rdkit_utils.py` contains utils for RDKit, in particular loading RDKit molecule instances from different
formats, including `mol2`, `sdf`, `pdbqt`, and `pdb`.
- `splitters.py` contains several ways for splitting the dataset; a usage sketch follows this list.
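Below is a minimal sketch of a scaffold-based split. It assumes `ScaffoldSplitter` is exposed in
`dgllife.utils` with a `train_val_test_split` method taking the fractions shown; treat the exact
signature as an assumption.
```python
from dgllife.data import Tox21
from dgllife.utils import ScaffoldSplitter, smiles_to_bigraph, CanonicalAtomFeaturizer

dataset = Tox21(smiles_to_bigraph, CanonicalAtomFeaturizer())
# Assumed API: an 8:1:1 split grouped by Bemis-Murcko scaffolds
train_set, val_set, test_set = ScaffoldSplitter.train_val_test_split(
    dataset, frac_train=0.8, frac_val=0.1, frac_test=0.1)
```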
## Example Usage
Currently we provide examples for molecular property prediction, generative models and protein-ligand binding
affinity prediction. See the examples folder for details.
For some examples we also provide pre-trained models, which can be used off the shelf without training from scratch.
```python
"""Load a pre-trained model for property prediction."""
from dgllife.data import Tox21
from dgllife.model import load_pretrained
from dgllife.utils import smiles_to_bigraph, CanonicalAtomFeaturizer
dataset = Tox21(smiles_to_bigraph, CanonicalAtomFeaturizer())
model = load_pretrained('GCN_Tox21') # Pretrained model loaded
model.eval()
smiles, g, label, mask = dataset[0]
feats = g.ndata.pop('h')
label_pred = model(g, feats)
print(smiles) # CCOc1ccc2nc(S(N)(=O)=O)sc2c1
print(label_pred[:, mask != 0]) # Mask non-existing labels
# tensor([[ 1.4190, -0.1820, 1.2974, 1.4416, 0.6914,
# 2.0957, 0.5919, 0.7715, 1.7273, 0.2070]])
```
```python
"""Load a pre-trained model for generating molecules."""
from IPython.display import SVG
from rdkit import Chem
from rdkit.Chem import Draw
from dgllife.model import load_pretrained
model = load_pretrained('DGMG_ZINC_canonical')
model.eval()
mols = []
for i in range(4):
    SMILES = model(rdkit_mol=True)
    mols.append(Chem.MolFromSmiles(SMILES))
# Generating 4 molecules takes less than a second.
SVG(Draw.MolsToGridImage(mols, molsPerRow=4, subImgSize=(180, 150), useSVG=True))
```
![](https://s3.us-east-2.amazonaws.com/dgl-data/model_zoo/drug_discovery/dgmg_model_zoo_example2.png)
## Speed Reference
Below we provide some reference numbers, measured in seconds per epoch, to show how DGL speeds up model training.
| Model | Original Implementation | DGL Implementation | Improvement |
| -------------------------- | ----------------------- | ------------------ | ----------- |
| GCN on Tox21 | 5.5 (DeepChem) | 1.0 | 5.5x |
| AttentiveFP on Aromaticity | 6.0 | 1.2 | 5x |
| JTNN on ZINC | 1826 | 743 | 2.5x |
"""Dataset classes."""
from .alchemy import *
from .csv_dataset import *
from .pdbbind import *
from .pubchem_aromaticity import *
from .tox21 import *
# -*- coding:utf-8 -*-
"""Tencent Alchemy Dataset https://alchemy.tencent.com/"""
import numpy as np
import os
import os.path as osp
import pandas as pd
import pathlib
import zipfile
from collections import defaultdict
from dgl import backend as F
from dgl.data.utils import download, get_download_dir, _get_dgl_url, save_graphs, load_graphs
from rdkit import Chem
from rdkit.Chem import ChemicalFeatures
from rdkit import RDConfig
from ..utils.mol_to_graph import mol_to_complete_graph
from ..utils.featurizers import atom_type_one_hot, atom_hybridization_one_hot, atom_is_aromatic
__all__ = ['TencentAlchemyDataset']
def alchemy_nodes(mol):
    """Featurization for all atoms in a molecule. The atom indices
    will be preserved.

    Parameters
    ----------
    mol : rdkit.Chem.rdchem.Mol
        RDKit molecule object

    Returns
    -------
    atom_feats_dict : dict
        Dictionary for atom features
    """
    atom_feats_dict = defaultdict(list)
    is_donor = defaultdict(int)
    is_acceptor = defaultdict(int)

    fdef_name = osp.join(RDConfig.RDDataDir, 'BaseFeatures.fdef')
    mol_featurizer = ChemicalFeatures.BuildFeatureFactory(fdef_name)
    mol_feats = mol_featurizer.GetFeaturesForMol(mol)
    mol_conformers = mol.GetConformers()
    assert len(mol_conformers) == 1

    for i in range(len(mol_feats)):
        if mol_feats[i].GetFamily() == 'Donor':
            node_list = mol_feats[i].GetAtomIds()
            for u in node_list:
                is_donor[u] = 1
        elif mol_feats[i].GetFamily() == 'Acceptor':
            node_list = mol_feats[i].GetAtomIds()
            for u in node_list:
                is_acceptor[u] = 1

    num_atoms = mol.GetNumAtoms()
    for u in range(num_atoms):
        atom = mol.GetAtomWithIdx(u)
        atom_type = atom.GetAtomicNum()
        num_h = atom.GetTotalNumHs()
        atom_feats_dict['node_type'].append(atom_type)

        h_u = []
        h_u += atom_type_one_hot(atom, ['H', 'C', 'N', 'O', 'F', 'S', 'Cl'])
        h_u.append(atom_type)
        h_u.append(is_acceptor[u])
        h_u.append(is_donor[u])
        h_u += atom_is_aromatic(atom)
        h_u += atom_hybridization_one_hot(atom, [Chem.rdchem.HybridizationType.SP,
                                                 Chem.rdchem.HybridizationType.SP2,
                                                 Chem.rdchem.HybridizationType.SP3])
        h_u.append(num_h)
        atom_feats_dict['n_feat'].append(F.tensor(np.array(h_u).astype(np.float32)))

    atom_feats_dict['n_feat'] = F.stack(atom_feats_dict['n_feat'], dim=0)
    atom_feats_dict['node_type'] = F.tensor(np.array(
        atom_feats_dict['node_type']).astype(np.int64))

    return atom_feats_dict
def alchemy_edges(mol, self_loop=False):
    """Featurization for all bonds in a molecule.
    The bond indices will be preserved.

    Parameters
    ----------
    mol : rdkit.Chem.rdchem.Mol
        RDKit molecule object
    self_loop : bool
        Whether to add self loops. Default to be False.

    Returns
    -------
    bond_feats_dict : dict
        Dictionary for bond features
    """
    bond_feats_dict = defaultdict(list)

    mol_conformers = mol.GetConformers()
    assert len(mol_conformers) == 1
    geom = mol_conformers[0].GetPositions()

    num_atoms = mol.GetNumAtoms()
    for u in range(num_atoms):
        for v in range(num_atoms):
            if u == v and not self_loop:
                continue

            e_uv = mol.GetBondBetweenAtoms(u, v)
            if e_uv is None:
                bond_type = None
            else:
                bond_type = e_uv.GetBondType()
            bond_feats_dict['e_feat'].append([
                float(bond_type == x)
                for x in (Chem.rdchem.BondType.SINGLE,
                          Chem.rdchem.BondType.DOUBLE,
                          Chem.rdchem.BondType.TRIPLE,
                          Chem.rdchem.BondType.AROMATIC, None)
            ])
            bond_feats_dict['distance'].append(
                np.linalg.norm(geom[u] - geom[v]))

    bond_feats_dict['e_feat'] = F.tensor(
        np.array(bond_feats_dict['e_feat']).astype(np.float32))
    bond_feats_dict['distance'] = F.tensor(
        np.array(bond_feats_dict['distance']).astype(np.float32)).reshape(-1, 1)

    return bond_feats_dict
class TencentAlchemyDataset(object):
    """
    Developed by the Tencent Quantum Lab, the dataset lists 12 quantum mechanical
    properties of 130,000+ organic molecules, comprising up to 12 heavy atoms
    (C, N, O, S, F and Cl), sampled from the GDBMedChem database. These properties
    have been calculated using the open-source computational chemistry program
    Python-based Simulation of Chemistry Framework (PySCF).

    For more details, check the `paper <https://arxiv.org/abs/1906.09427>`__.

    Parameters
    ----------
    mode : str
        'dev', 'valid' or 'test', respectively for training, validation and test.
        Default to be 'dev'. Note that 'test' is not available as the Alchemy
        contest is ongoing.
    mol_to_graph : callable, rdkit.Chem.rdchem.Mol -> DGLGraph
        A function turning an RDKit molecule instance into a DGLGraph.
        Default to :func:`dgllife.utils.mol_to_complete_graph`.
    node_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict
        Featurization for nodes like atoms in a molecule, which can be used to update
        ndata for a DGLGraph. By default, we construct graphs where nodes represent atoms
        and node features represent atom features. We store the atomic numbers under the
        name ``"node_type"`` and store the atom features under the name ``"n_feat"``.
        The atom features include:

        * One hot encoding for atom types
        * Atomic number of atoms
        * Whether the atom is a donor
        * Whether the atom is an acceptor
        * Whether the atom is aromatic
        * One hot encoding for atom hybridization
        * Total number of Hs on the atom
    edge_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict
        Featurization for edges like bonds in a molecule, which can be used to update
        edata for a DGLGraph. By default, we construct edges between every pair of atoms,
        excluding the self loops. We store the distance between the end atoms under the name
        ``"distance"`` and store the edge features under the name ``"e_feat"``. The edge
        features represent one hot encoding of edge types (bond types and non-bond edges).
    load : bool
        Whether to load the previously pre-processed dataset or pre-process from scratch.
        ``load`` should be False when we want to try different graph construction and
        featurization methods and need to preprocess from scratch. Default to True.
    """
    def __init__(self, mode='dev',
                 mol_to_graph=mol_to_complete_graph,
                 node_featurizer=alchemy_nodes,
                 edge_featurizer=alchemy_edges,
                 load=True):
        if mode == 'test':
            raise ValueError('The test mode is not supported before '
                             'the Alchemy contest finishes.')

        assert mode in ['dev', 'valid', 'test'], \
            'Expect mode to be dev, valid or test, got {}.'.format(mode)

        self.mode = mode
        # Construct DGLGraphs from raw data or use the preprocessed data
        self.load = load
        file_dir = osp.join(get_download_dir(), 'Alchemy_data')

        if load:
            file_name = "{}_processed_dgl".format(mode)
        else:
            file_name = "{}_single_sdf".format(mode)
        self.file_dir = pathlib.Path(file_dir, file_name)

        self._url = 'dataset/alchemy/'
        self.zip_file_path = pathlib.Path(file_dir, file_name + '.zip')
        download(_get_dgl_url(self._url + file_name + '.zip'), path=str(self.zip_file_path))
        if not os.path.exists(str(self.file_dir)):
            archive = zipfile.ZipFile(self.zip_file_path)
            archive.extractall(file_dir)
            archive.close()

        self._load(mol_to_graph, node_featurizer, edge_featurizer)

    def _load(self, mol_to_graph, node_featurizer, edge_featurizer):
        if self.load:
            self.graphs, label_dict = load_graphs(
                osp.join(self.file_dir, "{}_graphs.bin".format(self.mode)))
            self.labels = label_dict['labels']
            with open(osp.join(self.file_dir, "{}_smiles.txt".format(self.mode)), 'r') as f:
                smiles_ = f.readlines()
                self.smiles = [s.strip() for s in smiles_]
        else:
            print('Start preprocessing dataset...')
            target_file = pathlib.Path(self.file_dir, "{}_target.csv".format(self.mode))
            self.target = pd.read_csv(
                target_file,
                index_col=0,
                usecols=['gdb_idx',] + ['property_{:d}'.format(x) for x in range(12)])
            self.target = self.target[['property_{:d}'.format(x) for x in range(12)]]
            self.graphs, self.labels, self.smiles = [], [], []

            supp = Chem.SDMolSupplier(osp.join(self.file_dir, self.mode + ".sdf"))
            cnt = 0
            dataset_size = len(self.target)
            for mol, label in zip(supp, self.target.iterrows()):
                cnt += 1
                print('Processing molecule {:d}/{:d}'.format(cnt, dataset_size))
                graph = mol_to_graph(mol, node_featurizer=node_featurizer,
                                     edge_featurizer=edge_featurizer)
                smiles = Chem.MolToSmiles(mol)
                self.smiles.append(smiles)
                self.graphs.append(graph)
                label = F.tensor(np.array(label[1].tolist()).astype(np.float32))
                self.labels.append(label)

            save_graphs(osp.join(self.file_dir, "{}_graphs.bin".format(self.mode)),
                        self.graphs, labels={'labels': F.stack(self.labels, dim=0)})
            with open(osp.join(self.file_dir, "{}_smiles.txt".format(self.mode)), 'w') as f:
                for s in self.smiles:
                    f.write(s + '\n')

        self.set_mean_and_std()
        print(len(self.graphs), "loaded!")

    def __getitem__(self, item):
        """Get datapoint with index

        Parameters
        ----------
        item : int
            Datapoint index

        Returns
        -------
        str
            SMILES for the ith datapoint
        DGLGraph
            DGLGraph for the ith datapoint
        Tensor of dtype float32
            Labels of the datapoint for all tasks
        """
        return self.smiles[item], self.graphs[item], self.labels[item]

    def __len__(self):
        """Length of the dataset

        Returns
        -------
        int
            Length of Dataset
        """
        return len(self.graphs)

    def set_mean_and_std(self, mean=None, std=None):
        """Set mean and std or compute from labels for future normalization.

        Parameters
        ----------
        mean : int or float
            Default to be None.
        std : int or float
            Default to be None.
        """
        labels = np.array([i.numpy() for i in self.labels])
        if mean is None:
            mean = np.mean(labels, axis=0)
        if std is None:
            std = np.std(labels, axis=0)
        self.mean = mean
        self.std = std
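
# A minimal usage sketch (not part of the original module): load the 'dev'
# subset and normalize a label vector with the dataset-level statistics.
# Assumes the preprocessed archive is downloadable from the DGL server.
if __name__ == '__main__':
    dataset = TencentAlchemyDataset(mode='dev')
    smiles, g, label = dataset[0]
    normalized = (label.numpy() - dataset.mean) / dataset.std
    print(smiles, normalized.shape)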
"""Creating datasets from .csv files for molecular property prediction."""
import dgl.backend as F
import numpy as np
import os
from dgl.data.utils import save_graphs, load_graphs
__all__ = ['MoleculeCSVDataset']
class MoleculeCSVDataset(object):
    """MoleculeCSVDataset

    This is a general class for loading molecular data from a pandas.DataFrame.

    In data pre-processing, we set non-existing labels to 0 and return a mask
    that is 1 where a label exists.

    All molecules are converted into DGLGraphs. After the first-time construction, the
    DGLGraphs can be saved for reloading so that we do not need to reconstruct them every time.

    Parameters
    ----------
    df : pandas.DataFrame
        Dataframe including smiles and labels. Can be loaded by pandas.read_csv(file_path).
        One column includes smiles and the other columns include labels.
        Column names other than the smiles column are considered as task names.
    smiles_to_graph : callable, str -> DGLGraph
        A function turning a SMILES into a DGLGraph.
    node_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict
        Featurization for nodes like atoms in a molecule, which can be used to update
        ndata for a DGLGraph.
    edge_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict
        Featurization for edges like bonds in a molecule, which can be used to update
        edata for a DGLGraph.
    smiles_column : str
        Name of the column that includes smiles.
    cache_file_path : str
        Path to store the preprocessed DGLGraphs. For example, this can be ``'dglgraph.bin'``.
    task_names : list of str or None
        Columns in the data frame corresponding to real-valued labels. If None, we assume
        all columns except the smiles_column are labels. Default to None.
    load : bool
        Whether to load the previously pre-processed dataset or pre-process from scratch.
        ``load`` should be False when we want to try different graph construction and
        featurization methods and need to preprocess from scratch. Default to True.
    """
    def __init__(self, df, smiles_to_graph, node_featurizer, edge_featurizer,
                 smiles_column, cache_file_path, task_names=None, load=True):
        self.df = df
        self.smiles = self.df[smiles_column].tolist()
        if task_names is None:
            self.task_names = self.df.columns.drop([smiles_column]).tolist()
        else:
            self.task_names = task_names
        self.n_tasks = len(self.task_names)
        self.cache_file_path = cache_file_path
        self._pre_process(smiles_to_graph, node_featurizer, edge_featurizer, load)

    def _pre_process(self, smiles_to_graph, node_featurizer, edge_featurizer, load):
        """Pre-process the dataset

        * Convert molecules from smiles format into DGLGraphs
          and featurize their atoms
        * Set missing labels to be 0 and use a binary masking
          matrix to mask them

        Parameters
        ----------
        smiles_to_graph : callable, SMILES -> DGLGraph
            Function for converting a SMILES (str) into a DGLGraph.
        node_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict
            Featurization for nodes like atoms in a molecule, which can be used to update
            ndata for a DGLGraph.
        edge_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict
            Featurization for edges like bonds in a molecule, which can be used to update
            edata for a DGLGraph.
        load : bool
            Whether to load the previously pre-processed dataset or pre-process from scratch.
            ``load`` should be False when we want to try different graph construction and
            featurization methods and need to preprocess from scratch. Default to True.
        """
        if os.path.exists(self.cache_file_path) and load:
            # DGLGraphs have been constructed before, reload them
            print('Loading previously saved dgl graphs...')
            self.graphs, label_dict = load_graphs(self.cache_file_path)
            self.labels = label_dict['labels']
            self.mask = label_dict['mask']
        else:
            print('Processing dgl graphs from scratch...')
            self.graphs = []
            for i, s in enumerate(self.smiles):
                print('Processing molecule {:d}/{:d}'.format(i + 1, len(self)))
                self.graphs.append(smiles_to_graph(s, node_featurizer=node_featurizer,
                                                   edge_featurizer=edge_featurizer))
            _label_values = self.df[self.task_names].values
            # np.nan_to_num will also turn inf into a very large number
            self.labels = F.zerocopy_from_numpy(
                np.nan_to_num(_label_values).astype(np.float32))
            self.mask = F.zerocopy_from_numpy(
                (~np.isnan(_label_values)).astype(np.float32))
            save_graphs(self.cache_file_path, self.graphs,
                        labels={'labels': self.labels, 'mask': self.mask})

    def __getitem__(self, item):
        """Get datapoint with index

        Parameters
        ----------
        item : int
            Datapoint index

        Returns
        -------
        str
            SMILES for the ith datapoint
        DGLGraph
            DGLGraph for the ith datapoint
        Tensor of dtype float32
            Labels of the datapoint for all tasks
        Tensor of dtype float32
            Binary masks indicating the existence of labels for all tasks
        """
        return self.smiles[item], self.graphs[item], self.labels[item], self.mask[item]

    def __len__(self):
        """Length of the dataset

        Returns
        -------
        int
            Length of Dataset
        """
        return len(self.smiles)
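
# A minimal sketch (not part of the original module): batch datapoints with
# dgl.batch for use in a PyTorch DataLoader. The collate function and the
# file 'example.csv' (with a 'smiles' column) are illustrative assumptions.
if __name__ == '__main__':
    import dgl
    import pandas as pd
    import torch
    from torch.utils.data import DataLoader

    from dgllife.utils import smiles_to_bigraph

    def collate(samples):
        # Unzip a list of (smiles, graph, labels, mask) tuples
        smiles, graphs, labels, masks = map(list, zip(*samples))
        return smiles, dgl.batch(graphs), torch.stack(labels), torch.stack(masks)

    dataset = MoleculeCSVDataset(pd.read_csv('example.csv'), smiles_to_bigraph,
                                 None, None, 'smiles', 'example_dglgraph.bin')
    data_loader = DataLoader(dataset, batch_size=32, collate_fn=collate)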
"""PDBBind dataset processed by MoleculeNet."""
import dgl.backend as F
import multiprocessing
import numpy as np
import os
import pandas as pd

from dgl.data.utils import get_download_dir, download, _get_dgl_url, extract_archive

from ..utils import multiprocess_load_molecules, ACNN_graph_construction_and_featurization
__all__ = ['PDBBind']
class PDBBind(object):
    """PDBbind dataset processed by MoleculeNet.

    The description below is mainly based on
    `[1] <https://pubs.rsc.org/en/content/articlelanding/2018/sc/c7sc02664a#cit50>`__.
    The PDBbind database consists of experimentally measured binding affinities for
    bio-molecular complexes `[2] <https://www.ncbi.nlm.nih.gov/pubmed/?term=15163179%5Buid%5D>`__,
    `[3] <https://www.ncbi.nlm.nih.gov/pubmed/?term=15943484%5Buid%5D>`__. It provides detailed
    3D Cartesian coordinates of both ligands and their target proteins derived from experimental
    (e.g., X-ray crystallography) measurements. The availability of coordinates of the
    protein-ligand complexes permits structure-based featurization that is aware of the
    protein-ligand binding geometry. The authors of
    `[1] <https://pubs.rsc.org/en/content/articlelanding/2018/sc/c7sc02664a#cit50>`__ use the
    "refined" and "core" subsets of the database
    `[4] <https://www.ncbi.nlm.nih.gov/pubmed/?term=25301850%5Buid%5D>`__, more carefully
    processed for data artifacts, as additional benchmarking targets.

    References:

        * [1] MoleculeNet: a benchmark for molecular machine learning
        * [2] The PDBbind database: collection of binding affinities for protein-ligand complexes
          with known three-dimensional structures
        * [3] The PDBbind database: methodologies and updates
        * [4] PDB-wide collection of binding data: current status of the PDBbind database

    Parameters
    ----------
    subset : str
        In MoleculeNet, we can use either the "refined" subset or the "core" subset. We can
        retrieve them by setting ``subset`` to be ``'refined'`` or ``'core'``. The size
        of the ``'core'`` set is 195 and the size of the ``'refined'`` set is 3706.
    load_binding_pocket : bool
        Whether to load binding pockets or full proteins. Default to True.
    sanitize : bool
        Whether sanitization is performed in initializing RDKit molecule instances. See
        https://www.rdkit.org/docs/RDKit_Book.html for details of the sanitization.
        Default to False.
    calc_charges : bool
        Whether to add Gasteiger charges via RDKit. Setting this to be True will enforce
        ``sanitize`` to be True. Default to False.
    remove_hs : bool
        Whether to remove hydrogens via RDKit. Note that removing hydrogens can be quite
        slow for large molecules. Default to False.
    use_conformation : bool
        Whether we need to extract molecular conformation from proteins and ligands.
        Default to True.
    construct_graph_and_featurize : callable
        Construct a DGLHeteroGraph for the use of GNNs. Mapping self.ligand_mols[i],
        self.protein_mols[i], self.ligand_coordinates[i] and self.protein_coordinates[i]
        to a DGLHeteroGraph. Default to :func:`ACNN_graph_construction_and_featurization`.
    zero_padding : bool
        Whether to perform zero padding. While DGL does not necessarily require zero padding,
        pooling operations for variable length inputs can introduce stochastic behaviour, which
        is not desired for sensitive scenarios. Default to True.
    num_processes : int or None
        Number of worker processes to use. If None,
        then we will use the number of CPUs in the system. Default to 64.
    """
    def __init__(self, subset, load_binding_pocket=True, sanitize=False, calc_charges=False,
                 remove_hs=False, use_conformation=True,
                 construct_graph_and_featurize=ACNN_graph_construction_and_featurization,
                 zero_padding=True, num_processes=64):
        self.task_names = ['-logKd/Ki']
        self.n_tasks = len(self.task_names)

        self._url = 'dataset/pdbbind_v2015.tar.gz'
        root_dir_path = get_download_dir()
        data_path = root_dir_path + '/pdbbind_v2015.tar.gz'
        extracted_data_path = root_dir_path + '/pdbbind_v2015'
        download(_get_dgl_url(self._url), path=data_path)
        extract_archive(data_path, extracted_data_path)

        if subset == 'core':
            index_label_file = extracted_data_path + '/v2015/INDEX_core_data.2013'
        elif subset == 'refined':
            index_label_file = extracted_data_path + '/v2015/INDEX_refined_data.2015'
        else:
            raise ValueError(
                'Expect the subset_choice to be either '
                'core or refined, got {}'.format(subset))

        self._preprocess(extracted_data_path, index_label_file, load_binding_pocket,
                         sanitize, calc_charges, remove_hs, use_conformation,
                         construct_graph_and_featurize, zero_padding, num_processes)

    def _filter_out_invalid(self, ligands_loaded, proteins_loaded, use_conformation):
        """Filter out invalid ligand-protein pairs.

        Parameters
        ----------
        ligands_loaded : list
            Each element is a 2-tuple of the RDKit molecule instance and its associated atom
            coordinates. None is used to represent invalid/non-existing molecule or coordinates.
        proteins_loaded : list
            Each element is a 2-tuple of the RDKit molecule instance and its associated atom
            coordinates. None is used to represent invalid/non-existing molecule or coordinates.
        use_conformation : bool
            Whether we need conformation information (atom coordinates) and filter out molecules
            without valid conformation.
        """
        num_pairs = len(proteins_loaded)
        self.indices, self.ligand_mols, self.protein_mols = [], [], []
        if use_conformation:
            self.ligand_coordinates, self.protein_coordinates = [], []
        else:
            # Use None for placeholders.
            self.ligand_coordinates = [None for _ in range(num_pairs)]
            self.protein_coordinates = [None for _ in range(num_pairs)]

        for i in range(num_pairs):
            ligand_mol, ligand_coordinates = ligands_loaded[i]
            protein_mol, protein_coordinates = proteins_loaded[i]
            if (not use_conformation) and all(v is not None for v in [protein_mol, ligand_mol]):
                self.indices.append(i)
                self.ligand_mols.append(ligand_mol)
                self.protein_mols.append(protein_mol)
            elif all(v is not None for v in [
                    protein_mol, protein_coordinates, ligand_mol, ligand_coordinates]):
                self.indices.append(i)
                self.ligand_mols.append(ligand_mol)
                self.ligand_coordinates.append(ligand_coordinates)
                self.protein_mols.append(protein_mol)
                self.protein_coordinates.append(protein_coordinates)

    def _preprocess(self, root_path, index_label_file, load_binding_pocket,
                    sanitize, calc_charges, remove_hs, use_conformation,
                    construct_graph_and_featurize, zero_padding, num_processes):
        """Preprocess the dataset.

        The pre-processing proceeds as follows:

        1. Load the dataset
        2. Clean the dataset and filter out invalid pairs
        3. Construct graphs
        4. Prepare node and edge features

        Parameters
        ----------
        root_path : str
            Root path for molecule files.
        index_label_file : str
            Path to the index file for the dataset.
        load_binding_pocket : bool
            Whether to load binding pockets or full proteins.
        sanitize : bool
            Whether sanitization is performed in initializing RDKit molecule instances. See
            https://www.rdkit.org/docs/RDKit_Book.html for details of the sanitization.
        calc_charges : bool
            Whether to add Gasteiger charges via RDKit. Setting this to be True will enforce
            ``sanitize`` to be True.
        remove_hs : bool
            Whether to remove hydrogens via RDKit. Note that removing hydrogens can be quite
            slow for large molecules.
        use_conformation : bool
            Whether we need to extract molecular conformation from proteins and ligands.
        construct_graph_and_featurize : callable
            Construct a DGLHeteroGraph for the use of GNNs. Mapping self.ligand_mols[i],
            self.protein_mols[i], self.ligand_coordinates[i] and self.protein_coordinates[i]
            to a DGLHeteroGraph. Default to :func:`ACNN_graph_construction_and_featurization`.
        zero_padding : bool
            Whether to perform zero padding. While DGL does not necessarily require zero padding,
            pooling operations for variable length inputs can introduce stochastic behaviour,
            which is not desired for sensitive scenarios.
        num_processes : int or None
            Number of worker processes to use. If None,
            then we will use the number of CPUs in the system.
        """
        contents = []
        with open(index_label_file, 'r') as f:
            for line in f.readlines():
                if line[0] != "#":
                    splitted_elements = line.split()
                    if len(splitted_elements) == 8:
                        # Ignore "//"
                        contents.append(splitted_elements[:5] + splitted_elements[6:])
                    else:
                        print('Incorrect data format.')
                        print(splitted_elements)
        self.df = pd.DataFrame(contents, columns=(
            'PDB_code', 'resolution', 'release_year',
            '-logKd/Ki', 'Kd/Ki', 'reference', 'ligand_name'))
        pdbs = self.df['PDB_code'].tolist()

        self.ligand_files = [os.path.join(
            root_path, 'v2015', pdb, '{}_ligand.sdf'.format(pdb)) for pdb in pdbs]
        if load_binding_pocket:
            self.protein_files = [os.path.join(
                root_path, 'v2015', pdb, '{}_pocket.pdb'.format(pdb)) for pdb in pdbs]
        else:
            self.protein_files = [os.path.join(
                root_path, 'v2015', pdb, '{}_protein.pdb'.format(pdb)) for pdb in pdbs]

        # Use all available CPUs if num_processes is None
        if num_processes is None:
            num_processes = multiprocessing.cpu_count()
        num_processes = min(num_processes, len(pdbs))
        print('Loading ligands...')
        ligands_loaded = multiprocess_load_molecules(self.ligand_files,
                                                     sanitize=sanitize,
                                                     calc_charges=calc_charges,
                                                     remove_hs=remove_hs,
                                                     use_conformation=use_conformation,
                                                     num_processes=num_processes)
        print('Loading proteins...')
        proteins_loaded = multiprocess_load_molecules(self.protein_files,
                                                      sanitize=sanitize,
                                                      calc_charges=calc_charges,
                                                      remove_hs=remove_hs,
                                                      use_conformation=use_conformation,
                                                      num_processes=num_processes)

        self._filter_out_invalid(ligands_loaded, proteins_loaded, use_conformation)
        self.df = self.df.iloc[self.indices]
        self.labels = F.zerocopy_from_numpy(self.df[self.task_names].values.astype(np.float32))
        print('Finished cleaning the dataset, '
              'got {:d}/{:d} valid pairs'.format(len(self), len(pdbs)))

        # Prepare zero padding
        if zero_padding:
            max_num_ligand_atoms = 0
            max_num_protein_atoms = 0
            for i in range(len(self)):
                max_num_ligand_atoms = max(
                    max_num_ligand_atoms, self.ligand_mols[i].GetNumAtoms())
                max_num_protein_atoms = max(
                    max_num_protein_atoms, self.protein_mols[i].GetNumAtoms())
        else:
            max_num_ligand_atoms = None
            max_num_protein_atoms = None

        print('Start constructing graphs and featurizing them.')
        self.graphs = []
        for i in range(len(self)):
            print('Constructing and featurizing datapoint {:d}/{:d}'.format(i + 1, len(self)))
            self.graphs.append(construct_graph_and_featurize(
                self.ligand_mols[i], self.protein_mols[i],
                self.ligand_coordinates[i], self.protein_coordinates[i],
                max_num_ligand_atoms, max_num_protein_atoms))

    def __len__(self):
        """Get the size of the dataset.

        Returns
        -------
        int
            Number of valid ligand-protein pairs in the dataset.
        """
        return len(self.indices)

    def __getitem__(self, item):
        """Get the datapoint associated with the index.

        Parameters
        ----------
        item : int
            Index for the datapoint.

        Returns
        -------
        int
            Index for the datapoint.
        rdkit.Chem.rdchem.Mol
            RDKit molecule instance for the ligand molecule.
        rdkit.Chem.rdchem.Mol
            RDKit molecule instance for the protein molecule.
        DGLHeteroGraph
            Pre-processed DGLHeteroGraph with features extracted.
        Float32 tensor
            Label for the datapoint.
        """
        return item, self.ligand_mols[item], self.protein_mols[item], \
               self.graphs[item], self.labels[item]
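
# A minimal usage sketch (not part of the original module): load the small
# 'core' subset and inspect one datapoint. Downloading and featurizing can
# take a while.
if __name__ == '__main__':
    dataset = PDBBind(subset='core')
    idx, ligand_mol, protein_mol, g, label = dataset[0]
    print(len(dataset), label)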
"""Dataset for aromaticity prediction"""
import pandas as pd
from dgl.data.utils import get_download_dir, download, _get_dgl_url
from .csv_dataset import MoleculeCSVDataset
from ..utils.mol_to_graph import smiles_to_bigraph
__all__ = ['PubChemBioAssayAromaticity']
class PubChemBioAssayAromaticity(MoleculeCSVDataset):
    """Subset of PubChem BioAssay Dataset for aromaticity prediction.

    The dataset was constructed in `Pushing the Boundaries of Molecular Representation for Drug
    Discovery with the Graph Attention Mechanism.
    <https://www.ncbi.nlm.nih.gov/pubmed/31408336>`__ and is accompanied by the task of predicting
    the number of aromatic atoms in molecules.

    The dataset was constructed by sampling 3945 molecules with 0-40 aromatic atoms from the
    PubChem BioAssay dataset.

    Parameters
    ----------
    smiles_to_graph : callable, str -> DGLGraph
        A function turning smiles into a DGLGraph.
        Default to :func:`dgllife.utils.smiles_to_bigraph`.
    node_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict
        Featurization for nodes like atoms in a molecule, which can be used to update
        ndata for a DGLGraph. Default to None.
    edge_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict
        Featurization for edges like bonds in a molecule, which can be used to update
        edata for a DGLGraph. Default to None.
    load : bool
        Whether to load the previously pre-processed dataset or pre-process from scratch.
        ``load`` should be False when we want to try different graph construction and
        featurization methods and need to pre-process from scratch. Default to True.
    """
    def __init__(self, smiles_to_graph=smiles_to_bigraph,
                 node_featurizer=None, edge_featurizer=None, load=True):
        self._url = 'dataset/pubchem_bioassay_aromaticity.csv'
        data_path = get_download_dir() + '/pubchem_bioassay_aromaticity.csv'
        download(_get_dgl_url(self._url), path=data_path)
        df = pd.read_csv(data_path)

        super(PubChemBioAssayAromaticity, self).__init__(
            df, smiles_to_graph, node_featurizer, edge_featurizer, "cano_smiles",
            "pubchem_aromaticity_dglgraph.bin", load=load)
"""The Toxicology in the 21st Century initiative."""
import dgl.backend as F
import pandas as pd
from dgl.data.utils import get_download_dir, download, _get_dgl_url
from .csv_dataset import MoleculeCSVDataset
from ..utils.mol_to_graph import smiles_to_bigraph
__all__ = ['Tox21']
class Tox21(MoleculeCSVDataset):
    """Tox21 dataset.

    The Toxicology in the 21st Century (https://tripod.nih.gov/tox21/challenge/)
    initiative created a public database measuring toxicity of compounds, which
    has been used in the 2014 Tox21 Data Challenge. The dataset contains qualitative
    toxicity measurements for 8014 compounds on 12 different targets, including nuclear
    receptors and stress response pathways. Each target results in a binary label.

    A common issue for multi-task prediction is that some datapoints are not labeled for
    all tasks. This is also the case for Tox21. In data pre-processing, we set non-existing
    labels to be 0 so that they can be placed in tensors and used for masking in loss
    computation. See examples below for more details.

    All molecules are converted into DGLGraphs. After the first-time construction,
    the DGLGraphs will be saved for reloading so that we do not need to reconstruct them
    every time.

    Parameters
    ----------
    smiles_to_graph : callable, str -> DGLGraph
        A function turning smiles into a DGLGraph.
        Default to :func:`dgllife.utils.smiles_to_bigraph`.
    node_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict
        Featurization for nodes like atoms in a molecule, which can be used to update
        ndata for a DGLGraph. Default to None.
    edge_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict
        Featurization for edges like bonds in a molecule, which can be used to update
        edata for a DGLGraph. Default to None.
    load : bool
        Whether to load the previously pre-processed dataset or pre-process from scratch.
        ``load`` should be False when we want to try different graph construction and
        featurization methods and need to preprocess from scratch. Default to True.
    """
    def __init__(self, smiles_to_graph=smiles_to_bigraph,
                 node_featurizer=None,
                 edge_featurizer=None,
                 load=True):
        self._url = 'dataset/tox21.csv.gz'
        data_path = get_download_dir() + '/tox21.csv.gz'
        download(_get_dgl_url(self._url), path=data_path)
        df = pd.read_csv(data_path)
        self.id = df['mol_id']

        df = df.drop(columns=['mol_id'])

        super(Tox21, self).__init__(df, smiles_to_graph, node_featurizer, edge_featurizer,
                                    "smiles", "tox21_dglgraph.bin", load=load)

        self._weight_balancing()

    def _weight_balancing(self):
        """Perform re-balancing for each task.

        It's quite common that the number of positive samples and the
        number of negative samples are significantly different. To compensate
        for the class imbalance issue, we can weight each datapoint in
        loss computation.

        In particular, for each task we will set the weight of negative samples
        to be 1 and the weight of positive samples to be the number of negative
        samples divided by the number of positive samples.

        If weight balancing is performed, one attribute will be affected:

        * self._task_pos_weights is set, which is a list of positive sample weights
          for each task.
        """
        num_pos = F.sum(self.labels, dim=0)
        num_indices = F.sum(self.mask, dim=0)
        self._task_pos_weights = (num_indices - num_pos) / num_pos

    @property
    def task_pos_weights(self):
        """Get weights for positive samples on each task

        Returns
        -------
        Tensor of dtype float32
            Weights of positive samples on all tasks
        """
        return self._task_pos_weights
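
# A minimal sketch (not part of the original module) of how task_pos_weights
# is meant to be used: as the pos_weight of a weighted binary cross-entropy,
# with the mask zeroing out the contribution of non-existing labels.
if __name__ == '__main__':
    import torch
    import torch.nn as nn

    from dgllife.utils import CanonicalAtomFeaturizer

    dataset = Tox21(smiles_to_bigraph, CanonicalAtomFeaturizer())
    criterion = nn.BCEWithLogitsLoss(pos_weight=dataset.task_pos_weights,
                                     reduction='none')
    smiles, g, label, mask = dataset[0]
    logits = torch.randn(1, dataset.n_tasks)  # stand-in for model output
    loss = (criterion(logits, label.unsqueeze(0)) * mask.unsqueeze(0)).mean()
    print(loss)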
"""Model architectures and components at different levels."""
from .gnn import *
from .readout import *
from .model_zoo import *
from .pretrain import *
"""Graph neural networks for updating node representations."""
from .attentivefp import *
from .gat import *
from .gcn import *
from .mgcn import *
from .mpnn import *
from .schnet import *
"""AttentiveFP"""
import dgl.function as fn
import torch
import torch.nn as nn
import torch.nn.functional as F
from dgl.nn.pytorch import edge_softmax
__all__ = ['AttentiveFPGNN']
class AttentiveGRU1(nn.Module):
    """Update node features with attention and GRU.

    This will be used for incorporating the information of edge features
    into node features for message passing.

    Parameters
    ----------
    node_feat_size : int
        Size for the input node features.
    edge_feat_size : int
        Size for the input edge (bond) features.
    edge_hidden_size : int
        Size for the intermediate edge (bond) representations.
    dropout : float
        The probability for performing dropout.
    """
    def __init__(self, node_feat_size, edge_feat_size, edge_hidden_size, dropout):
        super(AttentiveGRU1, self).__init__()

        self.edge_transform = nn.Sequential(
            nn.Dropout(dropout),
            nn.Linear(edge_feat_size, edge_hidden_size)
        )
        self.gru = nn.GRUCell(edge_hidden_size, node_feat_size)

    def forward(self, g, edge_logits, edge_feats, node_feats):
        """Update node representations.

        Parameters
        ----------
        g : DGLGraph
            DGLGraph for a batch of graphs
        edge_logits : float32 tensor of shape (E, 1)
            The edge logits based on which softmax will be performed for weighting
            edges within 1-hop neighborhoods. E represents the number of edges.
        edge_feats : float32 tensor of shape (E, edge_feat_size)
            Previous edge features.
        node_feats : float32 tensor of shape (V, node_feat_size)
            Previous node features. V represents the number of nodes.

        Returns
        -------
        float32 tensor of shape (V, node_feat_size)
            Updated node features.
        """
        g = g.local_var()
        g.edata['e'] = edge_softmax(g, edge_logits) * self.edge_transform(edge_feats)
        g.update_all(fn.copy_edge('e', 'm'), fn.sum('m', 'c'))
        context = F.elu(g.ndata['c'])
        return F.relu(self.gru(context, node_feats))
class AttentiveGRU2(nn.Module):
    """Update node features with attention and GRU.

    This will be used in GNN layers for updating node representations.

    Parameters
    ----------
    node_feat_size : int
        Size for the input node features.
    edge_hidden_size : int
        Size for the intermediate edge (bond) representations.
    dropout : float
        The probability for performing dropout.
    """
    def __init__(self, node_feat_size, edge_hidden_size, dropout):
        super(AttentiveGRU2, self).__init__()

        self.project_node = nn.Sequential(
            nn.Dropout(dropout),
            nn.Linear(node_feat_size, edge_hidden_size)
        )
        self.gru = nn.GRUCell(edge_hidden_size, node_feat_size)

    def forward(self, g, edge_logits, node_feats):
        """Update node representations.

        Parameters
        ----------
        g : DGLGraph
            DGLGraph for a batch of graphs
        edge_logits : float32 tensor of shape (E, 1)
            The edge logits based on which softmax will be performed for weighting
            edges within 1-hop neighborhoods. E represents the number of edges.
        node_feats : float32 tensor of shape (V, node_feat_size)
            Previous node features. V represents the number of nodes.

        Returns
        -------
        float32 tensor of shape (V, node_feat_size)
            Updated node features.
        """
        g = g.local_var()
        g.edata['a'] = edge_softmax(g, edge_logits)
        g.ndata['hv'] = self.project_node(node_feats)

        g.update_all(fn.src_mul_edge('hv', 'a', 'm'), fn.sum('m', 'c'))
        context = F.elu(g.ndata['c'])
        return F.relu(self.gru(context, node_feats))
class GetContext(nn.Module):
    """Generate context for each node by message passing at the beginning.

    This layer incorporates the information of edge features into node
    representations so that message passing needs to be only performed over
    node representations.

    Parameters
    ----------
    node_feat_size : int
        Size for the input node features.
    edge_feat_size : int
        Size for the input edge (bond) features.
    graph_feat_size : int
        Size of the learned graph representation (molecular fingerprint).
    dropout : float
        The probability for performing dropout.
    """
    def __init__(self, node_feat_size, edge_feat_size, graph_feat_size, dropout):
        super(GetContext, self).__init__()

        self.project_node = nn.Sequential(
            nn.Linear(node_feat_size, graph_feat_size),
            nn.LeakyReLU()
        )
        self.project_edge1 = nn.Sequential(
            nn.Linear(node_feat_size + edge_feat_size, graph_feat_size),
            nn.LeakyReLU()
        )
        self.project_edge2 = nn.Sequential(
            nn.Dropout(dropout),
            nn.Linear(2 * graph_feat_size, 1),
            nn.LeakyReLU()
        )
        self.attentive_gru = AttentiveGRU1(graph_feat_size, graph_feat_size,
                                           graph_feat_size, dropout)

    def apply_edges1(self, edges):
        """Edge feature update.

        Parameters
        ----------
        edges : EdgeBatch
            Container for a batch of edges

        Returns
        -------
        dict
            Mapping ``'he1'`` to updated edge features.
        """
        return {'he1': torch.cat([edges.src['hv'], edges.data['he']], dim=1)}

    def apply_edges2(self, edges):
        """Edge feature update.

        Parameters
        ----------
        edges : EdgeBatch
            Container for a batch of edges

        Returns
        -------
        dict
            Mapping ``'he2'`` to updated edge features.
        """
        return {'he2': torch.cat([edges.dst['hv_new'], edges.data['he1']], dim=1)}

    def forward(self, g, node_feats, edge_feats):
        """Incorporate edge features and update node representations.

        Parameters
        ----------
        g : DGLGraph
            DGLGraph for a batch of graphs.
        node_feats : float32 tensor of shape (V, node_feat_size)
            Input node features. V for the number of nodes.
        edge_feats : float32 tensor of shape (E, edge_feat_size)
            Input edge features. E for the number of edges.

        Returns
        -------
        float32 tensor of shape (V, graph_feat_size)
            Updated node features.
        """
        g = g.local_var()
        g.ndata['hv'] = node_feats
        g.ndata['hv_new'] = self.project_node(node_feats)
        g.edata['he'] = edge_feats

        g.apply_edges(self.apply_edges1)
        g.edata['he1'] = self.project_edge1(g.edata['he1'])
        g.apply_edges(self.apply_edges2)
        logits = self.project_edge2(g.edata['he2'])

        return self.attentive_gru(g, logits, g.edata['he1'], g.ndata['hv_new'])
class GNNLayer(nn.Module):
    """GNNLayer for updating node features.

    This layer performs message passing over node representations and updates them.

    Parameters
    ----------
    node_feat_size : int
        Size for the input node features.
    graph_feat_size : int
        Size for the graph representations to be computed.
    dropout : float
        The probability for performing dropout.
    """
    def __init__(self, node_feat_size, graph_feat_size, dropout):
        super(GNNLayer, self).__init__()

        self.project_edge = nn.Sequential(
            nn.Dropout(dropout),
            nn.Linear(2 * node_feat_size, 1),
            nn.LeakyReLU()
        )
        self.attentive_gru = AttentiveGRU2(node_feat_size, graph_feat_size, dropout)

    def apply_edges(self, edges):
        """Edge feature generation.

        Generate edge features by concatenating the features of the destination
        and source nodes.

        Parameters
        ----------
        edges : EdgeBatch
            Container for a batch of edges.

        Returns
        -------
        dict
            Mapping ``'he'`` to the generated edge features.
        """
        return {'he': torch.cat([edges.dst['hv'], edges.src['hv']], dim=1)}

    def forward(self, g, node_feats):
        """Perform message passing and update node representations.

        Parameters
        ----------
        g : DGLGraph
            DGLGraph for a batch of graphs.
        node_feats : float32 tensor of shape (V, node_feat_size)
            Input node features. V for the number of nodes.

        Returns
        -------
        float32 tensor of shape (V, graph_feat_size)
            Updated node features.
        """
        g = g.local_var()
        g.ndata['hv'] = node_feats
        g.apply_edges(self.apply_edges)
        logits = self.project_edge(g.edata['he'])

        return self.attentive_gru(g, logits, node_feats)
class AttentiveFPGNN(nn.Module):
    """`Pushing the Boundaries of Molecular Representation for Drug Discovery with the Graph
    Attention Mechanism <https://www.ncbi.nlm.nih.gov/pubmed/31408336>`__

    This class performs message passing in AttentiveFP and returns the updated node
    representations.

    Parameters
    ----------
    node_feat_size : int
        Size for the input node features.
    edge_feat_size : int
        Size for the input edge features.
    num_layers : int
        Number of GNN layers. Default to 2.
    graph_feat_size : int
        Size for the graph representations to be computed. Default to 200.
    dropout : float
        The probability for performing dropout. Default to 0.
    """
    def __init__(self,
                 node_feat_size,
                 edge_feat_size,
                 num_layers=2,
                 graph_feat_size=200,
                 dropout=0.):
        super(AttentiveFPGNN, self).__init__()

        self.init_context = GetContext(node_feat_size, edge_feat_size, graph_feat_size, dropout)
        self.gnn_layers = nn.ModuleList()
        for i in range(num_layers - 1):
            self.gnn_layers.append(GNNLayer(graph_feat_size, graph_feat_size, dropout))

    def forward(self, g, node_feats, edge_feats):
        """Performs message passing and updates node representations.

        Parameters
        ----------
        g : DGLGraph
            DGLGraph for a batch of graphs.
        node_feats : float32 tensor of shape (V, node_feat_size)
            Input node features. V for the number of nodes.
        edge_feats : float32 tensor of shape (E, edge_feat_size)
            Input edge features. E for the number of edges.

        Returns
        -------
        node_feats : float32 tensor of shape (V, graph_feat_size)
            Updated node representations.
        """
        node_feats = self.init_context(g, node_feats, edge_feats)
        for gnn in self.gnn_layers:
            node_feats = gnn(g, node_feats)
        return node_feats
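
# A minimal sketch (not part of the original module): run the GNN on a toy
# graph with random features, assuming DGL's mutable DGLGraph API (v0.4).
# Feature sizes are arbitrary.
if __name__ == '__main__':
    import dgl

    g = dgl.DGLGraph()
    g.add_nodes(3)
    g.add_edges([0, 1, 2], [1, 2, 0])
    model = AttentiveFPGNN(node_feat_size=4, edge_feat_size=5)
    out = model(g, torch.randn(3, 4), torch.randn(3, 5))
    print(out.shape)  # (3, 200) with the default graph_feat_size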
"""Graph Attention Networks"""
import torch.nn as nn
import torch.nn.functional as F
from dgl.nn.pytorch import GATConv
__all__ = ['GAT']
class GATLayer(nn.Module):
    r"""Single GAT layer from `Graph Attention Networks <https://arxiv.org/abs/1710.10903>`__

    Parameters
    ----------
    in_feats : int
        Number of input node features
    out_feats : int
        Number of output node features
    num_heads : int
        Number of attention heads
    feat_drop : float
        Dropout applied to the input features
    attn_drop : float
        Dropout applied to attention values of edges
    alpha : float
        Hyperparameter in LeakyReLU, which is the slope for negative values.
        Default to 0.2.
    residual : bool
        Whether to perform skip connection, default to True.
    agg_mode : str
        The way to aggregate multi-head attention results, can be either
        'flatten' for concatenating all-head results or 'mean' for averaging
        all-head results.
    activation : activation function or None
        Activation function applied to the aggregated multi-head results, default to None.
    """
    def __init__(self, in_feats, out_feats, num_heads, feat_drop, attn_drop,
                 alpha=0.2, residual=True, agg_mode='flatten', activation=None):
        super(GATLayer, self).__init__()

        self.gat_conv = GATConv(in_feats=in_feats, out_feats=out_feats, num_heads=num_heads,
                                feat_drop=feat_drop, attn_drop=attn_drop,
                                negative_slope=alpha, residual=residual)
        assert agg_mode in ['flatten', 'mean']
        self.agg_mode = agg_mode
        self.activation = activation

    def forward(self, bg, feats):
        """Update node representations

        Parameters
        ----------
        bg : DGLGraph
            DGLGraph for a batch of graphs.
        feats : FloatTensor of shape (N, M1)
            * N is the total number of nodes in the batch of graphs
            * M1 is the input node feature size, which equals in_feats in initialization

        Returns
        -------
        feats : FloatTensor of shape (N, M2)
            * N is the total number of nodes in the batch of graphs
            * M2 is the output node representation size, which equals
              out_feats in initialization if self.agg_mode == 'mean' and
              out_feats * num_heads in initialization otherwise.
        """
        feats = self.gat_conv(bg, feats)
        if self.agg_mode == 'flatten':
            feats = feats.flatten(1)
        else:
            feats = feats.mean(1)

        if self.activation is not None:
            feats = self.activation(feats)

        return feats
class GAT(nn.Module):
    r"""GAT from `Graph Attention Networks <https://arxiv.org/abs/1710.10903>`__

    Parameters
    ----------
    in_feats : int
        Number of input node features
    hidden_feats : list of int
        ``hidden_feats[i]`` gives the output size of an attention head in the i-th GAT layer.
        ``len(hidden_feats)`` equals the number of GAT layers. By default, we use ``[32, 32]``.
    num_heads : list of int
        ``num_heads[i]`` gives the number of attention heads in the i-th GAT layer.
        ``len(num_heads)`` equals the number of GAT layers. By default, we use 4 attention heads
        for each GAT layer.
    feat_drops : list of float
        ``feat_drops[i]`` gives the dropout applied to the input features in the i-th GAT layer.
        ``len(feat_drops)`` equals the number of GAT layers. By default, this will be zero for
        all GAT layers.
    attn_drops : list of float
        ``attn_drops[i]`` gives the dropout applied to attention values of edges in the i-th GAT
        layer. ``len(attn_drops)`` equals the number of GAT layers. By default, this will be zero
        for all GAT layers.
    alphas : list of float
        Hyperparameters in LeakyReLU, which are the slopes for negative values. ``alphas[i]``
        gives the slope for negative value in the i-th GAT layer. ``len(alphas)`` equals the
        number of GAT layers. By default, this will be 0.2 for all GAT layers.
    residuals : list of bool
        ``residuals[i]`` decides if residual connection is to be used for the i-th GAT layer.
        ``len(residuals)`` equals the number of GAT layers. By default, residual connection
        is performed for each GAT layer.
    agg_modes : list of str
        The way to aggregate multi-head attention results for each GAT layer, which can be either
        'flatten' for concatenating all-head results or 'mean' for averaging all-head results.
        ``agg_modes[i]`` gives the way to aggregate multi-head attention results for the i-th
        GAT layer. ``len(agg_modes)`` equals the number of GAT layers. By default, we flatten
        all-head results for each GAT layer except the last one, where we average them.
    activations : list of activation function or None
        ``activations[i]`` gives the activation function applied to the aggregated multi-head
        results for the i-th GAT layer. ``len(activations)`` equals the number of GAT layers.
        By default, ELU is applied for each GAT layer except the last one, which has no
        activation.
    """
    def __init__(self, in_feats, hidden_feats=None, num_heads=None, feat_drops=None,
                 attn_drops=None, alphas=None, residuals=None, agg_modes=None, activations=None):
        super(GAT, self).__init__()

        if hidden_feats is None:
            hidden_feats = [32, 32]

        n_layers = len(hidden_feats)
        if num_heads is None:
            num_heads = [4 for _ in range(n_layers)]
        if feat_drops is None:
            feat_drops = [0. for _ in range(n_layers)]
        if attn_drops is None:
            attn_drops = [0. for _ in range(n_layers)]
        if alphas is None:
            alphas = [0.2 for _ in range(n_layers)]
        if residuals is None:
            residuals = [True for _ in range(n_layers)]
        if agg_modes is None:
            agg_modes = ['flatten' for _ in range(n_layers - 1)]
            agg_modes.append('mean')
        if activations is None:
            activations = [F.elu for _ in range(n_layers - 1)]
            activations.append(None)
        lengths = [len(hidden_feats), len(num_heads), len(feat_drops), len(attn_drops),
                   len(alphas), len(residuals), len(agg_modes), len(activations)]
        assert len(set(lengths)) == 1, 'Expect the lengths of hidden_feats, num_heads, ' \
                                       'feat_drops, attn_drops, alphas, residuals, ' \
                                       'agg_modes and activations to be the same, ' \
                                       'got {}'.format(lengths)

        self.hidden_feats = hidden_feats
        self.num_heads = num_heads
        self.agg_modes = agg_modes
        self.gnn_layers = nn.ModuleList()
        for i in range(n_layers):
            self.gnn_layers.append(GATLayer(in_feats, hidden_feats[i], num_heads[i],
                                            feat_drops[i], attn_drops[i], alphas[i],
                                            residuals[i], agg_modes[i], activations[i]))
            if agg_modes[i] == 'flatten':
                in_feats = hidden_feats[i] * num_heads[i]
            else:
                in_feats = hidden_feats[i]

    def forward(self, g, feats):
        """Update node representations.

        Parameters
        ----------
        g : DGLGraph
            DGLGraph for a batch of graphs
        feats : FloatTensor of shape (N, M1)
            * N is the total number of nodes in the batch of graphs
            * M1 is the input node feature size, which equals in_feats in initialization

        Returns
        -------
        feats : FloatTensor of shape (N, M2)
            * N is the total number of nodes in the batch of graphs
            * M2 is the output node representation size, which equals
              hidden_feats[-1] if agg_modes[-1] == 'mean' and
              hidden_feats[-1] * num_heads[-1] otherwise.
        """
        for gnn in self.gnn_layers:
            feats = gnn(g, feats)
        return feats
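
# A minimal sketch (not part of the original module): a 2-layer GAT with the
# default configuration on a toy graph, assuming DGL's mutable DGLGraph API
# (v0.4).
if __name__ == '__main__':
    import dgl
    import torch

    g = dgl.DGLGraph()
    g.add_nodes(4)
    g.add_edges([0, 1, 2, 3], [1, 2, 3, 0])
    model = GAT(in_feats=16)
    feats = model(g, torch.randn(4, 16))
    print(feats.shape)  # (4, 32): the last layer defaults to 'mean' aggregation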
"""Graph Convolutional Networks."""
import torch.nn as nn
import torch.nn.functional as F
from dgl.nn.pytorch import GraphConv
__all__ = ['GCN']
class GCNLayer(nn.Module):
    r"""Single GCN layer from `Semi-Supervised Classification with Graph Convolutional Networks
    <https://arxiv.org/abs/1609.02907>`__

    Parameters
    ----------
    in_feats : int
        Number of input node features.
    out_feats : int
        Number of output node features.
    activation : activation function
        Default to be None.
    residual : bool
        Whether to use residual connection, default to be True.
    batchnorm : bool
        Whether to use batch normalization on the output,
        default to be True.
    dropout : float
        The probability for dropout. Default to be 0., i.e. no
        dropout is performed.
    """
    def __init__(self, in_feats, out_feats, activation=None,
                 residual=True, batchnorm=True, dropout=0.):
        super(GCNLayer, self).__init__()

        self.activation = activation
        self.graph_conv = GraphConv(in_feats=in_feats, out_feats=out_feats,
                                    norm=False, activation=activation)
        self.dropout = nn.Dropout(dropout)

        self.residual = residual
        if residual:
            self.res_connection = nn.Linear(in_feats, out_feats)

        self.bn = batchnorm
        if batchnorm:
            self.bn_layer = nn.BatchNorm1d(out_feats)

    def forward(self, g, feats):
        """Update node representations.

        Parameters
        ----------
        g : DGLGraph
            DGLGraph for a batch of graphs
        feats : FloatTensor of shape (N, M1)
            * N is the total number of nodes in the batch of graphs
            * M1 is the input node feature size, which must match in_feats in initialization

        Returns
        -------
        new_feats : FloatTensor of shape (N, M2)
            * M2 is the output node feature size, which must match out_feats in initialization
        """
        new_feats = self.graph_conv(g, feats)
        if self.residual:
            res_feats = self.activation(self.res_connection(feats))
            new_feats = new_feats + res_feats
        new_feats = self.dropout(new_feats)

        if self.bn:
            new_feats = self.bn_layer(new_feats)

        return new_feats
class GCN(nn.Module):
    r"""GCN from `Semi-Supervised Classification with Graph Convolutional Networks
    <https://arxiv.org/abs/1609.02907>`__

    Parameters
    ----------
    in_feats : int
        Number of input node features.
    hidden_feats : list of int
        ``hidden_feats[i]`` gives the size of node representations after the i-th GCN layer.
        ``len(hidden_feats)`` equals the number of GCN layers. By default, we use
        ``[64, 64]``.
    activation : list of activation functions or None
        If None, no activation will be applied. If not None, ``activation[i]`` gives the
        activation function to be used for the i-th GCN layer. ``len(activation)`` equals
        the number of GCN layers. By default, ReLU is applied for all GCN layers.
    residual : list of bool
        ``residual[i]`` decides if residual connection is to be used for the i-th GCN layer.
        ``len(residual)`` equals the number of GCN layers. By default, residual connection
        is performed for each GCN layer.
    batchnorm : list of bool
        ``batchnorm[i]`` decides if batch normalization is to be applied on the output of
        the i-th GCN layer. ``len(batchnorm)`` equals the number of GCN layers. By default,
        batch normalization is applied for all GCN layers.
    dropout : list of float
        ``dropout[i]`` decides the dropout probability on the output of the i-th GCN layer.
        ``len(dropout)`` equals the number of GCN layers. By default, no dropout is
        performed for all layers.
    """
    def __init__(self, in_feats, hidden_feats=None, activation=None, residual=None,
                 batchnorm=None, dropout=None):
        super(GCN, self).__init__()

        if hidden_feats is None:
            hidden_feats = [64, 64]

        n_layers = len(hidden_feats)
        if activation is None:
            activation = [F.relu for _ in range(n_layers)]
        if residual is None:
            residual = [True for _ in range(n_layers)]
        if batchnorm is None:
            batchnorm = [True for _ in range(n_layers)]
        if dropout is None:
            dropout = [0. for _ in range(n_layers)]
        lengths = [len(hidden_feats), len(activation),
                   len(residual), len(batchnorm), len(dropout)]
        assert len(set(lengths)) == 1, 'Expect the lengths of hidden_feats, activation, ' \
                                       'residual, batchnorm and dropout to be the same, ' \
                                       'got {}'.format(lengths)

        self.hidden_feats = hidden_feats
        self.gnn_layers = nn.ModuleList()
        for i in range(n_layers):
            self.gnn_layers.append(GCNLayer(in_feats, hidden_feats[i], activation[i],
                                            residual[i], batchnorm[i], dropout[i]))
            in_feats = hidden_feats[i]

    def forward(self, g, feats):
        """Update node representations.

        Parameters
        ----------
        g : DGLGraph
            DGLGraph for a batch of graphs
        feats : FloatTensor of shape (N, M1)
            * N is the total number of nodes in the batch of graphs
            * M1 is the input node feature size, which equals in_feats in initialization

        Returns
        -------
        feats : FloatTensor of shape (N, M2)
            * N is the total number of nodes in the batch of graphs
            * M2 is the output node representation size, which equals
              hidden_feats[-1] in initialization.
        """
        for gnn in self.gnn_layers:
            feats = gnn(g, feats)
        return feats
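
# A minimal sketch (not part of the original module): a 2-layer GCN with the
# default configuration on a toy graph, assuming DGL's mutable DGLGraph API
# (v0.4).
if __name__ == '__main__':
    import dgl
    import torch

    g = dgl.DGLGraph()
    g.add_nodes(4)
    g.add_edges([0, 1, 2, 3], [1, 2, 3, 0])
    model = GCN(in_feats=16)
    model.eval()  # use running stats in BatchNorm for this tiny example
    feats = model(g, torch.randn(4, 16))
    print(feats.shape)  # (4, 64)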
"""MGCN"""
import dgl.function as fn
import torch
import torch.nn as nn
from .schnet import RBFExpansion
__all__ = ['MGCNGNN']
class EdgeEmbedding(nn.Module):
"""Module for embedding edges.
Edges whose end nodes have the same combination of types
share the same initial embedding.
Parameters
----------
num_types : int
Number of edge types to embed.
edge_feats : int
Size for the edge representations to learn.
"""
def __init__(self, num_types, edge_feats):
super(EdgeEmbedding, self).__init__()
self.embed = nn.Embedding(num_types, edge_feats)
def get_edge_types(self, edges):
"""Generates edge types.
The edge type is based on the type of the source and destination nodes.
Note that directions are not distinguished, e.g. C-O and O-C are the same edge type.
To map each pair of node types to a unique number, we use an unordered pairing function.
See more details in this discussion:
https://math.stackexchange.com/questions/23503/create-unique-number-from-2-numbers
Note that the number of edge types should be larger than the square of the maximum node
type in the dataset.
Parameters
----------
edges : EdgeBatch
Container for a batch of edges.
Returns
-------
dict
Mapping 'type' to the computed edge types.
"""
node_type1 = edges.src['type']
node_type2 = edges.dst['type']
        # Floor division keeps the edge types as int64 for the embedding lookup.
        return {
            'type': node_type1 * node_type2 + \
                    (torch.abs(node_type1 - node_type2) - 1) ** 2 // 4
        }
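    # Worked example (editor's note): with the unordered pairing above and
    # floor division, node types 1 and 6 map to
    # 1 * 6 + ((|1 - 6| - 1) ** 2) // 4 = 6 + 4 = 10 in either direction,
    # while (2, 3) maps to 2 * 3 + ((|2 - 3| - 1) ** 2) // 4 = 6.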
def forward(self, g, node_types):
"""Embeds edge types.
Parameters
----------
g : DGLGraph
DGLGraph for a batch of graphs.
node_types : int64 tensor of shape (V)
Node types to embed, V for the number of nodes.
Returns
-------
float32 tensor of shape (E, edge_feats)
Edge representations.
"""
g = g.local_var()
g.ndata['type'] = node_types
g.apply_edges(self.get_edge_types)
return self.embed(g.edata['type'])
class VEConv(nn.Module):
"""Vertex-Edge Convolution in MGCN
MGCN is introduced in `Molecular Property Prediction: A Multilevel Quantum Interactions
Modeling Perspective <https://arxiv.org/abs/1906.11081>`__.
This layer combines both node and edge features in updating node representations.
Parameters
----------
dist_feats : int
Size for the expanded distances.
feats : int
Size for the input and output node and edge representations.
update_edge : bool
Whether to update edge representations. Default to True.
"""
def __init__(self, dist_feats, feats, update_edge=True):
super(VEConv, self).__init__()
self.update_dists = nn.Sequential(
nn.Linear(dist_feats, feats),
nn.Softplus(beta=0.5, threshold=14),
nn.Linear(feats, feats)
)
if update_edge:
self.update_edge_feats = nn.Linear(feats, feats)
else:
self.update_edge_feats = None
def forward(self, g, node_feats, edge_feats, expanded_dists):
"""Performs message passing and updates node and edge representations.
Parameters
----------
g : DGLGraph
DGLGraph for a batch of graphs.
node_feats : float32 tensor of shape (V, feats)
Input node features.
edge_feats : float32 tensor of shape (E, feats)
Input edge features.
expanded_dists : float32 tensor of shape (E, dist_feats)
Expanded distances, i.e. the output of RBFExpansion.
Returns
-------
node_feats : float32 tensor of shape (V, feats)
Updated node representations.
edge_feats : float32 tensor of shape (E, feats)
Edge representations, updated if ``update_edge == True`` in initialization.
"""
expanded_dists = self.update_dists(expanded_dists)
if self.update_edge_feats is not None:
edge_feats = self.update_edge_feats(edge_feats)
g = g.local_var()
g.ndata.update({'hv': node_feats})
g.edata.update({'dist': expanded_dists, 'he': edge_feats})
g.update_all(message_func=[fn.u_mul_e('hv', 'dist', 'm_0'),
fn.copy_e('he', 'm_1')],
reduce_func=[fn.sum('m_0', 'hv_0'),
fn.sum('m_1', 'hv_1')])
node_feats = g.ndata.pop('hv_0') + g.ndata.pop('hv_1')
return node_feats, edge_feats
class MultiLevelInteraction(nn.Module):
"""Building block for MGCN.
MGCN is introduced in `Molecular Property Prediction: A Multilevel Quantum Interactions
Modeling Perspective <https://arxiv.org/abs/1906.11081>`__. This layer combines node features,
edge features and expanded distances in message passing and updates node and edge
representations.
Parameters
----------
feats : int
Size for the input and output node and edge representations.
dist_feats : int
Size for the expanded distances.
"""
def __init__(self, feats, dist_feats):
super(MultiLevelInteraction, self).__init__()
self.project_in_node_feats = nn.Linear(feats, feats)
self.conv = VEConv(dist_feats, feats)
self.project_out_node_feats = nn.Sequential(
nn.Linear(feats, feats),
nn.Softplus(beta=0.5, threshold=14),
nn.Linear(feats, feats)
)
self.project_edge_feats = nn.Sequential(
nn.Linear(feats, feats),
nn.Softplus(beta=0.5, threshold=14)
)
def forward(self, g, node_feats, edge_feats, expanded_dists):
"""Performs message passing and updates node and edge representations.
Parameters
----------
g : DGLGraph
DGLGraph for a batch of graphs.
node_feats : float32 tensor of shape (V, feats)
Input node features.
edge_feats : float32 tensor of shape (E, feats)
            Input edge features.
expanded_dists : float32 tensor of shape (E, dist_feats)
Expanded distances, i.e. the output of RBFExpansion.
Returns
-------
node_feats : float32 tensor of shape (V, feats)
Updated node representations.
edge_feats : float32 tensor of shape (E, feats)
Updated edge representations.
"""
new_node_feats = self.project_in_node_feats(node_feats)
new_node_feats, edge_feats = self.conv(g, new_node_feats, edge_feats, expanded_dists)
new_node_feats = self.project_out_node_feats(new_node_feats)
node_feats = node_feats + new_node_feats
edge_feats = self.project_edge_feats(edge_feats)
return node_feats, edge_feats
class MGCNGNN(nn.Module):
"""MGCN.
MGCN is introduced in `Molecular Property Prediction: A Multilevel Quantum Interactions
Modeling Perspective <https://arxiv.org/abs/1906.11081>`__.
This class performs message passing in MGCN and returns the updated node representations.
Parameters
----------
feats : int
Size for the node and edge embeddings to learn. Default to 128.
n_layers : int
Number of gnn layers to use. Default to 3.
num_node_types : int
Number of node types to embed. Default to 100.
num_edge_types : int
Number of edge types to embed. Default to 3000.
cutoff : float
Largest center in RBF expansion. Default to 30.
gap : float
Difference between two adjacent centers in RBF expansion. Default to 0.1.
"""
def __init__(self, feats=128, n_layers=3, num_node_types=100,
num_edge_types=3000, cutoff=30., gap=0.1):
super(MGCNGNN, self).__init__()
self.node_embed = nn.Embedding(num_node_types, feats)
self.edge_embed = EdgeEmbedding(num_edge_types, feats)
self.rbf = RBFExpansion(high=cutoff, gap=gap)
self.gnn_layers = nn.ModuleList()
for i in range(n_layers):
self.gnn_layers.append(MultiLevelInteraction(feats, len(self.rbf.centers)))
def forward(self, g, node_types, edge_dists):
"""Performs message passing and updates node representations.
Parameters
----------
g : DGLGraph
DGLGraph for a batch of graphs.
node_types : int64 tensor of shape (V)
Node types to embed, V for the number of nodes.
edge_dists : float32 tensor of shape (E, 1)
Distances between end nodes of edges, E for the number of edges.
Returns
-------
float32 tensor of shape (V, feats * (n_layers + 1))
Output node representations.
"""
node_feats = self.node_embed(node_types)
edge_feats = self.edge_embed(g, node_types)
expanded_dists = self.rbf(edge_dists)
all_layer_node_feats = [node_feats]
for gnn in self.gnn_layers:
node_feats, edge_feats = gnn(g, node_feats, edge_feats, expanded_dists)
all_layer_node_feats.append(node_feats)
return torch.cat(all_layer_node_feats, dim=1)
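# Usage sketch (editor's note, not part of the original module): a minimal,
# hypothetical example of MGCNGNN. Node types are toy atomic numbers and the
# edge distances are random, for illustration only.
#
#     import dgl
#     import torch
#     g = dgl.DGLGraph()
#     g.add_nodes(4)
#     g.add_edges([0, 1, 2, 3], [1, 2, 3, 0])
#     node_types = torch.LongTensor([6, 8, 1, 6])  # e.g. C, O, H, C
#     edge_dists = torch.rand(g.number_of_edges(), 1)
#     model = MGCNGNN(feats=128, n_layers=3)
#     out = model(g, node_types, edge_dists)  # shape (4, 128 * (3 + 1))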
"""MPNN"""
import torch.nn as nn
import torch.nn.functional as F
from dgl.nn.pytorch import NNConv
__all__ = ['MPNNGNN']
class MPNNGNN(nn.Module):
"""MPNN.
MPNN is introduced in `Neural Message Passing for Quantum Chemistry
<https://arxiv.org/abs/1704.01212>`__.
This class performs message passing in MPNN and returns the updated node representations.
Parameters
----------
    node_in_feats : int
        Size for the input node features.
    edge_in_feats : int
        Size for the input edge features.
    node_out_feats : int
        Size for the output node representations. Default to 64.
    edge_hidden_feats : int
        Size for the hidden edge representations. Default to 128.
    num_step_message_passing : int
        Number of message passing steps. Default to 6.
"""
def __init__(self, node_in_feats, edge_in_feats, node_out_feats=64,
edge_hidden_feats=128, num_step_message_passing=6):
super(MPNNGNN, self).__init__()
self.project_node_feats = nn.Sequential(
nn.Linear(node_in_feats, node_out_feats),
nn.ReLU()
)
self.num_step_message_passing = num_step_message_passing
edge_network = nn.Sequential(
nn.Linear(edge_in_feats, edge_hidden_feats),
nn.ReLU(),
nn.Linear(edge_hidden_feats, node_out_feats * node_out_feats)
)
self.gnn_layer = NNConv(
in_feats=node_out_feats,
out_feats=node_out_feats,
edge_func=edge_network,
aggregator_type='sum'
)
self.gru = nn.GRU(node_out_feats, node_out_feats)
def forward(self, g, node_feats, edge_feats):
"""Performs message passing and updates node representations.
Parameters
----------
g : DGLGraph
DGLGraph for a batch of graphs.
node_feats : float32 tensor of shape (V, node_in_feats)
Input node features.
edge_feats : float32 tensor of shape (E, edge_in_feats)
Input edge features.
Returns
-------
node_feats : float32 tensor of shape (V, node_out_feats)
Output node representations.
"""
node_feats = self.project_node_feats(node_feats) # (V, node_out_feats)
hidden_feats = node_feats.unsqueeze(0) # (1, V, node_out_feats)
for i in range(self.num_step_message_passing):
node_feats = F.relu(self.gnn_layer(g, node_feats, edge_feats))
node_feats, hidden_feats = self.gru(node_feats.unsqueeze(0), hidden_feats)
node_feats = node_feats.squeeze(0)
return node_feats
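# Usage sketch (editor's note, not part of the original module): a minimal,
# hypothetical example of MPNNGNN with random node and edge features.
#
#     import dgl
#     import torch
#     g = dgl.DGLGraph()
#     g.add_nodes(4)
#     g.add_edges([0, 1, 2, 3], [1, 2, 3, 0])
#     node_feats = torch.randn(4, 32)
#     edge_feats = torch.randn(g.number_of_edges(), 5)
#     model = MPNNGNN(node_in_feats=32, edge_in_feats=5)
#     out = model(g, node_feats, edge_feats)  # shape (4, 64)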
# -*- coding:utf-8 -*-
# pylint: disable=C0103, C0111, W0621
"""SchNet"""
import numpy as np
import torch
import torch.nn as nn
from dgl.nn.pytorch import CFConv
__all__ = ['SchNetGNN']
class RBFExpansion(nn.Module):
"""Expand distances between nodes by radial basis functions.
.. math::
        \exp(- \gamma \| d - \mu \|^2)
    where :math:`d` is the distance between two nodes and :math:`\mu` is a center used
    to localize the distances. We use multiple centers evenly distributed in the range
    of :math:`[\text{low}, \text{high}]`, with the difference between two adjacent
    centers being :math:`\text{gap}`.
    The number of centers is decided by :math:`\lceil (\text{high} - \text{low}) / \text{gap} \rceil`.
Choosing fewer centers corresponds to reducing the resolution of the filter.
Parameters
----------
low : float
Smallest center. Default to 0.
high : float
Largest center. Default to 30.
gap : float
Difference between two adjacent centers. :math:`\gamma` will be computed as the
reciprocal of gap. Default to 0.1.
"""
def __init__(self, low=0., high=30., gap=0.1):
super(RBFExpansion, self).__init__()
num_centers = int(np.ceil((high - low) / gap))
centers = np.linspace(low, high, num_centers)
self.centers = nn.Parameter(torch.tensor(centers).float(), requires_grad=False)
self.gamma = 1 / gap
def forward(self, edge_dists):
"""Expand distances.
Parameters
----------
edge_dists : float32 tensor of shape (E, 1)
Distances between end nodes of edges, E for the number of edges.
Returns
-------
float32 tensor of shape (E, len(self.centers))
Expanded distances.
"""
radial = edge_dists - self.centers
coef = - self.gamma
return torch.exp(coef * (radial ** 2))
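# Usage sketch (editor's note, not part of the original module): expanding two
# hypothetical distances with the default settings, which place
# ceil((30 - 0) / 0.1) = 300 centers.
#
#     import torch
#     rbf = RBFExpansion(low=0., high=30., gap=0.1)
#     dists = torch.tensor([[1.0], [2.5]])
#     expanded = rbf(dists)  # float32 tensor of shape (2, 300)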
class Interaction(nn.Module):
"""Building block for SchNet.
SchNet is introduced in `SchNet: A continuous-filter convolutional neural network for
modeling quantum interactions <https://arxiv.org/abs/1706.08566>`__.
This layer combines node and edge features in message passing and updates node
representations.
Parameters
----------
node_feats : int
Size for the input and output node features.
edge_in_feats : int
Size for the input edge features.
hidden_feats : int
Size for hidden representations.
"""
def __init__(self, node_feats, edge_in_feats, hidden_feats):
super(Interaction, self).__init__()
self.conv = CFConv(node_feats, edge_in_feats, hidden_feats, node_feats)
self.project_out = nn.Linear(node_feats, node_feats)
def forward(self, g, node_feats, edge_feats):
"""Performs message passing and updates node representations.
Parameters
----------
g : DGLGraph
DGLGraph for a batch of graphs.
node_feats : float32 tensor of shape (V, node_feats)
Input node features, V for the number of nodes.
edge_feats : float32 tensor of shape (E, edge_in_feats)
Input edge features, E for the number of edges.
Returns
-------
float32 tensor of shape (V, node_feats)
Updated node representations.
"""
node_feats = self.conv(g, node_feats, edge_feats)
return self.project_out(node_feats)
class SchNetGNN(nn.Module):
"""SchNet.
SchNet is introduced in `SchNet: A continuous-filter convolutional neural network for
modeling quantum interactions <https://arxiv.org/abs/1706.08566>`__.
This class performs message passing in SchNet and returns the updated node representations.
Parameters
----------
node_feats : int
Size for node representations to learn. Default to 64.
hidden_feats : list of int
``hidden_feats[i]`` gives the size of hidden representations for the i-th interaction
layer. ``len(hidden_feats)`` equals the number of interaction layers.
Default to ``[64, 64, 64]``.
num_node_types : int
Number of node types to embed. Default to 100.
cutoff : float
Largest center in RBF expansion. Default to 30.
gap : float
Difference between two adjacent centers in RBF expansion. Default to 0.1.
"""
def __init__(self, node_feats=64, hidden_feats=None, num_node_types=100, cutoff=30., gap=0.1):
super(SchNetGNN, self).__init__()
if hidden_feats is None:
hidden_feats = [64, 64, 64]
self.embed = nn.Embedding(num_node_types, node_feats)
self.rbf = RBFExpansion(high=cutoff, gap=gap)
n_layers = len(hidden_feats)
self.gnn_layers = nn.ModuleList()
for i in range(n_layers):
self.gnn_layers.append(
Interaction(node_feats, len(self.rbf.centers), hidden_feats[i]))
def forward(self, g, node_types, edge_dists):
"""Performs message passing and updates node representations.
Parameters
----------
g : DGLGraph
DGLGraph for a batch of graphs.
node_types : int64 tensor of shape (V)
Node types to embed, V for the number of nodes.
edge_dists : float32 tensor of shape (E, 1)
Distances between end nodes of edges, E for the number of edges.
Returns
-------
node_feats : float32 tensor of shape (V, node_feats)
Updated node representations.
"""
node_feats = self.embed(node_types)
expanded_dists = self.rbf(edge_dists)
for gnn in self.gnn_layers:
node_feats = gnn(g, node_feats, expanded_dists)
return node_feats
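# Usage sketch (editor's note, not part of the original module): a minimal,
# hypothetical example of SchNetGNN with toy atomic numbers and random
# edge distances.
#
#     import dgl
#     import torch
#     g = dgl.DGLGraph()
#     g.add_nodes(4)
#     g.add_edges([0, 1, 2, 3], [1, 2, 3, 0])
#     node_types = torch.LongTensor([6, 8, 1, 6])
#     edge_dists = torch.rand(g.number_of_edges(), 1)
#     model = SchNetGNN(node_feats=64, hidden_feats=[64, 64, 64])
#     out = model(g, node_types, edge_dists)  # shape (4, 64)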
"""Collection of model architectures"""
from .jtnn import *
from .dgmg import *
from .attentivefp_predictor import *
from .gat_predictor import *
from .gcn_predictor import *
from .mlp_predictor import *
from .schnet_predictor import *
from .mgcn_predictor import *
from .mpnn_predictor import *
from .acnn import *
"""Atomic Convolutional Networks for Predicting Protein-Ligand Binding Affinity"""
# pylint: disable=C0103, C0123
import itertools
import numpy as np
import torch
import torch.nn as nn
from dgl import BatchedDGLHeteroGraph
from dgl.nn.pytorch import AtomicConv
__all__ = ['ACNN']
def truncated_normal_(tensor, mean=0., std=1.):
"""Fills the given tensor in-place with elements sampled from the truncated normal
distribution parameterized by mean and std.
The generated values follow a normal distribution with specified mean and
standard deviation, except that values whose magnitude is more than 2 std
from the mean are dropped.
    We credit Ruotian Luo for this implementation:
https://discuss.pytorch.org/t/implementing-truncated-normal-initializer/4778/15.
Parameters
----------
tensor : Float32 tensor of arbitrary shape
Tensor to be filled.
mean : float
Mean of the truncated normal distribution.
std : float
Standard deviation of the truncated normal distribution.
"""
shape = tensor.shape
tmp = tensor.new_empty(shape + (4,)).normal_()
valid = (tmp < 2) & (tmp > -2)
ind = valid.max(-1, keepdim=True)[1]
tensor.data.copy_(tmp.gather(-1, ind).squeeze(-1))
tensor.data.mul_(std).add_(mean)
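# Usage sketch (editor's note, not part of the original module): filling a
# hypothetical weight tensor in place; the shape and std are illustrative.
#
#     w = torch.empty(3, 5)
#     truncated_normal_(w, mean=0., std=0.02)  # values lie within 2 std of the mean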
class ACNNPredictor(nn.Module):
"""Predictor for ACNN.
Parameters
----------
in_size : int
Number of radial filters used.
hidden_sizes : list of int
Specifying the hidden sizes for all layers in the predictor.
weight_init_stddevs : list of float
Specifying the standard deviations to use for truncated normal
        distributions in initializing weights for the predictor.
dropouts : list of float
Specifying the dropouts to use for all layers in the predictor.
features_to_use : None or float tensor of shape (T)
In the original paper, these are atomic numbers to consider, representing the types
of atoms. T for the number of types of atomic numbers. Default to None.
num_tasks : int
Output size.
"""
def __init__(self, in_size, hidden_sizes, weight_init_stddevs,
dropouts, features_to_use, num_tasks):
super(ACNNPredictor, self).__init__()
        if features_to_use is not None:
in_size *= len(features_to_use)
modules = []
for i, h in enumerate(hidden_sizes):
linear_layer = nn.Linear(in_size, h)
truncated_normal_(linear_layer.weight, std=weight_init_stddevs[i])
modules.append(linear_layer)
modules.append(nn.ReLU())
modules.append(nn.Dropout(dropouts[i]))
in_size = h
linear_layer = nn.Linear(in_size, num_tasks)
truncated_normal_(linear_layer.weight, std=weight_init_stddevs[-1])
modules.append(linear_layer)
self.project = nn.Sequential(*modules)
def forward(self, batch_size, frag1_node_indices_in_complex, frag2_node_indices_in_complex,
ligand_conv_out, protein_conv_out, complex_conv_out):
"""Perform the prediction.
Parameters
----------
batch_size : int
Number of datapoints in a batch.
        frag1_node_indices_in_complex : Int64 tensor of shape (V1)
            Indices for atoms in the first fragment (ligand) in the batched complex.
        frag2_node_indices_in_complex : list of int of length V2
            Indices for atoms in the second fragment (protein) in the batched complex.
        ligand_conv_out : Float32 tensor of shape (V1, K * T)
            Updated ligand node representations. V1 for the number of atoms in the
            ligand, K for the number of radial filters, and T for the number of types
            of atomic numbers.
        protein_conv_out : Float32 tensor of shape (V2, K * T)
            Updated protein node representations. V2 for the number of
            atoms in the protein, K for the number of radial filters,
            and T for the number of types of atomic numbers.
        complex_conv_out : Float32 tensor of shape (V1 + V2, K * T)
            Updated complex node representations. V1 and V2 separately
            for the number of atoms in the ligand and protein, K for
            the number of radial filters, and T for the number of
            types of atomic numbers.
Returns
-------
Float32 tensor of shape (B, O)
Predicted protein-ligand binding affinity. B for the number
of protein-ligand pairs in the batch and O for the number of tasks.
"""
ligand_feats = self.project(ligand_conv_out) # (V1, O)
protein_feats = self.project(protein_conv_out) # (V2, O)
complex_feats = self.project(complex_conv_out) # (V1+V2, O)
ligand_energy = ligand_feats.reshape(batch_size, -1).sum(-1, keepdim=True) # (B, O)
protein_energy = protein_feats.reshape(batch_size, -1).sum(-1, keepdim=True) # (B, O)
complex_ligand_energy = complex_feats[frag1_node_indices_in_complex].reshape(
batch_size, -1).sum(-1, keepdim=True)
complex_protein_energy = complex_feats[frag2_node_indices_in_complex].reshape(
batch_size, -1).sum(-1, keepdim=True)
complex_energy = complex_ligand_energy + complex_protein_energy
return complex_energy - (ligand_energy + protein_energy)
class ACNN(nn.Module):
"""Atomic Convolutional Networks.
The model was proposed in `Atomic Convolutional Networks for
Predicting Protein-Ligand Binding Affinity <https://arxiv.org/abs/1703.10603>`__.
The prediction proceeds as follows:
1. Perform message passing to update atom representations for the
ligand, protein and protein-ligand complex.
2. Predict the energy of atoms from their representations with an MLP.
3. Take the sum of predicted energy of atoms within each molecule for
predicted energy of the ligand, protein and protein-ligand complex.
4. Make the final prediction by subtracting the predicted ligand and protein
energy from the predicted complex energy.
Parameters
----------
hidden_sizes : list of int
``hidden_sizes[i]`` gives the size of hidden representations in the i-th
        hidden layer of the MLP. By default, ``[32, 32, 16]`` will be used.
weight_init_stddevs : list of float
``weight_init_stddevs[i]`` gives the std to initialize parameters in the
i-th layer of the MLP. Note that ``len(weight_init_stddevs) == len(hidden_sizes) + 1``
due to the output layer. By default, we use ``1 / sqrt(hidden_sizes[i])`` for hidden
layers and 0.01 for the output layer.
dropouts : list of float
``dropouts[i]`` gives the dropout in the i-th hidden layer of the MLP. By default,
no dropout is used.
features_to_use : None or float tensor of shape (T)
In the original paper, these are atomic numbers to consider, representing the types
        of atoms. T for the number of types of atomic numbers. If None, we use the same
        parameters
for all atoms regardless of their type. Default to None.
radial : list
The list consists of 3 sublists of floats, separately for the
options of interaction cutoff, the options of rbf kernel mean and the
options of rbf kernel scaling. By default,
``[[12.0], [0.0, 2.0, 4.0, 6.0, 8.0], [4.0]]`` will be used.
num_tasks : int
Number of output tasks. Default to 1.
"""
def __init__(self, hidden_sizes=None, weight_init_stddevs=None, dropouts=None,
features_to_use=None, radial=None, num_tasks=1):
super(ACNN, self).__init__()
if hidden_sizes is None:
hidden_sizes = [32, 32, 16]
if weight_init_stddevs is None:
weight_init_stddevs = [1. / float(np.sqrt(hidden_sizes[i]))
for i in range(len(hidden_sizes))]
weight_init_stddevs.append(0.01)
if dropouts is None:
dropouts = [0. for _ in range(len(hidden_sizes))]
if radial is None:
radial = [[12.0], [0.0, 2.0, 4.0, 6.0, 8.0], [4.0]]
# Take the product of sets of options and get a list of 3-tuples.
        radial_params = list(itertools.product(*radial))
radial_params = torch.stack(list(map(torch.tensor, zip(*radial_params))), dim=1)
interaction_cutoffs = radial_params[:, 0]
rbf_kernel_means = radial_params[:, 1]
rbf_kernel_scaling = radial_params[:, 2]
self.ligand_conv = AtomicConv(interaction_cutoffs, rbf_kernel_means,
rbf_kernel_scaling, features_to_use)
self.protein_conv = AtomicConv(interaction_cutoffs, rbf_kernel_means,
rbf_kernel_scaling, features_to_use)
self.complex_conv = AtomicConv(interaction_cutoffs, rbf_kernel_means,
rbf_kernel_scaling, features_to_use)
self.predictor = ACNNPredictor(radial_params.shape[0], hidden_sizes,
weight_init_stddevs, dropouts, features_to_use, num_tasks)
def forward(self, graph):
"""Apply the model for prediction.
Parameters
----------
graph : DGLHeteroGraph
DGLHeteroGraph consisting of the ligand graph, the protein graph
and the complex graph, along with preprocessed features. For a batch of
protein-ligand pairs, we assume zero padding is performed so that the
number of ligand and protein atoms is the same in all pairs.
Returns
-------
Float32 tensor of shape (B, O)
Predicted protein-ligand binding affinity. B for the number
of protein-ligand pairs in the batch and O for the number of tasks.
"""
ligand_graph = graph[('ligand_atom', 'ligand', 'ligand_atom')]
ligand_graph_node_feats = ligand_graph.ndata['atomic_number']
assert ligand_graph_node_feats.shape[-1] == 1
ligand_graph_distances = ligand_graph.edata['distance']
ligand_conv_out = self.ligand_conv(ligand_graph,
ligand_graph_node_feats,
ligand_graph_distances)
protein_graph = graph[('protein_atom', 'protein', 'protein_atom')]
protein_graph_node_feats = protein_graph.ndata['atomic_number']
assert protein_graph_node_feats.shape[-1] == 1
protein_graph_distances = protein_graph.edata['distance']
protein_conv_out = self.protein_conv(protein_graph,
protein_graph_node_feats,
protein_graph_distances)
complex_graph = graph[:, 'complex', :]
complex_graph_node_feats = complex_graph.ndata['atomic_number']
assert complex_graph_node_feats.shape[-1] == 1
complex_graph_distances = complex_graph.edata['distance']
complex_conv_out = self.complex_conv(complex_graph,
complex_graph_node_feats,
complex_graph_distances)
frag1_node_indices_in_complex = torch.where(complex_graph.ndata['_TYPE'] == 0)[0]
frag2_node_indices_in_complex = list(set(range(complex_graph.number_of_nodes())) -
set(frag1_node_indices_in_complex.tolist()))
        # Handle the case of a single graph, which lacks the batch_size attribute.
if not isinstance(graph, BatchedDGLHeteroGraph):
graph.batch_size = 1
return self.predictor(
graph.batch_size,
frag1_node_indices_in_complex,
frag2_node_indices_in_complex,
ligand_conv_out, protein_conv_out, complex_conv_out)
"""AttentiveFP"""
import torch.nn as nn
from ..gnn import AttentiveFPGNN
from ..readout import AttentiveFPReadout
__all__ = ['AttentiveFPPredictor']
class AttentiveFPPredictor(nn.Module):
"""AttentiveFP for regression and classification on graphs.
AttentiveFP is introduced in `Pushing the Boundaries of Molecular Representation for Drug
    Discovery with the Graph Attention Mechanism
    <https://www.ncbi.nlm.nih.gov/pubmed/31408336>`__.
Parameters
----------
node_feat_size : int
Size for the input node features.
edge_feat_size : int
Size for the input edge features.
num_layers : int
Number of GNN layers. Default to 2.
num_timesteps : int
        Number of times to update the graph representations with GRU. Default to 2.
graph_feat_size : int
Size for the learned graph representations. Default to 200.
n_tasks : int
Number of tasks, which is also the output size. Default to 1.
dropout : float
Probability for performing the dropout. Default to 0.
"""
def __init__(self,
node_feat_size,
edge_feat_size,
num_layers=2,
num_timesteps=2,
graph_feat_size=200,
n_tasks=1,
dropout=0.):
super(AttentiveFPPredictor, self).__init__()
self.gnn = AttentiveFPGNN(node_feat_size=node_feat_size,
edge_feat_size=edge_feat_size,
num_layers=num_layers,
graph_feat_size=graph_feat_size,
dropout=dropout)
self.readout = AttentiveFPReadout(feat_size=graph_feat_size,
num_timesteps=num_timesteps,
dropout=dropout)
self.predict = nn.Sequential(
nn.Dropout(dropout),
nn.Linear(graph_feat_size, n_tasks)
)
def forward(self, g, node_feats, edge_feats, get_node_weight=False):
"""Graph-level regression/soft classification.
Parameters
----------
g : DGLGraph
DGLGraph for a batch of graphs.
node_feats : float32 tensor of shape (V, node_feat_size)
Input node features. V for the number of nodes.
edge_feats : float32 tensor of shape (E, edge_feat_size)
Input edge features. E for the number of edges.
get_node_weight : bool
Whether to get the weights of atoms during readout. Default to False.
Returns
-------
float32 tensor of shape (G, n_tasks)
Prediction for the graphs in the batch. G for the number of graphs.
node_weights : list of float32 tensor of shape (V, 1), optional
This is returned when ``get_node_weight`` is ``True``.
The list has a length ``num_timesteps`` and ``node_weights[i]``
gives the node weights in the i-th update.
"""
node_feats = self.gnn(g, node_feats, edge_feats)
if get_node_weight:
g_feats, node_weights = self.readout(g, node_feats, get_node_weight)
return self.predict(g_feats), node_weights
else:
g_feats = self.readout(g, node_feats, get_node_weight)
return self.predict(g_feats)
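# Usage sketch (editor's note, not part of the original module): a minimal,
# hypothetical example of graph-level prediction on a single toy graph.
#
#     import dgl
#     import torch
#     g = dgl.DGLGraph()
#     g.add_nodes(4)
#     g.add_edges([0, 1, 2, 3], [1, 2, 3, 0])
#     node_feats = torch.randn(4, 32)
#     edge_feats = torch.randn(g.number_of_edges(), 5)
#     model = AttentiveFPPredictor(node_feat_size=32, edge_feat_size=5, n_tasks=1)
#     preds = model(g, node_feats, edge_feats)  # shape (1, 1) for one graph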
# pylint: disable=C0103, W0622, R1710, W0104
"""
Learning Deep Generative Models of Graphs
https://arxiv.org/pdf/1803.03324.pdf
"""
import dgl
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
from dgl import DGLGraph
from functools import partial
from rdkit import Chem
from torch.distributions import Categorical
__all__ = ['DGMG']
class MoleculeEnv(object):
"""MDP environment for generating molecules.
Parameters
----------
atom_types : list
E.g. ['C', 'N']
bond_types : list
E.g. [Chem.rdchem.BondType.SINGLE, Chem.rdchem.BondType.DOUBLE,
Chem.rdchem.BondType.TRIPLE, Chem.rdchem.BondType.AROMATIC]
"""
def __init__(self, atom_types, bond_types):
super(MoleculeEnv, self).__init__()
self.atom_types = atom_types
self.bond_types = bond_types
self.atom_type_to_id = dict()
self.bond_type_to_id = dict()
for id, a_type in enumerate(atom_types):
self.atom_type_to_id[a_type] = id
for id, b_type in enumerate(bond_types):
self.bond_type_to_id[b_type] = id
def get_decision_sequence(self, mol, atom_order):
"""Extract a decision sequence with which DGMG can generate the
molecule with a specified atom order.
Parameters
----------
mol : Chem.rdchem.Mol
atom_order : list
Specifies a mapping between the original atom
indices and the new atom indices. In particular,
            the atom with original index atom_order[i] is relabeled as i.
Returns
-------
decisions : list
decisions[i] is a 2-tuple (i, j)
- If i = 0, j specifies either the type of the atom to add
self.atom_types[j] or termination with j = len(self.atom_types)
- If i = 1, j specifies either the type of the bond to add
self.bond_types[j] or termination with j = len(self.bond_types)
- If i = 2, j specifies the destination atom id for the bond to add.
              With the formulation of DGMG, atom j must already exist when the decision is made.
"""
decisions = []
old2new = dict()
for new_id, old_id in enumerate(atom_order):
atom = mol.GetAtomWithIdx(old_id)
a_type = atom.GetSymbol()
decisions.append((0, self.atom_type_to_id[a_type]))
for bond in atom.GetBonds():
u = bond.GetBeginAtomIdx()
v = bond.GetEndAtomIdx()
if v == old_id:
u, v = v, u
if v in old2new:
decisions.append((1, self.bond_type_to_id[bond.GetBondType()]))
decisions.append((2, old2new[v]))
decisions.append((1, len(self.bond_types)))
old2new[old_id] = new_id
decisions.append((0, len(self.atom_types)))
return decisions
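    # Worked example (editor's note): with atom_types = ['C', 'O'] and
    # bond_types = [Chem.rdchem.BondType.SINGLE], the molecule 'CO' with
    # atom_order = [0, 1] yields
    #     [(0, 0),          # add a C atom
    #      (1, 1),          # stop adding bonds for it
    #      (0, 1),          # add an O atom
    #      (1, 0), (2, 0),  # add a single bond to atom 0
    #      (1, 1),          # stop adding bonds
    #      (0, 2)]          # stop adding atoms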
def reset(self, rdkit_mol=False):
"""Setup for generating a new molecule
Parameters
----------
rdkit_mol : bool
Whether to keep a Chem.rdchem.Mol object so
that we know what molecule is being generated
"""
self.dgl_graph = DGLGraph()
# If there are some features for nodes and edges,
# zero tensors will be set for those of new nodes and edges.
self.dgl_graph.set_n_initializer(dgl.frame.zero_initializer)
self.dgl_graph.set_e_initializer(dgl.frame.zero_initializer)
self.mol = None
if rdkit_mol:
# RWMol is a molecule class that is intended to be edited.
self.mol = Chem.RWMol(Chem.MolFromSmiles(''))
def num_atoms(self):
"""Get the number of atoms for the current molecule.
Returns
-------
int
"""
return self.dgl_graph.number_of_nodes()
def add_atom(self, type):
"""Add an atom of the specified type.
Parameters
----------
type : int
Should be in the range of [0, len(self.atom_types) - 1]
"""
self.dgl_graph.add_nodes(1)
if self.mol is not None:
self.mol.AddAtom(Chem.Atom(self.atom_types[type]))
def add_bond(self, u, v, type, bi_direction=True):
"""Add a bond of the specified type between atom u and v.
Parameters
----------
u : int
Index for the first atom
v : int
Index for the second atom
type : int
Index for the bond type
bi_direction : bool
Whether to add edges for both directions in the DGLGraph.
If not, we will only add the edge (u, v).
"""
if bi_direction:
self.dgl_graph.add_edges([u, v], [v, u])
else:
self.dgl_graph.add_edge(u, v)
if self.mol is not None:
self.mol.AddBond(u, v, self.bond_types[type])
def get_current_smiles(self):
"""Get the generated molecule in SMILES
Returns
-------
s : str
SMILES
"""
assert self.mol is not None, 'Expect a Chem.rdchem.Mol object initialized.'
s = Chem.MolToSmiles(self.mol)
return s
class GraphEmbed(nn.Module):
"""Compute a molecule representations out of atom representations.
Parameters
----------
node_hidden_size : int
Size of atom representation
"""
def __init__(self, node_hidden_size):
super(GraphEmbed, self).__init__()
# Setting from the paper
self.graph_hidden_size = 2 * node_hidden_size
# Embed graphs
self.node_gating = nn.Sequential(
nn.Linear(node_hidden_size, 1),
nn.Sigmoid()
)
self.node_to_graph = nn.Linear(node_hidden_size,
self.graph_hidden_size)
def forward(self, g):
"""
Parameters
----------
g : DGLGraph
Current molecule graph
Returns
-------
tensor of dtype float32 and shape (1, self.graph_hidden_size)
Computed representation for the current molecule graph
"""
if g.number_of_nodes() == 0:
# Use a zero tensor for an empty molecule.
return torch.zeros(1, self.graph_hidden_size)
else:
# Node features are stored as hv in ndata.
hvs = g.ndata['hv']
return (self.node_gating(hvs) *
self.node_to_graph(hvs)).sum(0, keepdim=True)
class GraphProp(nn.Module):
"""Perform message passing over a molecule graph and update its atom representations.
Parameters
----------
num_prop_rounds : int
        Number of message passing rounds each time propagation is performed
node_hidden_size : int
Size of atom representation
edge_hidden_size : int
Size of bond representation
"""
def __init__(self, num_prop_rounds, node_hidden_size, edge_hidden_size):
super(GraphProp, self).__init__()
self.num_prop_rounds = num_prop_rounds
# Setting from the paper
self.node_activation_hidden_size = 2 * node_hidden_size
message_funcs = []
self.reduce_funcs = []
node_update_funcs = []
for t in range(num_prop_rounds):
# input being [hv, hu, xuv]
message_funcs.append(nn.Linear(2 * node_hidden_size + edge_hidden_size,
self.node_activation_hidden_size))
self.reduce_funcs.append(partial(self.dgmg_reduce, round=t))
node_update_funcs.append(
nn.GRUCell(self.node_activation_hidden_size,
node_hidden_size))
self.message_funcs = nn.ModuleList(message_funcs)
self.node_update_funcs = nn.ModuleList(node_update_funcs)
def dgmg_msg(self, edges):
"""For an edge u->v, send a message concat([h_u, x_uv])
Parameters
----------
edges : batch of edges
Returns
-------
dict
Dictionary containing messages for the edge batch,
with the messages being tensors of shape (B, F1),
B for the number of edges and F1 for the message size.
"""
return {'m': torch.cat([edges.src['hv'],
edges.data['he']],
dim=1)}
def dgmg_reduce(self, nodes, round):
"""Aggregate messages.
Parameters
----------
nodes : batch of nodes
round : int
Update round
Returns
-------
dict
Dictionary containing aggregated messages for each node
in the batch, with the messages being tensors of shape
(B, F2), B for the number of nodes and F2 for the aggregated
message size
"""
hv_old = nodes.data['hv']
m = nodes.mailbox['m']
# Make copies of original atom representations to match the
# number of messages.
message = torch.cat([
hv_old.unsqueeze(1).expand(-1, m.size(1), -1), m], dim=2)
node_activation = (self.message_funcs[round](message)).sum(1)
return {'a': node_activation}
def forward(self, g):
"""
Parameters
----------
g : DGLGraph
"""
if g.number_of_edges() == 0:
return
else:
for t in range(self.num_prop_rounds):
g.update_all(message_func=self.dgmg_msg,
reduce_func=self.reduce_funcs[t])
g.ndata['hv'] = self.node_update_funcs[t](
g.ndata['a'], g.ndata['hv'])
class AddNode(nn.Module):
"""Stop or add an atom of a particular type.
Parameters
----------
env : MoleculeEnv
Environment for generating molecules
graph_embed_func : callable taking g as input
Function for computing molecule representation
node_hidden_size : int
Size of atom representation
dropout : float
Probability for dropout
"""
def __init__(self, env, graph_embed_func, node_hidden_size, dropout):
super(AddNode, self).__init__()
self.env = env
n_node_types = len(env.atom_types)
self.graph_op = {'embed': graph_embed_func}
self.stop = n_node_types
self.add_node = nn.Sequential(
nn.Linear(graph_embed_func.graph_hidden_size, graph_embed_func.graph_hidden_size),
nn.Dropout(p=dropout),
nn.Linear(graph_embed_func.graph_hidden_size, n_node_types + 1)
)
        # If a node is to be added, initialize its representation hv
self.node_type_embed = nn.Embedding(n_node_types, node_hidden_size)
self.initialize_hv = nn.Linear(node_hidden_size + \
graph_embed_func.graph_hidden_size,
node_hidden_size)
self.init_node_activation = torch.zeros(1, 2 * node_hidden_size)
self.dropout = nn.Dropout(p=dropout)
def _initialize_node_repr(self, g, node_type, graph_embed):
"""Initialize atom representation
Parameters
----------
g : DGLGraph
node_type : int
Index for the type of the new atom
graph_embed : tensor of dtype float32
Molecule representation
"""
num_nodes = g.number_of_nodes()
hv_init = torch.cat([
self.node_type_embed(torch.LongTensor([node_type])),
graph_embed], dim=1)
hv_init = self.dropout(hv_init)
hv_init = self.initialize_hv(hv_init)
g.nodes[num_nodes - 1].data['hv'] = hv_init
g.nodes[num_nodes - 1].data['a'] = self.init_node_activation
def prepare_log_prob(self, compute_log_prob):
"""Setup for returning log likelihood
Parameters
----------
compute_log_prob : bool
Whether to compute log likelihood
"""
if compute_log_prob:
self.log_prob = []
self.compute_log_prob = compute_log_prob
def forward(self, action=None):
"""
Parameters
----------
action : None or int
If None, a new action will be sampled. If not None,
teacher forcing will be used to enforce the decision of the
corresponding action.
Returns
-------
stop : bool
Whether we stop adding new atoms
"""
g = self.env.dgl_graph
graph_embed = self.graph_op['embed'](g)
logits = self.add_node(graph_embed).view(1, -1)
probs = F.softmax(logits, dim=1)
if action is None:
action = Categorical(probs).sample().item()
stop = bool(action == self.stop)
if not stop:
self.env.add_atom(action)
self._initialize_node_repr(g, action, graph_embed)
if self.compute_log_prob:
sample_log_prob = F.log_softmax(logits, dim=1)[:, action: action + 1]
self.log_prob.append(sample_log_prob)
return stop
class AddEdge(nn.Module):
"""Stop or add a bond of a particular type.
Parameters
----------
env : MoleculeEnv
Environment for generating molecules
graph_embed_func : callable taking g as input
Function for computing molecule representation
node_hidden_size : int
Size of atom representation
dropout : float
Probability for dropout
"""
def __init__(self, env, graph_embed_func, node_hidden_size, dropout):
super(AddEdge, self).__init__()
self.env = env
n_bond_types = len(env.bond_types)
self.stop = n_bond_types
self.graph_op = {'embed': graph_embed_func}
self.add_edge = nn.Sequential(
nn.Linear(graph_embed_func.graph_hidden_size + node_hidden_size,
graph_embed_func.graph_hidden_size + node_hidden_size),
nn.Dropout(p=dropout),
nn.Linear(graph_embed_func.graph_hidden_size + node_hidden_size, n_bond_types + 1)
)
def prepare_log_prob(self, compute_log_prob):
"""Setup for returning log likelihood
Parameters
----------
compute_log_prob : bool
Whether to compute log likelihood
"""
if compute_log_prob:
self.log_prob = []
self.compute_log_prob = compute_log_prob
def forward(self, action=None):
"""
Parameters
----------
action : None or int
If None, a new action will be sampled. If not None,
teacher forcing will be used to enforce the decision of the
corresponding action.
Returns
-------
stop : bool
Whether we stop adding new bonds
action : int
The type for the new bond
"""
g = self.env.dgl_graph
graph_embed = self.graph_op['embed'](g)
src_embed = g.nodes[g.number_of_nodes() - 1].data['hv']
logits = self.add_edge(
torch.cat([graph_embed, src_embed], dim=1))
probs = F.softmax(logits, dim=1)
if action is None:
action = Categorical(probs).sample().item()
stop = bool(action == self.stop)
if self.compute_log_prob:
sample_log_prob = F.log_softmax(logits, dim=1)[:, action: action + 1]
self.log_prob.append(sample_log_prob)
return stop, action
class ChooseDestAndUpdate(nn.Module):
"""Choose the atom to connect for the new bond.
Parameters
----------
env : MoleculeEnv
Environment for generating molecules
graph_prop_func : callable taking g as input
Function for performing message passing
and updating atom representations
node_hidden_size : int
Size of atom representation
dropout : float
Probability for dropout
"""
def __init__(self, env, graph_prop_func, node_hidden_size, dropout):
super(ChooseDestAndUpdate, self).__init__()
self.env = env
n_bond_types = len(self.env.bond_types)
# To be used for one-hot encoding of bond type
self.bond_embedding = torch.eye(n_bond_types)
self.graph_op = {'prop': graph_prop_func}
self.choose_dest = nn.Sequential(
nn.Linear(2 * node_hidden_size + n_bond_types, 2 * node_hidden_size + n_bond_types),
nn.Dropout(p=dropout),
nn.Linear(2 * node_hidden_size + n_bond_types, 1)
)
def _initialize_edge_repr(self, g, src_list, dest_list, edge_embed):
"""Initialize bond representation
Parameters
----------
g : DGLGraph
src_list : list of int
source atoms for new bonds
dest_list : list of int
destination atoms for new bonds
edge_embed : 2D tensor of dtype float32
Embeddings for the new bonds
"""
g.edges[src_list, dest_list].data['he'] = edge_embed.expand(len(src_list), -1)
def prepare_log_prob(self, compute_log_prob):
"""Setup for returning log likelihood
Parameters
----------
compute_log_prob : bool
Whether to compute log likelihood
"""
if compute_log_prob:
self.log_prob = []
self.compute_log_prob = compute_log_prob
def forward(self, bond_type, dest):
"""
Parameters
----------
bond_type : int
The type for the new bond
dest : int or None
If None, a new action will be sampled. If not None,
teacher forcing will be used to enforce the decision of the
corresponding action.
"""
g = self.env.dgl_graph
src = g.number_of_nodes() - 1
possible_dests = range(src)
src_embed_expand = g.nodes[src].data['hv'].expand(src, -1)
possible_dests_embed = g.nodes[possible_dests].data['hv']
edge_embed = self.bond_embedding[bond_type: bond_type + 1]
dests_scores = self.choose_dest(
torch.cat([possible_dests_embed,
src_embed_expand,
edge_embed.expand(src, -1)], dim=1)).view(1, -1)
dests_probs = F.softmax(dests_scores, dim=1)
if dest is None:
dest = Categorical(dests_probs).sample().item()
if not g.has_edge_between(src, dest):
# For undirected graphs, we add edges for both directions
# so that we can perform graph propagation.
src_list = [src, dest]
dest_list = [dest, src]
self.env.add_bond(src, dest, bond_type)
self._initialize_edge_repr(g, src_list, dest_list, edge_embed)
# Perform message passing when new bonds are added.
self.graph_op['prop'](g)
if self.compute_log_prob:
if dests_probs.nelement() > 1:
self.log_prob.append(
F.log_softmax(dests_scores, dim=1)[:, dest: dest + 1])
def weights_init(m):
'''Function to initialize weights for models
Code from https://gist.github.com/jeasinema/ed9236ce743c8efaf30fa2ff732749f5
Usage:
model = Model()
        model.apply(weights_init)
'''
if isinstance(m, nn.Linear):
init.xavier_normal_(m.weight.data)
init.normal_(m.bias.data)
elif isinstance(m, nn.GRUCell):
for param in m.parameters():
if len(param.shape) >= 2:
init.orthogonal_(param.data)
else:
init.normal_(param.data)
def dgmg_message_weight_init(m):
"""Weight initialization for graph propagation module
    This initialization is suggested by the authors and should only be used for
    the message passing functions, i.e. the f_e's in the paper.
"""
def _weight_init(m):
if isinstance(m, nn.Linear):
init.normal_(m.weight.data, std=1./10)
init.normal_(m.bias.data, std=1./10)
else:
raise ValueError('Expected the input to be of type nn.Linear!')
if isinstance(m, nn.ModuleList):
for layer in m:
layer.apply(_weight_init)
else:
m.apply(_weight_init)
class DGMG(nn.Module):
"""DGMG model
`Learning Deep Generative Models of Graphs <https://arxiv.org/abs/1803.03324>`__
Users only need to initialize an instance of this class.
Parameters
----------
atom_types : list
E.g. ['C', 'N'].
bond_types : list
E.g. [Chem.rdchem.BondType.SINGLE, Chem.rdchem.BondType.DOUBLE,
Chem.rdchem.BondType.TRIPLE, Chem.rdchem.BondType.AROMATIC].
node_hidden_size : int
Size of atom representation. Default to 128.
num_prop_rounds : int
        Number of message passing rounds each time propagation is performed. Default to 2.
dropout : float
Probability for dropout. Default to 0.2.
"""
def __init__(self, atom_types, bond_types, node_hidden_size=128,
num_prop_rounds=2, dropout=0.2):
super(DGMG, self).__init__()
self.env = MoleculeEnv(atom_types, bond_types)
# Graph embedding module
self.graph_embed = GraphEmbed(node_hidden_size)
# Graph propagation module
# For one-hot encoding, edge_hidden_size is just the number of bond types
self.graph_prop = GraphProp(num_prop_rounds, node_hidden_size, len(self.env.bond_types))
# Actions
self.add_node_agent = AddNode(
self.env, self.graph_embed, node_hidden_size, dropout)
self.add_edge_agent = AddEdge(
self.env, self.graph_embed, node_hidden_size, dropout)
self.choose_dest_agent = ChooseDestAndUpdate(
self.env, self.graph_prop, node_hidden_size, dropout)
# Weight initialization
self.init_weights()
def init_weights(self):
"""Initialize model weights"""
self.graph_embed.apply(weights_init)
self.graph_prop.apply(weights_init)
self.add_node_agent.apply(weights_init)
self.add_edge_agent.apply(weights_init)
self.choose_dest_agent.apply(weights_init)
self.graph_prop.message_funcs.apply(dgmg_message_weight_init)
def count_step(self):
"""Increment the step by 1."""
self.step_count += 1
def prepare_log_prob(self, compute_log_prob):
"""Setup for returning log likelihood
Parameters
----------
compute_log_prob : bool
Whether to compute log likelihood
"""
self.compute_log_prob = compute_log_prob
self.add_node_agent.prepare_log_prob(compute_log_prob)
self.add_edge_agent.prepare_log_prob(compute_log_prob)
self.choose_dest_agent.prepare_log_prob(compute_log_prob)
def add_node_and_update(self, a=None):
"""Decide if to add a new atom.
If a new atom should be added, update the graph.
Parameters
----------
a : None or int
If None, a new action will be sampled. If not None,
teacher forcing will be used to enforce the decision of the
corresponding action.
"""
self.count_step()
return self.add_node_agent(a)
def add_edge_or_not(self, a=None):
"""Decide if to add a new bond.
Parameters
----------
a : None or int
If None, a new action will be sampled. If not None,
teacher forcing will be used to enforce the decision of the
corresponding action.
"""
self.count_step()
return self.add_edge_agent(a)
def choose_dest_and_update(self, bond_type, a=None):
"""Choose destination and connect it to the latest atom.
Add edges for both directions and update the graph.
Parameters
----------
bond_type : int
The type of the new bond to add
a : None or int
If None, a new action will be sampled. If not None,
teacher forcing will be used to enforce the decision of the
corresponding action.
"""
self.count_step()
self.choose_dest_agent(bond_type, a)
def get_log_prob(self):
"""Compute the log likelihood for the decision sequence,
typically corresponding to the generation of a molecule.
Returns
-------
        scalar float32 tensor
            The log likelihood for the decision sequence
"""
return torch.cat(self.add_node_agent.log_prob).sum()\
+ torch.cat(self.add_edge_agent.log_prob).sum()\
+ torch.cat(self.choose_dest_agent.log_prob).sum()
def teacher_forcing(self, actions):
"""Generate a molecule according to a sequence of actions.
Parameters
----------
actions : list of 2-tuples of int
            actions[t] gives (i, j), the action to be executed by DGMG at timestep t.
- If i = 0, j specifies either the type of the atom to add or termination
- If i = 1, j specifies either the type of the bond to add or termination
- If i = 2, j specifies the destination atom id for the bond to add.
              With the formulation of DGMG, atom j must already exist when the decision is made.
"""
stop_node = self.add_node_and_update(a=actions[self.step_count][1])
while not stop_node:
# A new atom was just added.
stop_edge, bond_type = self.add_edge_or_not(a=actions[self.step_count][1])
while not stop_edge:
# A new bond is to be added.
self.choose_dest_and_update(bond_type, a=actions[self.step_count][1])
stop_edge, bond_type = self.add_edge_or_not(a=actions[self.step_count][1])
stop_node = self.add_node_and_update(a=actions[self.step_count][1])
def rollout(self, max_num_steps):
"""Sample a molecule from the distribution learned by DGMG."""
stop_node = self.add_node_and_update()
while (not stop_node) and (self.step_count <= max_num_steps):
stop_edge, bond_type = self.add_edge_or_not()
if self.env.num_atoms() == 1:
stop_edge = True
while (not stop_edge) and (self.step_count <= max_num_steps):
self.choose_dest_and_update(bond_type)
stop_edge, bond_type = self.add_edge_or_not()
stop_node = self.add_node_and_update()
def forward(self, actions=None, rdkit_mol=False, compute_log_prob=False, max_num_steps=400):
"""
Parameters
----------
actions : list of 2-tuples or None.
If actions are not None, generate a molecule according to actions.
Otherwise, a molecule will be generated based on sampled actions.
rdkit_mol : bool
Whether to maintain a Chem.rdchem.Mol object. This brings extra
            computational cost, but is necessary if we want to retrieve the
            generated molecule.
compute_log_prob : bool
Whether to compute log likelihood
max_num_steps : int
Maximum number of steps allowed. This only comes into effect
            during inference and prevents the generation process from running indefinitely.
Returns
-------
        scalar float32 tensor, optional
The log likelihood for the actions taken
str, optional
The generated molecule in the form of SMILES
"""
# Initialize an empty molecule
self.step_count = 0
self.env.reset(rdkit_mol=rdkit_mol)
self.prepare_log_prob(compute_log_prob)
if actions is not None:
# A sequence of decisions is given, use teacher forcing
self.teacher_forcing(actions)
else:
# Sample a molecule from the distribution learned by DGMG
self.rollout(max_num_steps)
if compute_log_prob and rdkit_mol:
return self.get_log_prob(), self.env.get_current_smiles()
if compute_log_prob:
return self.get_log_prob()
if rdkit_mol:
return self.env.get_current_smiles()
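# Usage sketch (editor's note, not part of the original module): sampling a
# molecule from a freshly initialized (untrained) DGMG model. The atom and
# bond types below are illustrative choices.
#
#     import torch
#     from rdkit import Chem
#     model = DGMG(atom_types=['C', 'N', 'O'],
#                  bond_types=[Chem.rdchem.BondType.SINGLE,
#                              Chem.rdchem.BondType.DOUBLE,
#                              Chem.rdchem.BondType.TRIPLE])
#     model.eval()
#     with torch.no_grad():
#         smiles = model(rdkit_mol=True)  # a SMILES string; random until trained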