Unverified Commit 828a5e5b authored by Mufei Li, committed by GitHub

[DGL-LifeSci] Migration and Refactor (#1226)

* First commit

* Update

* Update splitters

* Update

* Update

* Update

* Update

* Update

* Update

* Migrate ACNN

* Fix

* Fix

* Update

* Update

* Update

* Update

* Update

* Update

* Finish classification

* Update

* Fix

* Update

* Update

* Update

* Fix

* Fix

* Fix

* Update

* Update

* Update

* trigger CI

* Fix CI

* Update

* Update

* Update

* Add default values

* Rename

* Update deprecation message
parent e4948c5c
"""Apply weighted sum and max pooling to the node representations and concatenate the results."""
import dgl
import torch
import torch.nn as nn
from dgl import BatchedDGLGraph
from dgl.nn.pytorch import WeightAndSum
__all__ = ['WeightedSumAndMax']
class WeightedSumAndMax(nn.Module):
r"""Apply weighted sum and max pooling to the node
representations and concatenate the results.
Parameters
----------
in_feats : int
Input node feature size
"""
def __init__(self, in_feats):
super(WeightedSumAndMax, self).__init__()
self.weight_and_sum = WeightAndSum(in_feats)
def forward(self, bg, feats):
"""Readout
Parameters
----------
bg : DGLGraph
DGLGraph for a batch of graphs.
feats : FloatTensor of shape (N, M1)
* N is the total number of nodes in the batch of graphs
* M1 is the input node feature size, which must match
in_feats in initialization
Returns
-------
h_g : FloatTensor of shape (B, 2 * M1)
* B is the number of graphs in the batch
* M1 is the input node feature size, which must match
in_feats in initialization
"""
h_g_sum = self.weight_and_sum(bg, feats)
with bg.local_scope():
bg.ndata['h'] = feats
h_g_max = dgl.max_nodes(bg, 'h')
if not isinstance(bg, BatchedDGLGraph):
h_g_sum = h_g_sum.unsqueeze(0)
h_g_max = h_g_max.unsqueeze(0)
h_g = torch.cat([h_g_sum, h_g_max], dim=1)
return h_g
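# A minimal usage sketch for WeightedSumAndMax, assuming the DGL 0.4-era API
# targeted by this commit. The graph sizes and the feature size of 8 are
# placeholders chosen for illustration.
g1 = dgl.DGLGraph()
g1.add_nodes(3)
g2 = dgl.DGLGraph()
g2.add_nodes(2)
bg = dgl.batch([g1, g2])                       # BatchedDGLGraph with 5 nodes
feats = torch.randn(bg.number_of_nodes(), 8)
readout = WeightedSumAndMax(in_feats=8)
h_g = readout(bg, feats)                       # shape (2, 16): weighted sum || max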
"""Utils for data processing."""
from .complex_to_graph import *
from .early_stop import *
from .eval import *
from .featurizers import *
from .mol_to_graph import *
from .rdkit_utils import *
from .splitters import *
"""Convert complexes into DGLHeteroGraphs"""
import dgl.backend as F
import numpy as np
from dgl import graph, bipartite, hetero_from_relations
from ..utils.mol_to_graph import k_nearest_neighbors
__all__ = ['ACNN_graph_construction_and_featurization']
def filter_out_hydrogens(mol):
"""Get indices for non-hydrogen atoms.
Parameters
----------
mol : rdkit.Chem.rdchem.Mol
RDKit molecule instance.
Returns
-------
indices_left : list of int
Indices of non-hydrogen atoms.
"""
indices_left = []
for i, atom in enumerate(mol.GetAtoms()):
atomic_num = atom.GetAtomicNum()
# Hydrogen atoms have an atomic number of 1.
if atomic_num != 1:
indices_left.append(i)
return indices_left
def get_atomic_numbers(mol, indices):
"""Get the atomic numbers for the specified atoms.
Parameters
----------
mol : rdkit.Chem.rdchem.Mol
RDKit molecule instance.
indices : list of int
Specifying atoms.
Returns
-------
list of int
Atomic numbers computed.
"""
atomic_numbers = []
for i in indices:
atom = mol.GetAtomWithIdx(i)
atomic_numbers.append(atom.GetAtomicNum())
return atomic_numbers
def ACNN_graph_construction_and_featurization(ligand_mol,
protein_mol,
ligand_coordinates,
protein_coordinates,
max_num_ligand_atoms=None,
max_num_protein_atoms=None,
neighbor_cutoff=12.,
max_num_neighbors=12,
strip_hydrogens=False):
"""Graph construction and featurization for `Atomic Convolutional Networks for
Predicting Protein-Ligand Binding Affinity <https://arxiv.org/abs/1703.10603>`__.
Parameters
----------
ligand_mol : rdkit.Chem.rdchem.Mol
RDKit molecule instance for the ligand.
protein_mol : rdkit.Chem.rdchem.Mol
RDKit molecule instance for the protein.
ligand_coordinates : numpy.ndarray of shape (V1, 3)
Atom coordinates in a ligand.
protein_coordinates : numpy.ndarray of shape (V2, 3)
Atom coordinates in a protein.
max_num_ligand_atoms : int or None
Maximum number of atoms in ligands for zero padding, which should be no smaller than
ligand_mol.GetNumAtoms() if not None. If None, no zero padding will be performed.
Default to None.
max_num_protein_atoms : int or None
Maximum number of atoms in proteins for zero padding, which should be no smaller than
protein_mol.GetNumAtoms() if not None. If None, no zero padding will be performed.
Default to None.
neighbor_cutoff : float
Distance cutoff to define 'neighboring'. Default to 12.
max_num_neighbors : int
Maximum number of neighbors allowed for each atom. Default to 12.
strip_hydrogens : bool
Whether to exclude hydrogen atoms. Default to False.
Returns
-------
g : DGLHeteroGraph
Heterogeneous graph with 'ligand_atom' and 'protein_atom' node types and
'ligand', 'protein' and 'complex' relations, carrying 'distance' edge
features as well as 'atomic_number' and 'mask' node features.
"""
assert ligand_coordinates is not None, 'Expect ligand_coordinates to be provided.'
assert protein_coordinates is not None, 'Expect protein_coordinates to be provided.'
if max_num_ligand_atoms is not None:
assert max_num_ligand_atoms >= ligand_mol.GetNumAtoms(), \
'Expect max_num_ligand_atoms to be no smaller than ligand_mol.GetNumAtoms(), ' \
'got {:d} and {:d}'.format(max_num_ligand_atoms, ligand_mol.GetNumAtoms())
if max_num_protein_atoms is not None:
assert max_num_protein_atoms >= protein_mol.GetNumAtoms(), \
'Expect max_num_protein_atoms to be no smaller than protein_mol.GetNumAtoms(), ' \
'got {:d} and {:d}'.format(max_num_protein_atoms, protein_mol.GetNumAtoms())
if strip_hydrogens:
# Remove hydrogen atoms and their corresponding coordinates
ligand_atom_indices_left = filter_out_hydrogens(ligand_mol)
protein_atom_indices_left = filter_out_hydrogens(protein_mol)
ligand_coordinates = ligand_coordinates.take(ligand_atom_indices_left, axis=0)
protein_coordinates = protein_coordinates.take(protein_atom_indices_left, axis=0)
else:
ligand_atom_indices_left = list(range(ligand_mol.GetNumAtoms()))
protein_atom_indices_left = list(range(protein_mol.GetNumAtoms()))
# Compute number of nodes for each type
if max_num_ligand_atoms is None:
num_ligand_atoms = len(ligand_atom_indices_left)
else:
num_ligand_atoms = max_num_ligand_atoms
if max_num_protein_atoms is None:
num_protein_atoms = len(protein_atom_indices_left)
else:
num_protein_atoms = max_num_protein_atoms
# Construct graph for atoms in the ligand
ligand_srcs, ligand_dsts, ligand_dists = k_nearest_neighbors(
ligand_coordinates, neighbor_cutoff, max_num_neighbors)
ligand_graph = graph((ligand_srcs, ligand_dsts),
'ligand_atom', 'ligand', num_ligand_atoms)
ligand_graph.edata['distance'] = F.reshape(F.zerocopy_from_numpy(
np.array(ligand_dists).astype(np.float32)), (-1, 1))
# Construct graph for atoms in the protein
protein_srcs, protein_dsts, protein_dists = k_nearest_neighbors(
protein_coordinates, neighbor_cutoff, max_num_neighbors)
protein_graph = graph((protein_srcs, protein_dsts),
'protein_atom', 'protein', num_protein_atoms)
protein_graph.edata['distance'] = F.reshape(F.zerocopy_from_numpy(
np.array(protein_dists).astype(np.float32)), (-1, 1))
# Construct 4 graphs for the complex representation: the connections within
# ligand atoms, the connections within protein atoms, and the connections
# between ligand and protein atoms in both directions.
complex_srcs, complex_dsts, complex_dists = k_nearest_neighbors(
np.concatenate([ligand_coordinates, protein_coordinates]),
neighbor_cutoff, max_num_neighbors)
complex_srcs = np.array(complex_srcs)
complex_dsts = np.array(complex_dsts)
complex_dists = np.array(complex_dists)
offset = num_ligand_atoms
# ('ligand_atom', 'complex', 'ligand_atom')
inter_ligand_indices = np.intersect1d(
(complex_srcs < offset).nonzero()[0],
(complex_dsts < offset).nonzero()[0],
assume_unique=True)
inter_ligand_graph = graph(
(complex_srcs[inter_ligand_indices].tolist(),
complex_dsts[inter_ligand_indices].tolist()),
'ligand_atom', 'complex', num_ligand_atoms)
inter_ligand_graph.edata['distance'] = F.reshape(F.zerocopy_from_numpy(
complex_dists[inter_ligand_indices].astype(np.float32)), (-1, 1))
# ('protein_atom', 'complex', 'protein_atom')
inter_protein_indices = np.intersect1d(
(complex_srcs >= offset).nonzero()[0],
(complex_dsts >= offset).nonzero()[0],
assume_unique=True)
inter_protein_graph = graph(
((complex_srcs[inter_protein_indices] - offset).tolist(),
(complex_dsts[inter_protein_indices] - offset).tolist()),
'protein_atom', 'complex', num_protein_atoms)
inter_protein_graph.edata['distance'] = F.reshape(F.zerocopy_from_numpy(
complex_dists[inter_protein_indices].astype(np.float32)), (-1, 1))
# ('ligand_atom', 'complex', 'protein_atom')
ligand_protein_indices = np.intersect1d(
(complex_srcs < offset).nonzero()[0],
(complex_dsts >= offset).nonzero()[0],
assume_unique=True)
ligand_protein_graph = bipartite(
(complex_srcs[ligand_protein_indices].tolist(),
(complex_dsts[ligand_protein_indices] - offset).tolist()),
'ligand_atom', 'complex', 'protein_atom',
(num_ligand_atoms, num_protein_atoms))
ligand_protein_graph.edata['distance'] = F.reshape(F.zerocopy_from_numpy(
complex_dists[ligand_protein_indices].astype(np.float32)), (-1, 1))
# ('protein_atom', 'complex', 'ligand_atom')
protein_ligand_indices = np.intersect1d(
(complex_srcs >= offset).nonzero()[0],
(complex_dsts < offset).nonzero()[0],
assume_unique=True)
protein_ligand_graph = bipartite(
((complex_srcs[protein_ligand_indices] - offset).tolist(),
complex_dsts[protein_ligand_indices].tolist()),
'protein_atom', 'complex', 'ligand_atom',
(num_protein_atoms, num_ligand_atoms))
protein_ligand_graph.edata['distance'] = F.reshape(F.zerocopy_from_numpy(
complex_dists[protein_ligand_indices].astype(np.float32)), (-1, 1))
# Merge the graphs
g = hetero_from_relations(
[protein_graph,
ligand_graph,
inter_ligand_graph,
inter_protein_graph,
ligand_protein_graph,
protein_ligand_graph]
)
# Get atomic numbers for all remaining atoms and set node features
ligand_atomic_numbers = np.array(get_atomic_numbers(ligand_mol, ligand_atom_indices_left))
# zero padding
ligand_atomic_numbers = np.concatenate([
ligand_atomic_numbers, np.zeros(num_ligand_atoms - len(ligand_atom_indices_left))])
protein_atomic_numbers = np.array(get_atomic_numbers(protein_mol, protein_atom_indices_left))
# zero padding
protein_atomic_numbers = np.concatenate([
protein_atomic_numbers, np.zeros(num_protein_atoms - len(protein_atom_indices_left))])
g.nodes['ligand_atom'].data['atomic_number'] = F.reshape(F.zerocopy_from_numpy(
ligand_atomic_numbers.astype(np.float32)), (-1, 1))
g.nodes['protein_atom'].data['atomic_number'] = F.reshape(F.zerocopy_from_numpy(
protein_atomic_numbers.astype(np.float32)), (-1, 1))
# Prepare mask indicating the existence of nodes
ligand_masks = np.zeros((num_ligand_atoms, 1))
ligand_masks[:len(ligand_atom_indices_left), :] = 1
g.nodes['ligand_atom'].data['mask'] = F.zerocopy_from_numpy(
ligand_masks.astype(np.float32))
protein_masks = np.zeros((num_protein_atoms, 1))
protein_masks[:len(protein_atom_indices_left), :] = 1
g.nodes['protein_atom'].data['mask'] = F.zerocopy_from_numpy(
protein_masks.astype(np.float32))
return g
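# A usage sketch for the function above. It assumes the DGL version targeted
# by this commit and mdtraj installed (required by k_nearest_neighbors). The
# SMILES strings are toy placeholders; a real protein-ligand pair would
# typically be loaded with load_molecule from the rdkit_utils module below.
from rdkit import Chem
from rdkit.Chem import AllChem

def _toy_mol_with_conformer(smiles):
    # Embed a single 3D conformer so that atom coordinates are available
    mol = Chem.AddHs(Chem.MolFromSmiles(smiles))
    AllChem.EmbedMolecule(mol, randomSeed=0)
    return mol

ligand_mol = _toy_mol_with_conformer('CCO')
protein_mol = _toy_mol_with_conformer('CC(=O)NC')
g = ACNN_graph_construction_and_featurization(
    ligand_mol, protein_mol,
    ligand_mol.GetConformer().GetPositions(),
    protein_mol.GetConformer().GetPositions(),
    neighbor_cutoff=12., max_num_neighbors=12, strip_hydrogens=True)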
"""Early stopping"""
import datetime
import torch
__all__ = ['EarlyStopping']
class EarlyStopping(object):
"""Early stop tracker
Save a model checkpoint when observing a performance improvement on
the validation set and stop early if no improvement has been
observed for a given number of consecutive epochs.
Parameters
----------
mode : str
* 'higher': Higher metric suggests a better model
* 'lower': Lower metric suggests a better model
patience : int
The early stopping will happen if we do not observe performance
improvement for ``patience`` consecutive epochs.
filename : str or None
Filename for storing the model checkpoint
"""
def __init__(self, mode='higher', patience=10, filename=None):
if filename is None:
dt = datetime.datetime.now()
filename = 'early_stop_{}_{:02d}-{:02d}-{:02d}.pth'.format(
dt.date(), dt.hour, dt.minute, dt.second)
assert mode in ['higher', 'lower']
self.mode = mode
if self.mode == 'higher':
self._check = self._check_higher
else:
self._check = self._check_lower
self.patience = patience
self.counter = 0
self.filename = filename
self.best_score = None
self.early_stop = False
def _check_higher(self, score, prev_best_score):
"""Check if the new score is higher than the previous best score.
Parameters
----------
score : float
New score.
prev_best_score : float
Previous best score.
Returns
-------
bool
Whether the new score is higher than the previous best score.
"""
return (score > prev_best_score)
def _check_lower(self, score, prev_best_score):
"""Check if the new score is lower than the previous best score.
Parameters
----------
score : float
New score.
prev_best_score : float
Previous best score.
Returns
-------
bool
Whether the new score is lower than the previous best score.
"""
return (score < prev_best_score)
def step(self, score, model):
"""Update based on a new score.
The new score is typically model performance on the validation set
for a new epoch.
Parameters
----------
score : float
New score.
model : nn.Module
Model instance.
Returns
-------
bool
Whether an early stop should be performed.
"""
if self.best_score is None:
self.best_score = score
self.save_checkpoint(model)
elif self._check(score, self.best_score):
self.best_score = score
self.save_checkpoint(model)
self.counter = 0
else:
self.counter += 1
print(
f'EarlyStopping counter: {self.counter} out of {self.patience}')
if self.counter >= self.patience:
self.early_stop = True
return self.early_stop
def save_checkpoint(self, model):
'''Saves model when the metric on the validation set gets improved.
Parameters
----------
model : nn.Module
Model instance.
'''
torch.save({'model_state_dict': model.state_dict()}, self.filename)
def load_checkpoint(self, model):
'''Load the latest checkpoint
Parameters
----------
model : nn.Module
Model instance.
'''
model.load_state_dict(torch.load(self.filename)['model_state_dict'])
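# A training-loop sketch for EarlyStopping. The model, the validation losses
# and the checkpoint filename are toy placeholders.
import torch.nn as nn
model = nn.Linear(4, 1)
stopper = EarlyStopping(mode='lower', patience=3, filename='toy_checkpoint.pth')
for val_loss in [0.9, 0.8, 0.85, 0.83, 0.81, 0.84]:
    if stopper.step(val_loss, model):
        # Triggered at 0.81: three consecutive epochs without improving on 0.8
        break
stopper.load_checkpoint(model)                 # restore the best weights observed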
"""Evaluation of model performance."""
import numpy as np
import torch
import torch.nn.functional as F
from scipy.stats import pearsonr
from sklearn.metrics import roc_auc_score
__all__ = ['Meter']
class Meter(object):
"""Track and summarize model performance on a dataset for (multi-label) prediction.
When dealing with multitask learning, quite often we normalize the labels so they are
roughly at a same scale. During the evaluation, we need to undo the normalization on
the predicted labels. If mean and std are not None, we will undo the normalization.
Currently we support evaluation with 4 metrics:
* pearson r2
* mae
* rmse
* roc auc score
Parameters
----------
mean : torch.float32 tensor of shape (T) or None.
Mean of existing training labels across tasks if not None. T for the number of tasks.
Default to None.
std : torch.float32 tensor of shape (T) or None
Std of existing training labels across tasks if not None. Default to None.
"""
def __init__(self, mean=None, std=None):
self.mask = []
self.y_pred = []
self.y_true = []
if (mean is not None) and (std is not None):
self.mean = mean.cpu()
self.std = std.cpu()
else:
self.mean = None
self.std = None
def update(self, y_pred, y_true, mask=None):
"""Update for the result of an iteration
Parameters
----------
y_pred : float32 tensor
Predicted labels with shape (B, T),
B for number of graphs in the batch and T for the number of tasks
y_true : float32 tensor
Ground truth labels with shape (B, T)
mask : None or float32 tensor
Binary mask indicating the existence of ground truth labels with
shape (B, T). If None, we assume that all labels exist and create
a one-tensor for placeholder.
"""
self.y_pred.append(y_pred.detach().cpu())
self.y_true.append(y_true.detach().cpu())
if mask is None:
self.mask.append(torch.ones(self.y_pred[-1].shape))
else:
self.mask.append(mask.detach().cpu())
def _finalize(self):
"""Prepare for evaluation.
If normalization was performed on the ground truth labels during training,
we need to undo the normalization on the predicted labels.
Returns
-------
mask : float32 tensor
Binary mask indicating the existence of ground
truth labels with shape (B, T), B for batch size
and T for the number of tasks
y_pred : float32 tensor
Predicted labels with shape (B, T)
y_true : float32 tensor
Ground truth labels with shape (B, T)
"""
mask = torch.cat(self.mask, dim=0)
y_pred = torch.cat(self.y_pred, dim=0)
y_true = torch.cat(self.y_true, dim=0)
if (self.mean is not None) and (self.std is not None):
# During training, the ground truth labels were normalized with the
# training mean and std so that the tasks are on comparable scales.
# We undo that normalization on the predictions for evaluation.
y_pred = y_pred * self.std + self.mean
return mask, y_pred, y_true
def _reduce_scores(self, scores, reduction='none'):
"""Finalize the scores to return.
Parameters
----------
scores : list of float
Scores for all tasks.
reduction : 'none' or 'mean' or 'sum'
Controls the form of scores for all tasks
Returns
-------
float or list of float
* If ``reduction == 'none'``, return the list of scores for all tasks.
* If ``reduction == 'mean'``, return the mean of scores for all tasks.
* If ``reduction == 'sum'``, return the sum of scores for all tasks.
"""
if reduction == 'none':
return scores
elif reduction == 'mean':
return np.mean(scores)
elif reduction == 'sum':
return np.sum(scores)
else:
raise ValueError(
"Expect reduction to be 'none', 'mean' or 'sum', got {}".format(reduction))
def multilabel_score(self, score_func, reduction='none'):
"""Evaluate for multi-label prediction.
Parameters
----------
score_func : callable
A score function that takes task-specific ground truth and predicted labels as
input and returns a float as the score. The labels are given as 1D tensors.
reduction : 'none' or 'mean' or 'sum'
Controls the form of scores for all tasks
Returns
-------
float or list of float
* If ``reduction == 'none'``, return the list of scores for all tasks.
* If ``reduction == 'mean'``, return the mean of scores for all tasks.
* If ``reduction == 'sum'``, return the sum of scores for all tasks.
"""
mask, y_pred, y_true = self._finalize()
n_tasks = y_true.shape[1]
scores = []
for task in range(n_tasks):
task_w = mask[:, task]
task_y_true = y_true[:, task][task_w != 0]
task_y_pred = y_pred[:, task][task_w != 0]
scores.append(score_func(task_y_true, task_y_pred))
return self._reduce_scores(scores, reduction)
def pearson_r2(self, reduction='none'):
"""Compute squared Pearson correlation coefficient.
Parameters
----------
reduction : 'none' or 'mean' or 'sum'
Controls the form of scores for all tasks
Returns
-------
float or list of float
* If ``reduction == 'none'``, return the list of scores for all tasks.
* If ``reduction == 'mean'``, return the mean of scores for all tasks.
* If ``reduction == 'sum'``, return the sum of scores for all tasks.
"""
def score(y_true, y_pred):
return pearsonr(y_true.numpy(), y_pred.numpy())[0] ** 2
return self.multilabel_score(score, reduction)
def mae(self, reduction='none'):
"""Compute mean absolute error.
Parameters
----------
reduction : 'none' or 'mean' or 'sum'
Controls the form of scores for all tasks
Returns
-------
float or list of float
* If ``reduction == 'none'``, return the list of scores for all tasks.
* If ``reduction == 'mean'``, return the mean of scores for all tasks.
* If ``reduction == 'sum'``, return the sum of scores for all tasks.
"""
def score(y_true, y_pred):
return F.l1_loss(y_true, y_pred).data.item()
return self.multilabel_score(score, reduction)
def rmse(self, reduction='none'):
"""Compute root mean square error.
Parameters
----------
reduction : 'none' or 'mean' or 'sum'
Controls the form of scores for all tasks
Returns
-------
float or list of float
* If ``reduction == 'none'``, return the list of scores for all tasks.
* If ``reduction == 'mean'``, return the mean of scores for all tasks.
* If ``reduction == 'sum'``, return the sum of scores for all tasks.
"""
def score(y_true, y_pred):
return np.sqrt(F.mse_loss(y_pred, y_true).cpu().item())
return self.multilabel_score(score, reduction)
def roc_auc_score(self, reduction='none'):
"""Compute roc-auc score for binary classification.
Parameters
----------
reduction : 'none' or 'mean' or 'sum'
Controls the form of scores for all tasks
Returns
-------
float or list of float
* If ``reduction == 'none'``, return the list of scores for all tasks.
* If ``reduction == 'mean'``, return the mean of scores for all tasks.
* If ``reduction == 'sum'``, return the sum of scores for all tasks.
"""
# Todo: This function only supports binary classification and we may need
# to support categorical classes.
assert (self.mean is None) and (self.std is None), \
'Label normalization should not be performed for binary classification.'
def score(y_true, y_pred):
return roc_auc_score(y_true.long().numpy(), torch.sigmoid(y_pred).numpy())
return self.multilabel_score(score, reduction)
def compute_metric(self, metric_name, reduction='none'):
"""Compute metric based on metric name.
Parameters
----------
metric_name : str
* 'r2': compute squared Pearson correlation coefficient
* 'mae': compute mean absolute error
* 'rmse': compute root mean square error
* 'roc_auc_score': compute roc-auc score
reduction : 'none' or 'mean' or 'sum'
Controls the form of scores for all tasks
Returns
-------
float or list of float
* If ``reduction == 'none'``, return the list of scores for all tasks.
* If ``reduction == 'mean'``, return the mean of scores for all tasks.
* If ``reduction == 'sum'``, return the sum of scores for all tasks.
"""
if metric_name == 'r2':
return self.pearson_r2(reduction)
if metric_name == 'mae':
return self.mae(reduction)
if metric_name == 'rmse':
return self.rmse(reduction)
if metric_name == 'roc_auc_score':
return self.roc_auc_score(reduction)
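# A quick sketch of Meter for a regression setting with two tasks; the
# predictions and ground truths are random placeholders.
meter = Meter()
y_pred = torch.randn(4, 2)
y_true = torch.randn(4, 2)
meter.update(y_pred, y_true)
print(meter.compute_metric('rmse'))            # per-task RMSE, a list of 2 floats
print(meter.compute_metric('mae', 'mean'))     # MAE averaged over the 2 tasks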
"""Node and edge featurization for molecular graphs."""
import dgl.backend as F
import itertools
import numpy as np
from collections import defaultdict
from rdkit import Chem
__all__ = ['one_hot_encoding',
'atom_type_one_hot',
'atomic_number_one_hot',
'atomic_number',
'atom_degree_one_hot',
'atom_degree',
'atom_total_degree_one_hot',
'atom_total_degree',
'atom_implicit_valence_one_hot',
'atom_implicit_valence',
'atom_hybridization_one_hot',
'atom_total_num_H_one_hot',
'atom_total_num_H',
'atom_formal_charge_one_hot',
'atom_formal_charge',
'atom_num_radical_electrons_one_hot',
'atom_num_radical_electrons',
'atom_is_aromatic_one_hot',
'atom_is_aromatic',
'atom_chiral_tag_one_hot',
'atom_mass',
'ConcatFeaturizer',
'BaseAtomFeaturizer',
'CanonicalAtomFeaturizer',
'bond_type_one_hot',
'bond_is_conjugated_one_hot',
'bond_is_conjugated',
'bond_is_in_ring_one_hot',
'bond_is_in_ring',
'bond_stereo_one_hot',
'BaseBondFeaturizer',
'CanonicalBondFeaturizer']
def one_hot_encoding(x, allowable_set, encode_unknown=False):
"""One-hot encoding.
Parameters
----------
x
Value to encode.
allowable_set : list
The elements of the allowable_set should be of the
same type as x.
encode_unknown : bool
If True, map inputs not in the allowable set to the
additional last element.
Returns
-------
list
List of boolean values where at most one value is True.
The list is of length ``len(allowable_set)`` if ``encode_unknown=False``
and ``len(allowable_set) + 1`` otherwise.
"""
if encode_unknown and (allowable_set[-1] is not None):
# Copy rather than append in place to avoid mutating the caller's list
allowable_set = allowable_set + [None]
if encode_unknown and (x not in allowable_set):
x = None
return list(map(lambda s: x == s, allowable_set))
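# Quick sanity checks for one_hot_encoding (toy inputs for illustration):
assert one_hot_encoding('C', ['C', 'N', 'O']) == [True, False, False]
# With encode_unknown=True, values outside the set map to an extra final slot
assert one_hot_encoding('S', ['C', 'N', 'O'], encode_unknown=True) == \
    [False, False, False, True]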
#################################################################
# Atom featurization
#################################################################
def atom_type_one_hot(atom, allowable_set=None, encode_unknown=False):
"""One hot encoding for the type of an atom.
Parameters
----------
atom : rdkit.Chem.rdchem.Atom
RDKit atom instance.
allowable_set : list of str
Atom types to consider. Default: ``C``, ``N``, ``O``, ``S``, ``F``, ``Si``, ``P``,
``Cl``, ``Br``, ``Mg``, ``Na``, ``Ca``, ``Fe``, ``As``, ``Al``, ``I``, ``B``, ``V``,
``K``, ``Tl``, ``Yb``, ``Sb``, ``Sn``, ``Ag``, ``Pd``, ``Co``, ``Se``, ``Ti``, ``Zn``,
``H``, ``Li``, ``Ge``, ``Cu``, ``Au``, ``Ni``, ``Cd``, ``In``, ``Mn``, ``Zr``, ``Cr``,
``Pt``, ``Hg``, ``Pb``.
encode_unknown : bool
If True, map inputs not in the allowable set to the
additional last element. (Default: False)
Returns
-------
list
List of boolean values where at most one value is True.
"""
if allowable_set is None:
allowable_set = ['C', 'N', 'O', 'S', 'F', 'Si', 'P', 'Cl', 'Br', 'Mg', 'Na', 'Ca',
'Fe', 'As', 'Al', 'I', 'B', 'V', 'K', 'Tl', 'Yb', 'Sb', 'Sn',
'Ag', 'Pd', 'Co', 'Se', 'Ti', 'Zn', 'H', 'Li', 'Ge', 'Cu', 'Au',
'Ni', 'Cd', 'In', 'Mn', 'Zr', 'Cr', 'Pt', 'Hg', 'Pb']
return one_hot_encoding(atom.GetSymbol(), allowable_set, encode_unknown)
def atomic_number_one_hot(atom, allowable_set=None, encode_unknown=False):
"""One hot encoding for the atomic number of an atom.
Parameters
----------
atom : rdkit.Chem.rdchem.Atom
RDKit atom instance.
allowable_set : list of int
Atomic numbers to consider. Default: ``1`` - ``100``.
encode_unknown : bool
If True, map inputs not in the allowable set to the
additional last element. (Default: False)
Returns
-------
list
List of boolean values where at most one value is True.
"""
if allowable_set is None:
allowable_set = list(range(1, 101))
return one_hot_encoding(atom.GetAtomicNum(), allowable_set, encode_unknown)
def atomic_number(atom):
"""Get the atomic number for an atom.
Parameters
----------
atom : rdkit.Chem.rdchem.Atom
RDKit atom instance.
Returns
-------
list
List containing one int only.
"""
return [atom.GetAtomicNum()]
def atom_degree_one_hot(atom, allowable_set=None, encode_unknown=False):
"""One hot encoding for the degree of an atom.
Note that the result will be different depending on whether the Hs are
explicitly modeled in the graph.
Parameters
----------
atom : rdkit.Chem.rdchem.Atom
RDKit atom instance.
allowable_set : list of int
Atom degrees to consider. Default: ``0`` - ``10``.
encode_unknown : bool
If True, map inputs not in the allowable set to the
additional last element. (Default: False)
Returns
-------
list
List of boolean values where at most one value is True.
See Also
--------
atom_total_degree_one_hot
"""
if allowable_set is None:
allowable_set = list(range(11))
return one_hot_encoding(atom.GetDegree(), allowable_set, encode_unknown)
def atom_degree(atom):
"""Get the degree of an atom.
Note that the result will be different depending on whether the Hs are
explicitly modeled in the graph.
Parameters
----------
atom : rdkit.Chem.rdchem.Atom
RDKit atom instance.
Returns
-------
list
List containing one int only.
See Also
--------
atom_total_degree
"""
return [atom.GetDegree()]
def atom_total_degree_one_hot(atom, allowable_set=None, encode_unknown=False):
"""One hot encoding for the degree of an atom including Hs.
Parameters
----------
atom : rdkit.Chem.rdchem.Atom
RDKit atom instance.
allowable_set : list
Total degrees to consider. Default: ``0`` - ``5``.
encode_unknown : bool
If True, map inputs not in the allowable set to the
additional last element. (Default: False)
Returns
-------
list
List of boolean values where at most one value is True.
See Also
--------
atom_degree_one_hot
"""
if allowable_set is None:
allowable_set = list(range(6))
return one_hot_encoding(atom.GetTotalDegree(), allowable_set, encode_unknown)
def atom_total_degree(atom):
"""The degree of an atom including Hs.
See Also
--------
atom_degree
Returns
-------
list
List containing one int only.
"""
return [atom.GetTotalDegree()]
def atom_implicit_valence_one_hot(atom, allowable_set=None, encode_unknown=False):
"""One hot encoding for the implicit valences of an atom.
Parameters
----------
atom : rdkit.Chem.rdchem.Atom
RDKit atom instance.
allowable_set : list of int
Atom implicit valences to consider. Default: ``0`` - ``6``.
encode_unknown : bool
If True, map inputs not in the allowable set to the
additional last element. (Default: False)
Returns
-------
list
List of boolean values where at most one value is True.
"""
if allowable_set is None:
allowable_set = list(range(7))
return one_hot_encoding(atom.GetImplicitValence(), allowable_set, encode_unknown)
def atom_implicit_valence(atom):
"""Get the implicit valence of an atom.
Parameters
----------
atom : rdkit.Chem.rdchem.Atom
RDKit atom instance.
Returns
-------
list
List containing one int only.
"""
return [atom.GetImplicitValence()]
def atom_hybridization_one_hot(atom, allowable_set=None, encode_unknown=False):
"""One hot encoding for the hybridization of an atom.
Parameters
----------
atom : rdkit.Chem.rdchem.Atom
RDKit atom instance.
allowable_set : list of rdkit.Chem.rdchem.HybridizationType
Atom hybridizations to consider. Default: ``Chem.rdchem.HybridizationType.SP``,
``Chem.rdchem.HybridizationType.SP2``, ``Chem.rdchem.HybridizationType.SP3``,
``Chem.rdchem.HybridizationType.SP3D``, ``Chem.rdchem.HybridizationType.SP3D2``.
encode_unknown : bool
If True, map inputs not in the allowable set to the
additional last element. (Default: False)
Returns
-------
list
List of boolean values where at most one value is True.
"""
if allowable_set is None:
allowable_set = [Chem.rdchem.HybridizationType.SP,
Chem.rdchem.HybridizationType.SP2,
Chem.rdchem.HybridizationType.SP3,
Chem.rdchem.HybridizationType.SP3D,
Chem.rdchem.HybridizationType.SP3D2]
return one_hot_encoding(atom.GetHybridization(), allowable_set, encode_unknown)
def atom_total_num_H_one_hot(atom, allowable_set=None, encode_unknown=False):
"""One hot encoding for the total number of Hs of an atom.
Parameters
----------
atom : rdkit.Chem.rdchem.Atom
RDKit atom instance.
allowable_set : list of int
Total number of Hs to consider. Default: ``0`` - ``4``.
encode_unknown : bool
If True, map inputs not in the allowable set to the
additional last element. (Default: False)
Returns
-------
list
List of boolean values where at most one value is True.
"""
if allowable_set is None:
allowable_set = list(range(5))
return one_hot_encoding(atom.GetTotalNumHs(), allowable_set, encode_unknown)
def atom_total_num_H(atom):
"""Get the total number of Hs of an atom.
Parameters
----------
atom : rdkit.Chem.rdchem.Atom
RDKit atom instance.
Returns
-------
list
List containing one int only.
"""
return [atom.GetTotalNumHs()]
def atom_formal_charge_one_hot(atom, allowable_set=None, encode_unknown=False):
"""One hot encoding for the formal charge of an atom.
Parameters
----------
atom : rdkit.Chem.rdchem.Atom
RDKit atom instance.
allowable_set : list of int
Formal charges to consider. Default: ``-2`` - ``2``.
encode_unknown : bool
If True, map inputs not in the allowable set to the
additional last element. (Default: False)
Returns
-------
list
List of boolean values where at most one value is True.
"""
if allowable_set is None:
allowable_set = list(range(-2, 3))
return one_hot_encoding(atom.GetFormalCharge(), allowable_set, encode_unknown)
def atom_formal_charge(atom):
"""Get formal charge for an atom.
Parameters
----------
atom : rdkit.Chem.rdchem.Atom
RDKit atom instance.
Returns
-------
list
List containing one int only.
"""
return [atom.GetFormalCharge()]
def atom_num_radical_electrons_one_hot(atom, allowable_set=None, encode_unknown=False):
"""One hot encoding for the number of radical electrons of an atom.
Parameters
----------
atom : rdkit.Chem.rdchem.Atom
RDKit atom instance.
allowable_set : list of int
Number of radical electrons to consider. Default: ``0`` - ``4``.
encode_unknown : bool
If True, map inputs not in the allowable set to the
additional last element. (Default: False)
Returns
-------
list
List of boolean values where at most one value is True.
"""
if allowable_set is None:
allowable_set = list(range(5))
return one_hot_encoding(atom.GetNumRadicalElectrons(), allowable_set, encode_unknown)
def atom_num_radical_electrons(atom):
"""Get the number of radical electrons for an atom.
Parameters
----------
atom : rdkit.Chem.rdchem.Atom
RDKit atom instance.
Returns
-------
list
List containing one int only.
"""
return [atom.GetNumRadicalElectrons()]
def atom_is_aromatic_one_hot(atom, allowable_set=None, encode_unknown=False):
"""One hot encoding for whether the atom is aromatic.
Parameters
----------
atom : rdkit.Chem.rdchem.Atom
RDKit atom instance.
allowable_set : list of bool
Conditions to consider. Default: ``False`` and ``True``.
encode_unknown : bool
If True, map inputs not in the allowable set to the
additional last element. (Default: False)
Returns
-------
list
List of boolean values where at most one value is True.
"""
if allowable_set is None:
allowable_set = [False, True]
return one_hot_encoding(atom.GetIsAromatic(), allowable_set, encode_unknown)
def atom_is_aromatic(atom):
"""Get whether the atom is aromatic.
Parameters
----------
atom : rdkit.Chem.rdchem.Atom
RDKit atom instance.
Returns
-------
list
List containing one bool only.
"""
return [atom.GetIsAromatic()]
def atom_chiral_tag_one_hot(atom, allowable_set=None, encode_unknown=False):
"""One hot encoding for the chiral tag of an atom.
Parameters
----------
atom : rdkit.Chem.rdchem.Atom
RDKit atom instance.
allowable_set : list of rdkit.Chem.rdchem.ChiralType
Chiral tags to consider. Default: ``rdkit.Chem.rdchem.ChiralType.CHI_UNSPECIFIED``,
``rdkit.Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CW``,
``rdkit.Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CCW``,
``rdkit.Chem.rdchem.ChiralType.CHI_OTHER``.
encode_unknown : bool
If True, map inputs not in the allowable set to the
additional last element. (Default: False)
Returns
-------
list
List of boolean values where at most one value is True.
"""
if allowable_set is None:
allowable_set = [Chem.rdchem.ChiralType.CHI_UNSPECIFIED,
Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CW,
Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CCW,
Chem.rdchem.ChiralType.CHI_OTHER]
return one_hot_encoding(atom.GetChiralTag(), allowable_set, encode_unknown)
def atom_mass(atom, coef=0.01):
"""Get the mass of an atom and scale it.
Parameters
----------
atom : rdkit.Chem.rdchem.Atom
RDKit atom instance.
coef : float
The mass will be multiplied by ``coef``.
Returns
-------
list
List containing one float only.
"""
return [atom.GetMass() * coef]
class ConcatFeaturizer(object):
"""Concatenate the evaluation results of multiple functions as a single feature.
Parameters
----------
func_list : list
List of functions for computing molecular descriptors from objects of the same
data type, e.g. ``rdkit.Chem.rdchem.Atom``. Each function is of signature
``func(data_type) -> list of float or bool or int``. The resulting order of
the features will follow that of the functions in the list.
"""
def __init__(self, func_list):
self.func_list = func_list
def __call__(self, x):
"""Featurize the input data.
Parameters
----------
x :
Data to featurize.
Returns
-------
list
List of feature values, which can be of type bool, float or int.
"""
return list(itertools.chain.from_iterable(
[func(x) for func in self.func_list]))
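# Sketch: combine two of the atom descriptors above into a single feature
# vector; 'CCO' is a placeholder molecule.
atom = Chem.MolFromSmiles('CCO').GetAtomWithIdx(0)
featurizer = ConcatFeaturizer([atom_mass, atom_is_aromatic])
print(featurizer(atom))                        # roughly [0.1201, False]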
class BaseAtomFeaturizer(object):
"""An abstract class for atom featurizers.
Loop over all atoms in a molecule and featurize them with the ``featurizer_funcs``.
**We assume the resulting DGLGraph will not contain any virtual nodes.**
Parameters
----------
featurizer_funcs : dict
Mapping feature name to the featurization function.
Each function is of signature ``func(rdkit.Chem.rdchem.Atom) -> list or 1D numpy array``.
feat_sizes : dict
Mapping feature name to the size of the corresponding feature. If None, they will be
computed when needed. Default: None.
Examples
--------
>>> from dgl.data.life_sci import BaseAtomFeaturizer, atom_mass, atom_degree_one_hot
>>> from rdkit import Chem
>>> mol = Chem.MolFromSmiles('CCO')
>>> atom_featurizer = BaseAtomFeaturizer({'mass': atom_mass, 'degree': atom_degree_one_hot})
>>> atom_featurizer(mol)
{'mass': tensor([[0.1201],
[0.1201],
[0.1600]]),
'degree': tensor([[0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0.]])}
"""
def __init__(self, featurizer_funcs, feat_sizes=None):
self.featurizer_funcs = featurizer_funcs
if feat_sizes is None:
feat_sizes = dict()
self._feat_sizes = feat_sizes
def feat_size(self, feat_name):
"""Get the feature size for ``feat_name``.
Returns
-------
int
Feature size for the feature with name ``feat_name``.
"""
if feat_name not in self.featurizer_funcs:
raise ValueError('Expect feat_name to be in {}, got {}'.format(
list(self.featurizer_funcs.keys()), feat_name))
if feat_name not in self._feat_sizes:
atom = Chem.MolFromSmiles('C').GetAtomWithIdx(0)
self._feat_sizes[feat_name] = len(self.featurizer_funcs[feat_name](atom))
return self._feat_sizes[feat_name]
def __call__(self, mol):
"""Featurize all atoms in a molecule.
Parameters
----------
mol : rdkit.Chem.rdchem.Mol
RDKit molecule instance.
Returns
-------
dict
For each function in self.featurizer_funcs with the key ``k``, store the computed
feature under the key ``k``. Each feature is a tensor of dtype float32 and shape
(N, M), where N is the number of atoms in the molecule and M is the feature size.
"""
num_atoms = mol.GetNumAtoms()
atom_features = defaultdict(list)
# Compute features for each atom
for i in range(num_atoms):
atom = mol.GetAtomWithIdx(i)
for feat_name, feat_func in self.featurizer_funcs.items():
atom_features[feat_name].append(feat_func(atom))
# Stack the features and convert them to float arrays
processed_features = dict()
for feat_name, feat_list in atom_features.items():
feat = np.stack(feat_list)
processed_features[feat_name] = F.zerocopy_from_numpy(feat.astype(np.float32))
return processed_features
class CanonicalAtomFeaturizer(BaseAtomFeaturizer):
"""A default featurizer for atoms.
The atom features include:
* **One hot encoding of the atom type**. The supported atom types include
``C``, ``N``, ``O``, ``S``, ``F``, ``Si``, ``P``, ``Cl``, ``Br``, ``Mg``,
``Na``, ``Ca``, ``Fe``, ``As``, ``Al``, ``I``, ``B``, ``V``, ``K``, ``Tl``,
``Yb``, ``Sb``, ``Sn``, ``Ag``, ``Pd``, ``Co``, ``Se``, ``Ti``, ``Zn``,
``H``, ``Li``, ``Ge``, ``Cu``, ``Au``, ``Ni``, ``Cd``, ``In``, ``Mn``, ``Zr``,
``Cr``, ``Pt``, ``Hg``, ``Pb``.
* **One hot encoding of the atom degree**. The supported possibilities
include ``0 - 10``.
* **One hot encoding of the number of implicit Hs on the atom**. The supported
possibilities include ``0 - 6``.
* **Formal charge of the atom**.
* **Number of radical electrons of the atom**.
* **One hot encoding of the atom hybridization**. The supported possibilities include
``SP``, ``SP2``, ``SP3``, ``SP3D``, ``SP3D2``.
* **Whether the atom is aromatic**.
* **One hot encoding of the number of total Hs on the atom**. The supported possibilities
include ``0 - 4``.
**We assume the resulting DGLGraph will not contain any virtual nodes.**
Parameters
----------
atom_data_field : str
Name for storing atom features in DGLGraphs. Default to 'h'.
"""
def __init__(self, atom_data_field='h'):
super(CanonicalAtomFeaturizer, self).__init__(
featurizer_funcs={atom_data_field: ConcatFeaturizer(
[atom_type_one_hot,
atom_degree_one_hot,
atom_implicit_valence_one_hot,
atom_formal_charge,
atom_num_radical_electrons,
atom_hybridization_one_hot,
atom_is_aromatic,
atom_total_num_H_one_hot]
)})
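# Usage sketch for CanonicalAtomFeaturizer. The feature size of 74 follows
# from the defaults above: 43 + 11 + 7 + 1 + 1 + 5 + 1 + 5.
mol = Chem.MolFromSmiles('CCO')
atom_featurizer = CanonicalAtomFeaturizer(atom_data_field='h')
feats = atom_featurizer(mol)['h']              # float32 tensor of shape (3, 74)
print(atom_featurizer.feat_size('h'))          # 74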
def bond_type_one_hot(bond, allowable_set=None, encode_unknown=False):
"""One hot encoding for the type of a bond.
Parameters
----------
bond : rdkit.Chem.rdchem.Bond
RDKit bond instance.
allowable_set : list of Chem.rdchem.BondType
Bond types to consider. Default: ``Chem.rdchem.BondType.SINGLE``,
``Chem.rdchem.BondType.DOUBLE``, ``Chem.rdchem.BondType.TRIPLE``,
``Chem.rdchem.BondType.AROMATIC``.
encode_unknown : bool
If True, map inputs not in the allowable set to the
additional last element. (Default: False)
Returns
-------
list
List of boolean values where at most one value is True.
"""
if allowable_set is None:
allowable_set = [Chem.rdchem.BondType.SINGLE,
Chem.rdchem.BondType.DOUBLE,
Chem.rdchem.BondType.TRIPLE,
Chem.rdchem.BondType.AROMATIC]
return one_hot_encoding(bond.GetBondType(), allowable_set, encode_unknown)
def bond_is_conjugated_one_hot(bond, allowable_set=None, encode_unknown=False):
"""One hot encoding for whether the bond is conjugated.
Parameters
----------
bond : rdkit.Chem.rdchem.Bond
RDKit bond instance.
allowable_set : list of bool
Conditions to consider. Default: ``False`` and ``True``.
encode_unknown : bool
If True, map inputs not in the allowable set to the
additional last element. (Default: False)
Returns
-------
list
List of boolean values where at most one value is True.
"""
if allowable_set is None:
allowable_set = [False, True]
return one_hot_encoding(bond.GetIsConjugated(), allowable_set, encode_unknown)
def bond_is_conjugated(bond):
"""Get whether the bond is conjugated.
Parameters
----------
bond : rdkit.Chem.rdchem.Bond
RDKit bond instance.
Returns
-------
list
List containing one bool only.
"""
return [bond.GetIsConjugated()]
def bond_is_in_ring_one_hot(bond, allowable_set=None, encode_unknown=False):
"""One hot encoding for whether the bond is in a ring of any size.
Parameters
----------
bond : rdkit.Chem.rdchem.Bond
RDKit bond instance.
allowable_set : list of bool
Conditions to consider. Default: ``False`` and ``True``.
encode_unknown : bool
If True, map inputs not in the allowable set to the
additional last element. (Default: False)
Returns
-------
list
List of boolean values where at most one value is True.
"""
if allowable_set is None:
allowable_set = [False, True]
return one_hot_encoding(bond.IsInRing(), allowable_set, encode_unknown)
def bond_is_in_ring(bond):
"""Get whether the bond is in a ring of any size.
Parameters
----------
bond : rdkit.Chem.rdchem.Bond
RDKit bond instance.
Returns
-------
list
List containing one bool only.
"""
return [bond.IsInRing()]
def bond_stereo_one_hot(bond, allowable_set=None, encode_unknown=False):
"""One hot encoding for the stereo configuration of a bond.
Parameters
----------
bond : rdkit.Chem.rdchem.Bond
RDKit bond instance.
allowable_set : list of rdkit.Chem.rdchem.BondStereo
Stereo configurations to consider. Default: ``rdkit.Chem.rdchem.BondStereo.STEREONONE``,
``rdkit.Chem.rdchem.BondStereo.STEREOANY``, ``rdkit.Chem.rdchem.BondStereo.STEREOZ``,
``rdkit.Chem.rdchem.BondStereo.STEREOE``, ``rdkit.Chem.rdchem.BondStereo.STEREOCIS``,
``rdkit.Chem.rdchem.BondStereo.STEREOTRANS``.
encode_unknown : bool
If True, map inputs not in the allowable set to the
additional last element. (Default: False)
Returns
-------
list
List of boolean values where at most one value is True.
"""
if allowable_set is None:
allowable_set = [Chem.rdchem.BondStereo.STEREONONE,
Chem.rdchem.BondStereo.STEREOANY,
Chem.rdchem.BondStereo.STEREOZ,
Chem.rdchem.BondStereo.STEREOE,
Chem.rdchem.BondStereo.STEREOCIS,
Chem.rdchem.BondStereo.STEREOTRANS]
return one_hot_encoding(bond.GetStereo(), allowable_set, encode_unknown)
class BaseBondFeaturizer(object):
"""An abstract class for bond featurizers.
Loop over all bonds in a molecule and featurize them with the ``featurizer_funcs``.
We assume the constructed ``DGLGraph`` is a bi-directed graph where the **i** th bond in the
molecule, i.e. ``mol.GetBondWithIdx(i)``, corresponds to the **(2i)**-th and **(2i+1)**-th edges
in the DGLGraph.
**We assume the resulting DGLGraph will be created with :func:`smiles_to_bigraph` without
self loops.**
Parameters
----------
featurizer_funcs : dict
Mapping feature name to the featurization function.
Each function is of signature ``func(rdkit.Chem.rdchem.Bond) -> list or 1D numpy array``.
feat_sizes : dict
Mapping feature name to the size of the corresponding feature. If None, they will be
computed when needed. Default: None.
Examples
--------
>>> from dgl.data.life_sci import BaseBondFeaturizer, bond_type_one_hot, bond_is_in_ring
>>> from rdkit import Chem
>>> mol = Chem.MolFromSmiles('CCO')
>>> bond_featurizer = BaseBondFeaturizer({'bond_type': bond_type_one_hot, 'in_ring': bond_is_in_ring})
>>> bond_featurizer(mol)
{'bond_type': tensor([[1., 0., 0., 0.],
[1., 0., 0., 0.],
[1., 0., 0., 0.],
[1., 0., 0., 0.]]),
'in_ring': tensor([[0.], [0.], [0.], [0.]])}
"""
def __init__(self, featurizer_funcs, feat_sizes=None):
self.featurizer_funcs = featurizer_funcs
if feat_sizes is None:
feat_sizes = dict()
self._feat_sizes = feat_sizes
def feat_size(self, feat_name):
"""Get the feature size for ``feat_name``.
Returns
-------
int
Feature size for the feature with name ``feat_name``.
"""
if feat_name not in self.featurizer_funcs:
raise ValueError('Expect feat_name to be in {}, got {}'.format(
list(self.featurizer_funcs.keys()), feat_name))
if feat_name not in self._feat_sizes:
bond = Chem.MolFromSmiles('CO').GetBondWithIdx(0)
self._feat_sizes[feat_name] = len(self.featurizer_funcs[feat_name](bond))
return self._feat_sizes[feat_name]
def __call__(self, mol):
"""Featurize all bonds in a molecule.
Parameters
----------
mol : rdkit.Chem.rdchem.Mol
RDKit molecule instance.
Returns
-------
dict
For each function in self.featurizer_funcs with the key ``k``, store the computed
feature under the key ``k``. Each feature is a tensor of dtype float32 and shape
(N, M), where N is twice the number of bonds in the molecule, since each bond
corresponds to two directed edges.
"""
num_bonds = mol.GetNumBonds()
bond_features = defaultdict(list)
# Compute features for each bond
for i in range(num_bonds):
bond = mol.GetBondWithIdx(i)
for feat_name, feat_func in self.featurizer_funcs.items():
feat = feat_func(bond)
bond_features[feat_name].extend([feat, feat.copy()])
# Stack the features and convert them to float arrays
processed_features = dict()
for feat_name, feat_list in bond_features.items():
feat = np.stack(feat_list)
processed_features[feat_name] = F.zerocopy_from_numpy(feat.astype(np.float32))
return processed_features
class CanonicalBondFeaturizer(BaseBondFeaturizer):
"""A default featurizer for bonds.
The bond features include:
* **One hot encoding of the bond type**. The supported bond types include
``SINGLE``, ``DOUBLE``, ``TRIPLE``, ``AROMATIC``.
* **Whether the bond is conjugated.**
* **Whether the bond is in a ring of any size.**
* **One hot encoding of the stereo configuration of a bond**. The supported bond stereo
configurations include ``STEREONONE``, ``STEREOANY``, ``STEREOZ``, ``STEREOE``,
``STEREOCIS``, ``STEREOTRANS``.
**We assume the resulting DGLGraph will be created with :func:`smiles_to_bigraph` without
self loops.**
Parameters
----------
bond_data_field : str
Name for storing bond features in DGLGraphs. Default to 'e'.
"""
def __init__(self, bond_data_field='e'):
super(CanonicalBondFeaturizer, self).__init__(
featurizer_funcs={bond_data_field: ConcatFeaturizer(
[bond_type_one_hot,
bond_is_conjugated,
bond_is_in_ring,
bond_stereo_one_hot]
)})
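# Usage sketch for CanonicalBondFeaturizer. The feature size of 12 follows
# from the defaults above (4 + 1 + 1 + 6), and the two bonds in 'CCO' each
# yield two directed edges.
mol = Chem.MolFromSmiles('CCO')
bond_featurizer = CanonicalBondFeaturizer(bond_data_field='e')
feats = bond_featurizer(mol)['e']              # float32 tensor of shape (4, 12)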
"""Convert molecules into DGLGraphs."""
import numpy as np
from dgl import DGLGraph
from functools import partial
from rdkit import Chem
from rdkit.Chem import rdmolfiles, rdmolops
try:
import mdtraj
except ImportError:
# mdtraj is optional and only needed by k_nearest_neighbors
pass
__all__ = ['mol_to_graph',
'smiles_to_bigraph',
'mol_to_bigraph',
'smiles_to_complete_graph',
'mol_to_complete_graph',
'k_nearest_neighbors']
def mol_to_graph(mol, graph_constructor, node_featurizer, edge_featurizer):
"""Convert an RDKit molecule object into a DGLGraph and featurize for it.
Parameters
----------
mol : rdkit.Chem.rdchem.Mol
RDKit molecule holder
graph_constructor : callable
Takes an RDKit molecule as input and returns a DGLGraph
node_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict
Featurization for nodes like atoms in a molecule, which can be used to
update ndata for a DGLGraph.
edge_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict
Featurization for edges like bonds in a molecule, which can be used to
update edata for a DGLGraph.
Returns
-------
g : DGLGraph
Converted DGLGraph for the molecule
"""
new_order = rdmolfiles.CanonicalRankAtoms(mol)
mol = rdmolops.RenumberAtoms(mol, new_order)
g = graph_constructor(mol)
if node_featurizer is not None:
g.ndata.update(node_featurizer(mol))
if edge_featurizer is not None:
g.edata.update(edge_featurizer(mol))
return g
def construct_bigraph_from_mol(mol, add_self_loop=False):
"""Construct a bi-directed DGLGraph with topology only for the molecule.
The **i** th atom in the molecule, i.e. ``mol.GetAtomWithIdx(i)``, corresponds to the
**i** th node in the returned DGLGraph.
The **i** th bond in the molecule, i.e. ``mol.GetBondWithIdx(i)``, corresponds to the
**(2i)**-th and **(2i+1)**-th edges in the returned DGLGraph. The **(2i)**-th edge goes
from **u** to **v** and the **(2i+1)**-th edge from **v** to **u**, where
**u** is ``bond.GetBeginAtomIdx()`` and **v** is ``bond.GetEndAtomIdx()``.
If self loops are added, the last **n** edges will be the self loops of
atoms ``0, 1, ..., n-1``, in that order.
Parameters
----------
mol : rdkit.Chem.rdchem.Mol
RDKit molecule holder
add_self_loop : bool
Whether to add self loops in DGLGraphs. Default to False.
Returns
-------
g : DGLGraph
Empty bigraph topology of the molecule
"""
g = DGLGraph()
# Add nodes
num_atoms = mol.GetNumAtoms()
g.add_nodes(num_atoms)
# Add edges
src_list = []
dst_list = []
num_bonds = mol.GetNumBonds()
for i in range(num_bonds):
bond = mol.GetBondWithIdx(i)
u = bond.GetBeginAtomIdx()
v = bond.GetEndAtomIdx()
src_list.extend([u, v])
dst_list.extend([v, u])
g.add_edges(src_list, dst_list)
if add_self_loop:
nodes = g.nodes()
g.add_edges(nodes, nodes)
return g
def mol_to_bigraph(mol, add_self_loop=False,
node_featurizer=None,
edge_featurizer=None):
"""Convert an RDKit molecule object into a bi-directed DGLGraph and featurize for it.
Parameters
----------
mol : rdkit.Chem.rdchem.Mol
RDKit molecule holder
add_self_loop : bool
Whether to add self loops in DGLGraphs. Default to False.
node_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict
Featurization for nodes like atoms in a molecule, which can be used to update
ndata for a DGLGraph. Default to None.
edge_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict
Featurization for edges like bonds in a molecule, which can be used to update
edata for a DGLGraph. Default to None.
Returns
-------
g : DGLGraph
Bi-directed DGLGraph for the molecule
"""
return mol_to_graph(mol, partial(construct_bigraph_from_mol, add_self_loop=add_self_loop),
node_featurizer, edge_featurizer)
def smiles_to_bigraph(smiles, add_self_loop=False,
node_featurizer=None,
edge_featurizer=None):
"""Convert a SMILES into a bi-directed DGLGraph and featurize for it.
Parameters
----------
smiles : str
String of SMILES
add_self_loop : bool
Whether to add self loops in DGLGraphs. Default to False.
node_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict
Featurization for nodes like atoms in a molecule, which can be used to update
ndata for a DGLGraph. Default to None.
edge_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict
Featurization for edges like bonds in a molecule, which can be used to update
edata for a DGLGraph. Default to None.
Returns
-------
g : DGLGraph
Bi-directed DGLGraph for the molecule
"""
mol = Chem.MolFromSmiles(smiles)
return mol_to_bigraph(mol, add_self_loop, node_featurizer, edge_featurizer)
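# Sketch: build a featurized bi-directed graph, assuming the canonical
# featurizers are imported from the featurizers module above.
g = smiles_to_bigraph('CCO',
                      node_featurizer=CanonicalAtomFeaturizer(),
                      edge_featurizer=CanonicalBondFeaturizer())
print(g.number_of_nodes(), g.number_of_edges())   # 3 4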
def construct_complete_graph_from_mol(mol, add_self_loop=False):
"""Construct a complete graph with topology only for the molecule
The **i** th atom in the molecule, i.e. ``mol.GetAtomWithIdx(i)``, corresponds to the
**i** th node in the returned DGLGraph.
The edges, as (source, destination) pairs, are in the order of (0, 0), (0, 1), (0, 2),
..., (1, 0), (1, 1), (1, 2), ..., i.e. grouped by the source node.
If self loops are not created, the edges (0, 0), (1, 1), ... are excluded.
Parameters
----------
mol : rdkit.Chem.rdchem.Mol
RDKit molecule holder
add_self_loop : bool
Whether to add self loops in DGLGraphs. Default to False.
Returns
-------
g : DGLGraph
Empty complete graph topology of the molecule
"""
g = DGLGraph()
num_atoms = mol.GetNumAtoms()
g.add_nodes(num_atoms)
if add_self_loop:
g.add_edges(
[i for i in range(num_atoms) for j in range(num_atoms)],
[j for i in range(num_atoms) for j in range(num_atoms)])
else:
g.add_edges(
[i for i in range(num_atoms) for j in range(num_atoms - 1)], [
j for i in range(num_atoms)
for j in range(num_atoms) if i != j
])
return g
def mol_to_complete_graph(mol, add_self_loop=False,
node_featurizer=None,
edge_featurizer=None):
"""Convert an RDKit molecule into a complete DGLGraph and featurize for it.
Parameters
----------
mol : rdkit.Chem.rdchem.Mol
RDKit molecule holder
add_self_loop : bool
Whether to add self loops in DGLGraphs. Default to False.
node_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict
Featurization for nodes like atoms in a molecule, which can be used to update
ndata for a DGLGraph. Default to None.
edge_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict
Featurization for edges like bonds in a molecule, which can be used to update
edata for a DGLGraph. Default to None.
Returns
-------
g : DGLGraph
Complete DGLGraph for the molecule
"""
return mol_to_graph(mol, partial(construct_complete_graph_from_mol, add_self_loop=add_self_loop),
node_featurizer, edge_featurizer)
def smiles_to_complete_graph(smiles, add_self_loop=False,
node_featurizer=None,
edge_featurizer=None):
"""Convert a SMILES into a complete DGLGraph and featurize for it.
Parameters
----------
smiles : str
String of SMILES
add_self_loop : bool
Whether to add self loops in DGLGraphs. Default to False.
node_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict
Featurization for nodes like atoms in a molecule, which can be used to update
ndata for a DGLGraph. Default to None.
edge_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict
Featurization for edges like bonds in a molecule, which can be used to update
edata for a DGLGraph. Default to None.
Returns
-------
g : DGLGraph
Complete DGLGraph for the molecule
"""
mol = Chem.MolFromSmiles(smiles)
return mol_to_complete_graph(mol, add_self_loop, node_featurizer, edge_featurizer)
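# Sketch: a complete graph over the heavy atoms of 'CCO'; with 3 atoms and
# no self loops this yields 3 * 2 = 6 directed edges.
g = smiles_to_complete_graph('CCO')
print(g.number_of_nodes(), g.number_of_edges())   # 3 6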
def k_nearest_neighbors(coordinates, neighbor_cutoff, max_num_neighbors):
"""Find k nearest neighbors for each atom based on the 3D coordinates and
return the resulting edges.
For each atom, find its k nearest neighbors and return edges
from these neighbors to it.
Parameters
----------
coordinates : numpy.ndarray of shape (N, 3)
The 3D coordinates of atoms in the molecule. N for the number of atoms.
neighbor_cutoff : float
Distance cutoff to define 'neighboring'.
max_num_neighbors : int or None.
If not None, then this specifies the maximum number of closest neighbors
allowed for each atom.
Returns
-------
srcs : list of int
Source nodes.
dsts : list of int
Destination nodes.
distances : list of float
Distances between the end nodes.
"""
num_atoms = coordinates.shape[0]
traj = mdtraj.Trajectory(coordinates.reshape((1, num_atoms, 3)), None)
neighbors = mdtraj.geometry.compute_neighborlist(traj, neighbor_cutoff)
srcs, dsts, distances = [], [], []
for i in range(num_atoms):
delta = coordinates[i] - coordinates.take(neighbors[i], axis=0)
dist = np.linalg.norm(delta, axis=1)
if max_num_neighbors is not None and len(neighbors[i]) > max_num_neighbors:
sorted_neighbors = list(zip(dist, neighbors[i]))
# Sort neighbors based on distance from smallest to largest
sorted_neighbors.sort(key=lambda tup: tup[0])
dsts.extend([i for _ in range(max_num_neighbors)])
srcs.extend([int(sorted_neighbors[j][1]) for j in range(max_num_neighbors)])
distances.extend([float(sorted_neighbors[j][0]) for j in range(max_num_neighbors)])
else:
dsts.extend([i for _ in range(len(neighbors[i]))])
srcs.extend(neighbors[i].tolist())
distances.extend(dist.tolist())
return srcs, dsts, distances
# Todo(Mufei): smiles_to_knn_graph, mol_to_knn_graph
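# A sketch of k_nearest_neighbors (requires mdtraj); the coordinates are
# random placeholders and the cutoff of 2.0 keeps every pair within range.
coords = np.random.rand(5, 3).astype(np.float32)
srcs, dsts, dists = k_nearest_neighbors(coords, neighbor_cutoff=2., max_num_neighbors=3)
# Each atom keeps at most its 3 closest neighbors, so len(srcs) <= 5 * 3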
"""Utils for RDKit, mostly adapted from DeepChem
(https://github.com/deepchem/deepchem/blob/master/deepchem)."""
import warnings
from functools import partial
from multiprocessing import Pool
from rdkit import Chem
from rdkit.Chem import AllChem
try:
from StringIO import StringIO
except ImportError:
from io import StringIO
__all__ = ['get_mol_3D_coordinates',
'load_molecule',
'multiprocess_load_molecules']
def get_mol_3D_coordinates(mol):
"""Get 3D coordinates of the molecule.
Parameters
----------
mol : rdkit.Chem.rdchem.Mol
RDKit molecule instance.
Returns
-------
numpy.ndarray of shape (N, 3) or None
The 3D coordinates of atoms in the molecule. N for the number of atoms in
the molecule. For failures in getting the conformations, None will be returned.
"""
try:
conf = mol.GetConformer()
conf_num_atoms = conf.GetNumAtoms()
mol_num_atoms = mol.GetNumAtoms()
assert mol_num_atoms == conf_num_atoms, \
'Expect the number of atoms in the molecule and its conformation ' \
'to be the same, got {:d} and {:d}'.format(mol_num_atoms, conf_num_atoms)
return conf.GetPositions()
except Exception:
warnings.warn('Unable to get conformation of the molecule.')
return None
def load_molecule(molecule_file, sanitize=False, calc_charges=False,
remove_hs=False, use_conformation=True):
"""Load a molecule from a file.
Parameters
----------
molecule_file : str
Path to file for storing a molecule, which can be of format '.mol2', '.sdf',
'.pdbqt', or '.pdb'.
sanitize : bool
Whether sanitization is performed in initializing RDKit molecule instances. See
https://www.rdkit.org/docs/RDKit_Book.html for details of the sanitization.
Default to False.
calc_charges : bool
Whether to add Gasteiger charges via RDKit. Setting this to be True will enforce
``sanitize`` to be True. Default to False.
remove_hs : bool
Whether to remove hydrogens via RDKit. Note that removing hydrogens can be quite
slow for large molecules. Default to False.
use_conformation : bool
Whether we need to extract molecular conformation from proteins and ligands.
Default to True.
Returns
-------
mol : rdkit.Chem.rdchem.Mol
RDKit molecule instance for the loaded molecule.
coordinates : np.ndarray of shape (N, 3) or None
The 3D coordinates of atoms in the molecule. N for the number of atoms in
the molecule. None will be returned if ``use_conformation`` is False or
we failed to get conformation information.
"""
if molecule_file.endswith('.mol2'):
mol = Chem.MolFromMol2File(molecule_file, sanitize=False, removeHs=False)
elif molecule_file.endswith('.sdf'):
supplier = Chem.SDMolSupplier(molecule_file, sanitize=False, removeHs=False)
mol = supplier[0]
elif molecule_file.endswith('.pdbqt'):
with open(molecule_file) as f:
pdbqt_data = f.readlines()
pdb_block = ''
for line in pdbqt_data:
pdb_block += '{}\n'.format(line[:66])
mol = Chem.MolFromPDBBlock(pdb_block, sanitize=False, removeHs=False)
elif molecule_file.endswith('.pdb'):
mol = Chem.MolFromPDBFile(molecule_file, sanitize=False, removeHs=False)
else:
raise ValueError('Expect the format of the molecule_file to be '
'one of .mol2, .sdf, .pdbqt and .pdb, got {}'.format(molecule_file))
try:
if sanitize or calc_charges:
Chem.SanitizeMol(mol)
if calc_charges:
# Compute Gasteiger charges on the molecule.
try:
AllChem.ComputeGasteigerCharges(mol)
except Exception:
warnings.warn('Unable to compute charges for the molecule.')
if remove_hs:
mol = Chem.RemoveHs(mol)
except Exception:
return None, None
if use_conformation:
coordinates = get_mol_3D_coordinates(mol)
else:
coordinates = None
return mol, coordinates
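# A minimal usage sketch (the file path below is illustrative only):
#
#     mol, coords = load_molecule('ligand.sdf', sanitize=True, remove_hs=True)
#     if mol is None:
#         print('Failed to load and process the molecule.')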
def multiprocess_load_molecules(files, sanitize=False, calc_charges=False,
remove_hs=False, use_conformation=True, num_processes=2):
"""Load molecules from files with multiprocessing.
Parameters
----------
files : list of str
Each element is a path to a file storing a molecule, which can be of format '.mol2',
'.sdf', '.pdbqt', or '.pdb'.
sanitize : bool
Whether sanitization is performed in initializing RDKit molecule instances. See
https://www.rdkit.org/docs/RDKit_Book.html for details of the sanitization.
Default to False.
calc_charges : bool
Whether to add Gasteiger charges via RDKit. Setting this to be True will enforce
``sanitize`` to be True. Default to False.
remove_hs : bool
Whether to remove hydrogens via RDKit. Note that removing hydrogens can be quite
slow for large molecules. Default to False.
use_conformation : bool
Whether we need to extract molecular conformation from proteins and ligands.
Default to True.
num_processes : int or None
Number of worker processes to use. If None,
then we will use the number of CPUs in the system. Default to 2.
Returns
-------
list of 2-tuples
The first element of each 2-tuple is an RDKit molecule instance. The second element
of each 2-tuple is the 3D atom coordinates of the corresponding molecule if
use_conformation is True and the coordinates have been successfully loaded. Otherwise,
it will be None.
"""
if num_processes == 1:
mols_loaded = []
for i, f in enumerate(files):
mols_loaded.append(load_molecule(
f, sanitize=sanitize, calc_charges=calc_charges,
remove_hs=remove_hs, use_conformation=use_conformation))
else:
with Pool(processes=num_processes) as pool:
mols_loaded = pool.map_async(partial(
load_molecule, sanitize=sanitize, calc_charges=calc_charges,
remove_hs=remove_hs, use_conformation=use_conformation), files)
mols_loaded = mols_loaded.get()
return mols_loaded
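# A minimal usage sketch (the file paths below are illustrative only):
#
#     results = multiprocess_load_molecules(['ligand.sdf', 'protein.pdb'],
#                                           sanitize=True, num_processes=2)
#     mols = [mol for mol, _ in results]
#     coords = [coordinates for _, coordinates in results]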
"""Various methods for splitting chemical datasets.
We mostly adapt them from deepchem
(https://github.com/deepchem/deepchem/blob/master/deepchem/splits/splitters.py).
"""
import dgl.backend as F
import numpy as np
from dgl.data.utils import split_dataset, Subset
from collections import defaultdict
from functools import partial
from itertools import accumulate, chain
from rdkit import Chem
from rdkit.Chem import rdMolDescriptors
from rdkit.Chem.rdmolops import FastFindRings
from rdkit.Chem.Scaffolds import MurckoScaffold
__all__ = ['ConsecutiveSplitter',
'RandomSplitter',
'MolecularWeightSplitter',
'ScaffoldSplitter',
'SingleTaskStratifiedSplitter']
def base_k_fold_split(split_method, dataset, k, log):
"""Split dataset for k-fold cross validation.
Parameters
----------
split_method : callable
Arbitrary method for splitting the dataset
into training, validation and test subsets.
dataset
We assume ``len(dataset)`` gives the size for the dataset and ``dataset[i]``
gives the ith datapoint.
k : int
Number of folds to use and should be no smaller than 2.
log : bool
Whether to print a message at the start of preparing each fold.
Returns
-------
all_folds : list of 2-tuples
Each element of the list represents a fold and is a 2-tuple (train_set, val_set).
"""
assert k >= 2, 'Expect the number of folds to be no smaller than 2, got {:d}'.format(k)
all_folds = []
frac_per_part = 1. / k
for i in range(k):
if log:
print('Processing fold {:d}/{:d}'.format(i + 1, k))
# We are reusing the code for train-validation-test split.
train_set1, val_set, train_set2 = split_method(dataset,
frac_train=i * frac_per_part,
frac_val=frac_per_part,
frac_test=1. - (i + 1) * frac_per_part)
# For cross validation, each fold consists of only a train subset and
# a validation subset.
train_set = Subset(dataset, np.concatenate(
[train_set1.indices, train_set2.indices]).astype(np.int64))
all_folds.append((train_set, val_set))
return all_folds
def train_val_test_sanity_check(frac_train, frac_val, frac_test):
"""Sanity check for train-val-test split
Ensure that the fractions of the dataset to use for training,
validation and test add up to 1.
Parameters
----------
frac_train : float
Fraction of the dataset to use for training.
frac_val : float
Fraction of the dataset to use for validation.
frac_test : float
Fraction of the dataset to use for test.
"""
total_fraction = frac_train + frac_val + frac_test
assert np.allclose(total_fraction, 1.), \
'Expect the sum of fractions for training, validation and ' \
'test to be 1, got {:.4f}'.format(total_fraction)
def indices_split(dataset, frac_train, frac_val, frac_test, indices):
"""Reorder datapoints based on the specified indices and then take consecutive
chunks as subsets.
Parameters
----------
dataset
We assume ``len(dataset)`` gives the size for the dataset and ``dataset[i]``
gives the ith datapoint.
frac_train : float
Fraction of data to use for training.
frac_val : float
Fraction of data to use for validation.
frac_test : float
Fraction of data to use for test.
indices : list or ndarray
Indices specifying the order of datapoints.
Returns
-------
list of length 3
Subsets for training, validation and test, which are all :class:`Subset` instances.
"""
frac_list = np.array([frac_train, frac_val, frac_test])
assert np.allclose(np.sum(frac_list), 1.), \
'Expect frac_list sum to 1, got {:.4f}'.format(np.sum(frac_list))
num_data = len(dataset)
lengths = (num_data * frac_list).astype(int)
lengths[-1] = num_data - np.sum(lengths[:-1])
return [Subset(dataset, list(indices[offset - length:offset]))
for offset, length in zip(accumulate(lengths), lengths)]
def count_and_log(message, i, total, log_every_n):
"""Print a message to reflect the progress of processing once a while.
Parameters
----------
message : str
Message to print.
i : int
Current index.
total : int
Total count.
log_every_n : None or int
Molecule related computation can take a long time for a large dataset and we want
to learn the progress of processing. This can be done by printing a message whenever
a batch of ``log_every_n`` molecules have been processed. If None, no messages will
be printed.
"""
if (log_every_n is not None) and ((i + 1) % log_every_n == 0):
print('{} {:d}/{:d}'.format(message, i + 1, total))
def prepare_mols(dataset, mols, sanitize, log_every_n=1000):
"""Prepare RDKit molecule instances.
Parameters
----------
dataset
We assume ``len(dataset)`` gives the size for the dataset, ``dataset[i]``
gives the ith datapoint and ``dataset.smiles[i]`` gives the SMILES for the
ith datapoint.
mols : None or list of rdkit.Chem.rdchem.Mol
None or pre-computed RDKit molecule instances. If not None, we expect a
one-on-one correspondence between ``dataset.smiles`` and ``mols``, i.e.
``mols[i]`` corresponds to ``dataset.smiles[i]``.
sanitize : bool
This argument only comes into effect when ``mols`` is None and decides whether
sanitization is performed in initializing RDKit molecule instances. See
https://www.rdkit.org/docs/RDKit_Book.html for details of the sanitization.
log_every_n : None or int
Molecule related computation can take a long time for a large dataset and we want
to learn the progress of processing. This can be done by printing a message whenever
a batch of ``log_every_n`` molecules have been processed. If None, no messages will
be printed. Default to 1000.
Returns
-------
mols : list of rdkit.Chem.rdchem.Mol
RDKit molecule instances where there is a one-on-one correspondence between
``dataset.smiles`` and ``mols``, i.e. ``mols[i]`` corresponds to ``dataset.smiles[i]``.
"""
if mols is not None:
# Sanity check
assert len(mols) == len(dataset), \
'Expect mols to be of the same size as that of the dataset, ' \
'got {:d} and {:d}'.format(len(mols), len(dataset))
else:
if log_every_n is not None:
print('Start initializing RDKit molecule instances...')
mols = []
for i, s in enumerate(dataset.smiles):
count_and_log('Creating RDKit molecule instance',
i, len(dataset.smiles), log_every_n)
mols.append(Chem.MolFromSmiles(s, sanitize=sanitize))
return mols
class ConsecutiveSplitter(object):
"""Split datasets with the input order.
The dataset is split without permutation, so the splitting is deterministic.
"""
@staticmethod
def train_val_test_split(dataset, frac_train=0.8, frac_val=0.1, frac_test=0.1):
"""Split the dataset into three consecutive chunks for training, validation and test.
Parameters
----------
dataset
We assume ``len(dataset)`` gives the size for the dataset and ``dataset[i]``
gives the ith datapoint.
frac_train : float
Fraction of data to use for training. By default, we set this to be 0.8, i.e.
80% of the dataset is used for training.
frac_val : float
Fraction of data to use for validation. By default, we set this to be 0.1, i.e.
10% of the dataset is used for validation.
frac_test : float
Fraction of data to use for test. By default, we set this to be 0.1, i.e.
10% of the dataset is used for test.
Returns
-------
list of length 3
Subsets for training, validation and test, which are all :class:`Subset` instances.
"""
return split_dataset(dataset, frac_list=[frac_train, frac_val, frac_test], shuffle=False)
@staticmethod
def k_fold_split(dataset, k=5, log=True):
"""Split the dataset for k-fold cross validation by taking consecutive chunks.
Parameters
----------
dataset
We assume ``len(dataset)`` gives the size for the dataset and ``dataset[i]``
gives the ith datapoint.
k : int
Number of folds to use and should be no smaller than 2. Default to be 5.
log : bool
Whether to print a message at the start of preparing each fold.
Returns
-------
list of 2-tuples
Each element of the list represents a fold and is a 2-tuple (train_set, val_set).
"""
return base_k_fold_split(ConsecutiveSplitter.train_val_test_split, dataset, k, log)
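# A minimal usage sketch (hypothetical; ``dataset`` is any object supporting
# ``len(dataset)`` and ``dataset[i]``):
#
#     train_set, val_set, test_set = ConsecutiveSplitter.train_val_test_split(dataset)
#     folds = ConsecutiveSplitter.k_fold_split(dataset, k=5)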
class RandomSplitter(object):
"""Randomly reorder datasets and then split them.
The dataset is split with permutation and the splitting is hence random.
"""
@staticmethod
def train_val_test_split(dataset, frac_train=0.8, frac_val=0.1,
frac_test=0.1, random_state=None):
"""Randomly permute the dataset and then split it into
three consecutive chunks for training, validation and test.
Parameters
----------
dataset
We assume ``len(dataset)`` gives the size for the dataset and ``dataset[i]``
gives the ith datapoint.
frac_train : float
Fraction of data to use for training. By default, we set this to be 0.8, i.e.
80% of the dataset is used for training.
frac_val : float
Fraction of data to use for validation. By default, we set this to be 0.1, i.e.
10% of the dataset is used for validation.
frac_test : float
Fraction of data to use for test. By default, we set this to be 0.1, i.e.
10% of the dataset is used for test.
random_state : None, int or array_like, optional
Random seed used to initialize the pseudo-random number generator.
Can be any integer between 0 and 2**32 - 1 inclusive, an array
(or other sequence) of such integers, or None (the default).
If seed is None, then RandomState will try to read data from /dev/urandom
(or the Windows analogue) if available or seed from the clock otherwise.
Returns
-------
list of length 3
Subsets for training, validation and test.
"""
return split_dataset(dataset, frac_list=[frac_train, frac_val, frac_test],
shuffle=True, random_state=random_state)
@staticmethod
def k_fold_split(dataset, k=5, random_state=None, log=True):
"""Randomly permute the dataset and then split it
for k-fold cross validation by taking consecutive chunks.
Parameters
----------
dataset
We assume ``len(dataset)`` gives the size for the dataset and ``dataset[i]``
gives the ith datapoint.
k : int
Number of folds to use and should be no smaller than 2. Default to be 5.
random_state : None, int or array_like, optional
Random seed used to initialize the pseudo-random number generator.
Can be any integer between 0 and 2**32 - 1 inclusive, an array
(or other sequence) of such integers, or None (the default).
If seed is None, then RandomState will try to read data from /dev/urandom
(or the Windows analogue) if available or seed from the clock otherwise.
log : bool
Whether to print a message at the start of preparing each fold. Default to True.
Returns
-------
list of 2-tuples
Each element of the list represents a fold and is a 2-tuple (train_set, val_set).
"""
# Permute the dataset only once so that each datapoint
# will appear once in exactly one fold.
indices = np.random.RandomState(seed=random_state).permutation(len(dataset))
return base_k_fold_split(partial(indices_split, indices=indices), dataset, k, log)
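# A minimal usage sketch (hypothetical; fixing ``random_state`` makes the
# random split reproducible):
#
#     train_set, val_set, test_set = RandomSplitter.train_val_test_split(
#         dataset, frac_train=0.8, frac_val=0.1, frac_test=0.1, random_state=0)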
class MolecularWeightSplitter(object):
"""Sort molecules based on their weights and then split them."""
@staticmethod
def molecular_weight_indices(molecules, log_every_n):
"""Reorder molecules based on molecular weights.
Parameters
----------
molecules : list of rdkit.Chem.rdchem.Mol
Pre-computed RDKit molecule instances. We expect a one-on-one
correspondence between ``dataset.smiles`` and ``mols``, i.e.
``mols[i]`` corresponds to ``dataset.smiles[i]``.
log_every_n : None or int
Molecule related computation can take a long time for a large dataset and we want
to learn the progress of processing. This can be done by printing a message whenever
a batch of ``log_every_n`` molecules have been processed. If None, no messages will
be printed.
Returns
-------
indices : list or ndarray
Indices specifying the order of datapoints, which are basically
argsort of the molecular weights.
"""
if log_every_n is not None:
print('Start computing molecular weights.')
mws = []
for i, mol in enumerate(molecules):
count_and_log('Computing molecular weight for compound',
i, len(molecules), log_every_n)
mws.append(Chem.rdMolDescriptors.CalcExactMolWt(mol))
return np.argsort(mws)
@staticmethod
def train_val_test_split(dataset, mols=None, sanitize=True, frac_train=0.8,
frac_val=0.1, frac_test=0.1, log_every_n=1000):
"""Sort molecules based on their weights and then split them into
three consecutive chunks for training, validation and test.
Parameters
----------
dataset
We assume ``len(dataset)`` gives the size for the dataset, ``dataset[i]``
gives the ith datapoint and ``dataset.smiles[i]`` gives the SMILES for the
ith datapoint.
mols : None or list of rdkit.Chem.rdchem.Mol
None or pre-computed RDKit molecule instances. If not None, we expect a
one-on-one correspondence between ``dataset.smiles`` and ``mols``, i.e.
``mols[i]`` corresponds to ``dataset.smiles[i]``. Default to None.
sanitize : bool
This argument only comes into effect when ``mols`` is None and decides whether
sanitization is performed in initializing RDKit molecule instances. See
https://www.rdkit.org/docs/RDKit_Book.html for details of the sanitization.
Default to be True.
frac_train : float
Fraction of data to use for training. By default, we set this to be 0.8, i.e.
80% of the dataset is used for training.
frac_val : float
Fraction of data to use for validation. By default, we set this to be 0.1, i.e.
10% of the dataset is used for validation.
frac_test : float
Fraction of data to use for test. By default, we set this to be 0.1, i.e.
10% of the dataset is used for test.
log_every_n : None or int
Molecule related computation can take a long time for a large dataset and we want
to learn the progress of processing. This can be done by printing a message whenever
a batch of ``log_every_n`` molecules have been processed. If None, no messages will
be printed. Default to 1000.
Returns
-------
list of length 3
Subsets for training, validation and test, which are all :class:`Subset` instances.
"""
# Perform sanity check first as molecule instance initialization and descriptor
# computation can take a long time.
train_val_test_sanity_check(frac_train, frac_val, frac_test)
molecules = prepare_mols(dataset, mols, sanitize, log_every_n)
sorted_indices = MolecularWeightSplitter.molecular_weight_indices(molecules, log_every_n)
return indices_split(dataset, frac_train, frac_val, frac_test, sorted_indices)
@staticmethod
def k_fold_split(dataset, mols=None, sanitize=True, k=5, log_every_n=1000):
"""Sort molecules based on their weights and then split them
for k-fold cross validation by taking consecutive chunks.
Parameters
----------
dataset
We assume ``len(dataset)`` gives the size for the dataset, ``dataset[i]``
gives the ith datapoint and ``dataset.smiles[i]`` gives the SMILES for the
ith datapoint.
mols : None or list of rdkit.Chem.rdchem.Mol
None or pre-computed RDKit molecule instances. If not None, we expect a
one-on-one correspondence between ``dataset.smiles`` and ``mols``, i.e.
``mols[i]`` corresponds to ``dataset.smiles[i]``. Default to None.
sanitize : bool
This argument only comes into effect when ``mols`` is None and decides whether
sanitization is performed in initializing RDKit molecule instances. See
https://www.rdkit.org/docs/RDKit_Book.html for details of the sanitization.
Default to be True.
k : int
Number of folds to use and should be no smaller than 2. Default to be 5.
log_every_n : None or int
Molecule related computation can take a long time for a large dataset and we want
to learn the progress of processing. This can be done by printing a message whenever
a batch of ``log_every_n`` molecules have been processed. If None, no messages will
be printed. Default to 1000.
Returns
-------
list of 2-tuples
Each element of the list represents a fold and is a 2-tuple (train_set, val_set).
"""
molecules = prepare_mols(dataset, mols, sanitize, log_every_n)
sorted_indices = MolecularWeightSplitter.molecular_weight_indices(molecules, log_every_n)
return base_k_fold_split(partial(indices_split, indices=sorted_indices), dataset, k,
log=(log_every_n is not None))
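# A minimal usage sketch (hypothetical; assumes ``dataset.smiles`` exists as
# described in the docstrings above; ``log_every_n=None`` suppresses progress messages):
#
#     train_set, val_set, test_set = MolecularWeightSplitter.train_val_test_split(
#         dataset, sanitize=True, log_every_n=None)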
class ScaffoldSplitter(object):
"""Group molecules based on their Bemis-Murcko scaffolds and then split the groups.
Group molecules so that all molecules in a group have a same scaffold (see reference).
The dataset is then split at the level of groups.
References
----------
Bemis, G. W.; Murcko, M. A. “The Properties of Known Drugs.
1. Molecular Frameworks.” J. Med. Chem. 39:2887-93 (1996).
"""
@staticmethod
def get_ordered_scaffold_sets(molecules, include_chirality, log_every_n):
"""Group molecules based on their Bemis-Murcko scaffolds and
order these groups based on their sizes.
The order is decided by comparing the size of groups, where groups with a larger size
are placed before the ones with a smaller size.
Parameters
----------
molecules : list of rdkit.Chem.rdchem.Mol
Pre-computed RDKit molecule instances. We expect a one-on-one
correspondence between ``dataset.smiles`` and ``mols``, i.e.
``mols[i]`` corresponds to ``dataset.smiles[i]``.
include_chirality : bool
Whether to consider chirality in computing scaffolds.
log_every_n : None or int
Molecule related computation can take a long time for a large dataset and we want
to learn the progress of processing. This can be done by printing a message whenever
a batch of ``log_every_n`` molecules have been processed. If None, no messages will
be printed.
Returns
-------
scaffold_sets : list
Each element of the list is a list of int,
representing the indices of compounds with a same scaffold.
"""
if log_every_n is not None:
print('Start computing Bemis-Murcko scaffolds.')
scaffolds = defaultdict(list)
for i, mol in enumerate(molecules):
count_and_log('Computing Bemis-Murcko for compound',
i, len(molecules), log_every_n)
# For mols that have not been sanitized, we need to compute their ring information
try:
FastFindRings(mol)
mol_scaffold = MurckoScaffold.MurckoScaffoldSmiles(
mol=mol, includeChirality=include_chirality)
# Group molecules that have the same scaffold
scaffolds[mol_scaffold].append(i)
except Exception:
print('Failed to compute the scaffold for molecule {:d} '
'and it will be excluded.'.format(i + 1))
# Order groups of molecules by first comparing the size of groups
# and then the index of the first compound in the group.
scaffold_sets = [
scaffold_set for (scaffold, scaffold_set) in sorted(
scaffolds.items(), key=lambda x: (len(x[1]), x[1][0]), reverse=True)
]
return scaffold_sets
@staticmethod
def train_val_test_split(dataset, mols=None, sanitize=True, include_chirality=False,
frac_train=0.8, frac_val=0.1, frac_test=0.1, log_every_n=1000):
"""Split the dataset into training, validation and test set based on molecular scaffolds.
This splitting method ensures that molecules with a same scaffold will be collectively
in only one of the training, validation or test set. As a result, the fractions
of the dataset used for training and validation tend to be smaller than ``frac_train``
and ``frac_val``, while the fraction used for test tends to be larger
than ``frac_test``.
Parameters
----------
dataset
We assume ``len(dataset)`` gives the size for the dataset, ``dataset[i]``
gives the ith datapoint and ``dataset.smiles[i]`` gives the SMILES for the
ith datapoint.
mols : None or list of rdkit.Chem.rdchem.Mol
None or pre-computed RDKit molecule instances. If not None, we expect a
one-on-one correspondence between ``dataset.smiles`` and ``mols``, i.e.
``mols[i]`` corresponds to ``dataset.smiles[i]``. Default to None.
sanitize : bool
This argument only comes into effect when ``mols`` is None and decides whether
sanitization is performed in initializing RDKit molecule instances. See
https://www.rdkit.org/docs/RDKit_Book.html for details of the sanitization.
Default to True.
include_chirality : bool
Whether to consider chirality in computing scaffolds. Default to False.
frac_train : float
Fraction of data to use for training. By default, we set this to be 0.8, i.e.
80% of the dataset is used for training.
frac_val : float
Fraction of data to use for validation. By default, we set this to be 0.1, i.e.
10% of the dataset is used for validation.
frac_test : float
Fraction of data to use for test. By default, we set this to be 0.1, i.e.
10% of the dataset is used for test.
log_every_n : None or int
Molecule related computation can take a long time for a large dataset and we want
to learn the progress of processing. This can be done by printing a message whenever
a batch of ``log_every_n`` molecules have been processed. If None, no messages will
be printed. Default to 1000.
Returns
-------
list of length 3
Subsets for training, validation and test, which are all :class:`Subset` instances.
"""
# Perform sanity check first as molecule related computation can take a long time.
train_val_test_sanity_check(frac_train, frac_val, frac_test)
molecules = prepare_mols(dataset, mols, sanitize)
scaffold_sets = ScaffoldSplitter.get_ordered_scaffold_sets(
molecules, include_chirality, log_every_n)
train_indices, val_indices, test_indices = [], [], []
train_cutoff = int(frac_train * len(molecules))
val_cutoff = int((frac_train + frac_val) * len(molecules))
for group_indices in scaffold_sets:
if len(train_indices) + len(group_indices) > train_cutoff:
if len(train_indices) + len(val_indices) + len(group_indices) > val_cutoff:
test_indices.extend(group_indices)
else:
val_indices.extend(group_indices)
else:
train_indices.extend(group_indices)
return [Subset(dataset, train_indices),
Subset(dataset, val_indices),
Subset(dataset, test_indices)]
@staticmethod
def k_fold_split(dataset, mols=None, sanitize=True,
include_chirality=False, k=5, log_every_n=1000):
"""Group molecules based on their scaffolds and sort groups based on their sizes.
The groups are then split for k-fold cross validation.
As with usual k-fold splitting methods, each molecule will appear only once
in the validation set among all folds. In addition, this method ensures that
molecules with a same scaffold will be collectively in either the training
set or the validation set for each fold.
Note that the folds can be highly imbalanced depending on the
scaffold distribution in the dataset.
Parameters
----------
dataset
We assume ``len(dataset)`` gives the size for the dataset, ``dataset[i]``
gives the ith datapoint and ``dataset.smiles[i]`` gives the SMILES for the
ith datapoint.
mols : None or list of rdkit.Chem.rdchem.Mol
None or pre-computed RDKit molecule instances. If not None, we expect a
one-on-one correspondence between ``dataset.smiles`` and ``mols``, i.e.
``mols[i]`` corresponds to ``dataset.smiles[i]``. Default to None.
sanitize : bool
This argument only comes into effect when ``mols`` is None and decides whether
sanitization is performed in initializing RDKit molecule instances. See
https://www.rdkit.org/docs/RDKit_Book.html for details of the sanitization.
Default to True.
include_chirality : bool
Whether to consider chirality in computing scaffolds. Default to False.
k : int
Number of folds to use and should be no smaller than 2. Default to be 5.
log_every_n : None or int
Molecule related computation can take a long time for a large dataset and we want
to learn the progress of processing. This can be done by printing a message whenever
a batch of ``log_every_n`` molecules have been processed. If None, no messages will
be printed. Default to 1000.
Returns
-------
list of 2-tuples
Each element of the list represents a fold and is a 2-tuple (train_set, val_set).
"""
assert k >= 2, 'Expect the number of folds to be no smaller than 2, got {:d}'.format(k)
molecules = prepare_mols(dataset, mols, sanitize)
scaffold_sets = ScaffoldSplitter.get_ordered_scaffold_sets(
molecules, include_chirality, log_every_n)
# k buckets that form a relatively balanced partition of the dataset
index_buckets = [[] for _ in range(k)]
for group_indices in scaffold_sets:
bucket_chosen = int(np.argmin([len(bucket) for bucket in index_buckets]))
index_buckets[bucket_chosen].extend(group_indices)
all_folds = []
for i in range(k):
if log_every_n is not None:
print('Processing fold {:d}/{:d}'.format(i + 1, k))
train_indices = list(chain.from_iterable(index_buckets[:i] + index_buckets[i + 1:]))
val_indices = index_buckets[i]
all_folds.append((Subset(dataset, train_indices), Subset(dataset, val_indices)))
return all_folds
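# A minimal usage sketch (hypothetical; molecules sharing a Bemis-Murcko
# scaffold will end up in the same subset):
#
#     train_set, val_set, test_set = ScaffoldSplitter.train_val_test_split(
#         dataset, include_chirality=False)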
class SingleTaskStratifiedSplitter(object):
"""Splits the dataset by stratification on a single task.
We sort the molecules based on their label values for a task and then repeatedly
take buckets of datapoints to augment the training, validation and test subsets.
"""
@staticmethod
def train_val_test_split(dataset, labels, task_id, frac_train=0.8, frac_val=0.1,
frac_test=0.1, bucket_size=10, random_state=None):
"""Split the dataset into training, validation and test subsets as stated above.
Parameters
----------
dataset
We assume ``len(dataset)`` gives the size for the dataset, ``dataset[i]``
gives the ith datapoint and ``dataset.smiles[i]`` gives the SMILES for the
ith datapoint.
labels : tensor of shape (N, T)
Dataset labels for all tasks. N for the number of datapoints and T for the number
of tasks.
task_id : int
Index for the task.
frac_train : float
Fraction of data to use for training. By default, we set this to be 0.8, i.e.
80% of the dataset is used for training.
frac_val : float
Fraction of data to use for validation. By default, we set this to be 0.1, i.e.
10% of the dataset is used for validation.
frac_test : float
Fraction of data to use for test. By default, we set this to be 0.1, i.e.
10% of the dataset is used for test.
bucket_size : int
Size of bucket of datapoints. Default to 10.
random_state : None, int or array_like, optional
Random seed used to initialize the pseudo-random number generator.
Can be any integer between 0 and 2**32 - 1 inclusive, an array
(or other sequence) of such integers, or None (the default).
If seed is None, then RandomState will try to read data from /dev/urandom
(or the Windows analogue) if available or seed from the clock otherwise.
Returns
-------
list of length 3
Subsets for training, validation and test, which are all :class:`Subset` instances.
"""
train_val_test_sanity_check(frac_train, frac_val, frac_test)
if random_state is not None:
np.random.seed(random_state)
if not isinstance(labels, np.ndarray):
labels = F.asnumpy(labels)
task_labels = labels[:, task_id]
sorted_indices = np.argsort(task_labels)
train_bucket_cutoff = int(np.round(frac_train * bucket_size))
val_bucket_cutoff = int(np.round(frac_val * bucket_size)) + train_bucket_cutoff
train_indices, val_indices, test_indices = [], [], []
while sorted_indices.shape[0] >= bucket_size:
current_batch, sorted_indices = np.split(sorted_indices, [bucket_size])
shuffled = np.random.permutation(range(bucket_size))
train_indices.extend(
current_batch[shuffled[:train_bucket_cutoff]].tolist())
val_indices.extend(
current_batch[shuffled[train_bucket_cutoff:val_bucket_cutoff]].tolist())
test_indices.extend(
current_batch[shuffled[val_bucket_cutoff:]].tolist())
# Place the remaining samples in the training set.
train_indices.extend(sorted_indices.tolist())
return [Subset(dataset, train_indices),
Subset(dataset, val_indices),
Subset(dataset, test_indices)]
@staticmethod
def k_fold_split(dataset, labels, task_id, k=5, log=True):
"""Sort molecules based on their label values for a task and then split them
for k-fold cross validation by taking consecutive chunks.
Parameters
----------
dataset
We assume ``len(dataset)`` gives the size for the dataset, ``dataset[i]``
gives the ith datapoint and ``dataset.smiles[i]`` gives the SMILES for the
ith datapoint.
labels : tensor of shape (N, T)
Dataset labels for all tasks. N for the number of datapoints and T for the number
of tasks.
task_id : int
Index for the task.
k : int
Number of folds to use and should be no smaller than 2. Default to be 5.
log : bool
Whether to print a message at the start of preparing each fold.
Returns
-------
list of 2-tuples
Each element of the list represents a fold and is a 2-tuple (train_set, val_set).
"""
if not isinstance(labels, np.ndarray):
labels = F.asnumpy(labels)
task_labels = labels[:, task_id]
sorted_indices = np.argsort(task_labels).tolist()
return base_k_fold_split(partial(indices_split, indices=sorted_indices), dataset, k, log)
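# A minimal usage sketch (hypothetical; ``labels`` is an (N, T) tensor and we
# stratify on the first task):
#
#     train_set, val_set, test_set = SingleTaskStratifiedSplitter.train_val_test_split(
#         dataset, labels=labels, task_id=0, random_state=0)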
# Binding Affinity Prediction
## Datasets
- **PDBBind**: The PDBBind dataset in MoleculeNet [1] processed from the PDBBind database. The PDBBind
database consists of experimentally measured binding affinities for bio-molecular complexes [2], [3].
It provides detailed 3D Cartesian coordinates of both ligands and their target proteins derived from
experimental (e.g., X-ray crystallography) measurements. The availability of coordinates of the
protein-ligand complexes permits structure-based featurization that is aware of the protein-ligand
binding geometry. The authors of [1] use the "refined" and "core" subsets of the database [4], which are more
carefully processed for data artifacts, as additional benchmarking targets.
## Models
- **Atomic Convolutional Networks (ACNN)** [5]: Constructs nearest neighbor graphs separately for the ligand, protein and complex
based on the 3D coordinates of the atoms and predicts the binding free energy.
## Usage
Use `main.py` with arguments
```
-m {ACNN}, Model to use
-d {PDBBind_core_pocket_random, PDBBind_core_pocket_scaffold, PDBBind_core_pocket_stratified,
PDBBind_core_pocket_temporal, PDBBind_refined_pocket_random, PDBBind_refined_pocket_scaffold,
PDBBind_refined_pocket_stratified, PDBBind_refined_pocket_temporal}, dataset and splitting method to use
```
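For example, the following command (illustrative) trains and evaluates ACNN on the PDBBind core subset with
random splitting:
```
python main.py -m ACNN -d PDBBind_core_pocket_random
```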
## Performance
### PDBBind
#### ACNN
| Subset | Splitting Method | Test MAE | Test R2 |
| ------- | ---------------- | -------- | ------- |
| Core | Random | 1.7688 | 0.1511 |
| Core | Scaffold | 2.5420 | 0.1471 |
| Core | Stratified | 1.7419 | 0.1520 |
| Core | Temporal | 1.9543 | 0.1640 |
| Refined | Random | 1.1948 | 0.4373 |
| Refined | Scaffold | 1.4021 | 0.2086 |
| Refined | Stratified | 1.6376 | 0.3050 |
| Refined | Temporal | 1.2457 | 0.3438 |
## Speed
### ACNN
Compared to [DeepChem's implementation](https://github.com/joegomes/deepchem/tree/acdc), we achieve a speedup of
roughly 3.3x in training time per epoch (from 1.40s to 0.42s). If we do not mind the randomness introduced by
some kernel optimizations, we can achieve a speedup of roughly 4.4x (from 1.40s to 0.32s).
## References
[1] Wu et al. (2017) MoleculeNet: a benchmark for molecular machine learning. *Chemical Science* 9, 513-530.
[2] Wang et al. (2004) The PDBbind database: collection of binding affinities for protein-ligand complexes
with known three-dimensional structures. *J Med Chem* 3;47(12):2977-80.
[3] Wang et al. (2005) The PDBbind database: methodologies and updates. *J Med Chem* 16;48(12):4111-9.
[4] Liu et al. (2015) PDB-wide collection of binding data: current status of the PDBbind database. *Bioinformatics* 1;31(3):405-12.
[5] Gomes et al. (2017) Atomic Convolutional Networks for Predicting Protein-Ligand Binding Affinity. *arXiv preprint arXiv:1703.10603*.
import numpy as np
import torch
ACNN_PDBBind_core_pocket_random = {
'dataset': 'PDBBind',
'subset': 'core',
'load_binding_pocket': True,
'random_seed': 123,
'frac_train': 0.8,
'frac_val': 0.,
'frac_test': 0.2,
'batch_size': 24,
'shuffle': False,
'hidden_sizes': [32, 32, 16],
'weight_init_stddevs': [1. / float(np.sqrt(32)), 1. / float(np.sqrt(32)),
1. / float(np.sqrt(16)), 0.01],
'dropouts': [0., 0., 0.],
'atomic_numbers_considered': torch.tensor([
1., 6., 7., 8., 9., 11., 12., 15., 16., 17., 20., 25., 30., 35., 53.]),
'radial': [[12.0], [0.0, 4.0, 8.0], [4.0]],
'lr': 0.001,
'num_epochs': 120,
'metrics': ['r2', 'mae'],
'split': 'random'
}
ACNN_PDBBind_core_pocket_scaffold = {
'dataset': 'PDBBind',
'subset': 'core',
'load_binding_pocket': True,
'random_seed': 123,
'frac_train': 0.8,
'frac_val': 0.,
'frac_test': 0.2,
'batch_size': 24,
'shuffle': False,
'hidden_sizes': [32, 32, 16],
'weight_init_stddevs': [1. / float(np.sqrt(32)), 1. / float(np.sqrt(32)),
1. / float(np.sqrt(16)), 0.01],
'dropouts': [0., 0., 0.],
'atomic_numbers_considered': torch.tensor([
1., 6., 7., 8., 9., 11., 12., 15., 16., 17., 20., 25., 30., 35., 53.]),
'radial': [[12.0], [0.0, 4.0, 8.0], [4.0]],
'lr': 0.001,
'num_epochs': 170,
'metrics': ['r2', 'mae'],
'split': 'scaffold'
}
ACNN_PDBBind_core_pocket_stratified = {
'dataset': 'PDBBind',
'subset': 'core',
'load_binding_pocket': True,
'random_seed': 123,
'frac_train': 0.8,
'frac_val': 0.,
'frac_test': 0.2,
'batch_size': 24,
'shuffle': False,
'hidden_sizes': [32, 32, 16],
'weight_init_stddevs': [1. / float(np.sqrt(32)), 1. / float(np.sqrt(32)),
1. / float(np.sqrt(16)), 0.01],
'dropouts': [0., 0., 0.],
'atomic_numbers_considered': torch.tensor([
1., 6., 7., 8., 9., 11., 12., 15., 16., 17., 20., 25., 30., 35., 53.]),
'radial': [[12.0], [0.0, 4.0, 8.0], [4.0]],
'lr': 0.001,
'num_epochs': 110,
'metrics': ['r2', 'mae'],
'split': 'stratified'
}
ACNN_PDBBind_core_pocket_temporal = {
'dataset': 'PDBBind',
'subset': 'core',
'load_binding_pocket': True,
'random_seed': 123,
'frac_train': 0.8,
'frac_val': 0.,
'frac_test': 0.2,
'batch_size': 24,
'shuffle': False,
'hidden_sizes': [32, 32, 16],
'weight_init_stddevs': [1. / float(np.sqrt(32)), 1. / float(np.sqrt(32)),
1. / float(np.sqrt(16)), 0.01],
'dropouts': [0., 0., 0.],
'atomic_numbers_considered': torch.tensor([
1., 6., 7., 8., 9., 11., 12., 15., 16., 17., 20., 25., 30., 35., 53.]),
'radial': [[12.0], [0.0, 4.0, 8.0], [4.0]],
'lr': 0.001,
'num_epochs': 80,
'metrics': ['r2', 'mae'],
'split': 'temporal'
}
ACNN_PDBBind_refined_pocket_random = {
'dataset': 'PDBBind',
'subset': 'refined',
'load_binding_pocket': True,
'random_seed': 123,
'frac_train': 0.8,
'frac_val': 0.,
'frac_test': 0.2,
'batch_size': 24,
'shuffle': False,
'hidden_sizes': [128, 128, 64],
'weight_init_stddevs': [0.125, 0.125, 0.177, 0.01],
'dropouts': [0.4, 0.4, 0.],
'atomic_numbers_considered': torch.tensor([
1., 6., 7., 8., 9., 11., 12., 15., 16., 17., 19., 20., 25., 26., 27., 28.,
29., 30., 34., 35., 38., 48., 53., 55., 80.]),
'radial': [[12.0], [0.0, 2.0, 4.0, 6.0, 8.0], [4.0]],
'lr': 0.001,
'num_epochs': 200,
'metrics': ['r2', 'mae'],
'split': 'random'
}
ACNN_PDBBind_refined_pocket_scaffold = {
'dataset': 'PDBBind',
'subset': 'refined',
'load_binding_pocket': True,
'random_seed': 123,
'frac_train': 0.8,
'frac_val': 0.,
'frac_test': 0.2,
'batch_size': 24,
'shuffle': False,
'hidden_sizes': [128, 128, 64],
'weight_init_stddevs': [0.125, 0.125, 0.177, 0.01],
'dropouts': [0.4, 0.4, 0.],
'atomic_numbers_considered': torch.tensor([
1., 6., 7., 8., 9., 11., 12., 15., 16., 17., 19., 20., 25., 26., 27., 28.,
29., 30., 34., 35., 38., 48., 53., 55., 80.]),
'radial': [[12.0], [0.0, 2.0, 4.0, 6.0, 8.0], [4.0]],
'lr': 0.001,
'num_epochs': 350,
'metrics': ['r2', 'mae'],
'split': 'scaffold'
}
ACNN_PDBBind_refined_pocket_stratified = {
'dataset': 'PDBBind',
'subset': 'refined',
'load_binding_pocket': True,
'random_seed': 123,
'frac_train': 0.8,
'frac_val': 0.,
'frac_test': 0.2,
'batch_size': 24,
'shuffle': False,
'hidden_sizes': [128, 128, 64],
'weight_init_stddevs': [0.125, 0.125, 0.177, 0.01],
'dropouts': [0.4, 0.4, 0.],
'atomic_numbers_considered': torch.tensor([
1., 6., 7., 8., 9., 11., 12., 15., 16., 17., 19., 20., 25., 26., 27., 28.,
29., 30., 34., 35., 38., 48., 53., 55., 80.]),
'radial': [[12.0], [0.0, 2.0, 4.0, 6.0, 8.0], [4.0]],
'lr': 0.001,
'num_epochs': 400,
'metrics': ['r2', 'mae'],
'split': 'stratified'
}
ACNN_PDBBind_refined_pocket_temporal = {
'dataset': 'PDBBind',
'subset': 'refined',
'load_binding_pocket': True,
'random_seed': 123,
'frac_train': 0.8,
'frac_val': 0.,
'frac_test': 0.2,
'batch_size': 24,
'shuffle': False,
'hidden_sizes': [128, 128, 64],
'weight_init_stddevs': [0.125, 0.125, 0.177, 0.01],
'dropouts': [0.4, 0.4, 0.],
'atomic_numbers_considered': torch.tensor([
1., 6., 7., 8., 9., 11., 12., 15., 16., 17., 19., 20., 25., 26., 27., 28.,
29., 30., 34., 35., 38., 48., 53., 55., 80.]),
'radial': [[12.0], [0.0, 2.0, 4.0, 6.0, 8.0], [4.0]],
'lr': 0.001,
'num_epochs': 350,
'metrics': ['r2', 'mae'],
'split': 'temporal'
}
experiment_configures = {
'ACNN_PDBBind_core_pocket_random': ACNN_PDBBind_core_pocket_random,
'ACNN_PDBBind_core_pocket_scaffold': ACNN_PDBBind_core_pocket_scaffold,
'ACNN_PDBBind_core_pocket_stratified': ACNN_PDBBind_core_pocket_stratified,
'ACNN_PDBBind_core_pocket_temporal': ACNN_PDBBind_core_pocket_temporal,
'ACNN_PDBBind_refined_pocket_random': ACNN_PDBBind_refined_pocket_random,
'ACNN_PDBBind_refined_pocket_scaffold': ACNN_PDBBind_refined_pocket_scaffold,
'ACNN_PDBBind_refined_pocket_stratified': ACNN_PDBBind_refined_pocket_stratified,
'ACNN_PDBBind_refined_pocket_temporal': ACNN_PDBBind_refined_pocket_temporal
}
def get_exp_configure(exp_name):
return experiment_configures[exp_name]
import torch
import torch.nn as nn
from dgllife.utils.eval import Meter
from torch.utils.data import DataLoader
from utils import set_random_seed, load_dataset, collate, load_model
def update_msg_from_scores(msg, scores):
for metric, score in scores.items():
msg += ', {} {:.4f}'.format(metric, score)
return msg
def run_a_train_epoch(args, epoch, model, data_loader,
loss_criterion, optimizer):
model.train()
train_meter = Meter(args['train_mean'], args['train_std'])
epoch_loss = 0
for batch_id, batch_data in enumerate(data_loader):
indices, ligand_mols, protein_mols, bg, labels = batch_data
labels, bg = labels.to(args['device']), bg.to(args['device'])
prediction = model(bg)
loss = loss_criterion(prediction, (labels - args['train_mean']) / args['train_std'])
epoch_loss += loss.data.item() * len(indices)
optimizer.zero_grad()
loss.backward()
optimizer.step()
train_meter.update(prediction, labels)
avg_loss = epoch_loss / len(data_loader.dataset)
total_scores = {metric: train_meter.compute_metric(metric, 'mean')
for metric in args['metrics']}
msg = 'epoch {:d}/{:d}, training | loss {:.4f}'.format(
epoch + 1, args['num_epochs'], avg_loss)
msg = update_msg_from_scores(msg, total_scores)
print(msg)
def run_an_eval_epoch(args, model, data_loader):
model.eval()
eval_meter = Meter(args['train_mean'], args['train_std'])
with torch.no_grad():
for batch_id, batch_data in enumerate(data_loader):
indices, ligand_mols, protein_mols, bg, labels = batch_data
labels, bg = labels.to(args['device']), bg.to(args['device'])
prediction = model(bg)
eval_meter.update(prediction, labels)
total_scores = {metric: eval_meter.compute_metric(metric, 'mean')
for metric in args['metrics']}
return total_scores
def main(args):
args['device'] = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
set_random_seed(args['random_seed'])
dataset, train_set, test_set = load_dataset(args)
args['train_mean'] = train_set.labels_mean.to(args['device'])
args['train_std'] = train_set.labels_std.to(args['device'])
train_loader = DataLoader(dataset=train_set,
batch_size=args['batch_size'],
shuffle=True,
collate_fn=collate)
test_loader = DataLoader(dataset=test_set,
batch_size=args['batch_size'],
shuffle=False,
collate_fn=collate)
model = load_model(args)
loss_fn = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=args['lr'])
model.to(args['device'])
for epoch in range(args['num_epochs']):
run_a_train_epoch(args, epoch, model, train_loader, loss_fn, optimizer)
test_scores = run_an_eval_epoch(args, model, test_loader)
test_msg = update_msg_from_scores('test results', test_scores)
print(test_msg)
if __name__ == '__main__':
import argparse
from configure import get_exp_configure
parser = argparse.ArgumentParser(description='Protein-Ligand Binding Affinity Prediction')
parser.add_argument('-m', '--model', type=str, choices=['ACNN'],
help='Model to use')
parser.add_argument('-d', '--dataset', type=str,
choices=['PDBBind_core_pocket_random', 'PDBBind_core_pocket_scaffold',
'PDBBind_core_pocket_stratified', 'PDBBind_core_pocket_temporal',
'PDBBind_refined_pocket_random', 'PDBBind_refined_pocket_scaffold',
'PDBBind_refined_pocket_stratified', 'PDBBind_refined_pocket_temporal'],
help='Dataset to use')
args = parser.parse_args().__dict__
args['exp'] = '_'.join([args['model'], args['dataset']])
args.update(get_exp_configure(args['exp']))
main(args)
import dgl
import numpy as np
import random
import torch
from dgl.data.utils import Subset
from dgllife.data import PDBBind
from dgllife.model import ACNN
from dgllife.utils import RandomSplitter, ScaffoldSplitter, SingleTaskStratifiedSplitter
from itertools import accumulate
def set_random_seed(seed=0):
"""Set random seed.
Parameters
----------
seed : int
Random seed to use. Default to 0.
"""
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed(seed)
def load_dataset(args):
"""Load the dataset.
Parameters
----------
args : dict
Input arguments.
Returns
-------
dataset
Full dataset.
train_set
Train subset of the dataset.
test_set
Test subset of the dataset.
"""
assert args['dataset'] in ['PDBBind'], 'Unexpected dataset {}'.format(args['dataset'])
if args['dataset'] == 'PDBBind':
dataset = PDBBind(subset=args['subset'],
load_binding_pocket=args['load_binding_pocket'],
zero_padding=True)
# No validation set is used and frac_val = 0.
if args['split'] == 'random':
train_set, _, test_set = RandomSplitter.train_val_test_split(
dataset,
frac_train=args['frac_train'],
frac_val=args['frac_val'],
frac_test=args['frac_test'],
random_state=args['random_seed'])
elif args['split'] == 'scaffold':
train_set, _, test_set = ScaffoldSplitter.train_val_test_split(
dataset,
mols=dataset.ligand_mols,
sanitize=False,
frac_train=args['frac_train'],
frac_val=args['frac_val'],
frac_test=args['frac_test'])
elif args['split'] == 'stratified':
train_set, _, test_set = SingleTaskStratifiedSplitter.train_val_test_split(
dataset,
labels=dataset.labels,
task_id=0,
frac_train=args['frac_train'],
frac_val=args['frac_val'],
frac_test=args['frac_test'],
random_state=args['random_seed'])
elif args['split'] == 'temporal':
years = dataset.df['release_year'].values.astype(np.float32)
indices = np.argsort(years).tolist()
frac_list = np.array([args['frac_train'], args['frac_val'], args['frac_test']])
num_data = len(dataset)
lengths = (num_data * frac_list).astype(int)
lengths[-1] = num_data - np.sum(lengths[:-1])
train_set, val_set, test_set = [
Subset(dataset, list(indices[offset - length:offset]))
for offset, length in zip(accumulate(lengths), lengths)]
else:
raise ValueError('Expect the splitting method to be "random", "scaffold", '
'"stratified" or "temporal", got {}'.format(args['split']))
train_labels = torch.stack([train_set.dataset.labels[i] for i in train_set.indices])
train_set.labels_mean = train_labels.mean(dim=0)
train_set.labels_std = train_labels.std(dim=0)
return dataset, train_set, test_set
def collate(data):
"""Batch datapoints into a single batched heterograph along with labels.
Each datapoint is a 5-tuple: the index of the datapoint, the RDKit molecule
instances for the ligand and the protein, the constructed heterograph and the label.
"""
indices, ligand_mols, protein_mols, graphs, labels = map(list, zip(*data))
bg = dgl.batch_hetero(graphs)
# Set zero initializers so that node/edge features created later default to zeros.
for nty in bg.ntypes:
bg.set_n_initializer(dgl.init.zero_initializer, ntype=nty)
for ety in bg.canonical_etypes:
bg.set_e_initializer(dgl.init.zero_initializer, etype=ety)
labels = torch.stack(labels, dim=0)
return indices, ligand_mols, protein_mols, bg, labels
def load_model(args):
assert args['model'] in ['ACNN'], 'Unexpected model {}'.format(args['model'])
if args['model'] == 'ACNN':
model = ACNN(hidden_sizes=args['hidden_sizes'],
weight_init_stddevs=args['weight_init_stddevs'],
dropouts=args['dropouts'],
features_to_use=args['atomic_numbers_considered'],
radial=args['radial'])
return model
# Learning Deep Generative Models of Graphs (DGMG)
Yujia Li, Oriol Vinyals, Chris Dyer, Razvan Pascanu, and Peter Battaglia.
Learning Deep Generative Models of Graphs. *arXiv preprint arXiv:1803.03324*, 2018.
DGMG generates graphs by progressively adding nodes and edges as below:
![](https://user-images.githubusercontent.com/19576924/48605003-7f11e900-e9b6-11e8-8880-87362348e154.png)
For molecules, the nodes are atoms and the edges are bonds.
**Goal**: Given a set of real molecules, we want to learn the distribution of them and get new molecules
with similar properties. See the `Evaluation` section for more details.
## Dataset
### Preprocessing
With our implementation, this model has several limitations:
1. Information about protonation and chirality is ignored during generation
2. Molecules containing charged atoms such as `[N+]` and `[O-]` cannot be generated.
For example, the model can only generate `O=C1NC(=S)NC(=O)C1=CNC1=CC=C(N(=O)O)C=C1O` from
`O=C1NC(=S)NC(=O)C1=CNC1=CC=C([N+](=O)[O-])C=C1O` even with the correct decisions.
To avoid issues about validity and novelty, we filter out these molecules from the dataset.
### ChEMBL
The authors use the [ChEMBL database](https://www.ebi.ac.uk/chembl/). Since they
did not release the code, we use a subset from [Olivecrona et al.](https://github.com/MarcusOlivecrona/REINVENT),
another work on generative modeling.
The authors restrict their dataset to molecules with at most 20 heavy atoms and use a training/validation
split of 130,830/26,166 examples. We use the same split but need to relax the limit from 20 to 23 heavy atoms
as we are using a different subset.
### ZINC
After the pre-processing, we are left with 232,464 molecules for training and 5,000 molecules for validation.
## Usage
### Training
Training auto-regressive generative models tends to be very slow. According to the authors, they use multiprocessing to
speed up training and GPUs do not give much of a speed advantage. We follow their approach and perform multiprocess CPU
training.
To start training, use `train.py` with required arguments
```
-d DATASET, dataset to use (default: None), built-in support exists for ChEMBL, ZINC
-o {random,canonical}, order to generate graphs (default: None)
```
and optional arguments
```
-s SEED, random seed (default: 0)
-np NUM_PROCESSES, number of processes to use (default: 32)
```
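As an illustration, the command below (with hypothetical argument values) trains on ChEMBL with canonical
ordering, random seed 0 and 32 processes:
```
python train.py -d ChEMBL -o canonical -s 0 -np 32
```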
Even though multiprocessing yields a significant speedup compared to a single process, training can still take a long
time (several days). An epoch of training and validation can take up to an hour and a half on our machine. If training
from scratch is not necessary, we recommend using our pre-trained models.
Meanwhile, we make a checkpoint of our model whenever there is a performance improvement on the validation set so you
do not need to wait until the training terminates.
All training results can be found in `training_results`.
#### Dataset configuration
You can also use your own dataset with additional arguments
```
-tf TRAIN_FILE, Path to a file with one SMILES a line for training
data. This is only necessary if you want to use a new
dataset. (default: None)
-vf VAL_FILE, Path to a file with one SMILES a line for validation
data. This is only necessary if you want to use a new
dataset. (default: None)
```
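For instance, the following sketch (with hypothetical file paths and dataset name) trains on a custom dataset:
```
python train.py -d my_dataset -o canonical -tf train_smiles.txt -vf val_smiles.txt
```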
#### Monitoring
We can monitor the training process with tensorboard as below:
![](https://s3.us-west-2.amazonaws.com/dgl-data/dgllife/dgmg/tensorboard.png)
To use tensorboard, you need to install [tensorboardX](https://github.com/lanpa/tensorboardX) and
[TensorFlow](https://www.tensorflow.org/). You can launch tensorboard with `tensorboard --logdir=.`
If you are training on a remote server, you can still use it with:
1. Launch it on the remote server with `tensorboard --logdir=. --port=A`
2. In the terminal of your local machine, type `ssh -NfL localhost:B:localhost:A username@your_remote_host_name`
3. Go to the address `localhost:B` in your browser
### Evaluation
To start evaluation, use `eval.py` with required arguments
```
-d DATASET, dataset to use (default: None), built-in support exists for ChEMBL, ZINC
-o {random,canonical}, order to generate graphs, used for naming evaluation directory (default: None)
-p MODEL_PATH, path to saved model (default: None). This is not needed if you want to use pretrained models.
-pr, Whether to use a pre-trained model (default: False)
```
and optional arguments
```
-s SEED, random seed (default: 0)
-ns NUM_SAMPLES, Number of molecules to generate (default: 100000)
-mn MAX_NUM_STEPS, Max number of steps allowed in generated molecules to
ensure termination (default: 400)
-np NUM_PROCESSES, number of processes to use (default: 32)
-gt GENERATION_TIME, max time (seconds) allowed for generation with
multiprocess (default: 600)
```
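For example, the command below (illustrative) generates 100000 molecules with the pre-trained model for ZINC
and canonical ordering:
```
python eval.py -d ZINC -o canonical -pr -ns 100000 -np 32
```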
All evaluation results can be found in `eval_results`.
After the evaluation, 100000 molecules will be generated and stored in `generated_smiles.txt` under `eval_results`
directory, with three statistics logged in `generation_stats.txt` under `eval_results`:
1. `Validity among all` gives the percentage of molecules that are valid
2. `Uniqueness among valid ones` gives the percentage of valid molecules that are unique
3. `Novelty among unique ones` gives the percentage of unique valid molecules that are novel (not seen in training data)
We also provide a jupyter notebook where you can visualize the generated molecules
![](https://s3.us-west-2.amazonaws.com/dgl-data/dgllife/dgmg/DGMG_ZINC_canonical_vis.png)
and compare their property distributions against the training molecule property distributions
![](https://s3.us-west-2.amazonaws.com/dgl-data/dgllife/dgmg/DGMG_ZINC_canonical_dist.png)
Download it with `wget https://s3.us-west-2.amazonaws.com/dgl-data/dgllife/dgmg/eval_jupyter.ipynb` from the s3
bucket in U.S. or `wget https://s3.cn-north-1.amazonaws.com.cn/dgl-data/dgllife/dgmg/eval_jupyter.ipynb` from the
s3 bucket in China.
### Pre-trained models
The table below gives the statistics of the pre-trained models. With random order, training becomes significantly
more difficult as we now have `N^2` data points with `N` molecules.
| Pre-trained model | % valid | % unique among valid | % novel among unique |
| ------------------ | ------- | -------------------- | -------------------- |
| `ChEMBL_canonical` | 78.80 | 99.19 | 98.60 |
| `ChEMBL_random` | 29.09 | 99.87 | 100.00 |
| `ZINC_canonical` | 74.60 | 99.87 | 99.87 |
| `ZINC_random` | 12.37 | 99.38 | 100.00 |
import os
import pickle
import shutil
import torch
from dgllife.model import DGMG, load_pretrained
from utils import MoleculeDataset, set_random_seed, download_data,\
mkdir_p, summarize_molecules, get_unique_smiles, get_novel_smiles
def generate_and_save(log_dir, num_samples, max_num_steps, model):
with open(os.path.join(log_dir, 'generated_smiles.txt'), 'w') as f:
for i in range(num_samples):
with torch.no_grad():
s = model(rdkit_mol=True, max_num_steps=max_num_steps)
f.write(s + '\n')
def prepare_for_evaluation(rank, args):
worker_seed = args['seed'] + rank * 10000
set_random_seed(worker_seed)
torch.set_num_threads(1)
# Setup dataset and data loader
dataset = MoleculeDataset(args['dataset'], subset_id=rank, n_subsets=args['num_processes'])
# Initialize model
if not args['pretrained']:
model = DGMG(atom_types=dataset.atom_types,
bond_types=dataset.bond_types,
node_hidden_size=args['node_hidden_size'],
num_prop_rounds=args['num_propagation_rounds'], dropout=args['dropout'])
model.load_state_dict(torch.load(args['model_path'])['model_state_dict'])
else:
model = load_pretrained('_'.join(['DGMG', args['dataset'], args['order']]), log=False)
model.eval()
worker_num_samples = args['num_samples'] // args['num_processes']
if rank == args['num_processes'] - 1:
worker_num_samples += args['num_samples'] % args['num_processes']
worker_log_dir = os.path.join(args['log_dir'], str(rank))
mkdir_p(worker_log_dir, log=False)
generate_and_save(worker_log_dir, worker_num_samples, args['max_num_steps'], model)
def remove_worker_tmp_dir(args):
for rank in range(args['num_processes']):
worker_path = os.path.join(args['log_dir'], str(rank))
try:
shutil.rmtree(worker_path)
except OSError:
print('Directory {} does not exist!'.format(worker_path))
def aggregate_and_evaluate(args):
print('Merging generated SMILES into a single file...')
smiles = []
for rank in range(args['num_processes']):
with open(os.path.join(args['log_dir'], str(rank), 'generated_smiles.txt'), 'r') as f:
rank_smiles = f.read().splitlines()
smiles.extend(rank_smiles)
with open(os.path.join(args['log_dir'], 'generated_smiles.txt'), 'w') as f:
for s in smiles:
f.write(s + '\n')
print('Removing temporary dirs...')
remove_worker_tmp_dir(args)
# Summarize training molecules
print('Summarizing training molecules...')
train_file = '_'.join([args['dataset'], 'DGMG_train.txt'])
if not os.path.exists(train_file):
download_data(args['dataset'], train_file)
with open(train_file, 'r') as f:
train_smiles = f.read().splitlines()
train_summary = summarize_molecules(train_smiles, args['num_processes'])
with open(os.path.join(args['log_dir'], 'train_summary.pickle'), 'wb') as f:
pickle.dump(train_summary, f)
# Summarize generated molecules
print('Summarizing generated molecules...')
generation_summary = summarize_molecules(smiles, args['num_processes'])
with open(os.path.join(args['log_dir'], 'generation_summary.pickle'), 'wb') as f:
pickle.dump(generation_summary, f)
# Stats computation
print('Preparing generation statistics...')
valid_generated_smiles = generation_summary['smile']
unique_generated_smiles = get_unique_smiles(valid_generated_smiles)
unique_train_smiles = get_unique_smiles(train_summary['smile'])
novel_generated_smiles = get_novel_smiles(unique_generated_smiles, unique_train_smiles)
with open(os.path.join(args['log_dir'], 'generation_stats.txt'), 'w') as f:
f.write('Total number of generated molecules: {:d}\n'.format(len(smiles)))
f.write('Validity among all: {:.4f}\n'.format(
len(valid_generated_smiles) / len(smiles)))
f.write('Uniqueness among valid ones: {:.4f}\n'.format(
len(unique_generated_smiles) / len(valid_generated_smiles)))
f.write('Novelty among unique ones: {:.4f}\n'.format(
len(novel_generated_smiles) / len(unique_generated_smiles)))
if __name__ == '__main__':
import argparse
import datetime
import time
from rdkit import rdBase
from utils import setup
parser = argparse.ArgumentParser(description='Evaluating DGMG for molecule generation',
formatter_class=argparse.ArgumentDefaultsHelpFormatter)
# configure
parser.add_argument('-s', '--seed', type=int, default=0, help='random seed')
# dataset and setting
parser.add_argument('-d', '--dataset',
help='dataset to use')
parser.add_argument('-o', '--order', choices=['random', 'canonical'],
help='order to generate graphs, used for naming evaluation directory')
# log
parser.add_argument('-l', '--log-dir', default='./eval_results',
help='folder to save evaluation results')
parser.add_argument('-p', '--model-path', type=str, default=None,
help='path to saved model')
parser.add_argument('-pr', '--pretrained', action='store_true',
help='Whether to use a pre-trained model')
parser.add_argument('-ns', '--num-samples', type=int, default=100000,
help='Number of molecules to generate')
parser.add_argument('-mn', '--max-num-steps', type=int, default=400,
help='Max number of steps allowed in generating a molecule, to ensure termination')
# multi-process
parser.add_argument('-np', '--num-processes', type=int, default=32,
help='number of processes to use')
parser.add_argument('-gt', '--generation-time', type=int, default=600,
help='max time (seconds) allowed for generation with multiprocess')
args = parser.parse_args()
args = setup(args, train=False)
rdBase.DisableLog('rdApp.error')
t1 = time.time()
if args['num_processes'] == 1:
prepare_for_evaluation(0, args)
else:
import multiprocessing as mp
procs = []
for rank in range(args['num_processes']):
p = mp.Process(target=prepare_for_evaluation, args=(rank, args,))
procs.append(p)
p.start()
while time.time() - t1 <= args['generation_time']:
if any(p.is_alive() for p in procs):
time.sleep(5)
else:
break
else:
print('Timeout, killing all processes.')
for p in procs:
p.terminate()
p.join()
t2 = time.time()
print('It took {} for generation.'.format(
datetime.timedelta(seconds=t2 - t1)))
aggregate_and_evaluate(args)
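# Editorial usage sketch (the script file name and flag values are
# assumptions): sample molecules with the pre-trained ChEMBL model
# using 32 worker processes:
#
#   python eval_dgmg.py -d ChEMBL -o canonical -pr -np 32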
#
# calculation of synthetic accessibility score as described in:
#
# Estimation of Synthetic Accessibility Score of Drug-like Molecules
# based on Molecular Complexity and Fragment Contributions
# Peter Ertl and Ansgar Schuffenhauer
# Journal of Cheminformatics 1:8 (2009)
# http://www.jcheminf.com/content/1/1/8
#
# several small modifications to the original paper are included
# particularly a slightly different formula for the macrocyclic penalty
# and taking into account also molecule symmetry (fingerprint density)
#
# for a set of 10k diverse molecules the agreement between the original method
# as implemented in PipelinePilot and this implementation is r2 = 0.97
#
# peter ertl & greg landrum, september 2013
#
# A small modification is performed
#
# DGL team, August 2019
#
from __future__ import print_function
import math
import os
from rdkit import Chem
from rdkit.Chem import rdMolDescriptors
from rdkit.six.moves import cPickle
from rdkit.six import iteritems
from dgl.data.utils import download, _get_dgl_url
_fscores = None
def readFragmentScores(name='fpscores'):
import gzip
global _fscores
fname = '{}.pkl.gz'.format(name)
download(_get_dgl_url(os.path.join('dataset', fname)), path=fname)
_fscores = cPickle.load(gzip.open(fname))
outDict = {}
for i in _fscores:
for j in range(1, len(i)):
outDict[i[j]] = float(i[0])
_fscores = outDict
def numBridgeheadsAndSpiro(mol):
nSpiro = rdMolDescriptors.CalcNumSpiroAtoms(mol)
nBridgehead = rdMolDescriptors.CalcNumBridgeheadAtoms(mol)
return nBridgehead, nSpiro
def calculateScore(m):
if _fscores is None:
readFragmentScores()
# fragment score
# 2 is the *radius* of the circular fingerprint
fp = rdMolDescriptors.GetMorganFingerprint(m, 2)
fps = fp.GetNonzeroElements()
score1 = 0.
nf = 0
for bitId, v in iteritems(fps):
nf += v
sfp = bitId
score1 += _fscores.get(sfp, -4) * v
# We add this check to avoid ZeroDivisionError.
if nf != 0:
score1 /= nf
# features score
nAtoms = m.GetNumAtoms()
nChiralCenters = len(Chem.FindMolChiralCenters(m, includeUnassigned=True))
ri = m.GetRingInfo()
nBridgeheads, nSpiro = numBridgeheadsAndSpiro(m)
nMacrocycles = 0
for x in ri.AtomRings():
if len(x) > 8:
nMacrocycles += 1
sizePenalty = nAtoms**1.005 - nAtoms
stereoPenalty = math.log10(nChiralCenters + 1)
spiroPenalty = math.log10(nSpiro + 1)
bridgePenalty = math.log10(nBridgeheads + 1)
macrocyclePenalty = 0.
# ---------------------------------------
# This differs from the paper, which defines:
# macrocyclePenalty = math.log10(nMacrocycles+1)
# This form generates better results when 2 or more macrocycles are present
if nMacrocycles > 0:
macrocyclePenalty = math.log10(2)
score2 = 0. - sizePenalty - stereoPenalty - \
spiroPenalty - bridgePenalty - macrocyclePenalty
# correction for the fingerprint density
# not in the original publication, added in version 1.1
# to make highly symmetrical molecules easier to synthesize
score3 = 0.
if nAtoms > len(fps):
score3 = math.log(float(nAtoms) / len(fps)) * .5
sascore = score1 + score2 + score3
# need to transform "raw" value into scale between 1 and 10
min_score = -4.0
max_score = 2.5
sascore = 11. - (sascore - min_score + 1) / (max_score - min_score) * 9.
# smooth the 10-end
if sascore > 8.:
sascore = 8. + math.log(sascore + 1. - 9.)
if sascore > 10.:
sascore = 10.0
elif sascore < 1.:
sascore = 1.0
return sascore
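# Editorial usage sketch: the returned score lies on a 1 (easy) to 10
# (hard) scale; 'CCO' (ethanol) is an illustrative input, and the first
# call downloads the fragment score file:
#
#   >>> from rdkit import Chem
#   >>> score = calculateScore(Chem.MolFromSmiles('CCO'))
#   >>> 1. <= score <= 10.
#   True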
def processMols(mols):
print('smiles\tName\tsa_score')
for i, m in enumerate(mols):
if m is None:
continue
s = calculateScore(m)
smiles = Chem.MolToSmiles(m)
print(smiles + "\t" + m.GetProp('_Name') + "\t%.3f" % s)
if __name__ == '__main__':
import sys, time
t1 = time.time()
readFragmentScores("fpscores")
t2 = time.time()
suppl = Chem.SmilesMolSupplier(sys.argv[1])
t3 = time.time()
processMols(suppl)
t4 = time.time()
print('Reading took %.2f seconds. Calculating took %.2f seconds' % ((t2 - t1), (t4 - t3)),
file=sys.stderr)
#
# Copyright (c) 2013, Novartis Institutes for BioMedical Research Inc.
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are
# met:
#
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above
# copyright notice, this list of conditions and the following
# disclaimer in the documentation and/or other materials provided
# with the distribution.
# * Neither the name of Novartis Institutes for BioMedical Research Inc.
# nor the names of its contributors may be used to endorse or promote
# products derived from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
"""
Learning Deep Generative Models of Graphs
Paper: https://arxiv.org/pdf/1803.03324.pdf
"""
import datetime
import time
import torch
import torch.distributed as dist
from dgllife.model import DGMG
from torch.optim import Adam
from torch.utils.data import DataLoader
from utils import MoleculeDataset, Printer, set_random_seed, synchronize, launch_a_process
def evaluate(epoch, model, data_loader, printer):
model.eval()
batch_size = data_loader.batch_size
total_log_prob = 0
with torch.no_grad():
for i, data in enumerate(data_loader):
log_prob = model(actions=data, compute_log_prob=True).detach()
total_log_prob -= log_prob
if printer is not None:
prob = log_prob.detach().exp()
printer.update(epoch + 1, - log_prob / batch_size, prob / batch_size)
return total_log_prob / len(data_loader)
def main(rank, args):
"""
Parameters
----------
rank : int
Subprocess id
args : dict
Configuration
"""
if rank == 0:
t1 = time.time()
set_random_seed(args['seed'])
# Removing the line below will result in problems for multiprocessing.
torch.set_num_threads(1)
# Setup dataset and data loader
dataset = MoleculeDataset(args['dataset'], args['order'], ['train', 'val'],
subset_id=rank, n_subsets=args['num_processes'])
# Note that currently the batch size for the loaders should only be 1.
train_loader = DataLoader(dataset.train_set, batch_size=args['batch_size'],
shuffle=True, collate_fn=dataset.collate)
val_loader = DataLoader(dataset.val_set, batch_size=args['batch_size'],
shuffle=True, collate_fn=dataset.collate)
if rank == 0:
try:
from tensorboardX import SummaryWriter
writer = SummaryWriter(args['log_dir'])
except ImportError:
print('If you want to use tensorboard, install tensorboardX with pip.')
writer = None
train_printer = Printer(args['nepochs'], len(dataset.train_set), args['batch_size'], writer)
val_printer = Printer(args['nepochs'], len(dataset.val_set), args['batch_size'])
else:
val_printer = None
# Initialize model
model = DGMG(atom_types=dataset.atom_types,
bond_types=dataset.bond_types,
node_hidden_size=args['node_hidden_size'],
num_prop_rounds=args['num_propagation_rounds'],
dropout=args['dropout'])
if args['num_processes'] == 1:
from utils import Optimizer
optimizer = Optimizer(args['lr'], Adam(model.parameters(), lr=args['lr']))
else:
from utils import MultiProcessOptimizer
optimizer = MultiProcessOptimizer(args['num_processes'], args['lr'],
Adam(model.parameters(), lr=args['lr']))
if rank == 0:
t2 = time.time()
best_val_prob = 0
# Training
for epoch in range(args['nepochs']):
model.train()
if rank == 0:
print('Training')
for i, data in enumerate(train_loader):
log_prob = model(actions=data, compute_log_prob=True)
prob = log_prob.detach().exp()
loss_averaged = - log_prob
prob_averaged = prob
optimizer.backward_and_step(loss_averaged)
if rank == 0:
train_printer.update(epoch + 1, loss_averaged.item(), prob_averaged.item())
synchronize(args['num_processes'])
# Validation
val_log_prob = evaluate(epoch, model, val_loader, val_printer)
if args['num_processes'] > 1:
dist.all_reduce(val_log_prob, op=dist.ReduceOp.SUM)
val_log_prob /= args['num_processes']
# Strictly speaking, the probability computed here differs from the one
# computed on the training set: we first average the log likelihood and
# then exponentiate. By Jensen's inequality, the resulting value is a
# lower bound on the average probability.
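# Concretely, with per-sample log probs l_1, ..., l_n, we report
# exp(mean(l_i)), whereas the training loop tracks exp(l_i) per sample;
# Jensen gives exp(mean(l_i)) <= mean(exp(l_i)).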
val_prob = (- val_log_prob).exp().item()
val_log_prob = val_log_prob.item()
if val_prob >= best_val_prob:
if rank == 0:
torch.save({'model_state_dict': model.state_dict()}, args['checkpoint_dir'])
print('Old val prob {:.10f} | new val prob {:.10f} | model saved'.format(best_val_prob, val_prob))
best_val_prob = val_prob
elif epoch >= args['warmup_epochs']:
optimizer.decay_lr()
if rank == 0:
print('Validation')
if writer is not None:
writer.add_scalar('validation_log_prob', val_log_prob, epoch)
writer.add_scalar('validation_prob', val_prob, epoch)
writer.add_scalar('lr', optimizer.lr, epoch)
print('Validation log prob {:.4f} | prob {:.10f}'.format(val_log_prob, val_prob))
synchronize(args['num_processes'])
if rank == 0:
t3 = time.time()
print('It took {} to setup.'.format(datetime.timedelta(seconds=t2 - t1)))
print('It took {} to finish training.'.format(datetime.timedelta(seconds=t3 - t2)))
print('--------------------------------------------------------------------------')
print('On average, an epoch takes {}.'.format(datetime.timedelta(
seconds=(t3 - t2) / args['nepochs'])))
if __name__ == '__main__':
import argparse
from utils import setup
parser = argparse.ArgumentParser(description='Training DGMG for molecule generation',
formatter_class=argparse.ArgumentDefaultsHelpFormatter)
# configure
parser.add_argument('-s', '--seed', type=int, default=0, help='random seed')
parser.add_argument('-w', '--warmup-epochs', type=int, default=10,
help='Number of epochs where no lr decay is performed.')
# dataset and setting
parser.add_argument('-d', '--dataset',
help='dataset to use')
parser.add_argument('-o', '--order', choices=['random', 'canonical'],
help='order to generate graphs')
parser.add_argument('-tf', '--train-file', type=str, default=None,
help='Path to a file with one SMILES per line for training data. '
'This is only necessary if you want to use a new dataset.')
parser.add_argument('-vf', '--val-file', type=str, default=None,
help='Path to a file with one SMILES per line for validation data. '
'This is only necessary if you want to use a new dataset.')
# log
parser.add_argument('-l', '--log-dir', default='./training_results',
help='folder to save info like experiment configuration')
# multi-process
parser.add_argument('-np', '--num-processes', type=int, default=32,
help='number of processes to use')
parser.add_argument('-mi', '--master-ip', type=str, default='127.0.0.1')
parser.add_argument('-mp', '--master-port', type=str, default='12345')
args = parser.parse_args()
args = setup(args, train=True)
if args['num_processes'] == 1:
main(0, args)
else:
mp = torch.multiprocessing.get_context('spawn')
procs = []
for rank in range(args['num_processes']):
procs.append(mp.Process(target=launch_a_process, args=(rank, args, main), daemon=True))
procs[-1].start()
for p in procs:
p.join()
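# Editorial usage sketch (the script file name is an assumption):
# single-process training on ZINC with a canonical ordering:
#
#   python train_dgmg.py -d ZINC -o canonical -np 1
#
# Multi-process training initializes the gloo backend via
# launch_a_process from utils.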
import datetime
import math
import numpy as np
import os
import pickle
import random
import torch
import torch.distributed as dist
import torch.nn as nn
from collections import defaultdict
from datetime import timedelta
from dgl.data.utils import download, _get_dgl_url
from dgllife.model.model_zoo.dgmg import MoleculeEnv
from multiprocessing import Pool
from pprint import pprint
from rdkit import Chem
from rdkit.Chem import AllChem
from rdkit.Chem.Crippen import MolLogP
from rdkit.Chem.QED import qed
from torch.utils.data import Dataset
from sascorer import calculateScore
########################################################################################################################
# configuration #
########################################################################################################################
def mkdir_p(path, log=True):
"""Create a directory for the specified path.
Parameters
----------
path : str
Path name
log : bool
Whether to print result for directory creation
"""
import errno
try:
os.makedirs(path)
if log:
print('Created directory {}'.format(path))
except OSError as exc:
if exc.errno == errno.EEXIST and os.path.isdir(path) and log:
print('Directory {} already exists.'.format(path))
else:
raise
def get_date_postfix():
"""Get a date based postfix for directory name.
Returns
-------
post_fix : str
"""
dt = datetime.datetime.now()
post_fix = '{}_{:02d}-{:02d}-{:02d}'.format(
dt.date(), dt.hour, dt.minute, dt.second)
return post_fix
def setup_log_dir(args):
"""Name and create directory for logging.
Parameters
----------
args : dict
Configuration
Returns
-------
log_dir : str
Path for logging directory
"""
date_postfix = get_date_postfix()
log_dir = os.path.join(
args['log_dir'],
'{}_{}_{}'.format(args['dataset'], args['order'], date_postfix))
mkdir_p(log_dir)
return log_dir
def save_arg_dict(args, filename='settings.txt'):
"""Save all experiment settings in a file.
Parameters
----------
args : dict
Configuration
filename : str
Name for the file to save settings
"""
def _format_value(v):
if isinstance(v, float):
return '{:.4f}'.format(v)
elif isinstance(v, int):
return '{:d}'.format(v)
else:
return '{}'.format(v)
save_path = os.path.join(args['log_dir'], filename)
with open(save_path, 'w') as f:
for key, value in args.items():
f.write('{}\t{}\n'.format(key, _format_value(value)))
print('Saved settings to {}'.format(save_path))
def configure(args):
"""Use default hyperparameters.
Parameters
----------
args : dict
Old configuration
Returns
-------
args : dict
Updated configuration
"""
configure = {
'node_hidden_size': 128,
'num_propagation_rounds': 2,
'lr': 1e-4,
'dropout': 0.2,
'nepochs': 400,
'batch_size': 1,
}
args.update(configure)
return args
def set_random_seed(seed):
"""Fix random seed for reproducible results.
Parameters
----------
seed : int
Random seed to use.
"""
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
def setup_dataset(args):
"""Dataset setup
For unsupported datasets, we need to perform data preprocessing.
Parameters
----------
args : dict
Configuration
"""
if args['dataset'] in ['ChEMBL', 'ZINC']:
print('Built-in support for dataset {} exists.'.format(args['dataset']))
else:
print('Configure for new dataset {}...'.format(args['dataset']))
configure_new_dataset(args['dataset'], args['train_file'], args['val_file'])
def setup(args, train=True):
"""Setup
Parameters
----------
args : argparse.Namespace
Configuration
train : bool
Whether the setup is for training or evaluation
"""
# Convert argparse.Namespace into a dict
args = args.__dict__.copy()
# Dataset
args = configure(args)
# Log
print('Prepare logging directory...')
log_dir = setup_log_dir(args)
args['log_dir'] = log_dir
save_arg_dict(args)
if train:
setup_dataset(args)
args['checkpoint_dir'] = os.path.join(log_dir, 'checkpoint.pth')
pprint(args)
return args
########################################################################################################################
# multi-process #
########################################################################################################################
def synchronize(num_processes):
"""Synchronize all processes.
Parameters
----------
num_processes : int
Number of subprocesses used
"""
if num_processes > 1:
dist.barrier()
def launch_a_process(rank, args, target, minutes=720):
"""Launch a subprocess for training.
Parameters
----------
rank : int
Subprocess id
args : dict
Configuration
target : callable
Target function for the subprocess
minutes : int
Timeout minutes for operations executed against the process group
"""
dist_init_method = 'tcp://{master_ip}:{master_port}'.format(
master_ip=args['master_ip'], master_port=args['master_port'])
dist.init_process_group(backend='gloo',
init_method=dist_init_method,
# If you have a larger dataset, you will need to increase it.
timeout=timedelta(minutes=minutes),
world_size=args['num_processes'],
rank=rank)
assert torch.distributed.get_rank() == rank
target(rank, args)
########################################################################################################################
# optimization #
########################################################################################################################
class Optimizer(nn.Module):
"""Wrapper for optimization
Parameters
----------
lr : float
Initial learning rate
optimizer
model optimizer
"""
def __init__(self, lr, optimizer):
super(Optimizer, self).__init__()
self.lr = lr
self.optimizer = optimizer
self._reset()
def _reset(self):
self.optimizer.zero_grad()
def backward_and_step(self, loss):
"""Backward and update model.
Parameters
----------
loss : torch.tensor consisting of a float only
"""
loss.backward()
self.optimizer.step()
self._reset()
def decay_lr(self, decay_rate=0.99):
"""Decay learning rate.
Parameters
----------
decay_rate : float
Multiply the current learning rate by the decay_rate
"""
self.lr *= decay_rate
for param_group in self.optimizer.param_groups:
param_group['lr'] = self.lr
class MultiProcessOptimizer(Optimizer):
"""Wrapper for optimization with multiprocess
Parameters
----------
n_processes : int
Number of processes used
lr : float
Initial learning rate
optimizer
model optimizer
"""
def __init__(self, n_processes, lr, optimizer):
super(MultiProcessOptimizer, self).__init__(lr=lr, optimizer=optimizer)
self.n_processes = n_processes
def _sync_gradient(self):
"""Average gradients across all subprocesses."""
for param_group in self.optimizer.param_groups:
for p in param_group['params']:
if p.requires_grad and p.grad is not None:
dist.all_reduce(p.grad.data, op=dist.ReduceOp.SUM)
p.grad.data /= self.n_processes
def backward_and_step(self, loss):
"""Backward and update model.
Parameters
----------
loss : torch.tensor consisting of a float only
"""
loss.backward()
self._sync_gradient()
self.optimizer.step()
self._reset()
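# Editorial usage sketch for the wrappers above (single process):
#
#   >>> from torch.optim import Adam
#   >>> model = nn.Linear(4, 1)
#   >>> opt = Optimizer(1e-4, Adam(model.parameters(), lr=1e-4))
#   >>> opt.backward_and_step(model(torch.ones(2, 4)).sum())
#   >>> opt.decay_lr()  # lr <- 0.99 * lr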
########################################################################################################################
# data #
########################################################################################################################
def initialize_neutralization_reactions():
"""Reference neutralization reactions
Code adapted from RDKit Cookbook, by Hans de Winter.
"""
patts = (
# Imidazoles
('[n+;H]', 'n'),
# Amines
('[N+;!H0]', 'N'),
# Carboxylic acids and alcohols
('[$([O-]);!$([O-][#7])]', 'O'),
# Thiols
('[S-;X1]', 'S'),
# Sulfonamides
('[$([N-;X2]S(=O)=O)]', 'N'),
# Enamines
('[$([N-;X2][C,N]=C)]', 'N'),
# Tetrazoles
('[n-]', '[n]'),
# Sulfoxides
('[$([S-]=O)]', 'S'),
# Amides
('[$([N-]C=O)]', 'N'),
)
return [(Chem.MolFromSmarts(x), Chem.MolFromSmiles(y, False)) for x, y in patts]
def neutralize_charges(mol, reactions=None):
"""Deprotonation for molecules.
Code adapted from RDKit Cookbook, by Hans de Winter.
DGMG currently cannot generate protonated molecules.
For example, it can only generate
CC(C)(C)CC1CCC[NH+]1Cc1nnc(-c2ccccc2F)o1
from
CC(C)(C)CC1CCCN1Cc1nnc(-c2ccccc2F)o1
even with correct decisions.
Deprotonation is therefore an important step to avoid
false novel molecules.
Parameters
----------
mol : Chem.rdchem.Mol
reactions : list of 2-tuples
Rules for deprotonation
Returns
-------
mol : Chem.rdchem.Mol
Deprotonated molecule
"""
if reactions is None:
reactions = initialize_neutralization_reactions()
for i, (reactant, product) in enumerate(reactions):
while mol.HasSubstructMatch(reactant):
rms = AllChem.ReplaceSubstructs(mol, reactant, product)
mol = rms[0]
return mol
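# Editorial sketch: deprotonating the docstring example recovers the
# uncharged form (the exact output may vary across RDKit versions):
#
#   >>> m = Chem.MolFromSmiles('CC(C)(C)CC1CCC[NH+]1Cc1nnc(-c2ccccc2F)o1')
#   >>> Chem.MolToSmiles(neutralize_charges(m))
#   'CC(C)(C)CC1CCCN1Cc1nnc(-c2ccccc2F)o1'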
def standardize_mol(mol):
"""Standardize molecule to avoid false novel molecule.
Kekulize and deprotonate molecules to avoid false novel molecules.
In addition to deprotonation, we also kekulize molecules to avoid
explicit Hs in the SMILES. Otherwise we will get false novel molecules
as well. For example, DGMG can only generate
O=S(=O)(NC1=CC=CC(C(F)(F)F)=C1)C1=CNC=N1
from
O=S(=O)(Nc1cccc(C(F)(F)F)c1)c1c[nH]cn1.
One downside is that we remove all explicit aromatic rings and to
explicitly predict aromatic bond might make the learning easier for
the model.
"""
reactions = initialize_neutralization_reactions()
Chem.Kekulize(mol, clearAromaticFlags=True)
mol = neutralize_charges(mol, reactions)
return mol
def smiles_to_standard_mol(s):
"""Convert SMILES to a standard molecule.
Parameters
----------
s : str
SMILES
Returns
-------
Chem.rdchem.Mol
Standardized molecule
"""
mol = Chem.MolFromSmiles(s)
return standardize_mol(mol)
def mol_to_standard_smile(mol):
"""Standardize a molecule and convert it to a SMILES.
Parameters
----------
mol : Chem.rdchem.Mol
Returns
-------
str
SMILES
"""
return Chem.MolToSmiles(standardize_mol(mol))
def get_atom_and_bond_types(smiles, log=True):
"""Identify the atom types and bond types
appearing in this dataset.
Parameters
----------
smiles : list
List of smiles
log : bool
Whether to print the process of pre-processing.
Returns
-------
atom_types : list
E.g. ['C', 'N']
bond_types : list
E.g. [Chem.rdchem.BondType.SINGLE, Chem.rdchem.BondType.DOUBLE,
Chem.rdchem.BondType.TRIPLE, Chem.rdchem.BondType.AROMATIC]
"""
atom_types = set()
bond_types = set()
n_smiles = len(smiles)
for i, s in enumerate(smiles):
if log:
print('Processing smiles {:d}/{:d}'.format(i + 1, n_smiles))
mol = smiles_to_standard_mol(s)
if mol is None:
continue
for atom in mol.GetAtoms():
a_symbol = atom.GetSymbol()
if a_symbol not in atom_types:
atom_types.add(a_symbol)
for bond in mol.GetBonds():
b_type = bond.GetBondType()
if b_type not in bond_types:
bond_types.add(b_type)
return list(atom_types), list(bond_types)
def eval_decisions(env, decisions):
"""This function mimics the way DGMG generates a molecule and is
helpful for debugging and verification in data preprocessing.
Parameters
----------
env : MoleculeEnv
MDP environment for generating molecules
decisions : list of 2-tuples of int
A decision sequence for generating a molecule
Returns
-------
str
SMILES for the molecule generated with decisions
"""
env.reset(rdkit_mol=True)
t = 0
def whether_to_add_atom(t):
assert decisions[t][0] == 0
atom_type = decisions[t][1]
t += 1
return t, atom_type
def whether_to_add_bond(t):
assert decisions[t][0] == 1
bond_type = decisions[t][1]
t += 1
return t, bond_type
def decide_atom2(t):
assert decisions[t][0] == 2
dst = decisions[t][1]
t += 1
return t, dst
t, atom_type = whether_to_add_atom(t)
while atom_type != len(env.atom_types):
env.add_atom(atom_type)
t, bond_type = whether_to_add_bond(t)
while bond_type != len(env.bond_types):
t, dst = decide_atom2(t)
env.add_bond((env.num_atoms() - 1), dst, bond_type)
t, bond_type = whether_to_add_bond(t)
t, atom_type = whether_to_add_atom(t)
assert t == len(decisions)
return env.get_current_smiles()
def get_DGMG_smile(env, mol):
"""Mimics the reproduced SMILES with DGMG for a molecule.
Given a molecule, we are interested in what SMILES we will
get if we want to generate it with DGMG. This is an important
step to check false novel molecules.
Parameters
----------
env : MoleculeEnv
MDP environment for generating molecules
mol : Chem.rdchem.Mol
A molecule
Returns
-------
canonical_smile : str
SMILES of the generated molecule with a canonical decision sequence
random_smile : str
SMILES of the generated molecule with a random decision sequence
"""
canonical_decisions = env.get_decision_sequence(mol, list(range(mol.GetNumAtoms())))
canonical_smile = eval_decisions(env, canonical_decisions)
order = list(range(mol.GetNumAtoms()))
random.shuffle(order)
random_decisions = env.get_decision_sequence(mol, order)
random_smile = eval_decisions(env, random_decisions)
return canonical_smile, random_smile
def preprocess_dataset(atom_types, bond_types, smiles, max_num_atoms=23):
"""Preprocess the dataset
1. Standardize the SMILES of the dataset
2. Only keep the SMILES that DGMG can reproduce
3. Drop repeated SMILES
Parameters
----------
atom_types : list
The types of atoms appearing in a dataset. E.g. ['C', 'N']
bond_types : list
The types of bonds appearing in a dataset.
E.g. [Chem.rdchem.BondType.SINGLE, Chem.rdchem.BondType.DOUBLE,
Chem.rdchem.BondType.TRIPLE, Chem.rdchem.BondType.AROMATIC]
smiles : list of str
List of SMILES to preprocess
max_num_atoms : int or None
Maximum number of atoms allowed in a molecule; molecules with more
atoms are dropped. Default to 23.
Returns
-------
valid_smiles : list of str
SMILES left after preprocessing
"""
valid_smiles = []
env = MoleculeEnv(atom_types, bond_types)
for i, s in enumerate(smiles):
print('Processing {:d}/{:d}'.format(i + 1, len(smiles)))
raw_s = s.strip()
mol = smiles_to_standard_mol(raw_s)
if mol is None:
continue
standard_s = Chem.MolToSmiles(mol)
if (max_num_atoms is not None) and (mol.GetNumAtoms() > max_num_atoms):
continue
canonical_s, random_s = get_DGMG_smile(env, mol)
canonical_mol = Chem.MolFromSmiles(canonical_s)
random_mol = Chem.MolFromSmiles(random_s)
if (standard_s != canonical_s) or (canonical_s != random_s) or (canonical_mol is None) or (random_mol is None):
continue
valid_smiles.append(standard_s)
valid_smiles = list(set(valid_smiles))
return valid_smiles
def download_data(dataset, fname):
"""Download dataset if built-in support exists
Parameters
----------
dataset : str
Dataset name
fname : str
Name of dataset file
"""
if dataset not in ['ChEMBL', 'ZINC']:
# Datasets without built-in support should be processed locally.
return
data_path = fname
download(_get_dgl_url(os.path.join('dataset', fname)), path=data_path)
def load_smiles_from_file(f_name):
"""Load dataset into a list of SMILES
Parameters
----------
f_name : str
Path to a file of molecules, where each line of the file
is a molecule in SMILES format.
Returns
-------
smiles : list of str
List of molecules as SMILES
"""
with open(f_name, 'r') as f:
smiles = f.read().splitlines()
return smiles
def write_smiles_to_file(f_name, smiles):
"""Write dataset to a file.
Parameters
----------
f_name : str
Path to create a file of molecules, where each line of the file
is a molecule in SMILES format.
smiles : list of str
List of SMILES
"""
with open(f_name, 'w') as f:
for s in smiles:
f.write(s + '\n')
def configure_new_dataset(dataset, train_file, val_file):
"""Configure for a new dataset.
Parameters
----------
dataset : str
Dataset name
train_file : str
Path to a file with one SMILES per line for training data
val_file : str
Path to a file with one SMILES per line for validation data
"""
assert train_file is not None, 'Expect a file of SMILES for training, got None.'
assert val_file is not None, 'Expect a file of SMILES for validation, got None.'
train_smiles = load_smiles_from_file(train_file)
val_smiles = load_smiles_from_file(val_file)
all_smiles = train_smiles + val_smiles
# Get all atom and bond types in the dataset
path_to_atom_and_bond_types = '_'.join([dataset, 'atom_and_bond_types.pkl'])
if not os.path.exists(path_to_atom_and_bond_types):
atom_types, bond_types = get_atom_and_bond_types(all_smiles)
with open(path_to_atom_and_bond_types, 'wb') as f:
pickle.dump({'atom_types': atom_types, 'bond_types': bond_types}, f)
else:
with open(path_to_atom_and_bond_types, 'rb') as f:
type_info = pickle.load(f)
atom_types = type_info['atom_types']
bond_types = type_info['bond_types']
# Standardize training data
path_to_processed_train_data = '_'.join([dataset, 'DGMG', 'train.txt'])
if not os.path.exists(path_to_processed_train_data):
processed_train_smiles = preprocess_dataset(atom_types, bond_types, train_smiles, None)
write_smiles_to_file(path_to_processed_train_data, processed_train_smiles)
path_to_processed_val_data = '_'.join([dataset, 'DGMG', 'val.txt'])
if not os.path.exists(path_to_processed_val_data):
processed_val_smiles = preprocess_dataset(atom_types, bond_types, val_smiles, None)
write_smiles_to_file(path_to_processed_val_data, processed_val_smiles)
class MoleculeDataset(object):
"""Initialize and split the dataset.
Parameters
----------
dataset : str
Dataset name
order : None or str
Order to extract a decision sequence for generating a molecule. Defaults to None.
modes : None or list
List of subsets to use, which can contain 'train', 'val', corresponding to
training and validation. Defaults to None.
subset_id : int
With multiprocess training, we partition the training set into multiple subsets and
each process will use one subset only. This subset_id corresponds to subprocess id.
n_subsets : int
With multiprocess training, this corresponds to the number of total subprocesses.
"""
def __init__(self, dataset, order=None, modes=None, subset_id=0, n_subsets=1):
super(MoleculeDataset, self).__init__()
if modes is None:
modes = []
else:
assert order is not None, 'An order should be specified for extracting ' \
'decision sequences.'
assert order in ['random', 'canonical', None], \
"Unexpected order option to get sequences of graph generation decisions"
assert len(set(modes) - {'train', 'val'}) == 0, \
"modes should be a list, representing a subset of ['train', 'val']"
self.dataset = dataset
self.order = order
self.modes = modes
self.subset_id = subset_id
self.n_subsets = n_subsets
self._setup()
def collate(self, samples):
"""PyTorch's approach to batch multiple samples.
For auto-regressive generative models, we process one sample at a time.
Parameters
----------
samples : list
A list of length 1 that consists of decision sequence to generate a molecule.
Returns
-------
list
List of 2-tuples, a decision sequence to generate a molecule
"""
assert len(samples) == 1
return samples[0]
def _create_a_subset(self, smiles):
"""Create a dataset from a subset of smiles.
Parameters
----------
smiles : list of str
List of molecules in SMILES format
"""
# With multiprocess training, we evenly divide the SMILES into multiple subsets
subset_size = len(smiles) // self.n_subsets
return Subset(smiles[self.subset_id * subset_size: (self.subset_id + 1) * subset_size],
self.order, self.env)
def _setup(self):
"""
1. Instantiate an MDP environment for molecule generation
2. Download the dataset, which is a file of SMILES
3. Create subsets for training and validation
"""
if self.dataset == 'ChEMBL':
# For new datasets, get_atom_and_bond_types can be used to
# identify the atom and bond types in them.
self.atom_types = ['O', 'Cl', 'C', 'S', 'F', 'Br', 'N']
self.bond_types = [Chem.rdchem.BondType.SINGLE,
Chem.rdchem.BondType.DOUBLE,
Chem.rdchem.BondType.TRIPLE]
elif self.dataset == 'ZINC':
self.atom_types = ['Br', 'S', 'C', 'P', 'N', 'O', 'F', 'Cl', 'I']
self.bond_types = [Chem.rdchem.BondType.SINGLE,
Chem.rdchem.BondType.DOUBLE,
Chem.rdchem.BondType.TRIPLE]
else:
path_to_atom_and_bond_types = '_'.join([self.dataset, 'atom_and_bond_types.pkl'])
with open(path_to_atom_and_bond_types, 'rb') as f:
type_info = pickle.load(f)
self.atom_types = type_info['atom_types']
self.bond_types = type_info['bond_types']
self.env = MoleculeEnv(self.atom_types, self.bond_types)
dataset_prefix = self._dataset_prefix()
if 'train' in self.modes:
fname = '_'.join([dataset_prefix, 'train.txt'])
download_data(self.dataset, fname)
smiles = load_smiles_from_file(fname)
self.train_set = self._create_a_subset(smiles)
if 'val' in self.modes:
fname = '_'.join([dataset_prefix, 'val.txt'])
download_data(self.dataset, fname)
smiles = load_smiles_from_file(fname)
self.val_set = self._create_a_subset(smiles)
def _dataset_prefix(self):
"""Get the prefix for the data files of supported datasets.
Returns
-------
str
Prefix for dataset file name
"""
return '_'.join([self.dataset, 'DGMG'])
class Subset(Dataset):
"""A set of molecules which can be used for training, validation, test.
Parameters
----------
smiles : list
List of SMILES for the dataset
order : str
Specifies how decision sequences for molecule generation
are obtained, can be either "random" or "canonical"
env : MoleculeEnv object
MDP environment for generating molecules
"""
def __init__(self, smiles, order, env):
super(Subset, self).__init__()
self.smiles = smiles
self.order = order
self.env = env
self._setup()
def _setup(self):
"""Convert SMILES into rdkit molecule objects.
Decision sequences are extracted if we use a fixed order.
"""
smiles_ = []
mols = []
for s in self.smiles:
m = smiles_to_standard_mol(s)
if m is None:
continue
smiles_.append(s)
mols.append(m)
self.smiles = smiles_
self.mols = mols
if self.order == 'random':
return
self.decisions = []
for m in self.mols:
self.decisions.append(
self.env.get_decision_sequence(m, list(range(m.GetNumAtoms())))
)
def __len__(self):
"""Get number of molecules in the dataset."""
return len(self.mols)
def __getitem__(self, item):
"""Get the decision sequence for generating the molecule indexed by item."""
if self.order == 'canonical':
return self.decisions[item]
else:
m = self.mols[item]
nodes = list(range(m.GetNumAtoms()))
random.shuffle(nodes)
return self.env.get_decision_sequence(m, nodes)
########################################################################################################################
# progress tracking #
########################################################################################################################
class Printer(object):
def __init__(self, num_epochs, dataset_size, batch_size, writer=None):
"""Wrapper to track the learning progress.
Parameters
----------
num_epochs : int
Number of epochs for training
dataset_size : int
batch_size : int
writer : None or SummaryWriter
If not None, tensorboard will be used to visualize learning curves.
"""
super(Printer, self).__init__()
self.num_epochs = num_epochs
self.batch_size = batch_size
self.num_batches = math.ceil(dataset_size / batch_size)
self.count = 0
self.batch_count = 0
self.writer = writer
self._reset()
def _reset(self):
"""Reset when an epoch is completed."""
self.batch_loss = 0
self.batch_prob = 0
def _get_current_batch(self):
"""Get current batch index."""
remainder = self.batch_count % self.num_batches
if remainder == 0:
return self.num_batches
else:
return remainder
def update(self, epoch, loss, prob):
"""Update learning progress.
Parameters
----------
epoch : int
loss : float
prob : float
"""
self.count += 1
self.batch_loss += loss
self.batch_prob += prob
if self.count % self.batch_size == 0:
self.batch_count += 1
if self.writer is not None:
self.writer.add_scalar('train_log_prob', self.batch_loss, self.batch_count)
self.writer.add_scalar('train_prob', self.batch_prob, self.batch_count)
print('epoch {:d}/{:d}, batch {:d}/{:d}, loss {:.4f}, prob {:.4f}'.format(
epoch, self.num_epochs, self._get_current_batch(),
self.num_batches, self.batch_loss, self.batch_prob))
self._reset()
########################################################################################################################
# eval #
########################################################################################################################
def summarize_a_molecule(smile, checklist=None):
"""Get information about a molecule.
Parameters
----------
smile : str
Molecule in SMILES format
checklist : dict
Things to learn about the molecule
"""
if checklist is None:
checklist = {
'HBA': Chem.rdMolDescriptors.CalcNumHBA,
'HBD': Chem.rdMolDescriptors.CalcNumHBD,
'logP': MolLogP,
'SA': calculateScore,
'TPSA': Chem.rdMolDescriptors.CalcTPSA,
'QED': qed,
'NumAtoms': lambda mol: mol.GetNumAtoms(),
'NumBonds': lambda mol: mol.GetNumBonds()
}
summary = dict()
mol = Chem.MolFromSmiles(smile)
if mol is None:
summary.update({
'smile': smile,
'valid': False
})
for k in checklist.keys():
summary[k] = None
else:
mol = standardize_mol(mol)
summary.update({
'smile': Chem.MolToSmiles(mol),
'valid': True
})
Chem.SanitizeMol(mol)
for k, f in checklist.items():
summary[k] = f(mol)
return summary
def summarize_molecules(smiles, num_processes):
"""Summarize molecules with multiprocess.
Parameters
----------
smiles : list of str
List of molecules in SMILES for summarization
num_processes : int
Number of processes to use for summarization
Returns
-------
summary_for_valid : dict
Summary of all valid molecules, where
summary_for_valid[k] gives the values of all
valid molecules on item k.
"""
with Pool(processes=num_processes) as pool:
result = pool.map(summarize_a_molecule, smiles)
items = list(result[0].keys())
items.remove('valid')
summary_for_valid = defaultdict(list)
for summary in result:
if summary['valid']:
for k in items:
summary_for_valid[k].append(summary[k])
return summary_for_valid
def get_unique_smiles(smiles):
"""Given a list of smiles, return a list consisting of unique elements in it.
Parameters
----------
smiles : list of str
Molecules in SMILES
Returns
-------
list of str
Sublist where each SMILES occurs exactly once
"""
unique_set = set()
for mol_s in smiles:
if mol_s not in unique_set:
unique_set.add(mol_s)
return list(unique_set)
def get_novel_smiles(new_unique_smiles, reference_unique_smiles):
"""Get novel smiles which do not appear in the reference set.
Parameters
----------
new_unique_smiles : list of str
List of SMILES from which we want to identify novel ones
reference_unique_smiles : list of str
List of reference SMILES that we already have
Returns
-------
set of str
Novel SMILES that do not appear in the reference set
"""
return set(new_unique_smiles).difference(set(reference_unique_smiles))
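# Editorial end-to-end sketch of the metrics above, mirroring
# aggregate_and_evaluate in the evaluation script; 'not_a_smiles' is an
# intentionally invalid input:
#
#   >>> generated = ['CCO', 'CCO', 'c1ccccc1', 'not_a_smiles']
#   >>> summary = summarize_molecules(generated, num_processes=1)
#   >>> unique = get_unique_smiles(summary['smile'])  # valid and unique
#   >>> novel = get_novel_smiles(unique, get_unique_smiles(['CCO']))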
# Junction Tree Variational Autoencoder for Molecular Graph Generation (JTNN)
Wengong Jin, Regina Barzilay, Tommi Jaakkola.
Junction Tree Variational Autoencoder for Molecular Graph Generation.
*arXiv preprint arXiv:1802.04364*, 2018.
JTNN uses the junction tree algorithm to form a tree from the molecular graph.
The model then encodes the graph and the tree into two separate vectors `z_G` and `z_T`. Details can
be found in the original paper. The overall process is as below (from the original paper):
![image](https://user-images.githubusercontent.com/8686776/63677300-3fb6d980-c81f-11e9-8a65-57c8b03aaf52.png)
**Goal**: JTNN is an auto-encoder model, aiming to learn hidden representations for molecular graphs.
These representations can be used for downstream tasks, such as property prediction or molecule optimization.
## Dataset
### ZINC
> The ZINC database is a curated collection of commercially available chemical compounds
prepared especially for virtual screening. (introduction from Wikipedia)
Generally speaking, molecules in the ZINC dataset are more drug-like. We use ~220,000
molecules for training and 5,000 molecules for validation.
### Preprocessing
The `JTNNDataset` class processes a SMILES string into a dict, including the junction tree, the graph with
encoded nodes (atoms) and edges (bonds), and other information for the model to use.
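A minimal sketch of the expected workflow (the constructor and collator
arguments shown are assumptions based on this example's `datautils.py`;
check it for the exact signatures):
```
from torch.utils.data import DataLoader
from jtnn import JTNNDataset, JTNNCollator

# Build the dataset from the default ZINC training set and vocabulary.
dataset = JTNNDataset(data='train', vocab='vocab', training=True)
dataloader = DataLoader(dataset, batch_size=40,
                        collate_fn=JTNNCollator(dataset.vocab, True))
```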
## Usage
### Training
To start training, use `python train.py`. By default, the script will use the ZINC dataset
with a preprocessed vocabulary and save model checkpoints in the current working directory.
```
-s SAVE_PATH, Path to save checkpoint models, default to be current
working directory (default: ./)
-m MODEL_PATH, Path to load pre-trained model (default: None)
-b BATCH_SIZE, Batch size (default: 40)
-w HIDDEN_SIZE, Size of representation vectors (default: 200)
-l LATENT_SIZE, Latent size of node (atom) features and edge (bond)
features (default: 56)
-d DEPTH, Depth of message passing hops (default: 3)
-z BETA, Coefficient of KL Divergence term (default: 1.0)
-q LR, Learning Rate (default: 0.001)
```
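For example, to train with the defaults but a custom checkpoint directory (the path is illustrative):
```
python train.py -s ./checkpoints
```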
The model will be saved periodically.
All training checkpoints will be stored at `SAVE_PATH`, passed via the command line or set by default.
#### Dataset configuration
If you want to use your own dataset, please create a file containing one SMILES string per line,
and pass the file path to the `-t` or `--train` option.
```
-t TRAIN, --train TRAIN
Training file name (default: train)
```
### Evaluation
To start evaluation, use `python reconstruct_eval.py` with the following arguments
```
-t TRAIN, Training file name (default: test)
-m MODEL_PATH, Pre-trained model to be loaded for evaluation. If not
specified, the pre-trained model from the model zoo
will be used (default: None)
-w HIDDEN_SIZE, Hidden size of representation vector, should be
consistent with pre-trained model (default: 450)
-l LATENT_SIZE, Latent size of node (atom) features and edge (bond)
features, should be consistent with pre-trained model
(default: 56)
-d DEPTH, Depth of message passing hops, should be consistent
with pre-trained model (default: 3)
```
The script prints the success rate of reconstructing the input molecules.
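For example, to evaluate the pre-trained model from the model zoo with the default sizes:
```
python reconstruct_eval.py -w 450 -l 56 -d 3
```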
### Pre-trained models
Below gives the statistics of pre-trained `JTNN_ZINC` model.
| Pre-trained model | % Reconstruction Accuracy |
| ----------------- | ------------------------- |
| `JTNN_ZINC`       | 73.7                      |
### Visualization
Here we draw some "neighbors" of a given molecule by adding noise to the intermediate representations.
You can download the script from `https://s3.us-west-2.amazonaws.com/dgl-data/dgllife/jtnn_viz_neighbor_mol.ipynb` (S3
bucket in the U.S.) or `https://s3.cn-north-1.amazonaws.com.cn/dgl-data/dgllife/jtnn_viz_neighbor_mol.ipynb` (S3 bucket
in mainland China).
Please put this script in the current directory (`examples/pytorch/model_zoo/chem/generative_models/jtnn/`).
#### Given Molecule
![image](https://user-images.githubusercontent.com/8686776/63773593-0d37da00-c90e-11e9-8933-0abca4b430db.png)
#### Neighbor Molecules
![image](https://user-images.githubusercontent.com/8686776/63773602-1163f780-c90e-11e9-8341-5122dc0d0c82.png)
from .mol_tree import Vocab
from .datautils import JTNNDataset, JTNNCollator
from .chemutils import decode_stereo