Unverified Commit 36c7b771 authored by Mufei Li's avatar Mufei Li Committed by GitHub
Browse files

[LifeSci] Move to Independent Repo (#1592)

* Move LifeSci

* Remove doc
parent 94c67203
"""Weisfeiler-Lehman Network (WLN) for Reaction Center Prediction."""
# pylint: disable= no-member, arguments-differ, invalid-name
import dgl.function as fn
import torch
import torch.nn as nn
from ..gnn.wln import WLNLinear, WLN
__all__ = ['WLNReactionCenter']
# pylint: disable=W0221, E1101
class WLNContext(nn.Module):
    """Attention-based context computation for each node.

    A context vector is computed by taking a weighted sum of node representations,
    with weights computed from an attention module.

    Parameters
    ----------
    node_in_feats : int
        Size for the input node features.
    node_pair_in_feats : int
        Size for the input features of node pairs.
    """
    def __init__(self, node_in_feats, node_pair_in_feats):
        super(WLNContext, self).__init__()

        self.project_feature_sum = WLNLinear(node_in_feats, node_in_feats, bias=False)
        self.project_node_pair_feature = WLNLinear(node_pair_in_feats, node_in_feats)
        self.compute_attention = nn.Sequential(
            nn.ReLU(),
            WLNLinear(node_in_feats, 1),
            nn.Sigmoid()
        )

    def forward(self, batch_complete_graphs, node_feats, feat_sum, node_pair_feat):
        """Compute context vectors for each node.

        Parameters
        ----------
        batch_complete_graphs : DGLGraph
            A batch of fully connected graphs.
        node_feats : float32 tensor of shape (V, node_in_feats)
            Input node features. V for the number of nodes.
        feat_sum : float32 tensor of shape (E_full, node_in_feats)
            Sum of node_feats between each pair of nodes. E_full for the number of
            edges in the batch of complete graphs.
        node_pair_feat : float32 tensor of shape (E_full, node_pair_in_feats)
            Input features for each pair of nodes. E_full for the number of edges in
            the batch of complete graphs.

        Returns
        -------
        node_contexts : float32 tensor of shape (V, node_in_feats)
            Context vectors for nodes.
        """
        with batch_complete_graphs.local_scope():
            batch_complete_graphs.ndata['hv'] = node_feats
            # Attention weight in (0, 1) for each node pair, i.e. each edge
            # of the complete graph.
            batch_complete_graphs.edata['a'] = self.compute_attention(
                self.project_feature_sum(feat_sum) + \
                self.project_node_pair_feature(node_pair_feat)
            )
            # Weighted sum of source-node representations. fn.u_mul_e is the
            # current name of the deprecated fn.src_mul_edge and matches the
            # fn.u_add_v naming used elsewhere in this file.
            batch_complete_graphs.update_all(
                fn.u_mul_e('hv', 'a', 'm'), fn.sum('m', 'context'))
            node_contexts = batch_complete_graphs.ndata.pop('context')

        return node_contexts
class WLNReactionCenter(nn.Module):
    r"""Weisfeiler-Lehman Network (WLN) for Reaction Center Prediction.

    The model is introduced in `Predicting Organic Reaction Outcomes with
    Weisfeiler-Lehman Network <https://arxiv.org/abs/1709.04555>`__.

    The model uses WLN to update atom representations and then predicts the
    score for each pair of atoms to form a bond.

    Parameters
    ----------
    node_in_feats : int
        Size for the input node features.
    edge_in_feats : int
        Size for the input edge features.
    node_pair_in_feats : int
        Size for the input features of node pairs.
    node_out_feats : int
        Size for the output node representations. Default to 300.
    n_layers : int
        Number of times for message passing. Note that same parameters
        are shared across n_layers message passing. Default to 3.
    n_tasks : int
        Number of tasks for prediction. Default to 5.
    """
    def __init__(self,
                 node_in_feats,
                 edge_in_feats,
                 node_pair_in_feats,
                 node_out_feats=300,
                 n_layers=3,
                 n_tasks=5):
        super(WLNReactionCenter, self).__init__()

        # GNN over the molecular graphs for updating atom representations
        self.gnn = WLN(node_in_feats=node_in_feats,
                       edge_in_feats=edge_in_feats,
                       node_out_feats=node_out_feats,
                       n_layers=n_layers)
        # Attention-based context over the complete graphs
        self.context_module = WLNContext(node_in_feats=node_out_feats,
                                         node_pair_in_feats=node_pair_in_feats)
        self.project_feature_sum = WLNLinear(node_out_feats, node_out_feats, bias=False)
        self.project_node_pair_feature = WLNLinear(node_pair_in_feats, node_out_feats,
                                                   bias=False)
        self.project_context_sum = WLNLinear(node_out_feats, node_out_feats)
        self.predict = nn.Sequential(
            nn.ReLU(),
            WLNLinear(node_out_feats, n_tasks)
        )

    def forward(self, batch_mol_graphs, batch_complete_graphs,
                node_feats, edge_feats, node_pair_feats):
        r"""Predict score for each pair of nodes.

        Parameters
        ----------
        batch_mol_graphs : DGLGraph
            A batch of molecular graphs.
        batch_complete_graphs : DGLGraph
            A batch of fully connected graphs.
        node_feats : float32 tensor of shape (V, node_in_feats)
            Input node features. V for the number of nodes.
        edge_feats : float32 tensor of shape (E, edge_in_feats)
            Input edge features. E for the number of edges.
        node_pair_feats : float32 tensor of shape (E_full, node_pair_in_feats)
            Input features for each pair of nodes. E_full for the number of edges in
            the batch of complete graphs.

        Returns
        -------
        scores : float32 tensor of shape (E_full, n_tasks)
            Predicted scores for each pair of atoms. With the default n_tasks=5,
            the tasks correspond to performing one of the following
            5 actions in reaction:

            * The bond between them gets broken
            * Forming a single bond
            * Forming a double bond
            * Forming a triple bond
            * Forming an aromatic bond
        biased_scores : float32 tensor of shape (E_full, n_tasks)
            Comparing to scores, a bias is added if the pair is for a same atom.
        """
        node_feats = self.gnn(batch_mol_graphs, node_feats, edge_feats)

        # Compute context vectors for all atoms, which are weighted sum of atom
        # representations in all reactants.
        with batch_complete_graphs.local_scope():
            batch_complete_graphs.ndata['hv'] = node_feats
            batch_complete_graphs.apply_edges(fn.u_add_v('hv', 'hv', 'feature_sum'))
            feat_sum = batch_complete_graphs.edata.pop('feature_sum')
        node_contexts = self.context_module(batch_complete_graphs, node_feats,
                                            feat_sum, node_pair_feats)

        # Predict score
        with batch_complete_graphs.local_scope():
            batch_complete_graphs.ndata['context'] = node_contexts
            batch_complete_graphs.apply_edges(fn.u_add_v('context', 'context', 'context_sum'))
            scores = self.predict(
                self.project_feature_sum(feat_sum) + \
                self.project_node_pair_feature(node_pair_feats) + \
                self.project_context_sum(batch_complete_graphs.edata['context_sum'])
            )

        # Mask self loops (a node paired with itself) by subtracting a large
        # constant from their scores.
        nodes = batch_complete_graphs.nodes()
        e_ids = batch_complete_graphs.edge_ids(nodes, nodes)
        # Fix: the bias was previously hard-coded to 5 columns via
        # torch.zeros(scores.shape[0], 5), which breaks whenever n_tasks != 5.
        # zeros_like matches the shape, dtype and device of scores.
        bias = torch.zeros_like(scores)
        bias[e_ids, :] = 1e4
        biased_scores = scores - bias
        return scores, biased_scores
"""Weisfeiler-Lehman Network (WLN) for ranking candidate products"""
# pylint: disable= no-member, arguments-differ, invalid-name
import torch
import torch.nn as nn
from dgl.nn.pytorch import SumPooling
from ..gnn.wln import WLN
__all__ = ['WLNReactionRanking']
# pylint: disable=W0221, E1101
class WLNReactionRanking(nn.Module):
    r"""Weisfeiler-Lehman Network (WLN) for Candidate Product Ranking

    The model is introduced in `Predicting Organic Reaction Outcomes with
    Weisfeiler-Lehman Network <https://arxiv.org/abs/1709.04555>`__ and then
    further improved in `A graph-convolutional neural network model for the
    prediction of chemical reactivity
    <https://pubs.rsc.org/en/content/articlelanding/2019/sc/c8sc04228d#!divAbstract>`__

    The model updates representations of nodes in candidate products with WLN and predicts
    the score for candidate products to be the real product.

    Parameters
    ----------
    node_in_feats : int
        Size for the input node features.
    edge_in_feats : int
        Size for the input edge features.
    node_hidden_feats : int
        Size for the hidden node representations. Default to 500.
    num_encode_gnn_layers : int
        Number of WLN layers for updating node representations. Default to 3.
    """
    def __init__(self,
                 node_in_feats,
                 edge_in_feats,
                 node_hidden_feats=500,
                 num_encode_gnn_layers=3):
        super(WLNReactionRanking, self).__init__()

        # Shared encoder applied separately to reactants and candidate products.
        self.gnn = WLN(node_in_feats=node_in_feats,
                       edge_in_feats=edge_in_feats,
                       node_out_feats=node_hidden_feats,
                       n_layers=num_encode_gnn_layers,
                       set_comparison=False)
        # One extra WLN layer run over the (product - reactant) representation
        # difference; project_in_feats=False because its input already has
        # node_hidden_feats dimensions.
        self.diff_gnn = WLN(node_in_feats=node_hidden_feats,
                            edge_in_feats=edge_in_feats,
                            node_out_feats=node_hidden_feats,
                            n_layers=1,
                            project_in_feats=False,
                            set_comparison=False)
        self.readout = SumPooling()
        self.predict = nn.Sequential(
            nn.Linear(node_hidden_feats, node_hidden_feats),
            nn.ReLU(),
            nn.Linear(node_hidden_feats, 1)
        )

    def forward(self, reactant_graph, reactant_node_feats, reactant_edge_feats,
                product_graphs, product_node_feats, product_edge_feats,
                candidate_scores, batch_num_candidate_products):
        r"""Predicts the score for candidate products to be the true product

        Parameters
        ----------
        reactant_graph : DGLGraph
            DGLGraph for a batch of reactants.
        reactant_node_feats : float32 tensor of shape (V1, node_in_feats)
            Input node features for the reactants. V1 for the number of nodes.
        reactant_edge_feats : float32 tensor of shape (E1, edge_in_feats)
            Input edge features for the reactants. E1 for the number of edges in
            reactant_graph.
        product_graphs : DGLGraph
            DGLGraph for the candidate products in a batch of reactions.
        product_node_feats : float32 tensor of shape (V2, node_in_feats)
            Input node features for the candidate products. V2 for the number of nodes.
        product_edge_feats : float32 tensor of shape (E2, edge_in_feats)
            Input edge features for the candidate products. E2 for the number of edges
            in the graphs for candidate products.
        candidate_scores : float32 tensor of shape (B, 1)
            Scores for candidate products based on the model for reaction center prediction
        batch_num_candidate_products : list of int
            Number of candidate products for the reactions in the batch

        Returns
        -------
        float32 tensor of shape (B, 1)
            Predicted scores for candidate products
        """
        # Update representations for nodes in both reactants and candidate products
        batch_reactant_node_feats = self.gnn(
            reactant_graph, reactant_node_feats, reactant_edge_feats)
        batch_product_node_feats = self.gnn(
            product_graphs, product_node_feats, product_edge_feats)

        # Iterate over the reactions in the batch, tracking running node/graph
        # offsets into the flat batched tensors.
        reactant_node_start = 0
        product_graph_start = 0
        product_node_start = 0
        batch_diff_node_feats = []
        for i, num_candidate_products in enumerate(batch_num_candidate_products):
            # NOTE(review): batch_num_nodes is indexed as an attribute here,
            # which matches the DGL < 0.5 API — confirm against the pinned DGL version.
            reactant_node_end = reactant_node_start + reactant_graph.batch_num_nodes[i]
            product_graph_end = product_graph_start + num_candidate_products
            product_node_end = product_node_start + sum(
                product_graphs.batch_num_nodes[product_graph_start: product_graph_end])

            # (N, node_out_feats); note these locals shadow the function parameters
            # of the same names from here on.
            reactant_node_feats = batch_reactant_node_feats[reactant_node_start:
                                                            reactant_node_end, :]
            product_node_feats = batch_product_node_feats[product_node_start: product_node_end, :]

            old_feats_shape = reactant_node_feats.shape
            # (1, N, node_out_feats)
            expanded_reactant_node_feats = reactant_node_feats.reshape((1,) + old_feats_shape)
            # (B, N, node_out_feats); expand broadcasts the reactant features to
            # every candidate product of this reaction without copying.
            expanded_reactant_node_feats = expanded_reactant_node_feats.expand(
                (num_candidate_products,) + old_feats_shape)
            # (B, N, node_out_feats); assumes every candidate product of a
            # reaction has the same number of atoms N as its reactants —
            # otherwise this reshape would fail. TODO confirm with the caller.
            candidate_product_node_feats = product_node_feats.reshape(
                (num_candidate_products,) + old_feats_shape)

            # Get the node representation difference between candidate products and reactants
            diff_node_feats = candidate_product_node_feats - expanded_reactant_node_feats
            diff_node_feats = diff_node_feats.reshape(-1, diff_node_feats.shape[-1])
            batch_diff_node_feats.append(diff_node_feats)

            # Advance the offsets for the next reaction
            reactant_node_start = reactant_node_end
            product_graph_start = product_graph_end
            product_node_start = product_node_end

        batch_diff_node_feats = torch.cat(batch_diff_node_feats, dim=0)
        # One more GNN layer for message passing with the node representation difference
        diff_node_feats = self.diff_gnn(product_graphs, batch_diff_node_feats, product_edge_feats)
        candidate_product_feats = self.readout(product_graphs, diff_node_feats)
        # Final score combines the learned ranking score with the score from
        # the reaction center prediction model.
        return self.predict(candidate_product_feats) + candidate_scores
"""Utilities for using pretrained models."""
# pylint: disable= no-member, arguments-differ, invalid-name
import os
import torch
import torch.nn.functional as F
from dgl.data.utils import _get_dgl_url, download, get_download_dir, extract_archive
from rdkit import Chem
from ..model import GCNPredictor, GATPredictor, AttentiveFPPredictor, DGMG, DGLJTNNVAE, \
WLNReactionCenter, WLNReactionRanking, WeavePredictor, GIN
__all__ = ['load_pretrained']
# Mapping from supported pretrained model names to checkpoint paths relative
# to the DGL download server root; resolved via _get_dgl_url before download.
URL = {
    'GCN_Tox21': 'dgllife/pre_trained/gcn_tox21.pth',
    'GAT_Tox21': 'dgllife/pre_trained/gat_tox21.pth',
    'Weave_Tox21': 'dgllife/pre_trained/weave_tox21.pth',
    'AttentiveFP_Aromaticity': 'dgllife/pre_trained/attentivefp_aromaticity.pth',
    'DGMG_ChEMBL_canonical': 'pre_trained/dgmg_ChEMBL_canonical.pth',
    'DGMG_ChEMBL_random': 'pre_trained/dgmg_ChEMBL_random.pth',
    'DGMG_ZINC_canonical': 'pre_trained/dgmg_ZINC_canonical.pth',
    'DGMG_ZINC_random': 'pre_trained/dgmg_ZINC_random.pth',
    'JTNN_ZINC': 'pre_trained/JTNN_ZINC.pth',
    'wln_center_uspto': 'dgllife/pre_trained/wln_center_uspto_v3.pth',
    'wln_rank_uspto': 'dgllife/pre_trained/wln_rank_uspto.pth',
    'gin_supervised_contextpred': 'dgllife/pre_trained/gin_supervised_contextpred.pth',
    'gin_supervised_infomax': 'dgllife/pre_trained/gin_supervised_infomax.pth',
    'gin_supervised_edgepred': 'dgllife/pre_trained/gin_supervised_edgepred.pth',
    'gin_supervised_masking': 'dgllife/pre_trained/gin_supervised_masking.pth'
}
def download_and_load_checkpoint(model_name, model, model_postfix,
                                 local_pretrained_path='pre_trained.pth', log=True):
    """Download a pretrained checkpoint and load its weights into ``model``.

    The checkpoint is always mapped to CPU when loaded.

    Parameters
    ----------
    model_name : str
        Name of the model
    model : nn.Module
        Instantiated model instance
    model_postfix : str
        Postfix for pretrained model checkpoint
    local_pretrained_path : str
        Local name for the downloaded model checkpoint
    log : bool
        Whether to print progress for model loading

    Returns
    -------
    model : nn.Module
        Pretrained model
    """
    remote_url = _get_dgl_url(model_postfix)
    # Prefix the local file with the model name to avoid collisions between models
    local_path = '{}_{}'.format(model_name, local_pretrained_path)
    download(remote_url, path=local_path, log=log)
    state = torch.load(local_path, map_location='cpu')
    model.load_state_dict(state['model_state_dict'])
    if log:
        print('Pretrained model loaded')
    return model
# pylint: disable=I1101
def load_pretrained(model_name, log=True):
    """Load a pretrained model

    Instantiates the architecture matching ``model_name`` with the
    hyperparameters used for pretraining, then downloads and loads the
    corresponding checkpoint.

    Parameters
    ----------
    model_name : str
        Currently supported options include

        * ``'GCN_Tox21'``: A GCN-based model for molecular property prediction on Tox21
        * ``'GAT_Tox21'``: A GAT-based model for molecular property prediction on Tox21
        * ``'Weave_Tox21'``: A Weave model for molecular property prediction on Tox21
        * ``'AttentiveFP_Aromaticity'``: An AttentiveFP model for predicting number of
          aromatic atoms on a subset of Pubmed
        * ``'DGMG_ChEMBL_canonical'``: A DGMG model trained on ChEMBL with a canonical
          atom order
        * ``'DGMG_ChEMBL_random'``: A DGMG model trained on ChEMBL for molecule generation
          with a random atom order
        * ``'DGMG_ZINC_canonical'``: A DGMG model trained on ZINC for molecule generation
          with a canonical atom order
        * ``'DGMG_ZINC_random'``: A DGMG model pre-trained on ZINC for molecule generation
          with a random atom order
        * ``'JTNN_ZINC'``: A JTNN model pre-trained on ZINC for molecule generation
        * ``'wln_center_uspto'``: A WLN model pre-trained on USPTO for reaction prediction
        * ``'wln_rank_uspto'``: A WLN model pre-trained on USPTO for candidate product ranking
        * ``'gin_supervised_contextpred'``: A GIN model pre-trained with supervised learning
          and context prediction
        * ``'gin_supervised_infomax'``: A GIN model pre-trained with supervised learning
          and deep graph infomax
        * ``'gin_supervised_edgepred'``: A GIN model pre-trained with supervised learning
          and edge prediction
        * ``'gin_supervised_masking'``: A GIN model pre-trained with supervised learning
          and attribute masking
    log : bool
        Whether to print progress for model loading

    Returns
    -------
    model

    Raises
    ------
    RuntimeError
        If ``model_name`` is not a key of ``URL``.
    """
    if model_name not in URL:
        raise RuntimeError("Cannot find a pretrained model with name {}".format(model_name))

    # The feature sizes below are fixed by the featurization used during
    # pretraining; changing them would make the checkpoints unloadable.
    if model_name == 'GCN_Tox21':
        model = GCNPredictor(in_feats=74,
                             hidden_feats=[64, 64],
                             classifier_hidden_feats=64,
                             n_tasks=12)
    elif model_name == 'GAT_Tox21':
        model = GATPredictor(in_feats=74,
                             hidden_feats=[32, 32],
                             num_heads=[4, 4],
                             agg_modes=['flatten', 'mean'],
                             activations=[F.elu, None],
                             classifier_hidden_feats=64,
                             n_tasks=12)
    elif model_name == 'Weave_Tox21':
        model = WeavePredictor(node_in_feats=27,
                               edge_in_feats=7,
                               num_gnn_layers=2,
                               gnn_hidden_feats=50,
                               graph_feats=128,
                               n_tasks=12)
    elif model_name == 'AttentiveFP_Aromaticity':
        model = AttentiveFPPredictor(node_feat_size=39,
                                     edge_feat_size=10,
                                     num_layers=2,
                                     num_timesteps=2,
                                     graph_feat_size=200,
                                     n_tasks=1,
                                     dropout=0.2)
    elif model_name.startswith('DGMG'):
        # Atom vocabularies differ between the two training datasets. Every
        # 'DGMG*' key in URL starts with one of these two prefixes, so
        # atom_types is always bound below.
        if model_name.startswith('DGMG_ChEMBL'):
            atom_types = ['O', 'Cl', 'C', 'S', 'F', 'Br', 'N']
        elif model_name.startswith('DGMG_ZINC'):
            atom_types = ['Br', 'S', 'C', 'P', 'N', 'O', 'F', 'Cl', 'I']
        bond_types = [Chem.rdchem.BondType.SINGLE,
                      Chem.rdchem.BondType.DOUBLE,
                      Chem.rdchem.BondType.TRIPLE]
        model = DGMG(atom_types=atom_types,
                     bond_types=bond_types,
                     node_hidden_size=128,
                     num_prop_rounds=2,
                     dropout=0.2)
    elif model_name == "JTNN_ZINC":
        # JTNN additionally needs a vocabulary file; fetch and extract it once
        # into the DGL download directory if not already present.
        default_dir = get_download_dir()
        vocab_file = '{}/jtnn/{}.txt'.format(default_dir, 'vocab')
        if not os.path.exists(vocab_file):
            zip_file_path = '{}/jtnn.zip'.format(default_dir)
            download(_get_dgl_url('dgllife/jtnn.zip'), path=zip_file_path)
            extract_archive(zip_file_path, '{}/jtnn'.format(default_dir))
        model = DGLJTNNVAE(vocab_file=vocab_file,
                           depth=3,
                           hidden_size=450,
                           latent_size=56)
    elif model_name == 'wln_center_uspto':
        model = WLNReactionCenter(node_in_feats=82,
                                  edge_in_feats=6,
                                  node_pair_in_feats=10,
                                  node_out_feats=300,
                                  n_layers=3,
                                  n_tasks=5)
    elif model_name == 'wln_rank_uspto':
        model = WLNReactionRanking(node_in_feats=89,
                                   edge_in_feats=5,
                                   node_hidden_feats=500,
                                   num_encode_gnn_layers=3)
    elif model_name in ['gin_supervised_contextpred', 'gin_supervised_infomax',
                        'gin_supervised_edgepred', 'gin_supervised_masking']:
        model = GIN(num_node_emb_list=[120, 3],
                    num_edge_emb_list=[6, 3],
                    num_layers=5,
                    emb_dim=300,
                    JK='last',
                    dropout=0.5)

    return download_and_load_checkpoint(model_name, model, URL[model_name], log=log)
"""
Readout functions for computing molecular representations
out of node and edge representations.
"""
from .attentivefp_readout import *
from .weighted_sum_and_max import *
from .mlp_readout import *
from .weave_readout import *
"""Readout for AttentiveFP"""
# pylint: disable= no-member, arguments-differ, invalid-name
import dgl
import torch
import torch.nn as nn
import torch.nn.functional as F
__all__ = ['AttentiveFPReadout']
# pylint: disable=W0221
class GlobalPool(nn.Module):
    """One-step readout in AttentiveFP

    Performs a single attention-weighted pooling over node features followed
    by a GRU update of the graph representation.

    Parameters
    ----------
    feat_size : int
        Size for the input node features, graph features and output graph
        representations.
    dropout : float
        The probability for performing dropout.
    """
    def __init__(self, feat_size, dropout):
        super(GlobalPool, self).__init__()

        self.compute_logits = nn.Sequential(
            nn.Linear(2 * feat_size, 1),
            nn.LeakyReLU()
        )
        self.project_nodes = nn.Sequential(
            nn.Dropout(dropout),
            nn.Linear(feat_size, feat_size)
        )
        self.gru = nn.GRUCell(feat_size, feat_size)

    def forward(self, g, node_feats, g_feats, get_node_weight=False):
        """Perform one-step readout

        Parameters
        ----------
        g : DGLGraph
            DGLGraph for a batch of graphs.
        node_feats : float32 tensor of shape (V, node_feat_size)
            Input node features. V for the number of nodes.
        g_feats : float32 tensor of shape (G, graph_feat_size)
            Input graph features. G for the number of graphs.
        get_node_weight : bool
            Whether to get the weights of atoms during readout.

        Returns
        -------
        float32 tensor of shape (G, graph_feat_size)
            Updated graph features.
        float32 tensor of shape (V, 1)
            The weights of nodes in readout, only returned when
            ``get_node_weight`` is True.
        """
        with g.local_scope():
            # Score each node against its graph-level representation
            broadcast_g_feats = dgl.broadcast_nodes(g, F.relu(g_feats))
            g.ndata['z'] = self.compute_logits(
                torch.cat([broadcast_g_feats, node_feats], dim=1))
            # Normalize the scores into per-graph attention weights
            g.ndata['a'] = dgl.softmax_nodes(g, 'z')
            g.ndata['hv'] = self.project_nodes(node_feats)

            # Attention-weighted sum of projected node features, then GRU update
            context = F.elu(dgl.sum_nodes(g, 'hv', 'a'))
            updated_g_feats = self.gru(context, g_feats)

            if get_node_weight:
                return updated_g_feats, g.ndata['a']
            return updated_g_feats
class AttentiveFPReadout(nn.Module):
    """Readout in AttentiveFP

    AttentiveFP is introduced in `Pushing the Boundaries of Molecular Representation for
    Drug Discovery with the Graph Attention Mechanism
    <https://www.ncbi.nlm.nih.gov/pubmed/31408336>`__

    This class computes graph representations out of node features by repeatedly
    applying an attention-based pooling step with a GRU update.

    Parameters
    ----------
    feat_size : int
        Size for the input node features, graph features and output graph
        representations.
    num_timesteps : int
        Times of updating the graph representations with GRU. Default to 2.
    dropout : float
        The probability for performing dropout. Default to 0.
    """
    def __init__(self, feat_size, num_timesteps=2, dropout=0.):
        super(AttentiveFPReadout, self).__init__()

        # One GlobalPool per GRU-based update of the graph representation
        self.readouts = nn.ModuleList(
            [GlobalPool(feat_size, dropout) for _ in range(num_timesteps)])

    def forward(self, g, node_feats, get_node_weight=False):
        """Computes graph representations out of node features.

        Parameters
        ----------
        g : DGLGraph
            DGLGraph for a batch of graphs.
        node_feats : float32 tensor of shape (V, node_feat_size)
            Input node features. V for the number of nodes.
        get_node_weight : bool
            Whether to get the weights of nodes in readout. Default to False.

        Returns
        -------
        g_feats : float32 tensor of shape (G, graph_feat_size)
            Graph representations computed. G for the number of graphs.
        node_weights : list of float32 tensor of shape (V, 1), optional
            This is returned when ``get_node_weight`` is ``True``.
            The list has a length ``num_timesteps`` and ``node_weights[i]``
            gives the node weights in the i-th update.
        """
        # Initialize the graph representation with a plain sum of node features
        with g.local_scope():
            g.ndata['hv'] = node_feats
            g_feats = dgl.sum_nodes(g, 'hv')

        if not get_node_weight:
            for readout in self.readouts:
                g_feats = readout(g, node_feats, g_feats)
            return g_feats

        node_weights = []
        for readout in self.readouts:
            g_feats, weights_t = readout(g, node_feats, g_feats, get_node_weight)
            node_weights.append(weights_t)
        return g_feats, node_weights
"""Readout for SchNet"""
# pylint: disable= no-member, arguments-differ, invalid-name
import dgl
import torch.nn as nn
__all__ = ['MLPNodeReadout']
# pylint: disable=W0221
class MLPNodeReadout(nn.Module):
    """MLP-based Readout.

    Node representations are first transformed by a two-layer MLP, then
    aggregated per graph with max, mean or sum pooling.

    Parameters
    ----------
    node_feats : int
        Size for the input node features.
    hidden_feats : int
        Size for the hidden representations.
    graph_feats : int
        Size for the output graph representations.
    activation : callable
        Activation function. Default to None.
    mode : 'max' or 'mean' or 'sum'
        Whether to compute elementwise maximum, mean or sum of the node representations.
    """
    def __init__(self, node_feats, hidden_feats, graph_feats, activation=None, mode='sum'):
        super(MLPNodeReadout, self).__init__()

        assert mode in ['max', 'mean', 'sum'], \
            "Expect mode to be 'max' or 'mean' or 'sum', got {}".format(mode)
        self.mode = mode
        self.in_project = nn.Linear(node_feats, hidden_feats)
        self.activation = activation
        self.out_project = nn.Linear(hidden_feats, graph_feats)

    def forward(self, g, node_feats):
        """Computes graph representations out of node features.

        Parameters
        ----------
        g : DGLGraph
            DGLGraph for a batch of graphs.
        node_feats : float32 tensor of shape (V, node_feats)
            Input node features, V for the number of nodes.

        Returns
        -------
        graph_feats : float32 tensor of shape (G, graph_feats)
            Graph representations computed. G for the number of graphs.
        """
        # Two-layer MLP with an optional activation in between
        hidden = self.in_project(node_feats)
        if self.activation is not None:
            hidden = self.activation(hidden)
        hidden = self.out_project(hidden)

        with g.local_scope():
            g.ndata['h'] = hidden
            if self.mode == 'max':
                return dgl.max_nodes(g, 'h')
            if self.mode == 'mean':
                return dgl.mean_nodes(g, 'h')
            # mode == 'sum', guaranteed by the constructor's validation
            return dgl.sum_nodes(g, 'h')
"""Readout for Weave"""
# pylint: disable= no-member, arguments-differ, invalid-name
import dgl
import torch
import torch.nn as nn
from torch.distributions import Normal
__all__ = ['WeaveGather']
# pylint: disable=W0221, E1101, E1102
class WeaveGather(nn.Module):
    r"""Readout in Weave

    Computes graph representations by summing node representations, optionally
    after expanding each feature dimension into a normalized Gaussian histogram.

    Parameters
    ----------
    node_in_feats : int
        Size for the input node features.
    gaussian_expand : bool
        Whether to expand each dimension of node features by gaussian histogram.
        Default to True.
    gaussian_memberships : list of 2-tuples
        For each tuple, the first and second element separately specifies the mean
        and std for constructing a normal distribution. This argument comes into
        effect only when ``gaussian_expand==True``. By default, we set this to be
        ``[(-1.645, 0.283), (-1.080, 0.170), (-0.739, 0.134), (-0.468, 0.118),
        (-0.228, 0.114), (0., 0.114), (0.228, 0.114), (0.468, 0.118),
        (0.739, 0.134), (1.080, 0.170), (1.645, 0.283)]``.
    activation : callable
        Activation function to apply. Default to tanh.
        NOTE(review): the ``nn.Tanh()`` module default is shared across
        instances; harmless here since Tanh is stateless.
    """
    def __init__(self,
                 node_in_feats,
                 gaussian_expand=True,
                 gaussian_memberships=None,
                 activation=nn.Tanh()):
        super(WeaveGather, self).__init__()

        self.gaussian_expand = gaussian_expand
        if gaussian_expand:
            if gaussian_memberships is None:
                gaussian_memberships = [
                    (-1.645, 0.283), (-1.080, 0.170), (-0.739, 0.134), (-0.468, 0.118),
                    (-0.228, 0.114), (0., 0.114), (0.228, 0.114), (0.468, 0.118),
                    (0.739, 0.134), (1.080, 0.170), (1.645, 0.283)]
            means, stds = map(list, zip(*gaussian_memberships))
            # Stored as frozen Parameters (requires_grad=False) rather than
            # buffers so they keep these names in state_dict; changing this
            # would break loading of existing pretrained checkpoints.
            self.means = nn.ParameterList([
                nn.Parameter(torch.tensor(value), requires_grad=False)
                for value in means
            ])
            self.stds = nn.ParameterList([
                nn.Parameter(torch.tensor(value), requires_grad=False)
                for value in stds
            ])
            self.to_out = nn.Linear(node_in_feats * len(self.means), node_in_feats)
            self.activation = activation

    def gaussian_histogram(self, node_feats):
        r"""Constructs a gaussian histogram to capture the distribution of features

        Parameters
        ----------
        node_feats : float32 tensor of shape (V, node_in_feats)
            Input node features. V for the number of nodes in the batch of graphs.

        Returns
        -------
        float32 tensor of shape (V, node_in_feats * len(self.means))
            Updated node representations
        """
        gaussian_dists = [Normal(self.means[i], self.stds[i])
                          for i in range(len(self.means))]
        # Peak density of each Gaussian, attained at its mean
        max_log_probs = [gaussian_dists[i].log_prob(self.means[i])
                         for i in range(len(self.means))]
        # Normalize the probabilities by the maximum point-wise probabilities,
        # whose results will be in range [0, 1]. Note that division of probabilities
        # is equivalent to subtraction of log probabilities and the latter one is cheaper.
        log_probs = [gaussian_dists[i].log_prob(node_feats) - max_log_probs[i]
                     for i in range(len(self.means))]
        probs = torch.stack(log_probs, dim=2).exp() # (V, node_in_feats, len(self.means))
        # Add a bias to avoid numerical issues in division
        probs = probs + 1e-7
        # Normalize the probabilities across all Gaussian distributions
        probs = probs / probs.sum(2, keepdim=True)
        # Flatten the histogram dimension into the feature dimension
        return probs.reshape(node_feats.shape[0],
                             node_feats.shape[1] * len(self.means))

    def forward(self, g, node_feats):
        r"""Computes graph representations out of node representations.

        Parameters
        ----------
        g : DGLGraph
            DGLGraph for a batch of graphs.
        node_feats : float32 tensor of shape (V, node_in_feats)
            Input node features. V for the number of nodes in the batch of graphs.

        Returns
        -------
        g_feats : float32 tensor of shape (G, node_in_feats)
            Output graph representations. G for the number of graphs in the batch.
        """
        if self.gaussian_expand:
            node_feats = self.gaussian_histogram(node_feats)

        with g.local_scope():
            g.ndata['h'] = node_feats
            g_feats = dgl.sum_nodes(g, 'h')

        if self.gaussian_expand:
            # Project the expanded representation back to node_in_feats dims
            g_feats = self.to_out(g_feats)
            if self.activation is not None:
                g_feats = self.activation(g_feats)

        return g_feats
"""Apply weighted sum and max pooling to the node representations and concatenate the results."""
# pylint: disable= no-member, arguments-differ, invalid-name
import dgl
import torch
import torch.nn as nn
from dgl.nn.pytorch import WeightAndSum
__all__ = ['WeightedSumAndMax']
# pylint: disable=W0221
class WeightedSumAndMax(nn.Module):
    r"""Concatenate a learned weighted-sum readout with a max-pooling readout
    over node representations.

    Parameters
    ----------
    in_feats : int
        Input node feature size
    """
    def __init__(self, in_feats):
        super(WeightedSumAndMax, self).__init__()

        self.weight_and_sum = WeightAndSum(in_feats)

    def forward(self, bg, feats):
        """Readout

        Parameters
        ----------
        bg : DGLGraph
            DGLGraph for a batch of graphs.
        feats : FloatTensor of shape (N, M1)
            * N is the total number of nodes in the batch of graphs
            * M1 is the input node feature size, which must match
              in_feats in initialization

        Returns
        -------
        h_g : FloatTensor of shape (B, 2 * M1)
            * B is the number of graphs in the batch
            * M1 is the input node feature size, which must match
              in_feats in initialization
        """
        # Learned, attention-like weighted sum of node features per graph
        weighted_sum = self.weight_and_sum(bg, feats)
        # Elementwise max over the nodes of each graph
        with bg.local_scope():
            bg.ndata['h'] = feats
            max_pooled = dgl.max_nodes(bg, 'h')
        return torch.cat([weighted_sum, max_pooled], dim=1)
"""Utils for data processing."""
from .complex_to_graph import *
from .early_stop import *
from .eval import *
from .featurizers import *
from .mol_to_graph import *
from .rdkit_utils import *
from .splitters import *
"""Convert complexes into DGLHeteroGraphs"""
# pylint: disable= no-member, arguments-differ, invalid-name
import dgl.backend as F
import numpy as np
from dgl import graph, bipartite, hetero_from_relations
from ..utils.mol_to_graph import k_nearest_neighbors
__all__ = ['ACNN_graph_construction_and_featurization']
def filter_out_hydrogens(mol):
    """Get indices for non-hydrogen atoms.

    Parameters
    ----------
    mol : rdkit.Chem.rdchem.Mol
        RDKit molecule instance.

    Returns
    -------
    indices_left : list of int
        Indices of non-hydrogen atoms.
    """
    # Hydrogen atoms have an atomic number of 1.
    return [idx for idx, atom in enumerate(mol.GetAtoms())
            if atom.GetAtomicNum() != 1]
def get_atomic_numbers(mol, indices):
    """Get the atomic numbers for the specified atoms.

    Parameters
    ----------
    mol : rdkit.Chem.rdchem.Mol
        RDKit molecule instance.
    indices : list of int
        Specifying atoms.

    Returns
    -------
    list of int
        Atomic numbers computed.
    """
    return [mol.GetAtomWithIdx(i).GetAtomicNum() for i in indices]
# pylint: disable=C0326
def ACNN_graph_construction_and_featurization(ligand_mol,
                                              protein_mol,
                                              ligand_coordinates,
                                              protein_coordinates,
                                              max_num_ligand_atoms=None,
                                              max_num_protein_atoms=None,
                                              neighbor_cutoff=12.,
                                              max_num_neighbors=12,
                                              strip_hydrogens=False):
    """Graph construction and featurization for `Atomic Convolutional Networks for
    Predicting Protein-Ligand Binding Affinity <https://arxiv.org/abs/1703.10603>`__.

    Parameters
    ----------
    ligand_mol : rdkit.Chem.rdchem.Mol
        RDKit molecule instance.
    protein_mol : rdkit.Chem.rdchem.Mol
        RDKit molecule instance.
    ligand_coordinates : Float Tensor of shape (V1, 3)
        Atom coordinates in a ligand.
    protein_coordinates : Float Tensor of shape (V2, 3)
        Atom coordinates in a protein.
    max_num_ligand_atoms : int or None
        Maximum number of atoms in ligands for zero padding, which should be no smaller than
        ligand_mol.GetNumAtoms() if not None. If None, no zero padding will be performed.
        Default to None.
    max_num_protein_atoms : int or None
        Maximum number of atoms in proteins for zero padding, which should be no smaller than
        protein_mol.GetNumAtoms() if not None. If None, no zero padding will be performed.
        Default to None.
    neighbor_cutoff : float
        Distance cutoff to define 'neighboring'. Default to 12.
    max_num_neighbors : int
        Maximum number of neighbors allowed for each atom. Default to 12.
    strip_hydrogens : bool
        Whether to exclude hydrogen atoms. Default to False.

    Returns
    -------
    g : DGLHeteroGraph
        Heterograph with node types ``'ligand_atom'`` and ``'protein_atom'`` and six
        relations: intra-ligand (``'ligand'``), intra-protein (``'protein'``) and four
        ``'complex'`` relations (ligand-ligand, protein-protein, ligand->protein,
        protein->ligand). Every edge carries a ``'distance'`` feature; every node
        carries an ``'atomic_number'`` feature and a binary ``'mask'`` feature
        where 0 marks zero-padded dummy nodes.
    """
    assert ligand_coordinates is not None, 'Expect ligand_coordinates to be provided.'
    assert protein_coordinates is not None, 'Expect protein_coordinates to be provided.'
    if max_num_ligand_atoms is not None:
        assert max_num_ligand_atoms >= ligand_mol.GetNumAtoms(), \
            'Expect max_num_ligand_atoms to be no smaller than ligand_mol.GetNumAtoms(), ' \
            'got {:d} and {:d}'.format(max_num_ligand_atoms, ligand_mol.GetNumAtoms())
    if max_num_protein_atoms is not None:
        assert max_num_protein_atoms >= protein_mol.GetNumAtoms(), \
            'Expect max_num_protein_atoms to be no smaller than protein_mol.GetNumAtoms(), ' \
            'got {:d} and {:d}'.format(max_num_protein_atoms, protein_mol.GetNumAtoms())
    if strip_hydrogens:
        # Remove hydrogen atoms and their corresponding coordinates
        ligand_atom_indices_left = filter_out_hydrogens(ligand_mol)
        protein_atom_indices_left = filter_out_hydrogens(protein_mol)
        ligand_coordinates = ligand_coordinates.take(ligand_atom_indices_left, axis=0)
        protein_coordinates = protein_coordinates.take(protein_atom_indices_left, axis=0)
    else:
        ligand_atom_indices_left = list(range(ligand_mol.GetNumAtoms()))
        protein_atom_indices_left = list(range(protein_mol.GetNumAtoms()))
    # Compute number of nodes for each type (padded sizes when a max is given)
    if max_num_ligand_atoms is None:
        num_ligand_atoms = len(ligand_atom_indices_left)
    else:
        num_ligand_atoms = max_num_ligand_atoms
    if max_num_protein_atoms is None:
        num_protein_atoms = len(protein_atom_indices_left)
    else:
        num_protein_atoms = max_num_protein_atoms
    # Construct graph for atoms in the ligand: k-nearest-neighbor edges
    # within neighbor_cutoff, each annotated with its distance
    ligand_srcs, ligand_dsts, ligand_dists = k_nearest_neighbors(
        ligand_coordinates, neighbor_cutoff, max_num_neighbors)
    ligand_graph = graph((ligand_srcs, ligand_dsts),
                         'ligand_atom', 'ligand', num_ligand_atoms)
    ligand_graph.edata['distance'] = F.reshape(F.zerocopy_from_numpy(
        np.array(ligand_dists).astype(np.float32)), (-1, 1))
    # Construct graph for atoms in the protein
    protein_srcs, protein_dsts, protein_dists = k_nearest_neighbors(
        protein_coordinates, neighbor_cutoff, max_num_neighbors)
    protein_graph = graph((protein_srcs, protein_dsts),
                          'protein_atom', 'protein', num_protein_atoms)
    protein_graph.edata['distance'] = F.reshape(F.zerocopy_from_numpy(
        np.array(protein_dists).astype(np.float32)), (-1, 1))
    # Construct 4 graphs for complex representation: KNN is run on the
    # concatenated (ligand first, then protein) coordinates and the resulting
    # edges are split into ligand-ligand, protein-protein, ligand->protein
    # and protein->ligand relations below.
    complex_srcs, complex_dsts, complex_dists = k_nearest_neighbors(
        np.concatenate([ligand_coordinates, protein_coordinates]),
        neighbor_cutoff, max_num_neighbors)
    complex_srcs = np.array(complex_srcs)
    complex_dsts = np.array(complex_dsts)
    complex_dists = np.array(complex_dists)
    # Nodes with concatenated index >= offset are treated as protein atoms.
    # NOTE(review): offset is the padded ligand count; when
    # max_num_ligand_atoms > len(ligand_atom_indices_left), protein atoms in
    # the concatenated coordinates actually start at
    # len(ligand_atom_indices_left) — verify the padded case is intended.
    offset = num_ligand_atoms
    # ('ligand_atom', 'complex', 'ligand_atom')
    inter_ligand_indices = np.intersect1d(
        (complex_srcs < offset).nonzero()[0],
        (complex_dsts < offset).nonzero()[0],
        assume_unique=True)
    inter_ligand_graph = graph(
        (complex_srcs[inter_ligand_indices].tolist(),
         complex_dsts[inter_ligand_indices].tolist()),
        'ligand_atom', 'complex', num_ligand_atoms)
    inter_ligand_graph.edata['distance'] = F.reshape(F.zerocopy_from_numpy(
        complex_dists[inter_ligand_indices].astype(np.float32)), (-1, 1))
    # ('protein_atom', 'complex', 'protein_atom')
    inter_protein_indices = np.intersect1d(
        (complex_srcs >= offset).nonzero()[0],
        (complex_dsts >= offset).nonzero()[0],
        assume_unique=True)
    inter_protein_graph = graph(
        ((complex_srcs[inter_protein_indices] - offset).tolist(),
         (complex_dsts[inter_protein_indices] - offset).tolist()),
        'protein_atom', 'complex', num_protein_atoms)
    inter_protein_graph.edata['distance'] = F.reshape(F.zerocopy_from_numpy(
        complex_dists[inter_protein_indices].astype(np.float32)), (-1, 1))
    # ('ligand_atom', 'complex', 'protein_atom')
    ligand_protein_indices = np.intersect1d(
        (complex_srcs < offset).nonzero()[0],
        (complex_dsts >= offset).nonzero()[0],
        assume_unique=True)
    ligand_protein_graph = bipartite(
        (complex_srcs[ligand_protein_indices].tolist(),
         (complex_dsts[ligand_protein_indices] - offset).tolist()),
        'ligand_atom', 'complex', 'protein_atom',
        (num_ligand_atoms, num_protein_atoms))
    ligand_protein_graph.edata['distance'] = F.reshape(F.zerocopy_from_numpy(
        complex_dists[ligand_protein_indices].astype(np.float32)), (-1, 1))
    # ('protein_atom', 'complex', 'ligand_atom')
    protein_ligand_indices = np.intersect1d(
        (complex_srcs >= offset).nonzero()[0],
        (complex_dsts < offset).nonzero()[0],
        assume_unique=True)
    protein_ligand_graph = bipartite(
        ((complex_srcs[protein_ligand_indices] - offset).tolist(),
         complex_dsts[protein_ligand_indices].tolist()),
        'protein_atom', 'complex', 'ligand_atom',
        (num_protein_atoms, num_ligand_atoms))
    protein_ligand_graph.edata['distance'] = F.reshape(F.zerocopy_from_numpy(
        complex_dists[protein_ligand_indices].astype(np.float32)), (-1, 1))
    # Merge the graphs into one heterograph sharing node features
    g = hetero_from_relations(
        [protein_graph,
         ligand_graph,
         inter_ligand_graph,
         inter_protein_graph,
         ligand_protein_graph,
         protein_ligand_graph]
    )
    # Get atomic numbers for all atoms left and set node features
    ligand_atomic_numbers = np.array(get_atomic_numbers(ligand_mol, ligand_atom_indices_left))
    # zero padding
    ligand_atomic_numbers = np.concatenate([
        ligand_atomic_numbers, np.zeros(num_ligand_atoms - len(ligand_atom_indices_left))])
    protein_atomic_numbers = np.array(get_atomic_numbers(protein_mol, protein_atom_indices_left))
    # zero padding
    protein_atomic_numbers = np.concatenate([
        protein_atomic_numbers, np.zeros(num_protein_atoms - len(protein_atom_indices_left))])
    g.nodes['ligand_atom'].data['atomic_number'] = F.reshape(F.zerocopy_from_numpy(
        ligand_atomic_numbers.astype(np.float32)), (-1, 1))
    g.nodes['protein_atom'].data['atomic_number'] = F.reshape(F.zerocopy_from_numpy(
        protein_atomic_numbers.astype(np.float32)), (-1, 1))
    # Prepare mask indicating the existence of nodes (1 = real atom, 0 = padding)
    ligand_masks = np.zeros((num_ligand_atoms, 1))
    ligand_masks[:len(ligand_atom_indices_left), :] = 1
    g.nodes['ligand_atom'].data['mask'] = F.zerocopy_from_numpy(
        ligand_masks.astype(np.float32))
    protein_masks = np.zeros((num_protein_atoms, 1))
    protein_masks[:len(protein_atom_indices_left), :] = 1
    g.nodes['protein_atom'].data['mask'] = F.zerocopy_from_numpy(
        protein_masks.astype(np.float32))
    return g
"""Early stopping"""
# pylint: disable= no-member, arguments-differ, invalid-name
import datetime
import torch
__all__ = ['EarlyStopping']
# pylint: disable=C0103
class EarlyStopping(object):
    """Track validation performance and stop training when it stalls.

    A checkpoint of the model is written to ``filename`` whenever the
    validation metric improves. Once no improvement has been observed for
    ``patience`` consecutive calls to :meth:`step`, ``early_stop`` becomes
    True and :meth:`step` returns True.

    Parameters
    ----------
    mode : str
        * 'higher': a higher metric value indicates a better model
        * 'lower': a lower metric value indicates a better model
    patience : int
        Number of consecutive non-improving epochs tolerated before
        an early stop is signalled.
    filename : str or None
        Path of the checkpoint file. When None, a time-stamped name of the
        form ``early_stop_<date>_<HH>-<MM>-<SS>.pth`` is generated.
    """
    def __init__(self, mode='higher', patience=10, filename=None):
        if filename is None:
            # Auto-generate a unique checkpoint name from the current time
            now = datetime.datetime.now()
            filename = 'early_stop_{}_{:02d}-{:02d}-{:02d}.pth'.format(
                now.date(), now.hour, now.minute, now.second)
        assert mode in ['higher', 'lower']
        self.mode = mode
        # Bind the comparison once so step() does not re-branch on mode
        self._check = self._check_higher if self.mode == 'higher' else self._check_lower
        self.patience = patience
        self.counter = 0
        self.filename = filename
        self.best_score = None
        self.early_stop = False

    def _check_higher(self, score, prev_best_score):
        """Return True if ``score`` improves on ``prev_best_score``
        when higher is better.

        Parameters
        ----------
        score : float
            New score.
        prev_best_score : float
            Previous best score.

        Returns
        -------
        bool
            Whether the new score is an improvement.
        """
        return score > prev_best_score

    def _check_lower(self, score, prev_best_score):
        """Return True if ``score`` improves on ``prev_best_score``
        when lower is better.

        Parameters
        ----------
        score : float
            New score.
        prev_best_score : float
            Previous best score.

        Returns
        -------
        bool
            Whether the new score is an improvement.
        """
        return score < prev_best_score

    def step(self, score, model):
        """Record a new validation score and update the stopping state.

        Parameters
        ----------
        score : float
            New score, typically model performance on the validation set.
        model : nn.Module
            Model instance; checkpointed whenever the score improves.

        Returns
        -------
        bool
            Whether an early stop should be performed.
        """
        if self.best_score is None or self._check(score, self.best_score):
            # Improvement (or the very first score): checkpoint and reset
            self.best_score = score
            self.save_checkpoint(model)
            self.counter = 0
        else:
            self.counter += 1
            print(
                f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True
        return self.early_stop

    def save_checkpoint(self, model):
        '''Serialize the model's parameters to ``self.filename``.

        Parameters
        ----------
        model : nn.Module
            Model instance.
        '''
        torch.save({'model_state_dict': model.state_dict()}, self.filename)

    def load_checkpoint(self, model):
        '''Restore the model's parameters from the latest checkpoint.

        Parameters
        ----------
        model : nn.Module
            Model instance.
        '''
        model.load_state_dict(torch.load(self.filename)['model_state_dict'])
"""Evaluation of model performance."""
# pylint: disable= no-member, arguments-differ, invalid-name
import numpy as np
import torch
import torch.nn.functional as F
from scipy.stats import pearsonr
from sklearn.metrics import roc_auc_score
__all__ = ['Meter']
# pylint: disable=E1101
class Meter(object):
"""Track and summarize model performance on a dataset for (multi-label) prediction.
When dealing with multitask learning, quite often we normalize the labels so they are
roughly at a same scale. During the evaluation, we need to undo the normalization on
the predicted labels. If mean and std are not None, we will undo the normalization.
Currently we support evaluation with 4 metrics:
* ``pearson r2``
* ``mae``
* ``rmse``
* ``roc auc score``
Parameters
----------
mean : torch.float32 tensor of shape (T) or None.
Mean of existing training labels across tasks if not ``None``. ``T`` for the
number of tasks. Default to ``None`` and we assume no label normalization has been
performed.
std : torch.float32 tensor of shape (T)
Std of existing training labels across tasks if not ``None``. Default to ``None``
and we assume no label normalization has been performed.
Examples
--------
Below gives a demo for a fake evaluation epoch.
>>> import torch
>>> from dgllife.utils import Meter
>>> meter = Meter()
>>> # Simulate 10 fake mini-batches
>>> for batch_id in range(10):
>>> batch_label = torch.randn(3, 3)
>>> batch_pred = torch.randn(3, 3)
>>> meter.update(batch_pred, batch_label)
>>> # Get MAE for all tasks
>>> print(meter.compute_metric('mae'))
[1.1325558423995972, 1.0543707609176636, 1.094650149345398]
>>> # Get MAE averaged over all tasks
>>> print(meter.compute_metric('mae', reduction='mean'))
1.0938589175542195
>>> # Get the sum of MAE over all tasks
>>> print(meter.compute_metric('mae', reduction='sum'))
3.2815767526626587
"""
def __init__(self, mean=None, std=None):
self.mask = []
self.y_pred = []
self.y_true = []
if (mean is not None) and (std is not None):
self.mean = mean.cpu()
self.std = std.cpu()
else:
self.mean = None
self.std = None
def update(self, y_pred, y_true, mask=None):
"""Update for the result of an iteration
Parameters
----------
y_pred : float32 tensor
Predicted labels with shape ``(B, T)``,
``B`` for number of graphs in the batch and ``T`` for the number of tasks
y_true : float32 tensor
Ground truth labels with shape ``(B, T)``
mask : None or float32 tensor
Binary mask indicating the existence of ground truth labels with
shape ``(B, T)``. If None, we assume that all labels exist and create
a one-tensor for placeholder.
"""
self.y_pred.append(y_pred.detach().cpu())
self.y_true.append(y_true.detach().cpu())
if mask is None:
self.mask.append(torch.ones(self.y_pred[-1].shape))
else:
self.mask.append(mask.detach().cpu())
def _finalize(self):
"""Prepare for evaluation.
If normalization was performed on the ground truth labels during training,
we need to undo the normalization on the predicted labels.
Returns
-------
mask : float32 tensor
Binary mask indicating the existence of ground
truth labels with shape (B, T), B for batch size
and T for the number of tasks
y_pred : float32 tensor
Predicted labels with shape (B, T)
y_true : float32 tensor
Ground truth labels with shape (B, T)
"""
mask = torch.cat(self.mask, dim=0)
y_pred = torch.cat(self.y_pred, dim=0)
y_true = torch.cat(self.y_true, dim=0)
if (self.mean is not None) and (self.std is not None):
# To compensate for the imbalance between labels during training,
# we normalize the ground truth labels with training mean and std.
# We need to undo that for evaluation.
y_pred = y_pred * self.std + self.mean
return mask, y_pred, y_true
def _reduce_scores(self, scores, reduction='none'):
"""Finalize the scores to return.
Parameters
----------
scores : list of float
Scores for all tasks.
reduction : 'none' or 'mean' or 'sum'
Controls the form of scores for all tasks
Returns
-------
float or list of float
* If ``reduction == 'none'``, return the list of scores for all tasks.
* If ``reduction == 'mean'``, return the mean of scores for all tasks.
* If ``reduction == 'sum'``, return the sum of scores for all tasks.
"""
if reduction == 'none':
return scores
elif reduction == 'mean':
return np.mean(scores)
elif reduction == 'sum':
return np.sum(scores)
else:
raise ValueError(
"Expect reduction to be 'none', 'mean' or 'sum', got {}".format(reduction))
def multilabel_score(self, score_func, reduction='none'):
"""Evaluate for multi-label prediction.
Parameters
----------
score_func : callable
A score function that takes task-specific ground truth and predicted labels as
input and return a float as the score. The labels are in the form of 1D tensor.
reduction : 'none' or 'mean' or 'sum'
Controls the form of scores for all tasks
Returns
-------
float or list of float
* If ``reduction == 'none'``, return the list of scores for all tasks.
* If ``reduction == 'mean'``, return the mean of scores for all tasks.
* If ``reduction == 'sum'``, return the sum of scores for all tasks.
"""
mask, y_pred, y_true = self._finalize()
n_tasks = y_true.shape[1]
scores = []
for task in range(n_tasks):
task_w = mask[:, task]
task_y_true = y_true[:, task][task_w != 0]
task_y_pred = y_pred[:, task][task_w != 0]
task_score = score_func(task_y_true, task_y_pred)
if task_score is not None:
scores.append(task_score)
return self._reduce_scores(scores, reduction)
def pearson_r2(self, reduction='none'):
"""Compute squared Pearson correlation coefficient.
Parameters
----------
reduction : 'none' or 'mean' or 'sum'
Controls the form of scores for all tasks
Returns
-------
float or list of float
* If ``reduction == 'none'``, return the list of scores for all tasks.
* If ``reduction == 'mean'``, return the mean of scores for all tasks.
* If ``reduction == 'sum'``, return the sum of scores for all tasks.
"""
def score(y_true, y_pred):
return pearsonr(y_true.numpy(), y_pred.numpy())[0] ** 2
return self.multilabel_score(score, reduction)
def mae(self, reduction='none'):
"""Compute mean absolute error.
Parameters
----------
reduction : 'none' or 'mean' or 'sum'
Controls the form of scores for all tasks
Returns
-------
float or list of float
* If ``reduction == 'none'``, return the list of scores for all tasks.
* If ``reduction == 'mean'``, return the mean of scores for all tasks.
* If ``reduction == 'sum'``, return the sum of scores for all tasks.
"""
def score(y_true, y_pred):
return F.l1_loss(y_true, y_pred).data.item()
return self.multilabel_score(score, reduction)
def rmse(self, reduction='none'):
"""Compute root mean square error.
Parameters
----------
reduction : 'none' or 'mean' or 'sum'
Controls the form of scores for all tasks
Returns
-------
float or list of float
* If ``reduction == 'none'``, return the list of scores for all tasks.
* If ``reduction == 'mean'``, return the mean of scores for all tasks.
* If ``reduction == 'sum'``, return the sum of scores for all tasks.
"""
def score(y_true, y_pred):
return np.sqrt(F.mse_loss(y_pred, y_true).cpu().item())
return self.multilabel_score(score, reduction)
def roc_auc_score(self, reduction='none'):
"""Compute roc-auc score for binary classification.
ROC-AUC scores are not well-defined in cases where labels for a task have one single
class only. In this case we will simply ignore this task and print a warning message.
Parameters
----------
reduction : 'none' or 'mean' or 'sum'
Controls the form of scores for all tasks
Returns
-------
float or list of float
* If ``reduction == 'none'``, return the list of scores for all tasks.
* If ``reduction == 'mean'``, return the mean of scores for all tasks.
* If ``reduction == 'sum'``, return the sum of scores for all tasks.
"""
# Todo: This function only supports binary classification and we may need
# to support categorical classes.
assert (self.mean is None) and (self.std is None), \
'Label normalization should not be performed for binary classification.'
def score(y_true, y_pred):
if len(y_true.unique()) == 1:
print('Warning: Only one class {} present in y_true for a task. '
'ROC AUC score is not defined in that case.'.format(y_true[0]))
return None
else:
return roc_auc_score(y_true.long().numpy(), torch.sigmoid(y_pred).numpy())
return self.multilabel_score(score, reduction)
def compute_metric(self, metric_name, reduction='none'):
"""Compute metric based on metric name.
Parameters
----------
metric_name : str
* ``'r2'``: compute squared Pearson correlation coefficient
* ``'mae'``: compute mean absolute error
* ``'rmse'``: compute root mean square error
* ``'roc_auc_score'``: compute roc-auc score
reduction : 'none' or 'mean' or 'sum'
Controls the form of scores for all tasks
Returns
-------
float or list of float
* If ``reduction == 'none'``, return the list of scores for all tasks.
* If ``reduction == 'mean'``, return the mean of scores for all tasks.
* If ``reduction == 'sum'``, return the sum of scores for all tasks.
"""
if metric_name == 'r2':
return self.pearson_r2(reduction)
elif metric_name == 'mae':
return self.mae(reduction)
elif metric_name == 'rmse':
return self.rmse(reduction)
elif metric_name == 'roc_auc_score':
return self.roc_auc_score(reduction)
else:
raise ValueError('Expect metric_name to be "r2" or "mae" or "rmse" '
'or "roc_auc_score", got {}'.format(metric_name))
"""Node and edge featurization for molecular graphs."""
# pylint: disable= no-member, arguments-differ, invalid-name
import itertools
import math
import os.path as osp
from collections import defaultdict
from functools import partial

import numpy as np
import torch

import dgl.backend as F
from rdkit import Chem, RDConfig
from rdkit.Chem import AllChem, ChemicalFeatures
# Public API of this module: one-hot/descriptor functions for atoms and
# bonds, plus the featurizer classes built on top of them.
__all__ = ['one_hot_encoding',
           'atom_type_one_hot',
           'atomic_number_one_hot',
           'atomic_number',
           'atom_degree_one_hot',
           'atom_degree',
           'atom_total_degree_one_hot',
           'atom_total_degree',
           'atom_explicit_valence_one_hot',
           'atom_explicit_valence',
           'atom_implicit_valence_one_hot',
           'atom_implicit_valence',
           'atom_hybridization_one_hot',
           'atom_total_num_H_one_hot',
           'atom_total_num_H',
           'atom_formal_charge_one_hot',
           'atom_formal_charge',
           'atom_num_radical_electrons_one_hot',
           'atom_num_radical_electrons',
           'atom_is_aromatic_one_hot',
           'atom_is_aromatic',
           'atom_is_in_ring_one_hot',
           'atom_is_in_ring',
           'atom_chiral_tag_one_hot',
           'atom_mass',
           'ConcatFeaturizer',
           'BaseAtomFeaturizer',
           'CanonicalAtomFeaturizer',
           'WeaveAtomFeaturizer',
           'PretrainAtomFeaturizer',
           'bond_type_one_hot',
           'bond_is_conjugated_one_hot',
           'bond_is_conjugated',
           'bond_is_in_ring_one_hot',
           'bond_is_in_ring',
           'bond_stereo_one_hot',
           'bond_direction_one_hot',
           'BaseBondFeaturizer',
           'CanonicalBondFeaturizer',
           'WeaveEdgeFeaturizer',
           'PretrainBondFeaturizer']
def one_hot_encoding(x, allowable_set, encode_unknown=False):
    """One-hot encoding.

    Parameters
    ----------
    x
        Value to encode.
    allowable_set : list
        The elements of the allowable_set should be of the
        same type as x. The list is never modified.
    encode_unknown : bool
        If True, map inputs not in the allowable set to the
        additional last element.

    Returns
    -------
    list
        List of boolean values where at most one value is True.
        The list is of length ``len(allowable_set)`` if ``encode_unknown=False``
        and ``len(allowable_set) + 1`` otherwise.

    Examples
    --------
    >>> from dgllife.utils import one_hot_encoding
    >>> one_hot_encoding('C', ['C', 'O'])
    [True, False]
    >>> one_hot_encoding('S', ['C', 'O'])
    [False, False]
    >>> one_hot_encoding('S', ['C', 'O'], encode_unknown=True)
    [False, False, True]
    """
    if encode_unknown and (allowable_set[-1] is not None):
        # Extend a copy rather than appending in place: the previous
        # allowable_set.append(None) mutated the caller's list, silently
        # growing any set that was shared or reused across calls.
        allowable_set = allowable_set + [None]
    if encode_unknown and (x not in allowable_set):
        x = None
    return [x == s for s in allowable_set]
#################################################################
# Atom featurization
#################################################################
def atom_type_one_hot(atom, allowable_set=None, encode_unknown=False):
    """One hot encoding for the type of an atom.

    Parameters
    ----------
    atom : rdkit.Chem.rdchem.Atom
        RDKit atom instance.
    allowable_set : list of str
        Atom types to consider. Default: ``C``, ``N``, ``O``, ``S``, ``F``, ``Si``, ``P``,
        ``Cl``, ``Br``, ``Mg``, ``Na``, ``Ca``, ``Fe``, ``As``, ``Al``, ``I``, ``B``, ``V``,
        ``K``, ``Tl``, ``Yb``, ``Sb``, ``Sn``, ``Ag``, ``Pd``, ``Co``, ``Se``, ``Ti``, ``Zn``,
        ``H``, ``Li``, ``Ge``, ``Cu``, ``Au``, ``Ni``, ``Cd``, ``In``, ``Mn``, ``Zr``, ``Cr``,
        ``Pt``, ``Hg``, ``Pb``.
    encode_unknown : bool
        If True, map inputs not in the allowable set to the
        additional last element. (Default: False)

    Returns
    -------
    list
        List of boolean values where at most one value is True.

    See Also
    --------
    one_hot_encoding
    atomic_number
    atomic_number_one_hot
    """
    default_types = ['C', 'N', 'O', 'S', 'F', 'Si', 'P', 'Cl', 'Br', 'Mg', 'Na', 'Ca',
                     'Fe', 'As', 'Al', 'I', 'B', 'V', 'K', 'Tl', 'Yb', 'Sb', 'Sn',
                     'Ag', 'Pd', 'Co', 'Se', 'Ti', 'Zn', 'H', 'Li', 'Ge', 'Cu', 'Au',
                     'Ni', 'Cd', 'In', 'Mn', 'Zr', 'Cr', 'Pt', 'Hg', 'Pb']
    return one_hot_encoding(atom.GetSymbol(),
                            default_types if allowable_set is None else allowable_set,
                            encode_unknown)
def atomic_number_one_hot(atom, allowable_set=None, encode_unknown=False):
    """One hot encoding for the atomic number of an atom.

    Parameters
    ----------
    atom : rdkit.Chem.rdchem.Atom
        RDKit atom instance.
    allowable_set : list of int
        Atomic numbers to consider. Default: ``1`` - ``100``.
    encode_unknown : bool
        If True, map inputs not in the allowable set to the
        additional last element. (Default: False)

    Returns
    -------
    list
        List of boolean values where at most one value is True.

    See Also
    --------
    one_hot_encoding
    atomic_number
    atom_type_one_hot
    """
    allowable_set = list(range(1, 101)) if allowable_set is None else allowable_set
    return one_hot_encoding(atom.GetAtomicNum(), allowable_set, encode_unknown)
def atomic_number(atom):
    """Get the atomic number for an atom.

    Parameters
    ----------
    atom : rdkit.Chem.rdchem.Atom
        RDKit atom instance.

    Returns
    -------
    list
        List containing one int only.

    See Also
    --------
    atomic_number_one_hot
    atom_type_one_hot
    """
    number = atom.GetAtomicNum()
    return [number]
def atom_degree_one_hot(atom, allowable_set=None, encode_unknown=False):
    """One hot encoding for the degree of an atom.

    Note that the result will be different depending on whether the Hs are
    explicitly modeled in the graph.

    Parameters
    ----------
    atom : rdkit.Chem.rdchem.Atom
        RDKit atom instance.
    allowable_set : list of int
        Atom degrees to consider. Default: ``0`` - ``10``.
    encode_unknown : bool
        If True, map inputs not in the allowable set to the
        additional last element. (Default: False)

    Returns
    -------
    list
        List of boolean values where at most one value is True.

    See Also
    --------
    one_hot_encoding
    atom_degree
    atom_total_degree
    atom_total_degree_one_hot
    """
    allowable_set = list(range(11)) if allowable_set is None else allowable_set
    return one_hot_encoding(atom.GetDegree(), allowable_set, encode_unknown)
def atom_degree(atom):
    """Get the degree of an atom.

    Note that the result will be different depending on whether the Hs are
    explicitly modeled in the graph.

    Parameters
    ----------
    atom : rdkit.Chem.rdchem.Atom
        RDKit atom instance.

    Returns
    -------
    list
        List containing one int only.

    See Also
    --------
    atom_degree_one_hot
    atom_total_degree
    atom_total_degree_one_hot
    """
    degree = atom.GetDegree()
    return [degree]
def atom_total_degree_one_hot(atom, allowable_set=None, encode_unknown=False):
    """One hot encoding for the degree of an atom including Hs.

    Parameters
    ----------
    atom : rdkit.Chem.rdchem.Atom
        RDKit atom instance.
    allowable_set : list
        Total degrees to consider. Default: ``0`` - ``5``.
    encode_unknown : bool
        If True, map inputs not in the allowable set to the
        additional last element. (Default: False)

    Returns
    -------
    list
        List of boolean values where at most one value is True.

    See Also
    --------
    one_hot_encoding
    atom_degree
    atom_degree_one_hot
    atom_total_degree
    """
    allowable_set = list(range(6)) if allowable_set is None else allowable_set
    return one_hot_encoding(atom.GetTotalDegree(), allowable_set, encode_unknown)
def atom_total_degree(atom):
    """The degree of an atom including Hs.

    Parameters
    ----------
    atom : rdkit.Chem.rdchem.Atom
        RDKit atom instance.

    Returns
    -------
    list
        List containing one int only.

    See Also
    --------
    atom_total_degree_one_hot
    atom_degree
    atom_degree_one_hot
    """
    total_degree = atom.GetTotalDegree()
    return [total_degree]
def atom_explicit_valence_one_hot(atom, allowable_set=None, encode_unknown=False):
    """One hot encoding for the explicit valence of an atom.

    Parameters
    ----------
    atom : rdkit.Chem.rdchem.Atom
        RDKit atom instance.
    allowable_set : list of int
        Atom explicit valences to consider. Default: ``1`` - ``6``.
    encode_unknown : bool
        If True, map inputs not in the allowable set to the
        additional last element. (Default: False)

    Returns
    -------
    list
        List of boolean values where at most one value is True.

    See Also
    --------
    one_hot_encoding
    atom_explicit_valence
    """
    allowable_set = list(range(1, 7)) if allowable_set is None else allowable_set
    return one_hot_encoding(atom.GetExplicitValence(), allowable_set, encode_unknown)
def atom_explicit_valence(atom):
    """Get the explicit valence of an atom.

    Parameters
    ----------
    atom : rdkit.Chem.rdchem.Atom
        RDKit atom instance.

    Returns
    -------
    list
        List containing one int only.

    See Also
    --------
    atom_explicit_valence_one_hot
    """
    valence = atom.GetExplicitValence()
    return [valence]
def atom_implicit_valence_one_hot(atom, allowable_set=None, encode_unknown=False):
    """One hot encoding for the implicit valence of an atom.

    Parameters
    ----------
    atom : rdkit.Chem.rdchem.Atom
        RDKit atom instance.
    allowable_set : list of int
        Atom implicit valences to consider. Default: ``0`` - ``6``.
    encode_unknown : bool
        If True, map inputs not in the allowable set to the
        additional last element. (Default: False)

    Returns
    -------
    list
        List of boolean values where at most one value is True.

    See Also
    --------
    one_hot_encoding
    atom_implicit_valence
    """
    allowable_set = list(range(7)) if allowable_set is None else allowable_set
    return one_hot_encoding(atom.GetImplicitValence(), allowable_set, encode_unknown)
def atom_implicit_valence(atom):
    """Get the implicit valence of an atom.

    Parameters
    ----------
    atom : rdkit.Chem.rdchem.Atom
        RDKit atom instance.

    Returns
    -------
    list
        List containing one int only.

    See Also
    --------
    atom_implicit_valence_one_hot
    """
    valence = atom.GetImplicitValence()
    return [valence]
# pylint: disable=I1101
def atom_hybridization_one_hot(atom, allowable_set=None, encode_unknown=False):
    """One hot encoding for the hybridization of an atom.

    Parameters
    ----------
    atom : rdkit.Chem.rdchem.Atom
        RDKit atom instance.
    allowable_set : list of rdkit.Chem.rdchem.HybridizationType
        Atom hybridizations to consider. Default: ``Chem.rdchem.HybridizationType.SP``,
        ``Chem.rdchem.HybridizationType.SP2``, ``Chem.rdchem.HybridizationType.SP3``,
        ``Chem.rdchem.HybridizationType.SP3D``, ``Chem.rdchem.HybridizationType.SP3D2``.
    encode_unknown : bool
        If True, map inputs not in the allowable set to the
        additional last element. (Default: False)

    Returns
    -------
    list
        List of boolean values where at most one value is True.

    See Also
    --------
    one_hot_encoding
    """
    default_set = [Chem.rdchem.HybridizationType.SP,
                   Chem.rdchem.HybridizationType.SP2,
                   Chem.rdchem.HybridizationType.SP3,
                   Chem.rdchem.HybridizationType.SP3D,
                   Chem.rdchem.HybridizationType.SP3D2]
    return one_hot_encoding(atom.GetHybridization(),
                            allowable_set if allowable_set is not None else default_set,
                            encode_unknown)
def atom_total_num_H_one_hot(atom, allowable_set=None, encode_unknown=False):
    """One hot encoding for the total number of Hs of an atom.

    Parameters
    ----------
    atom : rdkit.Chem.rdchem.Atom
        RDKit atom instance.
    allowable_set : list of int
        Total number of Hs to consider. Default: ``0`` - ``4``.
    encode_unknown : bool
        If True, map inputs not in the allowable set to the
        additional last element. (Default: False)

    Returns
    -------
    list
        List of boolean values where at most one value is True.

    See Also
    --------
    one_hot_encoding
    atom_total_num_H
    """
    allowable_set = list(range(5)) if allowable_set is None else allowable_set
    return one_hot_encoding(atom.GetTotalNumHs(), allowable_set, encode_unknown)
def atom_total_num_H(atom):
    """Get the total number of Hs of an atom.

    Parameters
    ----------
    atom : rdkit.Chem.rdchem.Atom
        RDKit atom instance.

    Returns
    -------
    list
        List containing one int only.

    See Also
    --------
    atom_total_num_H_one_hot
    """
    num_H = atom.GetTotalNumHs()
    return [num_H]
def atom_formal_charge_one_hot(atom, allowable_set=None, encode_unknown=False):
    """One hot encoding for the formal charge of an atom.

    Parameters
    ----------
    atom : rdkit.Chem.rdchem.Atom
        RDKit atom instance.
    allowable_set : list of int
        Formal charges to consider. Default: ``-2`` - ``2``.
    encode_unknown : bool
        If True, map inputs not in the allowable set to the
        additional last element. (Default: False)

    Returns
    -------
    list
        List of boolean values where at most one value is True.

    See Also
    --------
    one_hot_encoding
    atom_formal_charge
    """
    allowable_set = list(range(-2, 3)) if allowable_set is None else allowable_set
    return one_hot_encoding(atom.GetFormalCharge(), allowable_set, encode_unknown)
def atom_formal_charge(atom):
    """Get formal charge for an atom.

    Parameters
    ----------
    atom : rdkit.Chem.rdchem.Atom
        RDKit atom instance.

    Returns
    -------
    list
        List containing one int only.

    See Also
    --------
    atom_formal_charge_one_hot
    """
    charge = atom.GetFormalCharge()
    return [charge]
def atom_partial_charge(atom):
    """Get Gasteiger partial charge for an atom.

    For using this function, you must have called
    ``AllChem.ComputeGasteigerCharges(mol)`` to compute Gasteiger charges.

    Occasionally, we can get nan or infinity Gasteiger charges, in which case
    we will set the result to be 0.

    Parameters
    ----------
    atom : rdkit.Chem.rdchem.Atom
        RDKit atom instance.

    Returns
    -------
    list
        List containing one float only.
    """
    import math

    gasteiger_charge = atom.GetProp('_GasteigerCharge')
    # Parse first and check the parsed value instead of comparing against a
    # fixed list of strings: RDKit/libc may spell non-finite values in several
    # ways ('nan', 'NaN', '-nan', 'inf', '1.#IND', ...), and the old string
    # comparison let some of them through as NaN or crashed on float().
    try:
        charge = float(gasteiger_charge)
    except ValueError:
        charge = 0.
    if not math.isfinite(charge):
        charge = 0.
    return [charge]
def atom_num_radical_electrons_one_hot(atom, allowable_set=None, encode_unknown=False):
    """One hot encoding for the number of radical electrons of an atom.

    Parameters
    ----------
    atom : rdkit.Chem.rdchem.Atom
        RDKit atom instance.
    allowable_set : list of int
        Number of radical electrons to consider. Default: ``0`` - ``4``.
    encode_unknown : bool
        If True, map inputs not in the allowable set to the
        additional last element. (Default: False)

    Returns
    -------
    list
        List of boolean values where at most one value is True.

    See Also
    --------
    one_hot_encoding
    atom_num_radical_electrons
    """
    if allowable_set is None:
        allowable_set = [0, 1, 2, 3, 4]
    return one_hot_encoding(atom.GetNumRadicalElectrons(), allowable_set, encode_unknown)
def atom_num_radical_electrons(atom):
    """Get the number of radical electrons for an atom.

    Parameters
    ----------
    atom : rdkit.Chem.rdchem.Atom
        RDKit atom instance.

    Returns
    -------
    list
        List containing one int only.

    See Also
    --------
    atom_num_radical_electrons_one_hot
    """
    num_radical_electrons = atom.GetNumRadicalElectrons()
    return [num_radical_electrons]
def atom_is_aromatic_one_hot(atom, allowable_set=None, encode_unknown=False):
    """One hot encoding for whether the atom is aromatic.

    Parameters
    ----------
    atom : rdkit.Chem.rdchem.Atom
        RDKit atom instance.
    allowable_set : list of bool
        Conditions to consider. Default: ``False`` and ``True``.
    encode_unknown : bool
        If True, map inputs not in the allowable set to the
        additional last element. (Default: False)

    Returns
    -------
    list
        List of boolean values where at most one value is True.

    See Also
    --------
    one_hot_encoding
    atom_is_aromatic
    """
    conditions = [False, True] if allowable_set is None else allowable_set
    return one_hot_encoding(atom.GetIsAromatic(), conditions, encode_unknown)
def atom_is_aromatic(atom):
    """Get whether the atom is aromatic.

    Parameters
    ----------
    atom : rdkit.Chem.rdchem.Atom
        RDKit atom instance.

    Returns
    -------
    list
        List containing one bool only.

    See Also
    --------
    atom_is_aromatic_one_hot
    """
    aromatic = atom.GetIsAromatic()
    return [aromatic]
def atom_is_in_ring_one_hot(atom, allowable_set=None, encode_unknown=False):
    """One hot encoding for whether the atom is in ring.

    Parameters
    ----------
    atom : rdkit.Chem.rdchem.Atom
        RDKit atom instance.
    allowable_set : list of bool
        Conditions to consider. Default: ``False`` and ``True``.
    encode_unknown : bool
        If True, map inputs not in the allowable set to the
        additional last element. (Default: False)

    Returns
    -------
    list
        List of boolean values where at most one value is True.

    See Also
    --------
    one_hot_encoding
    atom_is_in_ring
    """
    conditions = [False, True] if allowable_set is None else allowable_set
    return one_hot_encoding(atom.IsInRing(), conditions, encode_unknown)
def atom_is_in_ring(atom):
    """Get whether the atom is in ring.

    Parameters
    ----------
    atom : rdkit.Chem.rdchem.Atom
        RDKit atom instance.

    Returns
    -------
    list
        List containing one bool only.

    See Also
    --------
    atom_is_in_ring_one_hot
    """
    in_ring = atom.IsInRing()
    return [in_ring]
def atom_chiral_tag_one_hot(atom, allowable_set=None, encode_unknown=False):
    """One hot encoding for the chiral tag of an atom.

    Parameters
    ----------
    atom : rdkit.Chem.rdchem.Atom
        RDKit atom instance.
    allowable_set : list of rdkit.Chem.rdchem.ChiralType
        Chiral tags to consider. Default: ``rdkit.Chem.rdchem.ChiralType.CHI_UNSPECIFIED``,
        ``rdkit.Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CW``,
        ``rdkit.Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CCW``,
        ``rdkit.Chem.rdchem.ChiralType.CHI_OTHER``.
    encode_unknown : bool
        If True, map inputs not in the allowable set to the
        additional last element. (Default: False)

    Returns
    -------
    list
        List of boolean values where at most one value is True.

    See Also
    --------
    one_hot_encoding
    """
    if allowable_set is None:
        allowable_set = [Chem.rdchem.ChiralType.CHI_UNSPECIFIED,
                         Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CW,
                         Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CCW,
                         Chem.rdchem.ChiralType.CHI_OTHER]
    return one_hot_encoding(atom.GetChiralTag(), allowable_set, encode_unknown)
def atom_mass(atom, coef=0.01):
    """Get the mass of an atom and scale it.

    Parameters
    ----------
    atom : rdkit.Chem.rdchem.Atom
        RDKit atom instance.
    coef : float
        The mass will be multiplied by ``coef``.

    Returns
    -------
    list
        List containing one float only.
    """
    scaled_mass = coef * atom.GetMass()
    return [scaled_mass]
class ConcatFeaturizer(object):
    """Concatenate the evaluation results of multiple functions as a single feature.

    Parameters
    ----------
    func_list : list
        List of functions for computing molecular descriptors from objects of a
        same particular data type, e.g. ``rdkit.Chem.rdchem.Atom``. Each function
        is of signature ``func(data_type) -> list of float or bool or int``. The
        resulting order of the features will follow that of the functions in
        the list.
    """
    def __init__(self, func_list):
        self.func_list = func_list

    def __call__(self, x):
        """Featurize the input data.

        Parameters
        ----------
        x
            Data to featurize.

        Returns
        -------
        list
            List of feature values, which can be of type bool, float or int.
        """
        # Flatten the per-function feature lists into a single flat list,
        # preserving the order of the functions.
        return [value for func in self.func_list for value in func(x)]
class BaseAtomFeaturizer(object):
    """An abstract class for atom featurizers.

    Loop over all atoms in a molecule and featurize them with the
    ``featurizer_funcs``.

    **We assume the resulting DGLGraph will not contain any virtual nodes and a
    node i in the graph corresponds to exactly atom i in the molecule.**

    Parameters
    ----------
    featurizer_funcs : dict
        Mapping feature name to the featurization function.
        Each function is of signature
        ``func(rdkit.Chem.rdchem.Atom) -> list or 1D numpy array``.
    feat_sizes : dict
        Mapping feature name to the size of the corresponding feature. If None,
        they will be computed when needed. Default: None.

    Examples
    --------
    >>> from dgllife.utils import BaseAtomFeaturizer, atom_mass, atom_degree_one_hot
    >>> from rdkit import Chem
    >>> mol = Chem.MolFromSmiles('CCO')
    >>> atom_featurizer = BaseAtomFeaturizer({'mass': atom_mass, 'degree': atom_degree_one_hot})
    >>> atom_featurizer(mol)['mass'].shape
    torch.Size([3, 1])
    >>> atom_featurizer.feat_size('mass')
    1
    >>> atom_featurizer.feat_size('degree')
    11

    See Also
    --------
    CanonicalAtomFeaturizer
    """
    def __init__(self, featurizer_funcs, feat_sizes=None):
        self.featurizer_funcs = featurizer_funcs
        if feat_sizes is None:
            feat_sizes = dict()
        self._feat_sizes = feat_sizes

    def feat_size(self, feat_name=None):
        """Get the feature size for ``feat_name``.

        When there is only one feature, users do not need to provide ``feat_name``.

        Parameters
        ----------
        feat_name : str
            Feature for query. Default to None.

        Returns
        -------
        int
            Feature size for the feature with name ``feat_name``.

        Raises
        ------
        ValueError
            If ``feat_name`` has no associated featurization function.
        """
        if feat_name is None:
            assert len(self.featurizer_funcs) == 1, \
                'feat_name should be provided if there are more than one features'
            feat_name = list(self.featurizer_funcs.keys())[0]
        if feat_name not in self.featurizer_funcs:
            # Bug fix: the ValueError used to be returned instead of raised,
            # so callers silently got a truthy exception object.
            raise ValueError('Expect feat_name to be in {}, got {}'.format(
                list(self.featurizer_funcs.keys()), feat_name))
        if feat_name not in self._feat_sizes:
            # Lazily compute the feature size with a dummy carbon atom.
            atom = Chem.MolFromSmiles('C').GetAtomWithIdx(0)
            self._feat_sizes[feat_name] = len(self.featurizer_funcs[feat_name](atom))
        return self._feat_sizes[feat_name]

    def __call__(self, mol):
        """Featurize all atoms in a molecule.

        Parameters
        ----------
        mol : rdkit.Chem.rdchem.Mol
            RDKit molecule instance.

        Returns
        -------
        dict
            For each function in self.featurizer_funcs with the key ``k``, store
            the computed feature under the key ``k``. Each feature is a tensor of
            dtype float32 and shape (N, M), where N is the number of atoms in
            the molecule.
        """
        num_atoms = mol.GetNumAtoms()
        atom_features = defaultdict(list)
        # Compute features for each atom
        for i in range(num_atoms):
            atom = mol.GetAtomWithIdx(i)
            for feat_name, feat_func in self.featurizer_funcs.items():
                atom_features[feat_name].append(feat_func(atom))
        # Stack the features and convert them to float arrays
        processed_features = dict()
        for feat_name, feat_list in atom_features.items():
            feat = np.stack(feat_list)
            processed_features[feat_name] = F.zerocopy_from_numpy(feat.astype(np.float32))
        return processed_features
class CanonicalAtomFeaturizer(BaseAtomFeaturizer):
    """A default featurizer for atoms.

    The atom features include:

    * **One hot encoding of the atom type**. The supported atom types include
      ``C``, ``N``, ``O``, ``S``, ``F``, ``Si``, ``P``, ``Cl``, ``Br``, ``Mg``,
      ``Na``, ``Ca``, ``Fe``, ``As``, ``Al``, ``I``, ``B``, ``V``, ``K``, ``Tl``,
      ``Yb``, ``Sb``, ``Sn``, ``Ag``, ``Pd``, ``Co``, ``Se``, ``Ti``, ``Zn``,
      ``H``, ``Li``, ``Ge``, ``Cu``, ``Au``, ``Ni``, ``Cd``, ``In``, ``Mn``,
      ``Zr``, ``Cr``, ``Pt``, ``Hg``, ``Pb``.
    * **One hot encoding of the atom degree**. The supported possibilities
      include ``0 - 10``.
    * **One hot encoding of the number of implicit Hs on the atom**. The
      supported possibilities include ``0 - 6``.
    * **Formal charge of the atom**.
    * **Number of radical electrons of the atom**.
    * **One hot encoding of the atom hybridization**. The supported
      possibilities include ``SP``, ``SP2``, ``SP3``, ``SP3D``, ``SP3D2``.
    * **Whether the atom is aromatic**.
    * **One hot encoding of the number of total Hs on the atom**. The supported
      possibilities include ``0 - 4``.

    **We assume the resulting DGLGraph will not contain any virtual nodes.**

    Parameters
    ----------
    atom_data_field : str
        Name for storing atom features in DGLGraphs, default to 'h'.

    Examples
    --------
    >>> from rdkit import Chem
    >>> from dgllife.utils import CanonicalAtomFeaturizer
    >>> mol = Chem.MolFromSmiles('CCO')
    >>> atom_featurizer = CanonicalAtomFeaturizer(atom_data_field='feat')
    >>> atom_featurizer(mol)['feat'].shape
    torch.Size([3, 74])
    >>> # Get feature size for nodes
    >>> print(atom_featurizer.feat_size('feat'))
    74

    See Also
    --------
    BaseAtomFeaturizer
    """
    def __init__(self, atom_data_field='h'):
        # The concatenation order below fixes the layout of the 74-dim feature.
        feature_func = ConcatFeaturizer([
            atom_type_one_hot,
            atom_degree_one_hot,
            atom_implicit_valence_one_hot,
            atom_formal_charge,
            atom_num_radical_electrons,
            atom_hybridization_one_hot,
            atom_is_aromatic,
            atom_total_num_H_one_hot,
        ])
        super(CanonicalAtomFeaturizer, self).__init__(
            featurizer_funcs={atom_data_field: feature_func})
class WeaveAtomFeaturizer(object):
    """Atom featurizer in Weave.

    The atom featurization performed in `Molecular Graph Convolutions: Moving Beyond
    Fingerprints <https://arxiv.org/abs/1603.00856>`__, which considers:

    * atom types
    * chirality
    * formal charge
    * partial charge
    * aromatic atom
    * hybridization
    * hydrogen bond donor
    * hydrogen bond acceptor
    * the number of rings the atom belongs to for ring size between 3 and 8

    Parameters
    ----------
    atom_data_field : str
        Name for storing atom features in DGLGraphs, default to 'h'.
    atom_types : list of str or None
        Atom types to consider for one-hot encoding. If None, we will use a default
        choice of ``'H', 'C', 'N', 'O', 'F', 'P', 'S', 'Cl', 'Br', 'I'``.
    chiral_types : list of Chem.rdchem.ChiralType or None
        Atom chirality to consider for one-hot encoding. If None, we will use a default
        choice of ``Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CW,
        Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CCW``.
    hybridization_types : list of Chem.rdchem.HybridizationType or None
        Atom hybridization types to consider for one-hot encoding. If None, we will use a
        default choice of ``Chem.rdchem.HybridizationType.SP, Chem.rdchem.HybridizationType.SP2,
        Chem.rdchem.HybridizationType.SP3``.
    """
    def __init__(self, atom_data_field='h', atom_types=None, chiral_types=None,
                 hybridization_types=None):
        super(WeaveAtomFeaturizer, self).__init__()
        self._atom_data_field = atom_data_field
        if atom_types is None:
            atom_types = ['H', 'C', 'N', 'O', 'F', 'P', 'S', 'Cl', 'Br', 'I']
        self._atom_types = atom_types
        if chiral_types is None:
            chiral_types = [Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CW,
                            Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CCW]
        self._chiral_types = chiral_types
        if hybridization_types is None:
            hybridization_types = [Chem.rdchem.HybridizationType.SP,
                                   Chem.rdchem.HybridizationType.SP2,
                                   Chem.rdchem.HybridizationType.SP3]
        self._hybridization_types = hybridization_types
        # Atom types use encode_unknown=True so out-of-list elements map to an
        # extra "unknown" slot; the other one-hot encodings do not.
        self._featurizer = ConcatFeaturizer([
            partial(atom_type_one_hot, allowable_set=atom_types, encode_unknown=True),
            partial(atom_chiral_tag_one_hot, allowable_set=chiral_types),
            atom_formal_charge, atom_partial_charge, atom_is_aromatic,
            partial(atom_hybridization_one_hot, allowable_set=hybridization_types)
        ])

    def feat_size(self):
        """Get the feature size.

        Returns
        -------
        int
            Feature size.
        """
        # Featurize a dummy single-carbon molecule and read off the width.
        mol = Chem.MolFromSmiles('C')
        feats = self(mol)[self._atom_data_field]
        return feats.shape[-1]

    def get_donor_acceptor_info(self, mol_feats):
        """Bookkeep whether an atom is donor/acceptor for hydrogen bonds.

        Parameters
        ----------
        mol_feats : tuple of rdkit.Chem.rdMolChemicalFeatures.MolChemicalFeature
            Features for molecules.

        Returns
        -------
        is_donor : dict
            Mapping atom ids to binary values indicating whether atoms
            are donors for hydrogen bonds
        is_acceptor : dict
            Mapping atom ids to binary values indicating whether atoms
            are acceptors for hydrogen bonds
        """
        # defaultdict(bool) makes unseen atom ids default to False.
        is_donor = defaultdict(bool)
        is_acceptor = defaultdict(bool)
        # Get hydrogen bond donor/acceptor information
        for feats in mol_feats:
            if feats.GetFamily() == 'Donor':
                nodes = feats.GetAtomIds()
                for u in nodes:
                    is_donor[u] = True
            elif feats.GetFamily() == 'Acceptor':
                nodes = feats.GetAtomIds()
                for u in nodes:
                    is_acceptor[u] = True
        return is_donor, is_acceptor

    def __call__(self, mol):
        """Featurizes the input molecule.

        Parameters
        ----------
        mol : rdkit.Chem.rdchem.Mol
            RDKit molecule instance.

        Returns
        -------
        dict
            Mapping atom_data_field as specified in the input argument to the atom
            features, which is a float32 tensor of shape (N, M), N is the number of
            atoms and M is the feature size.
        """
        atom_features = []
        # Gasteiger charges are required by atom_partial_charge below.
        AllChem.ComputeGasteigerCharges(mol)
        num_atoms = mol.GetNumAtoms()
        # Get information for donor and acceptor
        fdef_name = osp.join(RDConfig.RDDataDir, 'BaseFeatures.fdef')
        mol_featurizer = ChemicalFeatures.BuildFeatureFactory(fdef_name)
        mol_feats = mol_featurizer.GetFeaturesForMol(mol)
        is_donor, is_acceptor = self.get_donor_acceptor_info(mol_feats)
        # Get a symmetrized smallest set of smallest rings
        # Following the practice from Chainer Chemistry (https://github.com/chainer/
        # chainer-chemistry/blob/da2507b38f903a8ee333e487d422ba6dcec49b05/chainer_chemistry/
        # dataset/preprocessors/weavenet_preprocessor.py)
        sssr = Chem.GetSymmSSSR(mol)
        for i in range(num_atoms):
            atom = mol.GetAtomWithIdx(i)
            # Features that can be computed directly from RDKit atom instances, which is a list
            feats = self._featurizer(atom)
            # Donor/acceptor indicator
            feats.append(float(is_donor[i]))
            feats.append(float(is_acceptor[i]))
            # Count the number of rings the atom belongs to for ring size between 3 and 8
            count = [0 for _ in range(3, 9)]
            for ring in sssr:
                ring_size = len(ring)
                if i in ring and 3 <= ring_size <= 8:
                    count[ring_size - 3] += 1
            feats.extend(count)
            atom_features.append(feats)
        atom_features = np.stack(atom_features)
        return {self._atom_data_field: F.zerocopy_from_numpy(atom_features.astype(np.float32))}
class PretrainAtomFeaturizer(object):
    """AtomFeaturizer in Strategies for Pre-training Graph Neural Networks.

    The atom featurization performed in `Strategies for Pre-training Graph Neural
    Networks <https://arxiv.org/abs/1905.12265>`__, which considers:

    * atomic number
    * chirality

    Parameters
    ----------
    atomic_number_types : list of int or None
        Atomic number types to consider for one-hot encoding. If None, we will use
        a default choice of 1-118.
    chiral_types : list of Chem.rdchem.ChiralType or None
        Atom chirality to consider for one-hot encoding. If None, we will use a
        default choice of ``Chem.rdchem.ChiralType.CHI_UNSPECIFIED,
        Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CW,
        Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CCW, Chem.rdchem.ChiralType.CHI_OTHER``.
    """
    def __init__(self, atomic_number_types=None, chiral_types=None):
        if atomic_number_types is None:
            atomic_number_types = list(range(1, 119))
        if chiral_types is None:
            chiral_types = [
                Chem.rdchem.ChiralType.CHI_UNSPECIFIED,
                Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CW,
                Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CCW,
                Chem.rdchem.ChiralType.CHI_OTHER
            ]
        self._atomic_number_types = atomic_number_types
        self._chiral_types = chiral_types

    def __call__(self, mol):
        """Featurizes the input molecule.

        Parameters
        ----------
        mol : rdkit.Chem.rdchem.Mol
            RDKit molecule instance.

        Returns
        -------
        dict
            Mapping 'atomic_number' and 'chirality_type' to separately an int64
            tensor of shape (N,), N is the number of atoms
        """
        # For each atom, record the index of its atomic number and chiral tag
        # within the configured allowable lists.
        index_pairs = []
        for i in range(mol.GetNumAtoms()):
            atom = mol.GetAtomWithIdx(i)
            index_pairs.append([
                self._atomic_number_types.index(atom.GetAtomicNum()),
                self._chiral_types.index(atom.GetChiralTag())
            ])
        stacked = F.zerocopy_from_numpy(np.stack(index_pairs).astype(np.int64))
        return {
            'atomic_number': stacked[:, 0],
            'chirality_type': stacked[:, 1]
        }
def bond_type_one_hot(bond, allowable_set=None, encode_unknown=False):
    """One hot encoding for the type of a bond.

    Parameters
    ----------
    bond : rdkit.Chem.rdchem.Bond
        RDKit bond instance.
    allowable_set : list of Chem.rdchem.BondType
        Bond types to consider. Default: ``Chem.rdchem.BondType.SINGLE``,
        ``Chem.rdchem.BondType.DOUBLE``, ``Chem.rdchem.BondType.TRIPLE``,
        ``Chem.rdchem.BondType.AROMATIC``.
    encode_unknown : bool
        If True, map inputs not in the allowable set to the
        additional last element. (Default: False)

    Returns
    -------
    list
        List of boolean values where at most one value is True.

    See Also
    --------
    one_hot_encoding
    """
    if allowable_set is None:
        allowable_set = [getattr(Chem.rdchem.BondType, name)
                         for name in ('SINGLE', 'DOUBLE', 'TRIPLE', 'AROMATIC')]
    return one_hot_encoding(bond.GetBondType(), allowable_set, encode_unknown)
def bond_is_conjugated_one_hot(bond, allowable_set=None, encode_unknown=False):
    """One hot encoding for whether the bond is conjugated.

    Parameters
    ----------
    bond : rdkit.Chem.rdchem.Bond
        RDKit bond instance.
    allowable_set : list of bool
        Conditions to consider. Default: ``False`` and ``True``.
    encode_unknown : bool
        If True, map inputs not in the allowable set to the
        additional last element. (Default: False)

    Returns
    -------
    list
        List of boolean values where at most one value is True.

    See Also
    --------
    one_hot_encoding
    bond_is_conjugated
    """
    conditions = [False, True] if allowable_set is None else allowable_set
    return one_hot_encoding(bond.GetIsConjugated(), conditions, encode_unknown)
def bond_is_conjugated(bond):
    """Get whether the bond is conjugated.

    Parameters
    ----------
    bond : rdkit.Chem.rdchem.Bond
        RDKit bond instance.

    Returns
    -------
    list
        List containing one bool only.

    See Also
    --------
    bond_is_conjugated_one_hot
    """
    conjugated = bond.GetIsConjugated()
    return [conjugated]
def bond_is_in_ring_one_hot(bond, allowable_set=None, encode_unknown=False):
    """One hot encoding for whether the bond is in a ring of any size.

    Parameters
    ----------
    bond : rdkit.Chem.rdchem.Bond
        RDKit bond instance.
    allowable_set : list of bool
        Conditions to consider. Default: ``False`` and ``True``.
    encode_unknown : bool
        If True, map inputs not in the allowable set to the
        additional last element. (Default: False)

    Returns
    -------
    list
        List of boolean values where at most one value is True.

    See Also
    --------
    one_hot_encoding
    bond_is_in_ring
    """
    conditions = [False, True] if allowable_set is None else allowable_set
    return one_hot_encoding(bond.IsInRing(), conditions, encode_unknown)
def bond_is_in_ring(bond):
    """Get whether the bond is in a ring of any size.

    Parameters
    ----------
    bond : rdkit.Chem.rdchem.Bond
        RDKit bond instance.

    Returns
    -------
    list
        List containing one bool only.

    See Also
    --------
    bond_is_in_ring_one_hot
    """
    in_ring = bond.IsInRing()
    return [in_ring]
def bond_stereo_one_hot(bond, allowable_set=None, encode_unknown=False):
    """One hot encoding for the stereo configuration of a bond.

    Parameters
    ----------
    bond : rdkit.Chem.rdchem.Bond
        RDKit bond instance.
    allowable_set : list of rdkit.Chem.rdchem.BondStereo
        Stereo configurations to consider. Default:
        ``rdkit.Chem.rdchem.BondStereo.STEREONONE``,
        ``rdkit.Chem.rdchem.BondStereo.STEREOANY``,
        ``rdkit.Chem.rdchem.BondStereo.STEREOZ``,
        ``rdkit.Chem.rdchem.BondStereo.STEREOE``,
        ``rdkit.Chem.rdchem.BondStereo.STEREOCIS``,
        ``rdkit.Chem.rdchem.BondStereo.STEREOTRANS``.
    encode_unknown : bool
        If True, map inputs not in the allowable set to the
        additional last element. (Default: False)

    Returns
    -------
    list
        List of boolean values where at most one value is True.

    See Also
    --------
    one_hot_encoding
    """
    if allowable_set is None:
        allowable_set = [getattr(Chem.rdchem.BondStereo, name)
                         for name in ('STEREONONE', 'STEREOANY', 'STEREOZ',
                                      'STEREOE', 'STEREOCIS', 'STEREOTRANS')]
    return one_hot_encoding(bond.GetStereo(), allowable_set, encode_unknown)
def bond_direction_one_hot(bond, allowable_set=None, encode_unknown=False):
    """One hot encoding for the direction of a bond.

    Parameters
    ----------
    bond : rdkit.Chem.rdchem.Bond
        RDKit bond instance.
    allowable_set : list of Chem.rdchem.BondDir
        Bond directions to consider. Default: ``Chem.rdchem.BondDir.NONE``,
        ``Chem.rdchem.BondDir.ENDUPRIGHT``, ``Chem.rdchem.BondDir.ENDDOWNRIGHT``.
    encode_unknown : bool
        If True, map inputs not in the allowable set to the
        additional last element. (Default: False)

    Returns
    -------
    list
        List of boolean values where at most one value is True.

    See Also
    --------
    one_hot_encoding
    """
    if allowable_set is None:
        allowable_set = [getattr(Chem.rdchem.BondDir, name)
                         for name in ('NONE', 'ENDUPRIGHT', 'ENDDOWNRIGHT')]
    return one_hot_encoding(bond.GetBondDir(), allowable_set, encode_unknown)
class BaseBondFeaturizer(object):
    """An abstract class for bond featurizers.

    Loop over all bonds in a molecule and featurize them with the
    ``featurizer_funcs``.

    We assume the constructed ``DGLGraph`` is a bi-directed graph where the
    **i** th bond in the molecule, i.e. ``mol.GetBondWithIdx(i)``, corresponds
    to the **(2i)**-th and **(2i+1)**-th edges in the DGLGraph.

    **We assume the resulting DGLGraph will be created with
    :func:`smiles_to_bigraph` without self loops.**

    Parameters
    ----------
    featurizer_funcs : dict
        Mapping feature name to the featurization function.
        Each function is of signature
        ``func(rdkit.Chem.rdchem.Bond) -> list or 1D numpy array``.
    feat_sizes : dict
        Mapping feature name to the size of the corresponding feature. If None,
        they will be computed when needed. Default: None.

    Examples
    --------
    >>> from dgllife.utils import BaseBondFeaturizer, bond_type_one_hot, bond_is_in_ring
    >>> from rdkit import Chem
    >>> mol = Chem.MolFromSmiles('CCO')
    >>> bond_featurizer = BaseBondFeaturizer({'type': bond_type_one_hot, 'ring': bond_is_in_ring})
    >>> bond_featurizer(mol)['type'].shape
    torch.Size([4, 4])
    >>> bond_featurizer.feat_size('type')
    4
    >>> bond_featurizer.feat_size('ring')
    1

    See Also
    --------
    CanonicalBondFeaturizer
    """
    def __init__(self, featurizer_funcs, feat_sizes=None):
        self.featurizer_funcs = featurizer_funcs
        if feat_sizes is None:
            feat_sizes = dict()
        self._feat_sizes = feat_sizes

    def feat_size(self, feat_name=None):
        """Get the feature size for ``feat_name``.

        When there is only one feature, users do not need to provide ``feat_name``.

        Parameters
        ----------
        feat_name : str
            Feature for query. Default to None.

        Returns
        -------
        int
            Feature size for the feature with name ``feat_name``.

        Raises
        ------
        ValueError
            If ``feat_name`` has no associated featurization function.
        """
        if feat_name is None:
            assert len(self.featurizer_funcs) == 1, \
                'feat_name should be provided if there are more than one features'
            feat_name = list(self.featurizer_funcs.keys())[0]
        if feat_name not in self.featurizer_funcs:
            # Bug fix: the ValueError used to be returned instead of raised,
            # so callers silently got a truthy exception object.
            raise ValueError('Expect feat_name to be in {}, got {}'.format(
                list(self.featurizer_funcs.keys()), feat_name))
        if feat_name not in self._feat_sizes:
            # Lazily compute the feature size with the single bond of methanol.
            bond = Chem.MolFromSmiles('CO').GetBondWithIdx(0)
            self._feat_sizes[feat_name] = len(self.featurizer_funcs[feat_name](bond))
        return self._feat_sizes[feat_name]

    def __call__(self, mol):
        """Featurize all bonds in a molecule.

        Parameters
        ----------
        mol : rdkit.Chem.rdchem.Mol
            RDKit molecule instance.

        Returns
        -------
        dict
            For each function in self.featurizer_funcs with the key ``k``, store
            the computed feature under the key ``k``. Each feature is a tensor of
            dtype float32 and shape (N, M), where N is the number of edges (twice
            the number of bonds) in the molecule.
        """
        num_bonds = mol.GetNumBonds()
        bond_features = defaultdict(list)
        # Compute features for each bond
        for i in range(num_bonds):
            bond = mol.GetBondWithIdx(i)
            for feat_name, feat_func in self.featurizer_funcs.items():
                feat = feat_func(bond)
                # Duplicate since each bond maps to two directed edges.
                bond_features[feat_name].extend([feat, feat.copy()])
        # Stack the features and convert them to float arrays
        processed_features = dict()
        for feat_name, feat_list in bond_features.items():
            feat = np.stack(feat_list)
            processed_features[feat_name] = F.zerocopy_from_numpy(feat.astype(np.float32))
        return processed_features
class CanonicalBondFeaturizer(BaseBondFeaturizer):
    """A default featurizer for bonds.

    The bond features include:

    * **One hot encoding of the bond type**. The supported bond types include
      ``SINGLE``, ``DOUBLE``, ``TRIPLE``, ``AROMATIC``.
    * **Whether the bond is conjugated.**
    * **Whether the bond is in a ring of any size.**
    * **One hot encoding of the stereo configuration of a bond**. The supported
      bond stereo configurations include ``STEREONONE``, ``STEREOANY``,
      ``STEREOZ``, ``STEREOE``, ``STEREOCIS``, ``STEREOTRANS``.

    **We assume the resulting DGLGraph will be created with
    :func:`smiles_to_bigraph` without self loops.**

    Parameters
    ----------
    bond_data_field : str
        Name for storing bond features in DGLGraphs, default to 'e'.

    Examples
    --------
    >>> from dgllife.utils import CanonicalBondFeaturizer
    >>> from rdkit import Chem
    >>> mol = Chem.MolFromSmiles('CCO')
    >>> bond_featurizer = CanonicalBondFeaturizer(bond_data_field='feat')
    >>> bond_featurizer(mol)
    {'feat': tensor([[1., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.],
                     [1., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.],
                     [1., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.],
                     [1., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.]])}
    >>> # Get feature size (note the field name used above)
    >>> bond_featurizer.feat_size('feat')
    12

    See Also
    --------
    BaseBondFeaturizer
    """
    def __init__(self, bond_data_field='e'):
        super(CanonicalBondFeaturizer, self).__init__(
            featurizer_funcs={bond_data_field: ConcatFeaturizer(
                [bond_type_one_hot,
                 bond_is_conjugated,
                 bond_is_in_ring,
                 bond_stereo_one_hot]
            )})
# pylint: disable=E1102
class WeaveEdgeFeaturizer(object):
    """Edge featurizer in Weave.

    The edge featurization is introduced in `Molecular Graph Convolutions:
    Moving Beyond Fingerprints <https://arxiv.org/abs/1603.00856>`__.

    This featurization is performed for a complete graph of atoms with self
    loops added, which considers:

    * Number of bonds between each pairs of atoms
    * One-hot encoding of bond type if a bond exists between a pair of atoms
    * Whether a pair of atoms belongs to a same ring

    Parameters
    ----------
    edge_data_field : str
        Name for storing edge features in DGLGraphs, default to ``'e'``.
    max_distance : int
        Maximum number of bonds to consider between each pair of atoms.
        Default to 7.
    bond_types : list of Chem.rdchem.BondType or None
        Bond types to consider for one hot encoding. If None, we consider by
        default single, double, triple and aromatic bonds.
    """
    def __init__(self, edge_data_field='e', max_distance=7, bond_types=None):
        super(WeaveEdgeFeaturizer, self).__init__()
        self._edge_data_field = edge_data_field
        self._max_distance = max_distance
        if bond_types is None:
            bond_types = [getattr(Chem.rdchem.BondType, name)
                          for name in ('SINGLE', 'DOUBLE', 'TRIPLE', 'AROMATIC')]
        self._bond_types = bond_types

    def feat_size(self):
        """Get the feature size.

        Returns
        -------
        int
            Feature size.
        """
        # Featurize a dummy single-carbon molecule and read off the width.
        dummy_mol = Chem.MolFromSmiles('C')
        return self(dummy_mol)[self._edge_data_field].shape[-1]

    def __call__(self, mol):
        """Featurizes the input molecule.

        Parameters
        ----------
        mol : rdkit.Chem.rdchem.Mol
            RDKit molecule instance.

        Returns
        -------
        dict
            Mapping self._edge_data_field to a float32 tensor of shape (N, M),
            where N is the number of atom pairs and M is the feature size.
        """
        num_atoms = mol.GetNumAtoms()
        num_bond_types = len(self._bond_types)

        # Part 1: topological-distance indicators. Flatten the (V, V) distance
        # matrix to (V^2, 1) and compare it against 0, 1, ..., max_distance - 1.
        pair_distances = torch.from_numpy(
            Chem.GetDistanceMatrix(mol)).float().reshape(-1, 1)
        thresholds = torch.arange(0, self._max_distance).float()
        distance_indicators = (pair_distances > thresholds).float()

        # Part 2: one-hot encoding of the bond type for bonded pairs, stored
        # symmetrically for both (u, v) and (v, u).
        bond_indicators = torch.zeros(num_atoms, num_atoms, num_bond_types)
        for bond in mol.GetBonds():
            u, v = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
            encoding = torch.tensor(
                bond_type_one_hot(bond, allowable_set=self._bond_types)).float()
            bond_indicators[u, v] = encoding
            bond_indicators[v, u] = encoding
        bond_indicators = bond_indicators.reshape(-1, num_bond_types)

        # Part 3: whether a pair of atoms shares a ring in the symmetrized SSSR.
        ring_mate_indicators = torch.zeros(num_atoms, num_atoms, 1)
        for ring in Chem.GetSymmSSSR(mol):
            members = list(ring)
            member_tensor = torch.tensor(members)
            for atom_id in members:
                ring_mate_indicators[atom_id, member_tensor] = 1
        ring_mate_indicators = ring_mate_indicators.reshape(-1, 1)

        feats = torch.cat(
            [distance_indicators, bond_indicators, ring_mate_indicators], dim=1)
        return {self._edge_data_field: feats}
class PretrainBondFeaturizer(object):
    """BondFeaturizer in Strategies for Pre-training Graph Neural Networks.

    The bond featurization performed in `Strategies for Pre-training Graph Neural
    Networks <https://arxiv.org/abs/1905.12265>`__, which considers:

    * bond type
    * bond direction

    Parameters
    ----------
    bond_types : list of Chem.rdchem.BondType or None
        Bond types to consider. Default to ``Chem.rdchem.BondType.SINGLE``,
        ``Chem.rdchem.BondType.DOUBLE``, ``Chem.rdchem.BondType.TRIPLE``,
        ``Chem.rdchem.BondType.AROMATIC``.
    bond_direction_types : list of Chem.rdchem.BondDir or None
        Bond directions to consider. Default to ``Chem.rdchem.BondDir.NONE``,
        ``Chem.rdchem.BondDir.ENDUPRIGHT``, ``Chem.rdchem.BondDir.ENDDOWNRIGHT``.
    self_loop : bool
        Whether self loops will be added. Default to True. Self loops get a
        dedicated bond type id of ``len(bond_types)`` and direction id 0.
    """
    def __init__(self, bond_types=None, bond_direction_types=None, self_loop=True):
        if bond_types is None:
            bond_types = [
                Chem.rdchem.BondType.SINGLE, Chem.rdchem.BondType.DOUBLE,
                Chem.rdchem.BondType.TRIPLE, Chem.rdchem.BondType.AROMATIC
            ]
        self._bond_types = bond_types
        if bond_direction_types is None:
            bond_direction_types = [
                Chem.rdchem.BondDir.NONE,
                Chem.rdchem.BondDir.ENDUPRIGHT,
                Chem.rdchem.BondDir.ENDDOWNRIGHT
            ]
        self._bond_direction_types = bond_direction_types
        self._self_loop = self_loop

    def __call__(self, mol):
        """Featurizes the input molecule.

        Parameters
        ----------
        mol : rdkit.Chem.rdchem.Mol
            RDKit molecule instance.

        Returns
        -------
        dict
            Mapping 'bond_type' and 'bond_direction_type' separately to an int64
            tensor of shape (N,), where N is the number of edges.
        """
        edge_features = []
        num_bonds = mol.GetNumBonds()
        # Compute features for each bond
        for i in range(num_bonds):
            bond = mol.GetBondWithIdx(i)
            bond_feats = [
                self._bond_types.index(bond.GetBondType()),
                self._bond_direction_types.index(bond.GetBondDir())
            ]
            # Duplicate since each bond maps to two directed edges.
            edge_features.extend([bond_feats, bond_feats.copy()])
        # Always produce a tensor here. Bug fix: previously, when self_loop was
        # False, edge_features stayed a plain Python list (and np.stack crashed
        # for molecules with no bonds), so the [:, 0] indexing below failed.
        if num_bonds == 0:
            edge_features = torch.zeros((0, 2), dtype=torch.int64)
        else:
            edge_features = F.zerocopy_from_numpy(
                np.stack(edge_features).astype(np.int64))
        if self._self_loop:
            # One self loop per atom, with a dedicated bond type id.
            self_loop_features = torch.zeros((mol.GetNumAtoms(), 2), dtype=torch.int64)
            self_loop_features[:, 0] = len(self._bond_types)
            edge_features = torch.cat([edge_features, self_loop_features], dim=0)
        return {'bond_type': edge_features[:, 0], 'bond_direction_type': edge_features[:, 1]}
"""Convert molecules into DGLGraphs."""
# pylint: disable= no-member, arguments-differ, invalid-name
from functools import partial
import torch
from dgl import DGLGraph
from rdkit import Chem
from rdkit.Chem import rdmolfiles, rdmolops
from sklearn.neighbors import NearestNeighbors
__all__ = ['mol_to_graph',
'smiles_to_bigraph',
'mol_to_bigraph',
'smiles_to_complete_graph',
'mol_to_complete_graph',
'k_nearest_neighbors',
'mol_to_nearest_neighbor_graph',
'smiles_to_nearest_neighbor_graph']
# pylint: disable=I1101
def mol_to_graph(mol, graph_constructor, node_featurizer, edge_featurizer, canonical_atom_order):
    """Convert an RDKit molecule object into a DGLGraph and featurize for it.

    This is the generic entry point for building an arbitrary ``DGLGraph``
    from an RDKit molecule instance: ``graph_constructor`` supplies the
    topology while the optional featurizers populate node/edge features.

    Parameters
    ----------
    mol : rdkit.Chem.rdchem.Mol
        RDKit molecule holder
    graph_constructor : callable
        Takes an RDKit molecule as input and returns a DGLGraph
    node_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict
        Featurization for nodes like atoms in a molecule, which can be used
        to update ndata for a DGLGraph.
    edge_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict
        Featurization for edges like bonds in a molecule, which can be used
        to update edata for a DGLGraph.
    canonical_atom_order : bool
        Whether to use a canonical order of atoms returned by RDKit. Setting it
        to true might change the order of atoms in the graph constructed.

    Returns
    -------
    g : DGLGraph
        Converted DGLGraph for the molecule

    See Also
    --------
    mol_to_bigraph
    mol_to_complete_graph
    mol_to_nearest_neighbor_graph
    """
    if canonical_atom_order:
        # Renumber atoms so the graph does not depend on the input atom order
        mol = rdmolops.RenumberAtoms(mol, rdmolfiles.CanonicalRankAtoms(mol))
    graph = graph_constructor(mol)
    for featurize, frame in ((node_featurizer, graph.ndata),
                             (edge_featurizer, graph.edata)):
        if featurize is not None:
            frame.update(featurize(mol))
    return graph
def construct_bigraph_from_mol(mol, add_self_loop=False):
    """Construct a bi-directed DGLGraph with topology only for the molecule.

    Atom ``i`` in the molecule, i.e. ``mol.GetAtomWithIdx(i)``, maps to node
    ``i`` in the returned DGLGraph. Bond ``i``, i.e. ``mol.GetBondWithIdx(i)``,
    maps to edges ``2i`` and ``2i + 1``, which separately go from **u** to
    **v** and from **v** to **u**, where **u** is ``bond.GetBeginAtomIdx()``
    and **v** is ``bond.GetEndAtomIdx()``. If self loops are added, the last
    **n** edges are self loops for atoms ``0, 1, ..., n-1`` in order.

    Parameters
    ----------
    mol : rdkit.Chem.rdchem.Mol
        RDKit molecule holder
    add_self_loop : bool
        Whether to add self loops in DGLGraphs. Default to False.

    Returns
    -------
    g : DGLGraph
        Empty bigraph topology of the molecule
    """
    g = DGLGraph()
    g.add_nodes(mol.GetNumAtoms())
    srcs, dsts = [], []
    # GetBonds yields bonds in index order, matching GetBondWithIdx(0..n-1)
    for bond in mol.GetBonds():
        begin, end = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
        # Each chemical bond yields a pair of directed edges
        srcs.extend([begin, end])
        dsts.extend([end, begin])
    g.add_edges(srcs, dsts)
    if add_self_loop:
        all_nodes = g.nodes()
        g.add_edges(all_nodes, all_nodes)
    return g
def mol_to_bigraph(mol, add_self_loop=False,
                   node_featurizer=None,
                   edge_featurizer=None,
                   canonical_atom_order=True):
    """Convert an RDKit molecule object into a bi-directed DGLGraph and featurize for it.

    Parameters
    ----------
    mol : rdkit.Chem.rdchem.Mol
        RDKit molecule holder
    add_self_loop : bool
        Whether to add self loops in DGLGraphs. Default to False.
    node_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict
        Featurization for nodes like atoms in a molecule, which can be used to
        update ndata for a DGLGraph. Default to None.
    edge_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict
        Featurization for edges like bonds in a molecule, which can be used to
        update edata for a DGLGraph. Default to None.
    canonical_atom_order : bool
        Whether to use a canonical order of atoms returned by RDKit. Setting it
        to true might change the order of atoms in the graph constructed. Default
        to True.

    Returns
    -------
    g : DGLGraph
        Bi-directed DGLGraph for the molecule

    Examples
    --------
    >>> from rdkit import Chem
    >>> from dgllife.utils import mol_to_bigraph
    >>> mol = Chem.MolFromSmiles('CCO')
    >>> g = mol_to_bigraph(mol)
    >>> print(g)
    DGLGraph(num_nodes=3, num_edges=4,
             ndata_schemes={}
             edata_schemes={})

    Node/edge features can be initialized at construction time by passing
    featurizers. A node featurizer maps the molecule to a dict updating
    ``g.ndata``, e.g. one-hot atomic numbers; an edge featurizer does the same
    for ``g.edata``, where each bond corresponds to two consecutive edges.

    See Also
    --------
    smiles_to_bigraph
    """
    # Delegate to the generic converter with a bigraph topology constructor
    constructor = partial(construct_bigraph_from_mol, add_self_loop=add_self_loop)
    return mol_to_graph(mol, constructor, node_featurizer,
                        edge_featurizer, canonical_atom_order)
def smiles_to_bigraph(smiles, add_self_loop=False,
                      node_featurizer=None,
                      edge_featurizer=None,
                      canonical_atom_order=True):
    """Convert a SMILES into a bi-directed DGLGraph and featurize for it.

    Parameters
    ----------
    smiles : str
        String of SMILES
    add_self_loop : bool
        Whether to add self loops in DGLGraphs. Default to False.
    node_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict
        Featurization for nodes like atoms in a molecule, which can be used to
        update ndata for a DGLGraph. Default to None.
    edge_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict
        Featurization for edges like bonds in a molecule, which can be used to
        update edata for a DGLGraph. Default to None.
    canonical_atom_order : bool
        Whether to use a canonical order of atoms returned by RDKit. Setting it
        to true might change the order of atoms in the graph constructed. Default
        to True.

    Returns
    -------
    g : DGLGraph
        Bi-directed DGLGraph for the molecule

    Examples
    --------
    >>> from dgllife.utils import smiles_to_bigraph
    >>> g = smiles_to_bigraph('CCO')
    >>> print(g)
    DGLGraph(num_nodes=3, num_edges=4,
             ndata_schemes={}
             edata_schemes={})

    Node/edge features can be initialized at construction time by passing
    featurizers, exactly as for :func:`mol_to_bigraph`.

    See Also
    --------
    mol_to_bigraph
    """
    # Parse the SMILES and reuse the molecule-based converter
    return mol_to_bigraph(Chem.MolFromSmiles(smiles),
                          add_self_loop=add_self_loop,
                          node_featurizer=node_featurizer,
                          edge_featurizer=edge_featurizer,
                          canonical_atom_order=canonical_atom_order)
def construct_complete_graph_from_mol(mol, add_self_loop=False):
    """Construct a complete graph with topology only for the molecule.

    Atom ``i`` in the molecule, i.e. ``mol.GetAtomWithIdx(i)``, maps to node
    ``i`` in the returned DGLGraph. Edges are created in source-major order,
    i.e. (0, 0), (0, 1), (0, 2), ..., (1, 0), (1, 1), ...; the self pairs
    (0, 0), (1, 1), ... are skipped unless ``add_self_loop`` is True.

    Parameters
    ----------
    mol : rdkit.Chem.rdchem.Mol
        RDKit molecule holder
    add_self_loop : bool
        Whether to add self loops in DGLGraphs. Default to False.

    Returns
    -------
    g : DGLGraph
        Empty complete graph topology of the molecule
    """
    num_atoms = mol.GetNumAtoms()
    pairs = [(u, v)
             for u in range(num_atoms)
             for v in range(num_atoms)
             if add_self_loop or u != v]
    return DGLGraph(pairs)
def mol_to_complete_graph(mol, add_self_loop=False,
                          node_featurizer=None,
                          edge_featurizer=None,
                          canonical_atom_order=True):
    """Convert an RDKit molecule into a complete DGLGraph and featurize for it.

    Parameters
    ----------
    mol : rdkit.Chem.rdchem.Mol
        RDKit molecule holder
    add_self_loop : bool
        Whether to add self loops in DGLGraphs. Default to False.
    node_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict
        Featurization for nodes like atoms in a molecule, which can be used to
        update ndata for a DGLGraph. Default to None.
    edge_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict
        Featurization for edges like bonds in a molecule, which can be used to
        update edata for a DGLGraph. Default to None.
    canonical_atom_order : bool
        Whether to use a canonical order of atoms returned by RDKit. Setting it
        to true might change the order of atoms in the graph constructed. Default
        to True.

    Returns
    -------
    g : DGLGraph
        Complete DGLGraph for the molecule

    Examples
    --------
    >>> from rdkit import Chem
    >>> from dgllife.utils import mol_to_complete_graph
    >>> mol = Chem.MolFromSmiles('CCO')
    >>> g = mol_to_complete_graph(mol)
    >>> print(g)
    DGLGraph(num_nodes=3, num_edges=6,
             ndata_schemes={}
             edata_schemes={})

    Node/edge features can be initialized at construction time by passing
    featurizers; an edge featurizer must enumerate atom pairs in the same
    source-major order used by :func:`construct_complete_graph_from_mol`.

    See Also
    --------
    smiles_to_complete_graph
    """
    # Delegate to the generic converter with a complete-graph constructor
    constructor = partial(construct_complete_graph_from_mol, add_self_loop=add_self_loop)
    return mol_to_graph(mol, constructor, node_featurizer,
                        edge_featurizer, canonical_atom_order)
def smiles_to_complete_graph(smiles, add_self_loop=False,
                             node_featurizer=None,
                             edge_featurizer=None,
                             canonical_atom_order=True):
    """Convert a SMILES into a complete DGLGraph and featurize for it.

    Parameters
    ----------
    smiles : str
        String of SMILES
    add_self_loop : bool
        Whether to add self loops in DGLGraphs. Default to False.
    node_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict
        Featurization for nodes like atoms in a molecule, which can be used to
        update ndata for a DGLGraph. Default to None.
    edge_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict
        Featurization for edges like bonds in a molecule, which can be used to
        update edata for a DGLGraph. Default to None.
    canonical_atom_order : bool
        Whether to use a canonical order of atoms returned by RDKit. Setting it
        to true might change the order of atoms in the graph constructed. Default
        to True.

    Returns
    -------
    g : DGLGraph
        Complete DGLGraph for the molecule

    Examples
    --------
    >>> from dgllife.utils import smiles_to_complete_graph
    >>> g = smiles_to_complete_graph('CCO')
    >>> print(g)
    DGLGraph(num_nodes=3, num_edges=6,
             ndata_schemes={}
             edata_schemes={})

    Node/edge features can be initialized at construction time by passing
    featurizers, exactly as for :func:`mol_to_complete_graph`.

    See Also
    --------
    mol_to_complete_graph
    """
    # Parse the SMILES and reuse the molecule-based converter
    return mol_to_complete_graph(Chem.MolFromSmiles(smiles),
                                 add_self_loop=add_self_loop,
                                 node_featurizer=node_featurizer,
                                 edge_featurizer=edge_featurizer,
                                 canonical_atom_order=canonical_atom_order)
def k_nearest_neighbors(coordinates, neighbor_cutoff, max_num_neighbors=None,
                        p_distance=2, self_loops=False):
    """Find k nearest neighbors for each atom.

    The returned edges are not guaranteed to be sorted by the distance
    between atoms.

    Parameters
    ----------
    coordinates : numpy.ndarray of shape (N, D)
        The coordinates of atoms in the molecule. N for the number of atoms
        and D for the dimensions of the coordinates.
    neighbor_cutoff : float
        If the distance between a pair of nodes is larger than neighbor_cutoff,
        they will not be considered as neighboring nodes.
    max_num_neighbors : int or None.
        If not None, then this specifies the maximum number of neighbors
        allowed for each atom. Default to None.
    p_distance : int
        We compute the distance between neighbors using Minkowski (:math:`l_p`)
        distance. When ``p_distance = 1``, Minkowski distance is equivalent to
        Manhattan distance. When ``p_distance = 2``, Minkowski distance is
        equivalent to the standard Euclidean distance. Default to 2.
    self_loops : bool
        Whether to allow a node to be its own neighbor. Default to False.

    Returns
    -------
    srcs : list of int
        Source nodes.
    dsts : list of int
        Destination nodes, corresponding to ``srcs``.
    distances : list of float
        Distances between the end nodes, corresponding to ``srcs`` and ``dsts``.

    See Also
    --------
    get_mol_3d_coordinates
    mol_to_nearest_neighbor_graph
    smiles_to_nearest_neighbor_graph
    """
    num_atoms = coordinates.shape[0]
    model = NearestNeighbors(radius=neighbor_cutoff, p=p_distance)
    model.fit(coordinates)
    neighbor_dists, neighbor_ids = model.radius_neighbors(coordinates)
    srcs, dsts, dists = [], [], []
    for center in range(num_atoms):
        center_dists = neighbor_dists[center].tolist()
        center_nbrs = neighbor_ids[center].tolist()
        if not self_loops:
            # Drop the atom itself (zero distance) from its own neighborhood
            center_dists.remove(0)
            center_nbrs.remove(center)
        if max_num_neighbors is not None and len(center_nbrs) > max_num_neighbors:
            # Keep only the closest max_num_neighbors neighbors
            ordered = sorted(zip(center_dists, center_nbrs), key=lambda pair: pair[0])
            ordered = ordered[:max_num_neighbors]
            center_dists = [d for d, _ in ordered]
            center_nbrs = [n for _, n in ordered]
        dsts.extend([center] * len(center_nbrs))
        srcs.extend(center_nbrs)
        dists.extend(center_dists)
    return srcs, dsts, dists
# pylint: disable=E1102
def mol_to_nearest_neighbor_graph(mol,
                                  coordinates,
                                  neighbor_cutoff,
                                  max_num_neighbors=None,
                                  p_distance=2,
                                  add_self_loop=False,
                                  node_featurizer=None,
                                  edge_featurizer=None,
                                  canonical_atom_order=True,
                                  keep_dists=False,
                                  dist_field='dist'):
    """Convert an RDKit molecule into a nearest neighbor graph and featurize for it.

    Different from bigraph and complete graph, the nearest neighbor graph
    may not be symmetric since i is the closest neighbor of j does not
    necessarily suggest the other way.

    Parameters
    ----------
    mol : rdkit.Chem.rdchem.Mol
        RDKit molecule holder
    coordinates : numpy.ndarray of shape (N, D)
        The coordinates of atoms in the molecule. N for the number of atoms
        and D for the dimensions of the coordinates.
    neighbor_cutoff : float
        If the distance between a pair of nodes is larger than neighbor_cutoff,
        they will not be considered as neighboring nodes.
    max_num_neighbors : int or None.
        If not None, then this specifies the maximum number of neighbors
        allowed for each atom. Default to None.
    p_distance : int
        We compute the distance between neighbors using Minkowski (:math:`l_p`)
        distance. When ``p_distance = 1``, Minkowski distance is equivalent to
        Manhattan distance. When ``p_distance = 2``, Minkowski distance is
        equivalent to the standard Euclidean distance. Default to 2.
    add_self_loop : bool
        Whether to add self loops in DGLGraphs. Default to False.
    node_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict
        Featurization for nodes like atoms in a molecule, which can be used to
        update ndata for a DGLGraph. Default to None.
    edge_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict
        Featurization for edges like bonds in a molecule, which can be used to
        update edata for a DGLGraph. Default to None.
    canonical_atom_order : bool
        Whether to use a canonical order of atoms returned by RDKit. Setting it
        to true might change the order of atoms in the graph constructed. Default
        to True.
    keep_dists : bool
        Whether to store the distance between neighboring atoms in ``edata`` of the
        constructed DGLGraphs. Default to False.
    dist_field : str
        Field for storing distance between neighboring atoms in ``edata``. This comes
        into effect only when ``keep_dists=True``. Default to ``'dist'``.

    Returns
    -------
    g : DGLGraph
        Nearest neighbor DGLGraph for the molecule

    See Also
    --------
    get_mol_3d_coordinates
    k_nearest_neighbors
    smiles_to_nearest_neighbor_graph
    """
    if canonical_atom_order:
        # NOTE(review): coordinates are NOT renumbered along with the atoms,
        # so with canonical_atom_order=True the caller must supply coordinates
        # already in the canonical order -- confirm against callers.
        new_order = rdmolfiles.CanonicalRankAtoms(mol)
        mol = rdmolops.RenumberAtoms(mol, new_order)
    srcs, dsts, dists = k_nearest_neighbors(coordinates=coordinates,
                                            neighbor_cutoff=neighbor_cutoff,
                                            max_num_neighbors=max_num_neighbors,
                                            p_distance=p_distance,
                                            self_loops=add_self_loop)
    g = DGLGraph()
    # Add nodes first since some nodes may be completely isolated
    num_atoms = mol.GetNumAtoms()
    g.add_nodes(num_atoms)
    # Add edges
    g.add_edges(srcs, dsts)
    if node_featurizer is not None:
        g.ndata.update(node_featurizer(mol))
    if edge_featurizer is not None:
        g.edata.update(edge_featurizer(mol))
    if keep_dists:
        # Bug fix: the message previously lacked the .format call, so the
        # placeholder was never filled with the field name.
        assert dist_field not in g.edata, \
            'Expect {} to be reserved for distance between ' \
            'neighboring atoms.'.format(dist_field)
        g.edata[dist_field] = torch.tensor(dists).float().reshape(-1, 1)
    return g
def smiles_to_nearest_neighbor_graph(smiles,
                                     coordinates,
                                     neighbor_cutoff,
                                     max_num_neighbors=None,
                                     p_distance=2,
                                     add_self_loop=False,
                                     node_featurizer=None,
                                     edge_featurizer=None,
                                     canonical_atom_order=True,
                                     keep_dists=False,
                                     dist_field='dist'):
    """Convert a SMILES into a nearest neighbor graph and featurize for it.

    Different from bigraph and complete graph, the nearest neighbor graph
    may not be symmetric since i is the closest neighbor of j does not
    necessarily suggest the other way.

    Parameters
    ----------
    smiles : str
        String of SMILES
    coordinates : numpy.ndarray of shape (N, D)
        The coordinates of atoms in the molecule. N for the number of atoms
        and D for the dimensions of the coordinates.
    neighbor_cutoff : float
        If the distance between a pair of nodes is larger than neighbor_cutoff,
        they will not be considered as neighboring nodes.
    max_num_neighbors : int or None.
        If not None, then this specifies the maximum number of neighbors
        allowed for each atom. Default to None.
    p_distance : int
        We compute the distance between neighbors using Minkowski (:math:`l_p`)
        distance. When ``p_distance = 1``, Minkowski distance is equivalent to
        Manhattan distance. When ``p_distance = 2``, Minkowski distance is
        equivalent to the standard Euclidean distance. Default to 2.
    add_self_loop : bool
        Whether to add self loops in DGLGraphs. Default to False.
    node_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict
        Featurization for nodes like atoms in a molecule, which can be used to
        update ndata for a DGLGraph. Default to None.
    edge_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict
        Featurization for edges like bonds in a molecule, which can be used to
        update edata for a DGLGraph. Default to None.
    canonical_atom_order : bool
        Whether to use a canonical order of atoms returned by RDKit. Setting it
        to true might change the order of atoms in the graph constructed. Default
        to True.
    keep_dists : bool
        Whether to store the distance between neighboring atoms in ``edata`` of the
        constructed DGLGraphs. Default to False.
    dist_field : str
        Field for storing distance between neighboring atoms in ``edata``. This comes
        into effect only when ``keep_dists=True``. Default to ``'dist'``.

    Returns
    -------
    g : DGLGraph
        Nearest neighbor DGLGraph for the molecule

    Examples
    --------
    >>> from dgllife.utils import smiles_to_nearest_neighbor_graph
    >>> from rdkit import Chem
    >>> from rdkit.Chem import AllChem
    >>> smiles = 'CC1(C(N2C(S1)C(C2=O)NC(=O)CC3=CC=CC=C3)C(=O)O)C'
    >>> mol = Chem.MolFromSmiles(smiles)
    >>> AllChem.EmbedMolecule(mol)
    >>> AllChem.MMFFOptimizeMolecule(mol)
    >>> coords = get_mol_3d_coordinates(mol)
    >>> g = smiles_to_nearest_neighbor_graph(smiles, coords, neighbor_cutoff=1.25)

    Pass ``keep_dists=True`` to additionally store the distance between the
    end atoms of each edge in ``g.edata``.

    See Also
    --------
    get_mol_3d_coordinates
    k_nearest_neighbors
    mol_to_nearest_neighbor_graph
    """
    # Parse the SMILES and reuse the molecule-based converter
    return mol_to_nearest_neighbor_graph(
        Chem.MolFromSmiles(smiles), coordinates, neighbor_cutoff,
        max_num_neighbors=max_num_neighbors, p_distance=p_distance,
        add_self_loop=add_self_loop, node_featurizer=node_featurizer,
        edge_featurizer=edge_featurizer, canonical_atom_order=canonical_atom_order,
        keep_dists=keep_dists, dist_field=dist_field)
"""Utils for RDKit, mostly adapted from DeepChem
(https://github.com/deepchem/deepchem/blob/master/deepchem)."""
# pylint: disable= no-member, arguments-differ, invalid-name
import warnings
from functools import partial
from multiprocessing import Pool
from rdkit import Chem
from rdkit.Chem import AllChem
__all__ = ['get_mol_3d_coordinates',
'load_molecule',
'multiprocess_load_molecules']
# pylint: disable=W0702
def get_mol_3d_coordinates(mol):
    """Get 3D coordinates of the molecule.

    This function requires that molecular conformation has been initialized.

    Parameters
    ----------
    mol : rdkit.Chem.rdchem.Mol
        RDKit molecule instance.

    Returns
    -------
    numpy.ndarray of shape (N, 3) or None
        The 3D coordinates of atoms in the molecule. N for the number of atoms in
        the molecule. For failures in getting the conformations, None will be returned.

    Examples
    --------
    An error will occur in the example below since the molecule object does not
    carry conformation information.

    >>> from rdkit import Chem
    >>> from dgllife.utils import get_mol_3d_coordinates
    >>> mol = Chem.MolFromSmiles('CCO')

    Below we give a working example based on molecule conformation initialized
    from calculation.

    >>> from rdkit.Chem import AllChem
    >>> AllChem.EmbedMolecule(mol)
    >>> AllChem.MMFFOptimizeMolecule(mol)
    >>> coords = get_mol_3d_coordinates(mol)
    >>> print(coords)
    array([[ 1.20967478, -0.25802181,  0.        ],
           [-0.05021255,  0.57068079,  0.        ],
           [-1.15946223, -0.31265898,  0.        ]])
    """
    try:
        conf = mol.GetConformer()
        conf_num_atoms = conf.GetNumAtoms()
        mol_num_atoms = mol.GetNumAtoms()
        assert mol_num_atoms == conf_num_atoms, \
            'Expect the number of atoms in the molecule and its conformation ' \
            'to be the same, got {:d} and {:d}'.format(mol_num_atoms, conf_num_atoms)
        return conf.GetPositions()
    except Exception:
        # Narrowed from a bare ``except`` so that KeyboardInterrupt/SystemExit
        # still propagate; any conformer-related failure yields None.
        warnings.warn('Unable to get conformation of the molecule.')
        return None
# pylint: disable=E1101
def load_molecule(molecule_file, sanitize=False, calc_charges=False,
                  remove_hs=False, use_conformation=True):
    """Load a molecule from a file of format ``.mol2`` or ``.sdf`` or ``.pdbqt`` or ``.pdb``.

    Parameters
    ----------
    molecule_file : str
        Path to file for storing a molecule, which can be of format ``.mol2`` or ``.sdf``
        or ``.pdbqt`` or ``.pdb``.
    sanitize : bool
        Whether sanitization is performed in initializing RDKit molecule instances. See
        https://www.rdkit.org/docs/RDKit_Book.html for details of the sanitization.
        Default to False.
    calc_charges : bool
        Whether to add Gasteiger charges via RDKit. Setting this to be True will enforce
        ``sanitize`` to be True. Default to False.
    remove_hs : bool
        Whether to remove hydrogens via RDKit. Note that removing hydrogens can be quite
        slow for large molecules. Default to False.
    use_conformation : bool
        Whether we need to extract molecular conformation from proteins and ligands.
        Default to True.

    Returns
    -------
    mol : rdkit.Chem.rdchem.Mol
        RDKit molecule instance for the loaded molecule.
    coordinates : np.ndarray of shape (N, 3) or None
        The 3D coordinates of atoms in the molecule. N for the number of atoms in
        the molecule. None will be returned if ``use_conformation`` is False or
        we failed to get conformation information.

    Raises
    ------
    ValueError
        If the file extension is not one of .mol2, .sdf, .pdbqt and .pdb.
    """
    if molecule_file.endswith('.mol2'):
        mol = Chem.MolFromMol2File(molecule_file, sanitize=False, removeHs=False)
    elif molecule_file.endswith('.sdf'):
        supplier = Chem.SDMolSupplier(molecule_file, sanitize=False, removeHs=False)
        mol = supplier[0]
    elif molecule_file.endswith('.pdbqt'):
        with open(molecule_file) as file:
            pdbqt_data = file.readlines()
        pdb_block = ''
        for line in pdbqt_data:
            # Strip the pdbqt-specific columns so RDKit can parse it as PDB
            pdb_block += '{}\n'.format(line[:66])
        mol = Chem.MolFromPDBBlock(pdb_block, sanitize=False, removeHs=False)
    elif molecule_file.endswith('.pdb'):
        mol = Chem.MolFromPDBFile(molecule_file, sanitize=False, removeHs=False)
    else:
        # Bug fix: this previously returned the ValueError instance rather
        # than raising it.
        raise ValueError('Expect the format of the molecule_file to be '
                         'one of .mol2, .sdf, .pdbqt and .pdb, got {}'.format(molecule_file))
    try:
        if sanitize or calc_charges:
            Chem.SanitizeMol(mol)
        if calc_charges:
            # Compute Gasteiger charges on the molecule.
            try:
                AllChem.ComputeGasteigerCharges(mol)
            except Exception:
                warnings.warn('Unable to compute charges for the molecule.')
        if remove_hs:
            mol = Chem.RemoveHs(mol)
    except Exception:
        # Best-effort loading: any RDKit processing failure yields (None, None)
        return None, None
    if use_conformation:
        coordinates = get_mol_3d_coordinates(mol)
    else:
        coordinates = None
    return mol, coordinates
def multiprocess_load_molecules(files, sanitize=False, calc_charges=False,
                                remove_hs=False, use_conformation=True, num_processes=2):
    """Load molecules from files with multiprocessing, which can be of format ``.mol2`` or
    ``.sdf`` or ``.pdbqt`` or ``.pdb``.

    Parameters
    ----------
    files : list of str
        Each element is a path to a file storing a molecule, which can be of format ``.mol2``,
        ``.sdf``, ``.pdbqt``, or ``.pdb``.
    sanitize : bool
        Whether sanitization is performed in initializing RDKit molecule instances. See
        https://www.rdkit.org/docs/RDKit_Book.html for details of the sanitization.
        Default to False.
    calc_charges : bool
        Whether to add Gasteiger charges via RDKit. Setting this to be True will enforce
        ``sanitize`` to be True. Default to False.
    remove_hs : bool
        Whether to remove hydrogens via RDKit. Note that removing hydrogens can be quite
        slow for large molecules. Default to False.
    use_conformation : bool
        Whether we need to extract molecular conformation from proteins and ligands.
        Default to True.
    num_processes : int or None
        Number of worker processes to use. If None,
        then we will use the number of CPUs in the system. Default to 2.

    Returns
    -------
    list of 2-tuples
        The first element of each 2-tuple is an RDKit molecule instance. The second element
        of each 2-tuple is the 3D atom coordinates of the corresponding molecule if
        use_conformation is True and the coordinates has been successfully loaded. Otherwise,
        it will be None.
    """
    loader = partial(load_molecule, sanitize=sanitize, calc_charges=calc_charges,
                     remove_hs=remove_hs, use_conformation=use_conformation)
    if num_processes == 1:
        # No worker pool needed for the sequential case
        return [loader(f) for f in files]
    with Pool(processes=num_processes) as pool:
        return pool.map_async(loader, files).get()
"""Various methods for splitting chemical datasets.
We mostly adapt them from deepchem
(https://github.com/deepchem/deepchem/blob/master/deepchem/splits/splitters.py).
"""
# pylint: disable= no-member, arguments-differ, invalid-name
# pylint: disable=E0611
from collections import defaultdict
from functools import partial
from itertools import accumulate, chain
from rdkit import Chem
from rdkit.Chem import rdMolDescriptors
from rdkit.Chem.rdmolops import FastFindRings
from rdkit.Chem.Scaffolds import MurckoScaffold
import dgl.backend as F
import numpy as np
from dgl.data.utils import split_dataset, Subset
__all__ = ['ConsecutiveSplitter',
'RandomSplitter',
'MolecularWeightSplitter',
'ScaffoldSplitter',
'SingleTaskStratifiedSplitter']
def base_k_fold_split(split_method, dataset, k, log):
    """Generate (train, val) subset pairs for k-fold cross validation.

    Parameters
    ----------
    split_method : callable
        A train/validation/test splitting function accepting ``dataset``
        along with ``frac_train``, ``frac_val`` and ``frac_test`` keyword
        arguments.
    dataset
        Map-style dataset supporting ``len(dataset)`` and ``dataset[i]``.
    k : int
        Number of folds, at least 2.
    log : bool
        Whether to print a message when preparation of each fold starts.

    Returns
    -------
    all_folds : list of 2-tuples
        One ``(train_set, val_set)`` pair of :class:`Subset` objects per fold.
    """
    assert k >= 2, 'Expect the number of folds to be no smaller than 2, got {:d}'.format(k)
    fold_frac = 1. / k
    all_folds = []
    for fold_id in range(k):
        if log:
            print('Processing fold {:d}/{:d}'.format(fold_id + 1, k))
        # Reuse the three-way splitter: the "validation" slice is the current
        # fold, while the two "train" slices are everything before and after it.
        before, val_set, after = split_method(
            dataset,
            frac_train=fold_id * fold_frac,
            frac_val=fold_frac,
            frac_test=1. - (fold_id + 1) * fold_frac)
        merged = np.concatenate([before.indices, after.indices]).astype(np.int64)
        all_folds.append((Subset(dataset, merged), val_set))
    return all_folds
def train_val_test_sanity_check(frac_train, frac_val, frac_test):
    """Assert that the three split fractions add up to 1.

    Parameters
    ----------
    frac_train : float
        Fraction of the dataset to use for training.
    frac_val : float
        Fraction of the dataset to use for validation.
    frac_test : float
        Fraction of the dataset to use for test.
    """
    total_fraction = frac_train + frac_val + frac_test
    message = ('Expect the sum of fractions for training, validation and '
               'test to be 1, got {:.4f}'.format(total_fraction))
    # np.allclose tolerates floating point rounding in the sum.
    assert np.allclose(total_fraction, 1.), message
def indices_split(dataset, frac_train, frac_val, frac_test, indices):
    """Reorder the dataset by ``indices`` and cut it into three consecutive chunks.

    Parameters
    ----------
    dataset
        Map-style dataset supporting ``len(dataset)`` and ``dataset[i]``.
    frac_train : float
        Fraction of data to use for training.
    frac_val : float
        Fraction of data to use for validation.
    frac_test : float
        Fraction of data to use for test.
    indices : list or ndarray
        Indices specifying the order of datapoints.

    Returns
    -------
    list of length 3
        Training, validation and test :class:`Subset` instances.
    """
    frac_list = np.array([frac_train, frac_val, frac_test])
    assert np.allclose(np.sum(frac_list), 1.), \
        'Expect frac_list sum to 1, got {:.4f}'.format(np.sum(frac_list))
    num_data = len(dataset)
    lengths = (num_data * frac_list).astype(int)
    # Give any rounding remainder to the last (test) chunk.
    lengths[-1] = num_data - np.sum(lengths[:-1])
    subsets = []
    end = 0
    for length in lengths:
        end += length
        subsets.append(Subset(dataset, list(indices[end - length:end])))
    return subsets
def count_and_log(message, i, total, log_every_n):
    """Print a progress message once every ``log_every_n`` items.

    Parameters
    ----------
    message : str
        Message to print.
    i : int
        Current (0-based) index.
    total : int
        Total count.
    log_every_n : None or int
        Print after every batch of ``log_every_n`` items; if None,
        nothing is ever printed.
    """
    if log_every_n is None:
        return
    if (i + 1) % log_every_n == 0:
        print('{} {:d}/{:d}'.format(message, i + 1, total))
def prepare_mols(dataset, mols, sanitize, log_every_n=1000):
    """Return RDKit molecule instances aligned with ``dataset.smiles``.

    Parameters
    ----------
    dataset
        Map-style dataset also exposing ``dataset.smiles``.
    mols : None or list of rdkit.Chem.rdchem.Mol
        Optional pre-computed molecules. When given, ``mols[i]`` must
        correspond to ``dataset.smiles[i]``.
    sanitize : bool
        Whether to sanitize molecules constructed from SMILES. Only used
        when ``mols`` is None.
    log_every_n : None or int
        Print a progress message every ``log_every_n`` molecules; nothing
        is printed if None. Default to 1000.

    Returns
    -------
    mols : list of rdkit.Chem.rdchem.Mol
        Molecules with a one-to-one correspondence to ``dataset.smiles``.
    """
    if mols is None:
        if log_every_n is not None:
            print('Start initializing RDKit molecule instances...')
        mols = []
        for i, s in enumerate(dataset.smiles):
            count_and_log('Creating RDKit molecule instance',
                          i, len(dataset.smiles), log_every_n)
            mols.append(Chem.MolFromSmiles(s, sanitize=sanitize))
    else:
        # Pre-computed molecules: just verify the one-to-one correspondence.
        assert len(mols) == len(dataset), \
            'Expect mols to be of the same size as that of the dataset, ' \
            'got {:d} and {:d}'.format(len(mols), len(dataset))
    return mols
class ConsecutiveSplitter(object):
    """Deterministic splitter that keeps datapoints in their original order.

    No permutation is applied, so repeated calls always yield the same split.
    """

    @staticmethod
    def train_val_test_split(dataset, frac_train=0.8, frac_val=0.1, frac_test=0.1):
        """Cut the dataset into three consecutive chunks.

        Parameters
        ----------
        dataset
            Map-style dataset supporting ``len(dataset)`` and ``dataset[i]``.
        frac_train : float
            Fraction of the data used for training. Default to 0.8.
        frac_val : float
            Fraction of the data used for validation. Default to 0.1.
        frac_test : float
            Fraction of the data used for test. Default to 0.1.

        Returns
        -------
        list of length 3
            Training, validation and test subsets, each supporting
            ``len(subset)`` and ``subset[i]``.
        """
        return split_dataset(dataset,
                             frac_list=[frac_train, frac_val, frac_test],
                             shuffle=False)

    @staticmethod
    def k_fold_split(dataset, k=5, log=True):
        """Split the dataset for k-fold cross validation with consecutive chunks.

        Parameters
        ----------
        dataset
            Map-style dataset supporting ``len(dataset)`` and ``dataset[i]``.
        k : int
            Number of folds, at least 2. Default to 5.
        log : bool
            Whether to print a message when preparation of each fold starts.
            Default to True.

        Returns
        -------
        list of 2-tuples
            One ``(train_set, val_set)`` pair per fold, each element
            supporting ``len(subset)`` and ``subset[i]``.
        """
        return base_k_fold_split(ConsecutiveSplitter.train_val_test_split,
                                 dataset, k, log)
class RandomSplitter(object):
    """Splitter that shuffles the dataset before cutting it into subsets."""

    @staticmethod
    def train_val_test_split(dataset, frac_train=0.8, frac_val=0.1,
                             frac_test=0.1, random_state=None):
        """Shuffle the dataset and cut it into three consecutive chunks.

        Parameters
        ----------
        dataset
            Map-style dataset supporting ``len(dataset)`` and ``dataset[i]``.
        frac_train : float
            Fraction of the data used for training. Default to 0.8.
        frac_val : float
            Fraction of the data used for validation. Default to 0.1.
        frac_test : float
            Fraction of the data used for test. Default to 0.1.
        random_state : None, int or array_like, optional
            Seed for the pseudo-random number generator: any integer in
            ``[0, 2**32 - 1]``, a sequence of such integers, or None
            (the default), in which case fresh OS entropy (or the clock)
            seeds the generator.

        Returns
        -------
        list of length 3
            Training, validation and test subsets, each supporting
            ``len(subset)`` and ``subset[i]``.
        """
        return split_dataset(dataset,
                             frac_list=[frac_train, frac_val, frac_test],
                             shuffle=True, random_state=random_state)

    @staticmethod
    def k_fold_split(dataset, k=5, random_state=None, log=True):
        """Shuffle the dataset once, then split it into k folds.

        The permutation is drawn a single time so that every datapoint
        appears in exactly one validation fold.

        Parameters
        ----------
        dataset
            Map-style dataset supporting ``len(dataset)`` and ``dataset[i]``.
        k : int
            Number of folds, at least 2. Default to 5.
        random_state : None, int or array_like, optional
            Seed for the pseudo-random number generator; see
            :meth:`train_val_test_split`.
        log : bool
            Whether to print a message when preparation of each fold starts.
            Default to True.

        Returns
        -------
        list of 2-tuples
            One ``(train_set, val_set)`` pair per fold, each element
            supporting ``len(subset)`` and ``subset[i]``.
        """
        permutation = np.random.RandomState(seed=random_state).permutation(len(dataset))
        return base_k_fold_split(partial(indices_split, indices=permutation),
                                 dataset, k, log)
# pylint: disable=I1101
class MolecularWeightSplitter(object):
    """Order molecules by molecular weight, then split on that order."""

    @staticmethod
    def molecular_weight_indices(molecules, log_every_n):
        """Argsort molecules by their exact molecular weight.

        Parameters
        ----------
        molecules : list of rdkit.Chem.rdchem.Mol
            RDKit molecule instances.
        log_every_n : None or int
            Print a progress message every ``log_every_n`` molecules;
            nothing is printed if None.

        Returns
        -------
        indices : list or ndarray
            Indices that sort the molecules by increasing molecular weight.
        """
        if log_every_n is not None:
            print('Start computing molecular weights.')
        weights = []
        for mol_id, mol in enumerate(molecules):
            count_and_log('Computing molecular weight for compound',
                          mol_id, len(molecules), log_every_n)
            weights.append(rdMolDescriptors.CalcExactMolWt(mol))
        return np.argsort(weights)

    @staticmethod
    def train_val_test_split(dataset, mols=None, sanitize=True, frac_train=0.8,
                             frac_val=0.1, frac_test=0.1, log_every_n=1000):
        """Sort molecules by weight and cut three consecutive chunks.

        Parameters
        ----------
        dataset
            Map-style dataset also exposing ``dataset.smiles``.
        mols : None or list of rdkit.Chem.rdchem.Mol
            Optional pre-computed molecules aligned with ``dataset.smiles``.
            Default to None.
        sanitize : bool
            Whether to sanitize molecules constructed from SMILES; only
            used when ``mols`` is None. Default to True.
        frac_train : float
            Fraction of the data used for training. Default to 0.8.
        frac_val : float
            Fraction of the data used for validation. Default to 0.1.
        frac_test : float
            Fraction of the data used for test. Default to 0.1.
        log_every_n : None or int
            Print a progress message every ``log_every_n`` molecules;
            nothing is printed if None. Default to 1000.

        Returns
        -------
        list of length 3
            Training, validation and test subsets, each supporting
            ``len(subset)`` and ``subset[i]``.
        """
        # Validate the fractions first: molecule construction and descriptor
        # computation can take a long time.
        train_val_test_sanity_check(frac_train, frac_val, frac_test)
        molecules = prepare_mols(dataset, mols, sanitize, log_every_n)
        order = MolecularWeightSplitter.molecular_weight_indices(molecules, log_every_n)
        return indices_split(dataset, frac_train, frac_val, frac_test, order)

    @staticmethod
    def k_fold_split(dataset, mols=None, sanitize=True, k=5, log_every_n=1000):
        """Sort molecules by weight and split consecutively for k folds.

        Parameters
        ----------
        dataset
            Map-style dataset also exposing ``dataset.smiles``.
        mols : None or list of rdkit.Chem.rdchem.Mol
            Optional pre-computed molecules aligned with ``dataset.smiles``.
            Default to None.
        sanitize : bool
            Whether to sanitize molecules constructed from SMILES; only
            used when ``mols`` is None. Default to True.
        k : int
            Number of folds, at least 2. Default to 5.
        log_every_n : None or int
            Print a progress message every ``log_every_n`` molecules;
            nothing is printed if None. Default to 1000.

        Returns
        -------
        list of 2-tuples
            One ``(train_set, val_set)`` pair per fold, each element
            supporting ``len(subset)`` and ``subset[i]``.
        """
        molecules = prepare_mols(dataset, mols, sanitize, log_every_n)
        order = MolecularWeightSplitter.molecular_weight_indices(molecules, log_every_n)
        return base_k_fold_split(partial(indices_split, indices=order),
                                 dataset, k, log=(log_every_n is not None))
# pylint: disable=W0702
class ScaffoldSplitter(object):
    """Split datasets at the level of Bemis-Murcko scaffold groups.

    Molecules sharing a scaffold are grouped together and each group is
    assigned to a single subset, so no scaffold is shared across subsets.

    References
    ----------
    Bemis, G. W.; Murcko, M. A. "The Properties of Known Drugs.
    1. Molecular Frameworks." J. Med. Chem. 39:2887-93 (1996).
    """

    @staticmethod
    def get_ordered_scaffold_sets(molecules, include_chirality, log_every_n):
        """Group molecule indices by scaffold, largest groups first.

        Ties in group size are broken by the smallest index in the group.

        Parameters
        ----------
        molecules : list of rdkit.Chem.rdchem.Mol
            RDKit molecule instances.
        include_chirality : bool
            Whether chirality is taken into account in scaffold computation.
        log_every_n : None or int
            Print a progress message every ``log_every_n`` molecules;
            nothing is printed if None.

        Returns
        -------
        scaffold_sets : list
            Each element is a list of int holding the indices of compounds
            that share one scaffold.
        """
        if log_every_n is not None:
            print('Start computing Bemis-Murcko scaffolds.')
        groups = defaultdict(list)
        for mol_id, mol in enumerate(molecules):
            count_and_log('Computing Bemis-Murcko for compound',
                          mol_id, len(molecules), log_every_n)
            try:
                # Unsanitized molecules lack ring information, which the
                # scaffold computation needs.
                FastFindRings(mol)
                scaffold = MurckoScaffold.MurckoScaffoldSmiles(
                    mol=mol, includeChirality=include_chirality)
                groups[scaffold].append(mol_id)
            except:
                print('Failed to compute the scaffold for molecule {:d} '
                      'and it will be excluded.'.format(mol_id + 1))
        # Order by decreasing (group size, first index in group).
        ordered = sorted(groups.items(),
                         key=lambda item: (len(item[1]), item[1][0]),
                         reverse=True)
        return [members for _, members in ordered]

    @staticmethod
    def train_val_test_split(dataset, mols=None, sanitize=True, include_chirality=False,
                             frac_train=0.8, frac_val=0.1, frac_test=0.1, log_every_n=1000):
        """Scaffold-based train/validation/test split.

        Each scaffold group lands entirely in one subset. As a consequence
        the realized training/validation fractions tend to fall below
        ``frac_train``/``frac_val`` while the test fraction tends to exceed
        ``frac_test``.

        Parameters
        ----------
        dataset
            Map-style dataset also exposing ``dataset.smiles``.
        mols : None or list of rdkit.Chem.rdchem.Mol
            Optional pre-computed molecules aligned with ``dataset.smiles``.
            Default to None.
        sanitize : bool
            Whether to sanitize molecules constructed from SMILES; only
            used when ``mols`` is None. Default to True.
        include_chirality : bool
            Whether chirality is taken into account in scaffold computation.
            Default to False.
        frac_train : float
            Fraction of the data used for training. Default to 0.8.
        frac_val : float
            Fraction of the data used for validation. Default to 0.1.
        frac_test : float
            Fraction of the data used for test. Default to 0.1.
        log_every_n : None or int
            Print a progress message every ``log_every_n`` molecules;
            nothing is printed if None. Default to 1000.

        Returns
        -------
        list of length 3
            Training, validation and test subsets, each supporting
            ``len(subset)`` and ``subset[i]``.
        """
        # Validate fractions before the expensive molecule work.
        train_val_test_sanity_check(frac_train, frac_val, frac_test)
        molecules = prepare_mols(dataset, mols, sanitize)
        scaffold_sets = ScaffoldSplitter.get_ordered_scaffold_sets(
            molecules, include_chirality, log_every_n)
        train_cutoff = int(frac_train * len(molecules))
        val_cutoff = int((frac_train + frac_val) * len(molecules))
        train_indices, val_indices, test_indices = [], [], []
        for group in scaffold_sets:
            # Assign whole groups greedily: train while it fits, then val,
            # then everything else spills into test.
            if len(train_indices) + len(group) <= train_cutoff:
                train_indices.extend(group)
            elif len(train_indices) + len(val_indices) + len(group) <= val_cutoff:
                val_indices.extend(group)
            else:
                test_indices.extend(group)
        return [Subset(dataset, train_indices),
                Subset(dataset, val_indices),
                Subset(dataset, test_indices)]

    @staticmethod
    def k_fold_split(dataset, mols=None, sanitize=True,
                     include_chirality=False, k=5, log_every_n=1000):
        """Scaffold-aware k-fold split.

        Scaffold groups are assigned greedily to the currently smallest of
        ``k`` buckets, and each bucket serves exactly once as the validation
        fold. Folds can end up imbalanced when a few scaffolds dominate the
        dataset.

        Parameters
        ----------
        dataset
            Map-style dataset also exposing ``dataset.smiles``.
        mols : None or list of rdkit.Chem.rdchem.Mol
            Optional pre-computed molecules aligned with ``dataset.smiles``.
            Default to None.
        sanitize : bool
            Whether to sanitize molecules constructed from SMILES; only
            used when ``mols`` is None. Default to True.
        include_chirality : bool
            Whether chirality is taken into account in scaffold computation.
            Default to False.
        k : int
            Number of folds, at least 2. Default to 5.
        log_every_n : None or int
            Print a progress message every ``log_every_n`` molecules;
            nothing is printed if None. Default to 1000.

        Returns
        -------
        list of 2-tuples
            One ``(train_set, val_set)`` pair per fold, each element
            supporting ``len(subset)`` and ``subset[i]``.
        """
        assert k >= 2, 'Expect the number of folds to be no smaller than 2, got {:d}'.format(k)
        molecules = prepare_mols(dataset, mols, sanitize)
        scaffold_sets = ScaffoldSplitter.get_ordered_scaffold_sets(
            molecules, include_chirality, log_every_n)
        # Greedily fill the least-loaded bucket to keep fold sizes balanced.
        index_buckets = [[] for _ in range(k)]
        for group in scaffold_sets:
            min(index_buckets, key=len).extend(group)
        all_folds = []
        for fold_id in range(k):
            if log_every_n is not None:
                print('Processing fold {:d}/{:d}'.format(fold_id + 1, k))
            train_indices = list(chain.from_iterable(
                index_buckets[:fold_id] + index_buckets[fold_id + 1:]))
            all_folds.append((Subset(dataset, train_indices),
                              Subset(dataset, index_buckets[fold_id])))
        return all_folds
class SingleTaskStratifiedSplitter(object):
    """Stratified splitting on the labels of a single task.

    Datapoints are sorted by their label value for the chosen task and then
    consumed bucket by bucket, each bucket contributing datapoints to the
    training, validation and test subsets.
    """

    @staticmethod
    def train_val_test_split(dataset, labels, task_id, frac_train=0.8, frac_val=0.1,
                             frac_test=0.1, bucket_size=10, random_state=None):
        """Stratified train/validation/test split.

        Parameters
        ----------
        dataset
            Map-style dataset also exposing ``dataset.smiles``.
        labels : tensor of shape (N, T)
            Labels over all T tasks for the N datapoints.
        task_id : int
            Column of ``labels`` used for stratification.
        frac_train : float
            Fraction of the data used for training. Default to 0.8.
        frac_val : float
            Fraction of the data used for validation. Default to 0.1.
        frac_test : float
            Fraction of the data used for test. Default to 0.1.
        bucket_size : int
            Number of datapoints consumed per stratification bucket.
            Default to 10.
        random_state : None, int or array_like, optional
            Seed for NumPy's pseudo-random number generator: any integer in
            ``[0, 2**32 - 1]``, a sequence of such integers, or None
            (the default).

        Returns
        -------
        list of length 3
            Training, validation and test subsets, each supporting
            ``len(subset)`` and ``subset[i]``.
        """
        train_val_test_sanity_check(frac_train, frac_val, frac_test)
        if random_state is not None:
            # NOTE: this seeds NumPy's *global* RNG.
            np.random.seed(random_state)
        if not isinstance(labels, np.ndarray):
            labels = F.asnumpy(labels)
        sorted_indices = np.argsort(labels[:, task_id])
        train_cut = int(np.round(frac_train * bucket_size))
        val_cut = train_cut + int(np.round(frac_val * bucket_size))
        train_indices, val_indices, test_indices = [], [], []
        while sorted_indices.shape[0] >= bucket_size:
            # Take the next bucket of label-sorted datapoints and spread it
            # randomly over the three subsets.
            bucket, sorted_indices = np.split(sorted_indices, [bucket_size])
            shuffled = np.random.permutation(range(bucket_size))
            train_indices.extend(bucket[shuffled[:train_cut]].tolist())
            val_indices.extend(bucket[shuffled[train_cut:val_cut]].tolist())
            test_indices.extend(bucket[shuffled[val_cut:]].tolist())
        # Any leftover datapoints (fewer than bucket_size) go to training.
        train_indices.extend(sorted_indices.tolist())
        return [Subset(dataset, train_indices),
                Subset(dataset, val_indices),
                Subset(dataset, test_indices)]

    @staticmethod
    def k_fold_split(dataset, labels, task_id, k=5, log=True):
        """Sort datapoints by a task's labels and split consecutively for k folds.

        Parameters
        ----------
        dataset
            Map-style dataset also exposing ``dataset.smiles``.
        labels : tensor of shape (N, T)
            Labels over all T tasks for the N datapoints.
        task_id : int
            Column of ``labels`` used for sorting.
        k : int
            Number of folds, at least 2. Default to 5.
        log : bool
            Whether to print a message when preparation of each fold starts.

        Returns
        -------
        list of 2-tuples
            One ``(train_set, val_set)`` pair per fold, each element
            supporting ``len(subset)`` and ``subset[i]``.
        """
        if not isinstance(labels, np.ndarray):
            labels = F.asnumpy(labels)
        order = np.argsort(labels[:, task_id]).tolist()
        return base_k_fold_split(partial(indices_split, indices=order),
                                 dataset, k, log)
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import os
from setuptools import find_packages
from setuptools import setup
# Directory containing this setup script; used to locate dgllife/libinfo.py.
CURRENT_DIR = os.path.dirname(__file__)
def get_lib_path():
    """Read the package version from ``dgllife/libinfo.py``.

    ``libinfo.py`` is executed in an isolated namespace instead of being
    imported, because importing it would trigger ``dgllife/__init__.py``
    and pull in heavy runtime dependencies at build time.

    Returns
    -------
    str
        The value of ``__version__`` defined in ``libinfo.py``.
    """
    libinfo_py = os.path.join(CURRENT_DIR, './dgllife/libinfo.py')
    libinfo = {'__file__': libinfo_py}
    # Close the file handle deterministically instead of leaking it to GC.
    with open(libinfo_py, 'rb') as f:
        exec(compile(f.read(), libinfo_py, 'exec'), libinfo, libinfo)
    return libinfo['__version__']
# Version string read from dgllife/libinfo.py at build time.
VERSION = get_lib_path()
# Package metadata and dependency pins for the dgllife distribution.
setup(
    name='dgllife',
    version=VERSION,
    description='DGL-based package for Life Science',
    keywords=[
        'pytorch',
        'dgl',
        'graph-neural-networks',
        'life-science',
        'drug-discovery'
    ],
    maintainer='DGL Team',
    # Ship only the dgllife package tree; skip any stray top-level packages.
    packages=[package for package in find_packages()
              if package.startswith('dgllife')],
    install_requires=[
        'scikit-learn>=0.22.2',
        'pandas',
        'requests>=2.22.0',
        'tqdm',
        'numpy>=1.14.0',
        'scipy>=1.1.0',
        'networkx>=2.1',
    ],
    url='https://github.com/dmlc/dgl/tree/master/apps/life_sci',
    classifiers=[
        'Development Status :: 3 - Alpha',
        'Programming Language :: Python :: 3',
        'License :: OSI Approved :: Apache Software License'
    ],
    license='APACHE'
)
"""
This is the global script that set the version information of DGL-LifeSci.
This script runs and updates all the locations related to versions.
List of affected files:
- app-root/python/dgllife/__init__.py
- app-root/conda/dgllife/meta.yaml
"""
import os
import re
# The release version to propagate to every version-bearing file below.
__version__ = "0.2.2"
# Echo the target version when the script runs.
print(__version__)
# Implementations
def update(file_name, pattern, repl):
    """Rewrite the version string matched by ``pattern`` in ``file_name``.

    Scans the file for a line matching ``pattern`` (which must occur in
    exactly one line, exactly once in it) and substitutes ``repl`` for the
    matched text, rewriting the file only when the version actually changes.

    Parameters
    ----------
    file_name : str
        Path of the file to update.
    pattern : str
        Regular expression whose (single) match is the version string.
    repl : str
        Replacement version string.

    Raises
    ------
    RuntimeError
        If no line of the file matches ``pattern``.
    """
    # Renamed from ``update`` so the local no longer shadows this function.
    new_lines = []
    hit_counter = 0
    need_update = False
    # Use a context manager so the input handle is closed deterministically.
    with open(file_name) as input_file:
        for l in input_file:
            result = re.findall(pattern, l)
            if result:
                assert len(result) == 1
                hit_counter += 1
                if result[0] != repl:
                    l = re.sub(pattern, repl, l)
                    need_update = True
                    print("%s: %s->%s" % (file_name, result[0], repl))
                else:
                    print("%s: version is already %s" % (file_name, repl))
            new_lines.append(l)
    if hit_counter != 1:
        raise RuntimeError("Cannot find version in %s" % file_name)
    # Only touch the file when the version actually changed.
    if need_update:
        with open(file_name, "w") as output_file:
            output_file.writelines(new_lines)
def main():
    """Propagate ``__version__`` into every file that records the version."""
    script_dir = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
    proj_root = os.path.abspath(os.path.join(script_dir, ".."))
    # Python package version string
    update(os.path.join(proj_root, "python/dgllife/libinfo.py"),
           r"(?<=__version__ = \")[.0-9a-z]+", __version__)
    # conda recipe version string
    update(os.path.join(proj_root, "conda/dgllife/meta.yaml"),
           "(?<=version: \")[.0-9a-z]+", __version__)
# Run the version bump only when executed as a script, not on import.
if __name__ == '__main__':
    main()
import os
import pandas as pd
from dgllife.data.csv_dataset import *
from dgllife.utils.featurizers import *
from dgllife.utils.mol_to_graph import *
def test_data_frame():
    """Build a tiny two-molecule, two-task DataFrame used as a test fixture."""
    records = [['CCO', 0, 1], ['CO', 2, 3]]
    return pd.DataFrame(records, columns=['smiles', 'task1', 'task2'])
def remove_file(fname):
    """Delete ``fname`` if it exists, ignoring OS-level failures.

    Uses EAFP (attempt the removal and catch ``OSError``, which includes
    ``FileNotFoundError``) instead of an exists-then-remove check, avoiding
    the race where the file disappears between the check and the removal.

    Parameters
    ----------
    fname : str
        Path of the file to remove.
    """
    try:
        os.remove(fname)
    except OSError:
        pass
def test_mol_csv():
    """Exercise MoleculeCSVDataset construction, task selection and caching."""
    df = test_data_frame()
    cache_path = 'test.bin'

    def build(**kwargs):
        # Shared defaults for every dataset constructed in this test.
        return MoleculeCSVDataset(df=df, smiles_to_graph=smiles_to_bigraph,
                                  smiles_column='smiles',
                                  cache_file_path=cache_path, **kwargs)

    dataset = build(node_featurizer=CanonicalAtomFeaturizer(),
                    edge_featurizer=CanonicalBondFeaturizer())
    assert dataset.task_names == ['task1', 'task2']
    smiles, graph, label, mask = dataset[0]
    assert label.shape[0] == 2
    assert mask.shape[0] == 2
    assert 'h' in graph.ndata
    assert 'e' in graph.edata

    # task_names restricts which label columns the dataset exposes.
    dataset = build(node_featurizer=None,
                    edge_featurizer=None,
                    task_names=['task1'])
    assert dataset.task_names == ['task1']

    # load=True: graphs come from the cache file written above.
    dataset = build(node_featurizer=CanonicalAtomFeaturizer(),
                    edge_featurizer=None,
                    load=True)
    smiles, graph, label, mask = dataset[0]
    assert 'h' in graph.ndata
    assert 'e' in graph.edata

    # load=False: graphs are rebuilt, so edge features are absent here.
    dataset = build(node_featurizer=CanonicalAtomFeaturizer(),
                    edge_featurizer=None,
                    load=False)
    smiles, graph, label, mask = dataset[0]
    assert 'h' in graph.ndata
    assert 'e' not in graph.edata

    remove_file(cache_path)
# Allow running this test module directly as a script.
if __name__ == '__main__':
    test_mol_csv()
import os
from dgllife.data import *
from dgllife.data.uspto import get_bond_changes, process_file
def remove_file(fname):
    """Best-effort removal of *fname*; silently ignore OS-level failures.

    Parameters
    ----------
    fname : str
        Path of the file to delete.
    """
    # EAFP: attempting the removal directly avoids the race between an
    # os.path.isfile() check and the actual os.remove() call.
    try:
        os.remove(fname)
    except OSError:
        # File already gone or not removable -- cleanup is best-effort.
        pass
def test_pubchem_aromaticity():
    """Smoke test: construct the PubChem BioAssay aromaticity dataset, then
    delete the graph cache it leaves behind."""
    print('Test pubchem aromaticity')
    PubChemBioAssayAromaticity()
    remove_file('pubchem_aromaticity_dglgraph.bin')
def test_tox21():
    """Smoke test: construct the Tox21 dataset, then delete its graph cache."""
    print('Test Tox21')
    Tox21()
    remove_file('tox21_dglgraph.bin')
def test_alchemy():
    """Smoke test: build the Tencent Alchemy validation split twice,
    once with defaults and once with load=False."""
    print('Test Alchemy')
    shared = dict(mode='valid', node_featurizer=None, edge_featurizer=None)
    dataset = TencentAlchemyDataset(**shared)
    dataset = TencentAlchemyDataset(load=False, **shared)
def test_pdbbind():
    """Smoke test: load the PDBBind core subset with hydrogens removed."""
    print('Test PDBBind')
    PDBBind(subset='core', remove_hs=True)
def test_wln_reaction():
    """Test datasets and utilities for reaction prediction with WLN.

    Exercises get_bond_changes, process_file, WLNCenterDataset and
    WLNRankDataset on two atom-mapped reaction SMILES fixtures, removing
    every intermediate file afterwards.
    """
    print('Test datasets for reaction prediction with WLN')
    # Atom-mapped reaction SMILES of the form reactants>>product; the trailing
    # '\n' lets each be written verbatim as one line of a raw reaction file.
    reaction1 = '[CH2:15]([CH:16]([CH3:17])[CH3:18])[Mg+:19].[CH2:20]1[O:21][CH2:22][CH2:23]' \
                '[CH2:24]1.[Cl-:14].[OH:1][c:2]1[n:3][cH:4][c:5]([C:6](=[O:7])[N:8]([O:9]' \
                '[CH3:10])[CH3:11])[cH:12][cH:13]1>>[OH:1][c:2]1[n:3][cH:4][c:5]([C:6](=[O:7])' \
                '[CH2:15][CH:16]([CH3:17])[CH3:18])[cH:12][cH:13]1\n'
    reaction2 = '[CH3:14][NH2:15].[N+:1](=[O:2])([O-:3])[c:4]1[cH:5][c:6]([C:7](=[O:8])[OH:9])' \
                '[cH:10][cH:11][c:12]1[Cl:13].[OH2:16]>>[N+:1](=[O:2])([O-:3])[c:4]1[cH:5][c:6]' \
                '([C:7](=[O:8])[OH:9])[cH:10][cH:11][c:12]1[NH:15][CH3:14]\n'
    reactions = [reaction1, reaction2]
    # Test utility functions
    # For reaction2: the bond between atoms 12 and 13 is lost (0.0) and the
    # bond between atoms 12 and 15 is formed (1.0).
    assert get_bond_changes(reaction2) == {('12', '13', 0.0), ('12', '15', 1.0)}
    with open('test.txt', 'w') as f:
        for reac in reactions:
            f.write(reac)
    # process_file writes '<input>.proc', pairing each reaction with its
    # detected bond changes (see the assertion on the line format below).
    process_file('test.txt')
    with open('test.txt.proc', 'r') as f:
        lines = f.readlines()
        for i in range(len(lines)):
            l = lines[i].strip()
            react = reactions[i].strip()
            bond_changes = get_bond_changes(react)
            # Expected line format: '<reaction> a1-a2-type;a1-a2-type...'.
            # NOTE(review): the join iterates a set, so this comparison relies
            # on process_file using the same iteration order -- fragile.
            assert l == '{} {}'.format(
                react,
                ';'.join(['{}-{}-{}'.format(x[0], x[1], x[2]) for x in bond_changes]))
    remove_file('test.txt.proc')
    # Test configured dataset
    dataset = WLNCenterDataset('test.txt', 'test_graphs.bin')
    remove_file('test_graphs.bin')
    with open('test_candidate_bond_changes.txt', 'w') as f:
        for reac in reactions:
            # simulate fake candidate bond changes
            candidate_string = ''
            for i in range(2):
                candidate_string += '{} {} {:.1f} {:.3f};'.format(i+1, i+2, 0.0, 0.234)
            candidate_string += '\n'
            f.write(candidate_string)
    # NOTE(review): 'test.txt.proc' was removed above; presumably the
    # WLNCenterDataset construction regenerated it from 'test.txt' -- verify.
    dataset = WLNRankDataset('test.txt.proc', 'test_candidate_bond_changes.txt', 'train')
    remove_file('test.txt')
    remove_file('test.txt.proc')
    remove_file('test_graphs.bin')
    remove_file('test_candidate_bond_changes.txt')
if __name__ == '__main__':
    # Run every dataset test in sequence when executed as a script.
    test_pubchem_aromaticity()
    test_tox21()
    test_alchemy()
    test_pdbbind()
    test_wln_reaction()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment