Unverified Commit 36c7b771 authored by Mufei Li's avatar Mufei Li Committed by GitHub
Browse files

[LifeSci] Move to Independent Repo (#1592)

* Move LifeSci

* Remove doc
parent 94c67203
import numpy as np
import torch
from torch.nn import BCEWithLogitsLoss
from torch.optim import Adam
from torch.utils.data import DataLoader
from dgl import model_zoo
from utils import Meter, EarlyStopping, collate_molgraphs, set_random_seed, \
load_dataset_for_classification, load_model
def run_a_train_epoch(args, epoch, model, data_loader, loss_criterion, optimizer):
model.train()
train_meter = Meter()
for batch_id, batch_data in enumerate(data_loader):
smiles, bg, labels, masks = batch_data
atom_feats = bg.ndata.pop(args['atom_data_field'])
atom_feats, labels, masks = atom_feats.to(args['device']), \
labels.to(args['device']), \
masks.to(args['device'])
logits = model(bg, atom_feats)
# Mask non-existing labels
loss = (loss_criterion(logits, labels) * (masks != 0).float()).mean()
optimizer.zero_grad()
loss.backward()
optimizer.step()
print('epoch {:d}/{:d}, batch {:d}/{:d}, loss {:.4f}'.format(
epoch + 1, args['num_epochs'], batch_id + 1, len(data_loader), loss.item()))
train_meter.update(logits, labels, masks)
train_score = np.mean(train_meter.compute_metric(args['metric_name']))
print('epoch {:d}/{:d}, training {} {:.4f}'.format(
epoch + 1, args['num_epochs'], args['metric_name'], train_score))
def run_an_eval_epoch(args, model, data_loader):
model.eval()
eval_meter = Meter()
with torch.no_grad():
for batch_id, batch_data in enumerate(data_loader):
smiles, bg, labels, masks = batch_data
atom_feats = bg.ndata.pop(args['atom_data_field'])
atom_feats, labels = atom_feats.to(args['device']), labels.to(args['device'])
logits = model(bg, atom_feats)
eval_meter.update(logits, labels, masks)
return np.mean(eval_meter.compute_metric(args['metric_name']))
def main(args):
args['device'] = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
set_random_seed(args['random_seed'])
# Interchangeable with other datasets
dataset, train_set, val_set, test_set = load_dataset_for_classification(args)
train_loader = DataLoader(train_set, batch_size=args['batch_size'],
collate_fn=collate_molgraphs)
val_loader = DataLoader(val_set, batch_size=args['batch_size'],
collate_fn=collate_molgraphs)
test_loader = DataLoader(test_set, batch_size=args['batch_size'],
collate_fn=collate_molgraphs)
if args['pre_trained']:
args['num_epochs'] = 0
model = model_zoo.chem.load_pretrained(args['exp'])
else:
args['n_tasks'] = dataset.n_tasks
model = load_model(args)
loss_criterion = BCEWithLogitsLoss(pos_weight=dataset.task_pos_weights.to(args['device']),
reduction='none')
optimizer = Adam(model.parameters(), lr=args['lr'])
stopper = EarlyStopping(patience=args['patience'])
model.to(args['device'])
for epoch in range(args['num_epochs']):
# Train
run_a_train_epoch(args, epoch, model, train_loader, loss_criterion, optimizer)
# Validation and early stop
val_score = run_an_eval_epoch(args, model, val_loader)
early_stop = stopper.step(val_score, model)
print('epoch {:d}/{:d}, validation {} {:.4f}, best validation {} {:.4f}'.format(
epoch + 1, args['num_epochs'], args['metric_name'],
val_score, args['metric_name'], stopper.best_score))
if early_stop:
break
if not args['pre_trained']:
stopper.load_checkpoint(model)
test_score = run_an_eval_epoch(args, model, test_loader)
print('test {} {:.4f}'.format(args['metric_name'], test_score))
if __name__ == '__main__':
import argparse
from configure import get_exp_configure
parser = argparse.ArgumentParser(description='Molecule Classification')
parser.add_argument('-m', '--model', type=str, choices=['GCN', 'GAT'],
help='Model to use')
parser.add_argument('-d', '--dataset', type=str, choices=['Tox21'],
help='Dataset to use')
parser.add_argument('-p', '--pre-trained', action='store_true',
help='Whether to skip training and use a pre-trained model')
args = parser.parse_args().__dict__
args['exp'] = '_'.join([args['model'], args['dataset']])
args.update(get_exp_configure(args['exp']))
main(args)
from dgl.data.chem import BaseAtomFeaturizer, CanonicalAtomFeaturizer, ConcatFeaturizer, \
atom_type_one_hot, atom_degree_one_hot, atom_formal_charge, atom_num_radical_electrons, \
atom_hybridization_one_hot, atom_total_num_H_one_hot, BaseBondFeaturizer
from functools import partial
from utils import chirality
GCN_Tox21 = {
'random_seed': 0,
'batch_size': 128,
'lr': 1e-3,
'num_epochs': 100,
'atom_data_field': 'h',
'frac_train': 0.8,
'frac_val': 0.1,
'frac_test': 0.1,
'in_feats': 74,
'gcn_hidden_feats': [64, 64],
'classifier_hidden_feats': 64,
'patience': 10,
'atom_featurizer': CanonicalAtomFeaturizer(),
'metric_name': 'roc_auc'
}
GAT_Tox21 = {
'random_seed': 0,
'batch_size': 128,
'lr': 1e-3,
'num_epochs': 100,
'atom_data_field': 'h',
'frac_train': 0.8,
'frac_val': 0.1,
'frac_test': 0.1,
'in_feats': 74,
'gat_hidden_feats': [32, 32],
'classifier_hidden_feats': 64,
'num_heads': [4, 4],
'patience': 10,
'atom_featurizer': CanonicalAtomFeaturizer(),
'metric_name': 'roc_auc'
}
MPNN_Alchemy = {
'random_seed': 0,
'batch_size': 16,
'num_epochs': 250,
'node_in_feats': 15,
'edge_in_feats': 5,
'output_dim': 12,
'lr': 0.0001,
'patience': 50,
'metric_name': 'l1',
'weight_decay': 0
}
SCHNET_Alchemy = {
'random_seed': 0,
'batch_size': 16,
'num_epochs': 250,
'norm': True,
'output_dim': 12,
'lr': 0.0001,
'patience': 50,
'metric_name': 'l1',
'weight_decay': 0
}
MGCN_Alchemy = {
'random_seed': 0,
'batch_size': 16,
'num_epochs': 250,
'norm': True,
'output_dim': 12,
'lr': 0.0001,
'patience': 50,
'metric_name': 'l1',
'weight_decay': 0
}
AttentiveFP_Aromaticity = {
'random_seed': 8,
'graph_feat_size': 200,
'num_layers': 2,
'num_timesteps': 2,
'node_feat_size': 39,
'edge_feat_size': 10,
'output_size': 1,
'dropout': 0.2,
'weight_decay': 10 ** (-5.0),
'lr': 10 ** (-2.5),
'batch_size': 128,
'num_epochs': 800,
'frac_train': 0.8,
'frac_val': 0.1,
'frac_test': 0.1,
'patience': 80,
'metric_name': 'rmse',
# Follow the atom featurization in the original work
'atom_featurizer': BaseAtomFeaturizer(
featurizer_funcs={'hv': ConcatFeaturizer([
partial(atom_type_one_hot, allowable_set=[
'B', 'C', 'N', 'O', 'F', 'Si', 'P', 'S', 'Cl', 'As', 'Se', 'Br', 'Te', 'I', 'At'],
encode_unknown=True),
partial(atom_degree_one_hot, allowable_set=list(range(6))),
atom_formal_charge, atom_num_radical_electrons,
partial(atom_hybridization_one_hot, encode_unknown=True),
lambda atom: [0], # A placeholder for aromatic information,
atom_total_num_H_one_hot, chirality
],
)}
),
'bond_featurizer': BaseBondFeaturizer({
'he': lambda bond: [0 for _ in range(10)]
})
}
experiment_configures = {
'GCN_Tox21': GCN_Tox21,
'GAT_Tox21': GAT_Tox21,
'MPNN_Alchemy': MPNN_Alchemy,
'SCHNET_Alchemy': SCHNET_Alchemy,
'MGCN_Alchemy': MGCN_Alchemy,
'AttentiveFP_Aromaticity': AttentiveFP_Aromaticity
}
def get_exp_configure(exp_name):
return experiment_configures[exp_name]
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from dgl import model_zoo
from utils import Meter, set_random_seed, collate_molgraphs, EarlyStopping, \
load_dataset_for_regression, load_model
def regress(args, model, bg):
if args['model'] == 'MPNN':
h = bg.ndata.pop('n_feat')
e = bg.edata.pop('e_feat')
h, e = h.to(args['device']), e.to(args['device'])
return model(bg, h, e)
elif args['model'] in ['SCHNET', 'MGCN']:
node_types = bg.ndata.pop('node_type')
edge_distances = bg.edata.pop('distance')
node_types, edge_distances = node_types.to(args['device']), \
edge_distances.to(args['device'])
return model(bg, node_types, edge_distances)
else:
atom_feats, bond_feats = bg.ndata.pop('hv'), bg.edata.pop('he')
atom_feats, bond_feats = atom_feats.to(args['device']), bond_feats.to(args['device'])
return model(bg, atom_feats, bond_feats)
def run_a_train_epoch(args, epoch, model, data_loader,
loss_criterion, optimizer):
model.train()
train_meter = Meter()
for batch_id, batch_data in enumerate(data_loader):
smiles, bg, labels, masks = batch_data
labels, masks = labels.to(args['device']), masks.to(args['device'])
prediction = regress(args, model, bg)
loss = (loss_criterion(prediction, labels) * (masks != 0).float()).mean()
optimizer.zero_grad()
loss.backward()
optimizer.step()
train_meter.update(prediction, labels, masks)
total_score = np.mean(train_meter.compute_metric(args['metric_name']))
print('epoch {:d}/{:d}, training {} {:.4f}'.format(
epoch + 1, args['num_epochs'], args['metric_name'], total_score))
def run_an_eval_epoch(args, model, data_loader):
model.eval()
eval_meter = Meter()
with torch.no_grad():
for batch_id, batch_data in enumerate(data_loader):
smiles, bg, labels, masks = batch_data
labels = labels.to(args['device'])
prediction = regress(args, model, bg)
eval_meter.update(prediction, labels, masks)
total_score = np.mean(eval_meter.compute_metric(args['metric_name']))
return total_score
def main(args):
args['device'] = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
set_random_seed(args['random_seed'])
train_set, val_set, test_set = load_dataset_for_regression(args)
train_loader = DataLoader(dataset=train_set,
batch_size=args['batch_size'],
shuffle=True,
collate_fn=collate_molgraphs)
val_loader = DataLoader(dataset=val_set,
batch_size=args['batch_size'],
shuffle=True,
collate_fn=collate_molgraphs)
if test_set is not None:
test_loader = DataLoader(dataset=test_set,
batch_size=args['batch_size'],
collate_fn=collate_molgraphs)
if args['pre_trained']:
args['num_epochs'] = 0
model = model_zoo.chem.load_pretrained(args['exp'])
else:
model = load_model(args)
if args['model'] in ['SCHNET', 'MGCN']:
model.set_mean_std(train_set.mean, train_set.std, args['device'])
loss_fn = nn.MSELoss(reduction='none')
optimizer = torch.optim.Adam(model.parameters(), lr=args['lr'], weight_decay=args['weight_decay'])
stopper = EarlyStopping(mode='lower', patience=args['patience'])
model.to(args['device'])
for epoch in range(args['num_epochs']):
# Train
run_a_train_epoch(args, epoch, model, train_loader, loss_fn, optimizer)
# Validation and early stop
val_score = run_an_eval_epoch(args, model, val_loader)
early_stop = stopper.step(val_score, model)
print('epoch {:d}/{:d}, validation {} {:.4f}, best validation {} {:.4f}'.format(
epoch + 1, args['num_epochs'], args['metric_name'], val_score,
args['metric_name'], stopper.best_score))
if early_stop:
break
if test_set is not None:
if not args['pre_trained']:
stopper.load_checkpoint(model)
test_score = run_an_eval_epoch(args, model, test_loader)
print('test {} {:.4f}'.format(args['metric_name'], test_score))
if __name__ == "__main__":
import argparse
from configure import get_exp_configure
parser = argparse.ArgumentParser(description='Molecule Regression')
parser.add_argument('-m', '--model', type=str,
choices=['MPNN', 'SCHNET', 'MGCN', 'AttentiveFP'],
help='Model to use')
parser.add_argument('-d', '--dataset', type=str, choices=['Alchemy', 'Aromaticity'],
help='Dataset to use')
parser.add_argument('-p', '--pre-trained', action='store_true',
help='Whether to skip training and use a pre-trained model')
args = parser.parse_args().__dict__
args['exp'] = '_'.join([args['model'], args['dataset']])
args.update(get_exp_configure(args['exp']))
main(args)
import datetime
import dgl
import numpy as np
import random
import torch
import torch.nn.functional as F
from dgl import model_zoo
from dgl.data.chem import smiles_to_bigraph, one_hot_encoding, RandomSplitter
from sklearn.metrics import roc_auc_score
def set_random_seed(seed=0):
"""Set random seed.
Parameters
----------
seed : int
Random seed to use
"""
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed(seed)
def chirality(atom):
try:
return one_hot_encoding(atom.GetProp('_CIPCode'), ['R', 'S']) + \
[atom.HasProp('_ChiralityPossible')]
except:
return [False, False] + [atom.HasProp('_ChiralityPossible')]
class Meter(object):
"""Track and summarize model performance on a dataset for
(multi-label) binary classification."""
def __init__(self):
self.mask = []
self.y_pred = []
self.y_true = []
def update(self, y_pred, y_true, mask):
"""Update for the result of an iteration
Parameters
----------
y_pred : float32 tensor
Predicted molecule labels with shape (B, T),
B for batch size and T for the number of tasks
y_true : float32 tensor
Ground truth molecule labels with shape (B, T)
mask : float32 tensor
Mask for indicating the existence of ground
truth labels with shape (B, T)
"""
self.y_pred.append(y_pred.detach().cpu())
self.y_true.append(y_true.detach().cpu())
self.mask.append(mask.detach().cpu())
def roc_auc_score(self):
"""Compute roc-auc score for each task.
Returns
-------
list of float
roc-auc score for all tasks
"""
mask = torch.cat(self.mask, dim=0)
y_pred = torch.cat(self.y_pred, dim=0)
y_true = torch.cat(self.y_true, dim=0)
# Todo: support categorical classes
# This assumes binary case only
y_pred = torch.sigmoid(y_pred)
n_tasks = y_true.shape[1]
scores = []
for task in range(n_tasks):
task_w = mask[:, task]
task_y_true = y_true[:, task][task_w != 0].numpy()
task_y_pred = y_pred[:, task][task_w != 0].numpy()
scores.append(roc_auc_score(task_y_true, task_y_pred))
return scores
def l1_loss(self, reduction):
"""Compute l1 loss for each task.
Returns
-------
list of float
l1 loss for all tasks
reduction : str
* 'mean': average the metric over all labeled data points for each task
* 'sum': sum the metric over all labeled data points for each task
"""
mask = torch.cat(self.mask, dim=0)
y_pred = torch.cat(self.y_pred, dim=0)
y_true = torch.cat(self.y_true, dim=0)
n_tasks = y_true.shape[1]
scores = []
for task in range(n_tasks):
task_w = mask[:, task]
task_y_true = y_true[:, task][task_w != 0]
task_y_pred = y_pred[:, task][task_w != 0]
scores.append(F.l1_loss(task_y_true, task_y_pred, reduction=reduction).item())
return scores
def rmse(self):
"""Compute RMSE for each task.
Returns
-------
list of float
rmse for all tasks
"""
mask = torch.cat(self.mask, dim=0)
y_pred = torch.cat(self.y_pred, dim=0)
y_true = torch.cat(self.y_true, dim=0)
n_data, n_tasks = y_true.shape
scores = []
for task in range(n_tasks):
task_w = mask[:, task]
task_y_true = y_true[:, task][task_w != 0]
task_y_pred = y_pred[:, task][task_w != 0]
scores.append(np.sqrt(F.mse_loss(task_y_pred, task_y_true).cpu().item()))
return scores
def compute_metric(self, metric_name, reduction='mean'):
"""Compute metric for each task.
Parameters
----------
metric_name : str
Name for the metric to compute.
reduction : str
Only comes into effect when the metric_name is l1_loss.
* 'mean': average the metric over all labeled data points for each task
* 'sum': sum the metric over all labeled data points for each task
Returns
-------
list of float
Metric value for each task
"""
assert metric_name in ['roc_auc', 'l1', 'rmse'], \
'Expect metric name to be "roc_auc", "l1" or "rmse", got {}'.format(metric_name)
assert reduction in ['mean', 'sum']
if metric_name == 'roc_auc':
return self.roc_auc_score()
if metric_name == 'l1':
return self.l1_loss(reduction)
if metric_name == 'rmse':
return self.rmse()
class EarlyStopping(object):
"""Early stop performing
Parameters
----------
mode : str
* 'higher': Higher metric suggests a better model
* 'lower': Lower metric suggests a better model
patience : int
Number of epochs to wait before early stop
if the metric stops getting improved
filename : str or None
Filename for storing the model checkpoint
"""
def __init__(self, mode='higher', patience=10, filename=None):
if filename is None:
dt = datetime.datetime.now()
filename = 'early_stop_{}_{:02d}-{:02d}-{:02d}.pth'.format(
dt.date(), dt.hour, dt.minute, dt.second)
assert mode in ['higher', 'lower']
self.mode = mode
if self.mode == 'higher':
self._check = self._check_higher
else:
self._check = self._check_lower
self.patience = patience
self.counter = 0
self.filename = filename
self.best_score = None
self.early_stop = False
def _check_higher(self, score, prev_best_score):
return (score > prev_best_score)
def _check_lower(self, score, prev_best_score):
return (score < prev_best_score)
def step(self, score, model):
if self.best_score is None:
self.best_score = score
self.save_checkpoint(model)
elif self._check(score, self.best_score):
self.best_score = score
self.save_checkpoint(model)
self.counter = 0
else:
self.counter += 1
print(
f'EarlyStopping counter: {self.counter} out of {self.patience}')
if self.counter >= self.patience:
self.early_stop = True
return self.early_stop
def save_checkpoint(self, model):
'''Saves model when the metric on the validation set gets improved.'''
torch.save({'model_state_dict': model.state_dict()}, self.filename)
def load_checkpoint(self, model):
'''Load model saved with early stopping.'''
model.load_state_dict(torch.load(self.filename)['model_state_dict'])
def collate_molgraphs(data):
"""Batching a list of datapoints for dataloader.
Parameters
----------
data : list of 3-tuples or 4-tuples.
Each tuple is for a single datapoint, consisting of
a SMILES, a DGLGraph, all-task labels and optionally
a binary mask indicating the existence of labels.
Returns
-------
smiles : list
List of smiles
bg : DGLGraph
The batched DGLGraph.
labels : Tensor of dtype float32 and shape (B, T)
Batched datapoint labels. B is len(data) and
T is the number of total tasks.
masks : Tensor of dtype float32 and shape (B, T)
Batched datapoint binary mask, indicating the
existence of labels. If binary masks are not
provided, return a tensor with ones.
"""
assert len(data[0]) in [3, 4], \
'Expect the tuple to be of length 3 or 4, got {:d}'.format(len(data[0]))
if len(data[0]) == 3:
smiles, graphs, labels = map(list, zip(*data))
masks = None
else:
smiles, graphs, labels, masks = map(list, zip(*data))
bg = dgl.batch(graphs)
bg.set_n_initializer(dgl.init.zero_initializer)
bg.set_e_initializer(dgl.init.zero_initializer)
labels = torch.stack(labels, dim=0)
if masks is None:
masks = torch.ones(labels.shape)
else:
masks = torch.stack(masks, dim=0)
return smiles, bg, labels, masks
def load_dataset_for_classification(args):
"""Load dataset for classification tasks.
Parameters
----------
args : dict
Configurations.
Returns
-------
dataset
The whole dataset.
train_set
Subset for training.
val_set
Subset for validation.
test_set
Subset for test.
"""
assert args['dataset'] in ['Tox21']
if args['dataset'] == 'Tox21':
from dgl.data.chem import Tox21
dataset = Tox21(smiles_to_bigraph, args['atom_featurizer'])
train_set, val_set, test_set = RandomSplitter.train_val_test_split(
dataset, frac_train=args['frac_train'], frac_val=args['frac_val'],
frac_test=args['frac_test'], random_state=args['random_seed'])
return dataset, train_set, val_set, test_set
def load_dataset_for_regression(args):
"""Load dataset for regression tasks.
Parameters
----------
args : dict
Configurations.
Returns
-------
train_set
Subset for training.
val_set
Subset for validation.
test_set
Subset for test.
"""
assert args['dataset'] in ['Alchemy', 'Aromaticity']
if args['dataset'] == 'Alchemy':
from dgl.data.chem import TencentAlchemyDataset
train_set = TencentAlchemyDataset(mode='dev')
val_set = TencentAlchemyDataset(mode='valid')
test_set = None
if args['dataset'] == 'Aromaticity':
from dgl.data.chem import PubChemBioAssayAromaticity
dataset = PubChemBioAssayAromaticity(smiles_to_bigraph,
args['atom_featurizer'],
args['bond_featurizer'])
train_set, val_set, test_set = RandomSplitter.train_val_test_split(
dataset, frac_train=args['frac_train'], frac_val=args['frac_val'],
frac_test=args['frac_test'], random_state=args['random_seed'])
return train_set, val_set, test_set
def load_model(args):
if args['model'] == 'GCN':
model = model_zoo.chem.GCNClassifier(in_feats=args['in_feats'],
gcn_hidden_feats=args['gcn_hidden_feats'],
classifier_hidden_feats=args['classifier_hidden_feats'],
n_tasks=args['n_tasks'])
if args['model'] == 'GAT':
model = model_zoo.chem.GATClassifier(in_feats=args['in_feats'],
gat_hidden_feats=args['gat_hidden_feats'],
num_heads=args['num_heads'],
classifier_hidden_feats=args['classifier_hidden_feats'],
n_tasks=args['n_tasks'])
if args['model'] == 'MPNN':
model = model_zoo.chem.MPNNModel(node_input_dim=args['node_in_feats'],
edge_input_dim=args['edge_in_feats'],
output_dim=args['output_dim'])
if args['model'] == 'SCHNET':
model = model_zoo.chem.SchNet(norm=args['norm'], output_dim=args['output_dim'])
if args['model'] == 'MGCN':
model = model_zoo.chem.MGCNModel(norm=args['norm'], output_dim=args['output_dim'])
if args['model'] == 'AttentiveFP':
model = model_zoo.chem.AttentiveFP(node_feat_size=args['node_feat_size'],
edge_feat_size=args['edge_feat_size'],
num_layers=args['num_layers'],
num_timesteps=args['num_timesteps'],
graph_feat_size=args['graph_feat_size'],
output_size=args['output_size'],
dropout=args['dropout'])
return model
scikit-learn==0.21.2
pandas==0.25.1
requests==2.22.0
"""Decorator for deprecation message.
This is used in migrating the chem related code to DGL-LifeSci.
Todo(Mufei): remove it in v0.5.
The code is adapted from
https://stackoverflow.com/questions/2536307/
decorators-in-the-python-standard-lib-deprecated-specifically/48632082#48632082.
"""
import warnings
def deprecated(message, mode='func'):
"""Print formatted deprecation message.
Parameters
----------
message : str
mode : str
'func' for function and 'class' for class.
Return
------
callable
"""
assert mode in ['func', 'class']
def deprecated_decorator(func):
def deprecated_func(*args, **kwargs):
if mode == 'func':
warnings.warn("{} is deprecated and will be removed from dgl in v0.5. {}".format(
func.__name__, message), category=DeprecationWarning, stacklevel=2)
else:
warnings.warn("The class is deprecated and "
"will be removed from dgl in v0.5. {}".format(message),
category=DeprecationWarning, stacklevel=2)
warnings.simplefilter('default', DeprecationWarning)
return func(*args, **kwargs)
return deprecated_func
return deprecated_decorator
# Customize Dataset
Generally we follow the practise of PyTorch.
A Dataset class should implement `__getitem__(self, index)` and `__len__(self)`method
```python
class CustomDataset:
def __init__(self):
# Initialize Dataset and preprocess data
def __getitem__(self, index):
# Return the corresponding DGLGraph/label needed for training/evaluation based on index
return self.graphs[index], self.labels[index]
def __len__(self):
return len(self.graphs)
```
DGL supports various backends such as MXNet and PyTorch, therefore we want our dataset to be also backend agnostic.
We prefer user using numpy array in the dataset, and not including any operator/tensor from the specific backend.
If you want to convert the numpy array to the corresponding tensor, you can use the following code
```python
import dgl.backend as F
# g is a DGLGraph, h is a numpy array
g.ndata['h'] = F.zerocopy_from_numpy(h)
# Now g.ndata is a PyTorch Tensor or a MXNet NDArray based on backend used
```
If your dataset is in `.csv` format, you may use
[`CSVDataset`](https://github.com/dmlc/dgl/blob/master/python/dgl/data/chem/csv_dataset.py).
from .datasets import *
from .utils import *
from .csv_dataset import MoleculeCSVDataset
from .tox21 import Tox21
from .alchemy import TencentAlchemyDataset
from .pubchem_aromaticity import PubChemBioAssayAromaticity
from .pdbbind import PDBBind
# -*- coding:utf-8 -*-
"""Example dataloader of Tencent Alchemy Dataset
https://alchemy.tencent.com/
"""
import numpy as np
import os
import os.path as osp
import pathlib
import zipfile
from collections import defaultdict
from ..utils import mol_to_complete_graph, atom_type_one_hot, \
atom_hybridization_one_hot, atom_is_aromatic
from ...utils import download, get_download_dir, _get_dgl_url, save_graphs, load_graphs
from ....utils import retry_method_with_fix
from .... import backend as F
from ....contrib.deprecation import deprecated
try:
import pandas as pd
from rdkit import Chem
from rdkit.Chem import ChemicalFeatures
from rdkit import RDConfig
except ImportError:
pass
def alchemy_nodes(mol):
"""Featurization for all atoms in a molecule. The atom indices
will be preserved.
Parameters
----------
mol : rdkit.Chem.rdchem.Mol
RDKit molecule object
Returns
-------
atom_feats_dict : dict
Dictionary for atom features
"""
atom_feats_dict = defaultdict(list)
is_donor = defaultdict(int)
is_acceptor = defaultdict(int)
fdef_name = osp.join(RDConfig.RDDataDir, 'BaseFeatures.fdef')
mol_featurizer = ChemicalFeatures.BuildFeatureFactory(fdef_name)
mol_feats = mol_featurizer.GetFeaturesForMol(mol)
mol_conformers = mol.GetConformers()
assert len(mol_conformers) == 1
for i in range(len(mol_feats)):
if mol_feats[i].GetFamily() == 'Donor':
node_list = mol_feats[i].GetAtomIds()
for u in node_list:
is_donor[u] = 1
elif mol_feats[i].GetFamily() == 'Acceptor':
node_list = mol_feats[i].GetAtomIds()
for u in node_list:
is_acceptor[u] = 1
num_atoms = mol.GetNumAtoms()
for u in range(num_atoms):
atom = mol.GetAtomWithIdx(u)
atom_type = atom.GetAtomicNum()
num_h = atom.GetTotalNumHs()
atom_feats_dict['node_type'].append(atom_type)
h_u = []
h_u += atom_type_one_hot(atom, ['H', 'C', 'N', 'O', 'F', 'S', 'Cl'])
h_u.append(atom_type)
h_u.append(is_acceptor[u])
h_u.append(is_donor[u])
h_u += atom_is_aromatic(atom)
h_u += atom_hybridization_one_hot(atom, [Chem.rdchem.HybridizationType.SP,
Chem.rdchem.HybridizationType.SP2,
Chem.rdchem.HybridizationType.SP3])
h_u.append(num_h)
atom_feats_dict['n_feat'].append(F.tensor(np.asarray(h_u, dtype=np.float32)))
atom_feats_dict['n_feat'] = F.stack(atom_feats_dict['n_feat'], dim=0)
atom_feats_dict['node_type'] = F.tensor(
np.asarray(atom_feats_dict['node_type'], dtype=np.int64))
return atom_feats_dict
def alchemy_edges(mol, self_loop=False):
"""Featurization for all bonds in a molecule.
The bond indices will be preserved.
Parameters
----------
mol : rdkit.Chem.rdchem.Mol
RDKit molecule object
self_loop : bool
Whether to add self loops. Default to be False.
Returns
-------
bond_feats_dict : dict
Dictionary for bond features
"""
bond_feats_dict = defaultdict(list)
mol_conformers = mol.GetConformers()
assert len(mol_conformers) == 1
geom = mol_conformers[0].GetPositions()
num_atoms = mol.GetNumAtoms()
for u in range(num_atoms):
for v in range(num_atoms):
if u == v and not self_loop:
continue
e_uv = mol.GetBondBetweenAtoms(u, v)
if e_uv is None:
bond_type = None
else:
bond_type = e_uv.GetBondType()
bond_feats_dict['e_feat'].append([
float(bond_type == x)
for x in (Chem.rdchem.BondType.SINGLE,
Chem.rdchem.BondType.DOUBLE,
Chem.rdchem.BondType.TRIPLE,
Chem.rdchem.BondType.AROMATIC, None)
])
bond_feats_dict['distance'].append(
np.linalg.norm(geom[u] - geom[v]))
bond_feats_dict['e_feat'] = F.tensor(
np.asarray(bond_feats_dict['e_feat'], dtype=np.float32))
bond_feats_dict['distance'] = F.tensor(
np.asarray(bond_feats_dict['distance'], dtype=np.float32)).reshape(-1, 1)
return bond_feats_dict
class TencentAlchemyDataset(object):
"""
Developed by the Tencent Quantum Lab, the dataset lists 12 quantum mechanical
properties of 130, 000+ organic molecules, comprising up to 12 heavy atoms
(C, N, O, S, F and Cl), sampled from the GDBMedChem database. These properties
have been calculated using the open-source computational chemistry program
Python-based Simulation of Chemistry Framework (PySCF).
For more details, check the `paper <https://arxiv.org/abs/1906.09427>`__.
Parameters
----------
mode : str
'dev', 'valid' or 'test', separately for training, validation and test.
Default to be 'dev'. Note that 'test' is not available as the Alchemy
contest is ongoing.
mol_to_graph: callable, str -> DGLGraph
A function turning an RDKit molecule instance into a DGLGraph.
Default to :func:`dgl.data.chem.mol_to_complete_graph`.
node_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict
Featurization for nodes like atoms in a molecule, which can be used to update
ndata for a DGLGraph. By default, we construct graphs where nodes represent atoms
and node features represent atom features. We store the atomic numbers under the
name ``"node_type"`` and store the atom features under the name ``"n_feat"``.
The atom features include:
* One hot encoding for atom types
* Atomic number of atoms
* Whether the atom is a donor
* Whether the atom is an acceptor
* Whether the atom is aromatic
* One hot encoding for atom hybridization
* Total number of Hs on the atom
edge_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict
Featurization for edges like bonds in a molecule, which can be used to update
edata for a DGLGraph. By default, we construct edges between every pair of atoms,
excluding the self loops. We store the distance between the end atoms under the name
``"distance"`` and store the edge features under the name ``"e_feat"``. The edge
features represent one hot encoding of edge types (bond types and non-bond edges).
load : bool
Whether to load the previously pre-processed dataset or pre-process from scratch.
``load`` should be False when we want to try different graph construction and
featurization methods and need to preprocess from scratch. Default to True.
"""
@deprecated('Import TencentAlchemyDataset from dgllife.data.alchemy instead.', 'class')
def __init__(self, mode='dev',
mol_to_graph=mol_to_complete_graph,
node_featurizer=alchemy_nodes,
edge_featurizer=alchemy_edges,
load=True):
if mode == 'test':
raise ValueError('The test mode is not supported before '
'the Alchemy contest finishes.')
assert mode in ['dev', 'valid', 'test'], \
'Expect mode to be dev, valid or test, got {}.'.format(mode)
self.mode = mode
# Construct DGLGraphs from raw data or use the preprocessed data
self.load = load
file_dir = osp.join(get_download_dir(), 'Alchemy_data')
if load:
file_name = "%s_processed_dgl" % (mode)
else:
file_name = "%s_single_sdf" % (mode)
self._file_dir = file_dir
self.file_dir = pathlib.Path(file_dir, file_name)
self._url = 'dataset/alchemy/'
self.zip_file_path = pathlib.Path(file_dir, file_name + '.zip')
self._file_name = file_name
self._load(mol_to_graph, node_featurizer, edge_featurizer)
def _download_and_extract(self):
download(
_get_dgl_url(self._url + self._file_name + '.zip'),
path=str(self.zip_file_path))
if not os.path.exists(str(self.file_dir)):
archive = zipfile.ZipFile(self.zip_file_path)
archive.extractall(self._file_dir)
archive.close()
@retry_method_with_fix(_download_and_extract)
def _load(self, mol_to_graph, node_featurizer, edge_featurizer):
if self.load:
self.graphs, label_dict = load_graphs(osp.join(self.file_dir, "%s_graphs.bin" % self.mode))
self.labels = label_dict['labels']
with open(osp.join(self.file_dir, "%s_smiles.txt" % self.mode), 'r') as f:
smiles_ = f.readlines()
self.smiles = [s.strip() for s in smiles_]
else:
print('Start preprocessing dataset...')
target_file = pathlib.Path(self.file_dir, "%s_target.csv" % self.mode)
self.target = pd.read_csv(
target_file,
index_col=0,
usecols=['gdb_idx',] + ['property_%d' % x for x in range(12)])
self.target = self.target[['property_%d' % x for x in range(12)]]
self.graphs, self.labels, self.smiles = [], [], []
supp = Chem.SDMolSupplier(osp.join(self.file_dir, self.mode + ".sdf"))
cnt = 0
dataset_size = len(self.target)
for mol, label in zip(supp, self.target.iterrows()):
cnt += 1
print('Processing molecule {:d}/{:d}'.format(cnt, dataset_size))
graph = mol_to_graph(mol, node_featurizer=node_featurizer,
edge_featurizer=edge_featurizer)
smiles = Chem.MolToSmiles(mol)
self.smiles.append(smiles)
self.graphs.append(graph)
label = F.tensor(np.asarray(label[1].tolist(), dtype=np.float32))
self.labels.append(label)
save_graphs(osp.join(self.file_dir, "%s_graphs.bin" % self.mode), self.graphs,
labels={'labels': F.stack(self.labels, dim=0)})
with open(osp.join(self.file_dir, "%s_smiles.txt" % self.mode), 'w') as f:
for s in self.smiles:
f.write(s + '\n')
self.set_mean_and_std()
print(len(self.graphs), "loaded!")
def __getitem__(self, item):
"""Get datapoint with index
Parameters
----------
item : int
Datapoint index
Returns
-------
str
SMILES for the ith datapoint
DGLGraph
DGLGraph for the ith datapoint
Tensor of dtype float32
Labels of the datapoint for all tasks
"""
return self.smiles[item], self.graphs[item], self.labels[item]
def __len__(self):
"""Length of the dataset
Returns
-------
int
Length of Dataset
"""
return len(self.graphs)
def set_mean_and_std(self, mean=None, std=None):
"""Set mean and std or compute from labels for future normalization.
Parameters
----------
mean : int or float
Default to be None.
std : int or float
Default to be None.
"""
labels = np.asarray([i.numpy() for i in self.labels])
if mean is None:
mean = np.mean(labels, axis=0)
if std is None:
std = np.std(labels, axis=0)
self.mean = mean
self.std = std
from __future__ import absolute_import
import numpy as np
import os
import sys
from ...utils import save_graphs, load_graphs
from .... import backend as F
from ....contrib.deprecation import deprecated
class MoleculeCSVDataset(object):
"""MoleculeCSVDataset
This is a general class for loading molecular data from pandas.DataFrame.
In data pre-processing, we set non-existing labels to be 0,
and returning mask with 1 where label exists.
All molecules are converted into DGLGraphs. After the first-time construction, the
DGLGraphs can be saved for reloading so that we do not need to reconstruct them every time.
Parameters
----------
df: pandas.DataFrame
Dataframe including smiles and labels. Can be loaded by pandas.read_csv(file_path).
One column includes smiles and other columns for labels.
Column names other than smiles column would be considered as task names.
smiles_to_graph: callable, str -> DGLGraph
A function turning a SMILES into a DGLGraph.
node_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict
Featurization for nodes like atoms in a molecule, which can be used to update
ndata for a DGLGraph.
edge_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict
Featurization for edges like bonds in a molecule, which can be used to update
edata for a DGLGraph.
smiles_column: str
Column name that including smiles.
cache_file_path: str
Path to store the preprocessed DGLGraphs. For example, this can be ``'dglgraph.bin'``.
task_names : list of str or None
Columns in the data frame corresponding to real-valued labels. If None, we assume
all columns except the smiles_column are labels. Default to None.
load : bool
Whether to load the previously pre-processed dataset or pre-process from scratch.
``load`` should be False when we want to try different graph construction and
featurization methods and need to preprocess from scratch. Default to True.
"""
@deprecated('Import MoleculeCSVDataset from dgllife.data instead.', 'class')
def __init__(self, df, smiles_to_graph, node_featurizer, edge_featurizer,
smiles_column, cache_file_path, task_names=None, load=True):
if 'rdkit' not in sys.modules:
from ....base import dgl_warning
dgl_warning(
"Please install RDKit (Recommended Version is 2018.09.3)")
self.df = df
self.smiles = self.df[smiles_column].tolist()
if task_names is None:
self.task_names = self.df.columns.drop([smiles_column]).tolist()
else:
self.task_names = task_names
self.n_tasks = len(self.task_names)
self.cache_file_path = cache_file_path
self._pre_process(smiles_to_graph, node_featurizer, edge_featurizer, load)
def _pre_process(self, smiles_to_graph, node_featurizer, edge_featurizer, load):
"""Pre-process the dataset
* Convert molecules from smiles format into DGLGraphs
and featurize their atoms
* Set missing labels to be 0 and use a binary masking
matrix to mask them
Parameters
----------
smiles_to_graph : callable, SMILES -> DGLGraph
Function for converting a SMILES (str) into a DGLGraph.
node_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict
Featurization for nodes like atoms in a molecule, which can be used to update
ndata for a DGLGraph.
edge_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict
Featurization for edges like bonds in a molecule, which can be used to update
edata for a DGLGraph.
load : bool
Whether to load the previously pre-processed dataset or pre-process from scratch.
``load`` should be False when we want to try different graph construction and
featurization methods and need to preprocess from scratch. Default to True.
"""
if os.path.exists(self.cache_file_path) and load:
# DGLGraphs have been constructed before, reload them
print('Loading previously saved dgl graphs...')
self.graphs, label_dict = load_graphs(self.cache_file_path)
self.labels = label_dict['labels']
self.mask = label_dict['mask']
else:
print('Processing dgl graphs from scratch...')
self.graphs = []
for i, s in enumerate(self.smiles):
print('Processing molecule {:d}/{:d}'.format(i+1, len(self)))
self.graphs.append(smiles_to_graph(s, node_featurizer=node_featurizer,
edge_featurizer=edge_featurizer))
_label_values = self.df[self.task_names].values
# np.nan_to_num will also turn inf into a very large number
self.labels = F.zerocopy_from_numpy(np.nan_to_num(_label_values).astype(np.float32))
self.mask = F.zerocopy_from_numpy((~np.isnan(_label_values)).astype(np.float32))
save_graphs(self.cache_file_path, self.graphs,
labels={'labels': self.labels, 'mask': self.mask})
def __getitem__(self, item):
"""Get datapoint with index
Parameters
----------
item : int
Datapoint index
Returns
-------
str
SMILES for the ith datapoint
DGLGraph
DGLGraph for the ith datapoint
Tensor of dtype float32
Labels of the datapoint for all tasks
Tensor of dtype float32
Binary masks indicating the existence of labels for all tasks
"""
return self.smiles[item], self.graphs[item], self.labels[item], self.mask[item]
def __len__(self):
"""Length of the dataset
Returns
-------
int
Length of Dataset
"""
return len(self.smiles)
"""PDBBind dataset processed by MoleculeNet."""
import numpy as np
import os
import pandas as pd
from ..utils import multiprocess_load_molecules, ACNN_graph_construction_and_featurization
from ...utils import get_download_dir, download, _get_dgl_url, extract_archive
from ....utils import retry_method_with_fix
from .... import backend as F
from ....contrib.deprecation import deprecated
class PDBBind(object):
"""PDBbind dataset processed by MoleculeNet.
The description below is mainly based on
`[1] <https://pubs.rsc.org/en/content/articlelanding/2018/sc/c7sc02664a#cit50>`__.
The PDBBind database consists of experimentally measured binding affinities for
bio-molecular complexes `[2] <https://www.ncbi.nlm.nih.gov/pubmed/?term=15163179%5Buid%5D>`__,
`[3] <https://www.ncbi.nlm.nih.gov/pubmed/?term=15943484%5Buid%5D>`__. It provides detailed
3D Cartesian coordinates of both ligands and their target proteins derived from experimental
(e.g., X-ray crystallography) measurements. The availability of coordinates of the
protein-ligand complexes permits structure-based featurization that is aware of the
protein-ligand binding geometry. The authors of
`[1] <https://pubs.rsc.org/en/content/articlelanding/2018/sc/c7sc02664a#cit50>`__ use the
"refined" and "core" subsets of the database
`[4] <https://www.ncbi.nlm.nih.gov/pubmed/?term=25301850%5Buid%5D>`__, more carefully
processed for data artifacts, as additional benchmarking targets.
References:
* [1] MoleculeNet: a benchmark for molecular machine learning
* [2] The PDBbind database: collection of binding affinities for protein-ligand complexes
with known three-dimensional structures
* [3] The PDBbind database: methodologies and updates
* [4] PDB-wide collection of binding data: current status of the PDBbind database
Parameters
----------
subset : str
In MoleculeNet, we can use either the "refined" subset or the "core" subset. We can
retrieve them by setting ``subset`` to be ``'refined'`` or ``'core'``. The size
of the ``'core'`` set is 195 and the size of the ``'refined'`` set is 3706.
load_binding_pocket : bool
Whether to load binding pockets or full proteins. Default to True.
add_hydrogens : bool
Whether to add hydrogens via pdbfixer. Default to False.
sanitize : bool
Whether sanitization is performed in initializing RDKit molecule instances. See
https://www.rdkit.org/docs/RDKit_Book.html for details of the sanitization.
Default to False.
calc_charges : bool
Whether to add Gasteiger charges via RDKit. Setting this to be True will enforce
``add_hydrogens`` and ``sanitize`` to be True. Default to False.
remove_hs : bool
Whether to remove hydrogens via RDKit. Note that removing hydrogens can be quite
slow for large molecules. Default to False.
use_conformation : bool
Whether we need to extract molecular conformation from proteins and ligands.
Default to True.
construct_graph_and_featurize : callable
Construct a DGLHeteroGraph for the use of GNNs. Mapping self.ligand_mols[i],
self.protein_mols[i], self.ligand_coordinates[i] and self.protein_coordinates[i]
to a DGLHeteroGraph. Default to :func:`ACNN_graph_construction_and_featurization`.
zero_padding : bool
Whether to perform zero padding. While DGL does not necessarily require zero padding,
pooling operations for variable length inputs can introduce stochastic behaviour, which
is not desired for sensitive scenarios. Default to True.
num_processes : int or None
Number of worker processes to use. If None,
then we will use the number of CPUs in the system. Default to 64.
"""
@deprecated('Import PDBBind from dgllife.data instead.', 'class')
def __init__(self, subset, load_binding_pocket=True, add_hydrogens=False,
sanitize=False, calc_charges=False, remove_hs=False, use_conformation=True,
construct_graph_and_featurize=ACNN_graph_construction_and_featurization,
zero_padding=True, num_processes=64):
self.task_names = ['-logKd/Ki']
self.n_tasks = len(self.task_names)
self._url = 'dataset/pdbbind_v2015.tar.gz'
root_dir_path = get_download_dir()
data_path = root_dir_path + '/pdbbind_v2015.tar.gz'
extracted_data_path = root_dir_path + '/pdbbind_v2015'
if subset == 'core':
index_label_file = extracted_data_path + '/v2015/INDEX_core_data.2013'
elif subset == 'refined':
index_label_file = extracted_data_path + '/v2015/INDEX_refined_data.2015'
else:
raise ValueError(
'Expect the subset_choice to be either '
'core or refined, got {}'.format(subset))
self._data_path = data_path
self._extracted_data_path = extracted_data_path
self._preprocess(extracted_data_path, index_label_file, load_binding_pocket,
add_hydrogens, sanitize, calc_charges, remove_hs, use_conformation,
construct_graph_and_featurize, zero_padding, num_processes)
def _filter_out_invalid(self, ligands_loaded, proteins_loaded, use_conformation):
"""Filter out invalid ligand-protein pairs.
Parameters
----------
ligands_loaded : list
Each element is a 2-tuple of the RDKit molecule instance and its associated atom
coordinates. None is used to represent invalid/non-existing molecule or coordinates.
proteins_loaded : list
Each element is a 2-tuple of the RDKit molecule instance and its associated atom
coordinates. None is used to represent invalid/non-existing molecule or coordinates.
use_conformation : bool
Whether we need conformation information (atom coordinates) and filter out molecules
without valid conformation.
"""
num_pairs = len(proteins_loaded)
self.indices, self.ligand_mols, self.protein_mols = [], [], []
if use_conformation:
self.ligand_coordinates, self.protein_coordinates = [], []
else:
# Use None for placeholders.
self.ligand_coordinates = [None for _ in range(num_pairs)]
self.protein_coordinates = [None for _ in range(num_pairs)]
for i in range(num_pairs):
ligand_mol, ligand_coordinates = ligands_loaded[i]
protein_mol, protein_coordinates = proteins_loaded[i]
if (not use_conformation) and all(v is not None for v in [protein_mol, ligand_mol]):
self.indices.append(i)
self.ligand_mols.append(ligand_mol)
self.protein_mols.append(protein_mol)
elif all(v is not None for v in [
protein_mol, protein_coordinates, ligand_mol, ligand_coordinates]):
self.indices.append(i)
self.ligand_mols.append(ligand_mol)
self.ligand_coordinates.append(ligand_coordinates)
self.protein_mols.append(protein_mol)
self.protein_coordinates.append(protein_coordinates)
def _download_and_extract(self):
download(_get_dgl_url(self._url), path=self._data_path)
extract_archive(self._data_path, self._extracted_data_path)
@retry_method_with_fix(_download_and_extract)
def _preprocess(self, root_path, index_label_file, load_binding_pocket,
add_hydrogens, sanitize, calc_charges, remove_hs, use_conformation,
construct_graph_and_featurize, zero_padding, num_processes):
"""Preprocess the dataset.
The pre-processing proceeds as follows:
1. Load the dataset
2. Clean the dataset and filter out invalid pairs
3. Construct graphs
4. Prepare node and edge features
Parameters
----------
root_path : str
Root path for molecule files.
index_label_file : str
Path to the index file for the dataset.
load_binding_pocket : bool
Whether to load binding pockets or full proteins.
add_hydrogens : bool
Whether to add hydrogens via pdbfixer.
sanitize : bool
Whether sanitization is performed in initializing RDKit molecule instances. See
https://www.rdkit.org/docs/RDKit_Book.html for details of the sanitization.
calc_charges : bool
Whether to add Gasteiger charges via RDKit. Setting this to be True will enforce
``add_hydrogens`` and ``sanitize`` to be True.
remove_hs : bool
Whether to remove hydrogens via RDKit. Note that removing hydrogens can be quite
slow for large molecules.
use_conformation : bool
Whether we need to extract molecular conformation from proteins and ligands.
construct_graph_and_featurize : callable
Construct a DGLHeteroGraph for the use of GNNs. Mapping self.ligand_mols[i],
self.protein_mols[i], self.ligand_coordinates[i] and self.protein_coordinates[i]
to a DGLHeteroGraph. Default to :func:`ACNN_graph_construction_and_featurization`.
zero_padding : bool
Whether to perform zero padding. While DGL does not necessarily require zero padding,
pooling operations for variable length inputs can introduce stochastic behaviour, which
is not desired for sensitive scenarios.
num_processes : int or None
Number of worker processes to use. If None,
then we will use the number of CPUs in the system.
"""
contents = []
with open(index_label_file, 'r') as f:
for line in f.readlines():
if line[0] != "#":
splitted_elements = line.split()
if len(splitted_elements) == 8:
# Ignore "//"
contents.append(splitted_elements[:5] + splitted_elements[6:])
else:
print('Incorrect data format.')
print(splitted_elements)
self.df = pd.DataFrame(contents, columns=(
'PDB_code', 'resolution', 'release_year',
'-logKd/Ki', 'Kd/Ki', 'reference', 'ligand_name'))
pdbs = self.df['PDB_code'].tolist()
self.ligand_files = [os.path.join(
root_path, 'v2015', pdb, '{}_ligand.sdf'.format(pdb)) for pdb in pdbs]
if load_binding_pocket:
self.protein_files = [os.path.join(
root_path, 'v2015', pdb, '{}_pocket.pdb'.format(pdb)) for pdb in pdbs]
else:
self.protein_files = [os.path.join(
root_path, 'v2015', pdb, '{}_protein.pdb'.format(pdb)) for pdb in pdbs]
num_processes = min(num_processes, len(pdbs))
print('Loading ligands...')
ligands_loaded = multiprocess_load_molecules(self.ligand_files,
add_hydrogens=add_hydrogens,
sanitize=sanitize,
calc_charges=calc_charges,
remove_hs=remove_hs,
use_conformation=use_conformation,
num_processes=num_processes)
print('Loading proteins...')
proteins_loaded = multiprocess_load_molecules(self.protein_files,
add_hydrogens=add_hydrogens,
sanitize=sanitize,
calc_charges=calc_charges,
remove_hs=remove_hs,
use_conformation=use_conformation,
num_processes=num_processes)
self._filter_out_invalid(ligands_loaded, proteins_loaded, use_conformation)
self.df = self.df.iloc[self.indices]
self.labels = F.zerocopy_from_numpy(self.df[self.task_names].values.astype(np.float32))
print('Finished cleaning the dataset, '
'got {:d}/{:d} valid pairs'.format(len(self), len(pdbs)))
# Prepare zero padding
if zero_padding:
max_num_ligand_atoms = 0
max_num_protein_atoms = 0
for i in range(len(self)):
max_num_ligand_atoms = max(
max_num_ligand_atoms, self.ligand_mols[i].GetNumAtoms())
max_num_protein_atoms = max(
max_num_protein_atoms, self.protein_mols[i].GetNumAtoms())
else:
max_num_ligand_atoms = None
max_num_protein_atoms = None
print('Start constructing graphs and featurizing them.')
self.graphs = []
for i in range(len(self)):
print('Constructing and featurizing datapoint {:d}/{:d}'.format(i+1, len(self)))
self.graphs.append(construct_graph_and_featurize(
self.ligand_mols[i], self.protein_mols[i],
self.ligand_coordinates[i], self.protein_coordinates[i],
max_num_ligand_atoms, max_num_protein_atoms))
def __len__(self):
"""Get the size of the dataset.
Returns
-------
int
Number of valid ligand-protein pairs in the dataset.
"""
return len(self.indices)
def __getitem__(self, item):
"""Get the datapoint associated with the index.
Parameters
----------
item : int
Index for the datapoint.
Returns
-------
int
Index for the datapoint.
rdkit.Chem.rdchem.Mol
RDKit molecule instance for the ligand molecule.
rdkit.Chem.rdchem.Mol
RDKit molecule instance for the protein molecule.
DGLHeteroGraph
Pre-processed DGLHeteroGraph with features extracted.
Float32 tensor
Label for the datapoint.
"""
return item, self.ligand_mols[item], self.protein_mols[item], \
self.graphs[item], self.labels[item]
import pandas as pd
import sys
from .csv_dataset import MoleculeCSVDataset
from ..utils import smiles_to_bigraph
from ...utils import get_download_dir, download, _get_dgl_url
from ....utils import retry_method_with_fix
from ....base import dgl_warning
from ....contrib.deprecation import deprecated
class PubChemBioAssayAromaticity(MoleculeCSVDataset):
"""Subset of PubChem BioAssay Dataset for aromaticity prediction.
The dataset was constructed in `Pushing the Boundaries of Molecular Representation for Drug
Discovery with the Graph Attention Mechanism.
<https://www.ncbi.nlm.nih.gov/pubmed/31408336>`__ and is accompanied by the task of predicting
the number of aromatic atoms in molecules.
The dataset was constructed by sampling 3945 molecules with 0-40 aromatic atoms from the
PubChem BioAssay dataset.
Parameters
----------
smiles_to_graph: callable, str -> DGLGraph
A function turning smiles into a DGLGraph.
Default to :func:`dgl.data.chem.smiles_to_bigraph`.
node_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict
Featurization for nodes like atoms in a molecule, which can be used to update
ndata for a DGLGraph. Default to None.
edge_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict
Featurization for edges like bonds in a molecule, which can be used to update
edata for a DGLGraph. Default to None.
load : bool
Whether to load the previously pre-processed dataset or pre-process from scratch.
``load`` should be False when we want to try different graph construction and
featurization methods and need to pre-process from scratch. Default to True.
"""
@deprecated('Import PubChemBioAssayAromaticity from dgllife.data instead.', 'class')
def __init__(self, smiles_to_graph=smiles_to_bigraph,
node_featurizer=None, edge_featurizer=None, load=True):
if 'pandas' not in sys.modules:
dgl_warning("Please install pandas")
self._url = 'dataset/pubchem_bioassay_aromaticity.csv'
data_path = get_download_dir() + '/pubchem_bioassay_aromaticity.csv'
self._data_path = data_path
self._load(data_path, smiles_to_graph, node_featurizer, edge_featurizer, load)
def _download(self):
download(_get_dgl_url(self._url), path=self._data_path)
@retry_method_with_fix(_download)
def _load(self, data_path, smiles_to_graph, node_featurizer, edge_featurizer, load):
df = pd.read_csv(data_path)
super(PubChemBioAssayAromaticity, self).__init__(
df, smiles_to_graph, node_featurizer, edge_featurizer, "cano_smiles",
"pubchem_aromaticity_dglgraph.bin", load=load)
import sys
from .csv_dataset import MoleculeCSVDataset
from ..utils import smiles_to_bigraph
from ...utils import get_download_dir, download, _get_dgl_url
from .... import backend as F
from ....utils import retry_method_with_fix
from ....base import dgl_warning
from ....contrib.deprecation import deprecated
try:
import pandas as pd
except ImportError:
pass
class Tox21(MoleculeCSVDataset):
"""Tox21 dataset.
The Toxicology in the 21st Century (https://tripod.nih.gov/tox21/challenge/)
initiative created a public database measuring toxicity of compounds, which
has been used in the 2014 Tox21 Data Challenge. The dataset contains qualitative
toxicity measurements for 8014 compounds on 12 different targets, including nuclear
receptors and stress response pathways. Each target results in a binary label.
A common issue for multi-task prediction is that some datapoints are not labeled for
all tasks. This is also the case for Tox21. In data pre-processing, we set non-existing
labels to be 0 so that they can be placed in tensors and used for masking in loss computation.
See examples below for more details.
All molecules are converted into DGLGraphs. After the first-time construction,
the DGLGraphs will be saved for reloading so that we do not need to reconstruct them everytime.
Parameters
----------
smiles_to_graph: callable, str -> DGLGraph
A function turning smiles into a DGLGraph.
Default to :func:`dgl.data.chem.smiles_to_bigraph`.
node_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict
Featurization for nodes like atoms in a molecule, which can be used to update
ndata for a DGLGraph. Default to None.
edge_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict
Featurization for edges like bonds in a molecule, which can be used to update
edata for a DGLGraph. Default to None.
load : bool
Whether to load the previously pre-processed dataset or pre-process from scratch.
``load`` should be False when we want to try different graph construction and
featurization methods and need to preprocess from scratch. Default to True.
"""
@deprecated('Import Tox21 from dgllife.data instead.', 'class')
def __init__(self, smiles_to_graph=smiles_to_bigraph,
node_featurizer=None,
edge_featurizer=None,
load=True):
if 'pandas' not in sys.modules:
dgl_warning("Please install pandas")
self._url = 'dataset/tox21.csv.gz'
data_path = get_download_dir() + '/tox21.csv.gz'
self._data_path = data_path
self._load(data_path, smiles_to_graph, node_featurizer, edge_featurizer, load)
def _download(self):
download(_get_dgl_url(self._url), path=self._data_path)
@retry_method_with_fix(_download)
def _load(self, data_path, smiles_to_graph, node_featurizer, edge_featurizer, load):
df = pd.read_csv(data_path)
self.id = df['mol_id']
df = df.drop(columns=['mol_id'])
super(Tox21, self).__init__(df, smiles_to_graph, node_featurizer, edge_featurizer,
"smiles", "tox21_dglgraph.bin", load=load)
self._weight_balancing()
def _weight_balancing(self):
"""Perform re-balancing for each task.
It's quite common that the number of positive samples and the
number of negative samples are significantly different. To compensate
for the class imbalance issue, we can weight each datapoint in
loss computation.
In particular, for each task we will set the weight of negative samples
to be 1 and the weight of positive samples to be the number of negative
samples divided by the number of positive samples.
If weight balancing is performed, one attribute will be affected:
* self._task_pos_weights is set, which is a list of positive sample weights
for each task.
"""
num_pos = F.sum(self.labels, dim=0)
num_indices = F.sum(self.mask, dim=0)
self._task_pos_weights = (num_indices - num_pos) / num_pos
@property
def task_pos_weights(self):
"""Get weights for positive samples on each task
Returns
-------
numpy.ndarray
numpy array gives the weight of positive samples on all tasks
"""
return self._task_pos_weights
from .splitters import *
from .featurizers import *
from .mol_to_graph import *
from .complex_to_graph import *
from .rdkit_utils import *
"""Convert complexes into DGLHeteroGraphs"""
import numpy as np
from ..utils import k_nearest_neighbors
from .... import graph, bipartite, hetero_from_relations
from .... import backend as F
from ....contrib.deprecation import deprecated
__all__ = ['ACNN_graph_construction_and_featurization']
def filter_out_hydrogens(mol):
"""Get indices for non-hydrogen atoms.
Parameters
----------
mol : rdkit.Chem.rdchem.Mol
RDKit molecule instance.
Returns
-------
indices_left : list of int
Indices of non-hydrogen atoms.
"""
indices_left = []
for i, atom in enumerate(mol.GetAtoms()):
atomic_num = atom.GetAtomicNum()
# Hydrogen atoms have an atomic number of 1.
if atomic_num != 1:
indices_left.append(i)
return indices_left
def get_atomic_numbers(mol, indices):
"""Get the atomic numbers for the specified atoms.
Parameters
----------
mol : rdkit.Chem.rdchem.Mol
RDKit molecule instance.
indices : list of int
Specifying atoms.
Returns
-------
list of int
Atomic numbers computed.
"""
atomic_numbers = []
for i in indices:
atom = mol.GetAtomWithIdx(i)
atomic_numbers.append(atom.GetAtomicNum())
return atomic_numbers
@deprecated('Import it from dgllife.utils instead.')
def ACNN_graph_construction_and_featurization(ligand_mol,
protein_mol,
ligand_coordinates,
protein_coordinates,
max_num_ligand_atoms=None,
max_num_protein_atoms=None,
neighbor_cutoff=12.,
max_num_neighbors=12,
strip_hydrogens=False):
"""Graph construction and featurization for `Atomic Convolutional Networks for
Predicting Protein-Ligand Binding Affinity <https://arxiv.org/abs/1703.10603>`__.
Parameters
----------
ligand_mol : rdkit.Chem.rdchem.Mol
RDKit molecule instance.
protein_mol : rdkit.Chem.rdchem.Mol
RDKit molecule instance.
ligand_coordinates : Float Tensor of shape (V1, 3)
Atom coordinates in a ligand.
protein_coordinates : Float Tensor of shape (V2, 3)
Atom coordinates in a protein.
max_num_ligand_atoms : int or None
Maximum number of atoms in ligands for zero padding, which should be no smaller than
ligand_mol.GetNumAtoms() if not None. If None, no zero padding will be performed.
Default to None.
max_num_protein_atoms : int or None
Maximum number of atoms in proteins for zero padding, which should be no smaller than
protein_mol.GetNumAtoms() if not None. If None, no zero padding will be performed.
Default to None.
neighbor_cutoff : float
Distance cutoff to define 'neighboring'. Default to 12.
max_num_neighbors : int
Maximum number of neighbors allowed for each atom. Default to 12.
strip_hydrogens : bool
Whether to exclude hydrogen atoms. Default to False.
"""
assert ligand_coordinates is not None, 'Expect ligand_coordinates to be provided.'
assert protein_coordinates is not None, 'Expect protein_coordinates to be provided.'
if max_num_ligand_atoms is not None:
assert max_num_ligand_atoms >= ligand_mol.GetNumAtoms(), \
'Expect max_num_ligand_atoms to be no smaller than ligand_mol.GetNumAtoms()'
if max_num_protein_atoms is not None:
assert max_num_protein_atoms >= protein_mol.GetNumAtoms(), \
'Expect max_num_protein_atoms to be no smaller than protein_mol.GetNumAtoms()'
if strip_hydrogens:
# Remove hydrogen atoms and their corresponding coordinates
ligand_atom_indices_left = filter_out_hydrogens(ligand_mol)
protein_atom_indices_left = filter_out_hydrogens(protein_mol)
ligand_coordinates = ligand_coordinates.take(ligand_atom_indices_left, axis=0)
protein_coordinates = protein_coordinates.take(protein_atom_indices_left, axis=0)
else:
ligand_atom_indices_left = list(range(ligand_mol.GetNumAtoms()))
protein_atom_indices_left = list(range(protein_mol.GetNumAtoms()))
# Compute number of nodes for each type
if max_num_ligand_atoms is None:
num_ligand_atoms = len(ligand_atom_indices_left)
else:
num_ligand_atoms = max_num_ligand_atoms
if max_num_protein_atoms is None:
num_protein_atoms = len(protein_atom_indices_left)
else:
num_protein_atoms = max_num_protein_atoms
# Construct graph for atoms in the ligand
ligand_srcs, ligand_dsts, ligand_dists = k_nearest_neighbors(
ligand_coordinates, neighbor_cutoff, max_num_neighbors)
ligand_graph = graph((ligand_srcs, ligand_dsts),
'ligand_atom', 'ligand', num_ligand_atoms)
ligand_graph.edata['distance'] = F.reshape(F.zerocopy_from_numpy(
np.asarray(ligand_dists, dtype=np.float32)), (-1, 1))
# Construct graph for atoms in the protein
protein_srcs, protein_dsts, protein_dists = k_nearest_neighbors(
protein_coordinates, neighbor_cutoff, max_num_neighbors)
protein_graph = graph((protein_srcs, protein_dsts),
'protein_atom', 'protein', num_protein_atoms)
protein_graph.edata['distance'] = F.reshape(F.zerocopy_from_numpy(
np.asarray(protein_dists, dtype=np.float32)), (-1, 1))
# Construct 4 graphs for complex representation, including the connection within
# protein atoms, the connection within ligand atoms and the connection between
# protein and ligand atoms.
complex_srcs, complex_dsts, complex_dists = k_nearest_neighbors(
np.concatenate([ligand_coordinates, protein_coordinates]),
neighbor_cutoff, max_num_neighbors)
complex_srcs = np.asarray(complex_srcs)
complex_dsts = np.asarray(complex_dsts)
complex_dists = np.asarray(complex_dists)
offset = num_ligand_atoms
# ('ligand_atom', 'complex', 'ligand_atom')
inter_ligand_indices = np.intersect1d(
(complex_srcs < offset).nonzero()[0],
(complex_dsts < offset).nonzero()[0],
assume_unique=True)
inter_ligand_graph = graph(
(complex_srcs[inter_ligand_indices].tolist(),
complex_dsts[inter_ligand_indices].tolist()),
'ligand_atom', 'complex', num_ligand_atoms)
inter_ligand_graph.edata['distance'] = F.reshape(F.zerocopy_from_numpy(
complex_dists[inter_ligand_indices].astype(np.float32)), (-1, 1))
# ('protein_atom', 'complex', 'protein_atom')
inter_protein_indices = np.intersect1d(
(complex_srcs >= offset).nonzero()[0],
(complex_dsts >= offset).nonzero()[0],
assume_unique=True)
inter_protein_graph = graph(
((complex_srcs[inter_protein_indices] - offset).tolist(),
(complex_dsts[inter_protein_indices] - offset).tolist()),
'protein_atom', 'complex', num_protein_atoms)
inter_protein_graph.edata['distance'] = F.reshape(F.zerocopy_from_numpy(
complex_dists[inter_protein_indices].astype(np.float32)), (-1, 1))
# ('ligand_atom', 'complex', 'protein_atom')
ligand_protein_indices = np.intersect1d(
(complex_srcs < offset).nonzero()[0],
(complex_dsts >= offset).nonzero()[0],
assume_unique=True)
ligand_protein_graph = bipartite(
(complex_srcs[ligand_protein_indices].tolist(),
(complex_dsts[ligand_protein_indices] - offset).tolist()),
'ligand_atom', 'complex', 'protein_atom',
(num_ligand_atoms, num_protein_atoms))
ligand_protein_graph.edata['distance'] = F.reshape(F.zerocopy_from_numpy(
complex_dists[ligand_protein_indices].astype(np.float32)), (-1, 1))
# ('protein_atom', 'complex', 'ligand_atom')
protein_ligand_indices = np.intersect1d(
(complex_srcs >= offset).nonzero()[0],
(complex_dsts < offset).nonzero()[0],
assume_unique=True)
protein_ligand_graph = bipartite(
((complex_srcs[protein_ligand_indices] - offset).tolist(),
complex_dsts[protein_ligand_indices].tolist()),
'protein_atom', 'complex', 'ligand_atom',
(num_protein_atoms, num_ligand_atoms))
protein_ligand_graph.edata['distance'] = F.reshape(F.zerocopy_from_numpy(
complex_dists[protein_ligand_indices].astype(np.float32)), (-1, 1))
# Merge the graphs
g = hetero_from_relations(
[protein_graph,
ligand_graph,
inter_ligand_graph,
inter_protein_graph,
ligand_protein_graph,
protein_ligand_graph]
)
# Get atomic numbers for all atoms left and set node features
ligand_atomic_numbers = np.asarray(get_atomic_numbers(ligand_mol, ligand_atom_indices_left))
# zero padding
ligand_atomic_numbers = np.concatenate([
ligand_atomic_numbers, np.zeros(num_ligand_atoms - len(ligand_atom_indices_left))])
protein_atomic_numbers = np.asarray(get_atomic_numbers(protein_mol, protein_atom_indices_left))
# zero padding
protein_atomic_numbers = np.concatenate([
protein_atomic_numbers, np.zeros(num_protein_atoms - len(protein_atom_indices_left))])
g.nodes['ligand_atom'].data['atomic_number'] = F.reshape(F.zerocopy_from_numpy(
ligand_atomic_numbers.astype(np.float32)), (-1, 1))
g.nodes['protein_atom'].data['atomic_number'] = F.reshape(F.zerocopy_from_numpy(
protein_atomic_numbers.astype(np.float32)), (-1, 1))
# Prepare mask indicating the existence of nodes
ligand_masks = np.zeros((num_ligand_atoms, 1))
ligand_masks[:len(ligand_atom_indices_left), :] = 1
g.nodes['ligand_atom'].data['mask'] = F.zerocopy_from_numpy(
ligand_masks.astype(np.float32))
protein_masks = np.zeros((num_protein_atoms, 1))
protein_masks[:len(protein_atom_indices_left), :] = 1
g.nodes['protein_atom'].data['mask'] = F.zerocopy_from_numpy(
protein_masks.astype(np.float32))
return g
import itertools
import numpy as np
from collections import defaultdict
from .... import backend as F
from ....contrib.deprecation import deprecated
try:
from rdkit import Chem
from rdkit.Chem import rdmolfiles, rdmolops
except ImportError:
pass
__all__ = ['one_hot_encoding',
'atom_type_one_hot',
'atomic_number_one_hot',
'atomic_number',
'atom_degree_one_hot',
'atom_degree',
'atom_total_degree_one_hot',
'atom_total_degree',
'atom_implicit_valence_one_hot',
'atom_implicit_valence',
'atom_hybridization_one_hot',
'atom_total_num_H_one_hot',
'atom_total_num_H',
'atom_formal_charge_one_hot',
'atom_formal_charge',
'atom_num_radical_electrons_one_hot',
'atom_num_radical_electrons',
'atom_is_aromatic_one_hot',
'atom_is_aromatic',
'atom_chiral_tag_one_hot',
'atom_mass',
'ConcatFeaturizer',
'BaseAtomFeaturizer',
'CanonicalAtomFeaturizer',
'bond_type_one_hot',
'bond_is_conjugated_one_hot',
'bond_is_conjugated',
'bond_is_in_ring_one_hot',
'bond_is_in_ring',
'bond_stereo_one_hot',
'BaseBondFeaturizer',
'CanonicalBondFeaturizer']
@deprecated('Import it from dgllife.utils instead.')
def one_hot_encoding(x, allowable_set, encode_unknown=False):
"""One-hot encoding.
Parameters
----------
x
Value to encode.
allowable_set : list
The elements of the allowable_set should be of the
same type as x.
encode_unknown : bool
If True, map inputs not in the allowable set to the
additional last element.
Returns
-------
list
List of boolean values where at most one value is True.
The list is of length ``len(allowable_set)`` if ``encode_unknown=False``
and ``len(allowable_set) + 1`` otherwise.
"""
if encode_unknown and (allowable_set[-1] is not None):
allowable_set.append(None)
if encode_unknown and (x not in allowable_set):
x = None
return list(map(lambda s: x == s, allowable_set))
#################################################################
# Atom featurization
#################################################################
@deprecated('Import it from dgllife.utils instead.')
def atom_type_one_hot(atom, allowable_set=None, encode_unknown=False):
"""One hot encoding for the type of an atom.
Parameters
----------
atom : rdkit.Chem.rdchem.Atom
RDKit atom instance.
allowable_set : list of str
Atom types to consider. Default: ``C``, ``N``, ``O``, ``S``, ``F``, ``Si``, ``P``,
``Cl``, ``Br``, ``Mg``, ``Na``, ``Ca``, ``Fe``, ``As``, ``Al``, ``I``, ``B``, ``V``,
``K``, ``Tl``, ``Yb``, ``Sb``, ``Sn``, ``Ag``, ``Pd``, ``Co``, ``Se``, ``Ti``, ``Zn``,
``H``, ``Li``, ``Ge``, ``Cu``, ``Au``, ``Ni``, ``Cd``, ``In``, ``Mn``, ``Zr``, ``Cr``,
``Pt``, ``Hg``, ``Pb``.
encode_unknown : bool
If True, map inputs not in the allowable set to the
additional last element. (Default: False)
Returns
-------
list
List of boolean values where at most one value is True.
"""
if allowable_set is None:
allowable_set = ['C', 'N', 'O', 'S', 'F', 'Si', 'P', 'Cl', 'Br', 'Mg', 'Na', 'Ca',
'Fe', 'As', 'Al', 'I', 'B', 'V', 'K', 'Tl', 'Yb', 'Sb', 'Sn',
'Ag', 'Pd', 'Co', 'Se', 'Ti', 'Zn', 'H', 'Li', 'Ge', 'Cu', 'Au',
'Ni', 'Cd', 'In', 'Mn', 'Zr', 'Cr', 'Pt', 'Hg', 'Pb']
return one_hot_encoding(atom.GetSymbol(), allowable_set, encode_unknown)
@deprecated('Import it from dgllife.utils instead.')
def atomic_number_one_hot(atom, allowable_set=None, encode_unknown=False):
"""One hot encoding for the atomic number of an atom.
Parameters
----------
atom : rdkit.Chem.rdchem.Atom
RDKit atom instance.
allowable_set : list of int
Atomic numbers to consider. Default: ``1`` - ``100``.
encode_unknown : bool
If True, map inputs not in the allowable set to the
additional last element. (Default: False)
Returns
-------
list
List of boolean values where at most one value is True.
"""
if allowable_set is None:
allowable_set = list(range(1, 101))
return one_hot_encoding(atom.GetAtomicNum(), allowable_set, encode_unknown)
@deprecated('Import it from dgllife.utils instead.')
def atomic_number(atom):
"""Get the atomic number for an atom.
Parameters
----------
atom : rdkit.Chem.rdchem.Atom
RDKit atom instance.
Returns
-------
list
List containing one int only.
"""
return [atom.GetAtomicNum()]
@deprecated('Import it from dgllife.utils instead.')
def atom_degree_one_hot(atom, allowable_set=None, encode_unknown=False):
"""One hot encoding for the degree of an atom.
Note that the result will be different depending on whether the Hs are
explicitly modeled in the graph.
Parameters
----------
atom : rdkit.Chem.rdchem.Atom
RDKit atom instance.
allowable_set : list of int
Atom degrees to consider. Default: ``0`` - ``10``.
encode_unknown : bool
If True, map inputs not in the allowable set to the
additional last element. (Default: False)
Returns
-------
list
List of boolean values where at most one value is True.
See Also
--------
atom_total_degree_one_hot
"""
if allowable_set is None:
allowable_set = list(range(11))
return one_hot_encoding(atom.GetDegree(), allowable_set, encode_unknown)
@deprecated('Import it from dgllife.utils instead.')
def atom_degree(atom):
"""Get the degree of an atom.
Note that the result will be different depending on whether the Hs are
explicitly modeled in the graph.
Parameters
----------
atom : rdkit.Chem.rdchem.Atom
RDKit atom instance.
Returns
-------
list
List containing one int only.
See Also
--------
atom_total_degree
"""
return [atom.GetDegree()]
@deprecated('Import it from dgllife.utils instead.')
def atom_total_degree_one_hot(atom, allowable_set=None, encode_unknown=False):
"""One hot encoding for the degree of an atom including Hs.
Parameters
----------
atom : rdkit.Chem.rdchem.Atom
RDKit atom instance.
allowable_set : list
Total degrees to consider. Default: ``0`` - ``5``.
encode_unknown : bool
If True, map inputs not in the allowable set to the
additional last element. (Default: False)
See Also
--------
atom_degree_one_hot
"""
if allowable_set is None:
allowable_set = list(range(6))
return one_hot_encoding(atom.GetTotalDegree(), allowable_set, encode_unknown)
@deprecated('Import it from dgllife.utils instead.')
def atom_total_degree(atom):
"""The degree of an atom including Hs.
See Also
--------
atom_degree
Returns
-------
list
List containing one int only.
"""
return [atom.GetTotalDegree()]
@deprecated('Import it from dgllife.utils instead.')
def atom_implicit_valence_one_hot(atom, allowable_set=None, encode_unknown=False):
"""One hot encoding for the implicit valences of an atom.
Parameters
----------
atom : rdkit.Chem.rdchem.Atom
RDKit atom instance.
allowable_set : list of int
Atom implicit valences to consider. Default: ``0`` - ``6``.
encode_unknown : bool
If True, map inputs not in the allowable set to the
additional last element. (Default: False)
Returns
-------
list
List of boolean values where at most one value is True.
"""
if allowable_set is None:
allowable_set = list(range(7))
return one_hot_encoding(atom.GetImplicitValence(), allowable_set, encode_unknown)
@deprecated('Import it from dgllife.utils instead.')
def atom_implicit_valence(atom):
"""Get the implicit valence of an atom.
Parameters
----------
atom : rdkit.Chem.rdchem.Atom
RDKit atom instance.
Reurns
------
list
List containing one int only.
"""
return [atom.GetImplicitValence()]
@deprecated('Import it from dgllife.utils instead.')
def atom_hybridization_one_hot(atom, allowable_set=None, encode_unknown=False):
"""One hot encoding for the hybridization of an atom.
Parameters
----------
atom : rdkit.Chem.rdchem.Atom
RDKit atom instance.
allowable_set : list of rdkit.Chem.rdchem.HybridizationType
Atom hybridizations to consider. Default: ``Chem.rdchem.HybridizationType.SP``,
``Chem.rdchem.HybridizationType.SP2``, ``Chem.rdchem.HybridizationType.SP3``,
``Chem.rdchem.HybridizationType.SP3D``, ``Chem.rdchem.HybridizationType.SP3D2``.
encode_unknown : bool
If True, map inputs not in the allowable set to the
additional last element. (Default: False)
Returns
-------
list
List of boolean values where at most one value is True.
"""
if allowable_set is None:
allowable_set = [Chem.rdchem.HybridizationType.SP,
Chem.rdchem.HybridizationType.SP2,
Chem.rdchem.HybridizationType.SP3,
Chem.rdchem.HybridizationType.SP3D,
Chem.rdchem.HybridizationType.SP3D2]
return one_hot_encoding(atom.GetHybridization(), allowable_set, encode_unknown)
@deprecated('Import it from dgllife.utils instead.')
def atom_total_num_H_one_hot(atom, allowable_set=None, encode_unknown=False):
"""One hot encoding for the total number of Hs of an atom.
Parameters
----------
atom : rdkit.Chem.rdchem.Atom
RDKit atom instance.
allowable_set : list of int
Total number of Hs to consider. Default: ``0`` - ``4``.
encode_unknown : bool
If True, map inputs not in the allowable set to the
additional last element. (Default: False)
Returns
-------
list
List of boolean values where at most one value is True.
"""
if allowable_set is None:
allowable_set = list(range(5))
return one_hot_encoding(atom.GetTotalNumHs(), allowable_set, encode_unknown)
@deprecated('Import it from dgllife.utils instead.')
def atom_total_num_H(atom):
"""Get the total number of Hs of an atom.
Parameters
----------
atom : rdkit.Chem.rdchem.Atom
RDKit atom instance.
Returns
-------
list
List containing one int only.
"""
return [atom.GetTotalNumHs()]
@deprecated('Import it from dgllife.utils instead.')
def atom_formal_charge_one_hot(atom, allowable_set=None, encode_unknown=False):
"""One hot encoding for the formal charge of an atom.
Parameters
----------
atom : rdkit.Chem.rdchem.Atom
RDKit atom instance.
allowable_set : list of int
Formal charges to consider. Default: ``-2`` - ``2``.
encode_unknown : bool
If True, map inputs not in the allowable set to the
additional last element. (Default: False)
Returns
-------
list
List of boolean values where at most one value is True.
"""
if allowable_set is None:
allowable_set = list(range(-2, 3))
return one_hot_encoding(atom.GetFormalCharge(), allowable_set, encode_unknown)
@deprecated('Import it from dgllife.utils instead.')
def atom_formal_charge(atom):
"""Get formal charge for an atom.
Parameters
----------
atom : rdkit.Chem.rdchem.Atom
RDKit atom instance.
Returns
-------
list
List containing one int only.
"""
return [atom.GetFormalCharge()]
@deprecated('Import it from dgllife.utils instead.')
def atom_num_radical_electrons_one_hot(atom, allowable_set=None, encode_unknown=False):
"""One hot encoding for the number of radical electrons of an atom.
Parameters
----------
atom : rdkit.Chem.rdchem.Atom
RDKit atom instance.
allowable_set : list of int
Number of radical electrons to consider. Default: ``0`` - ``4``.
encode_unknown : bool
If True, map inputs not in the allowable set to the
additional last element. (Default: False)
Returns
-------
list
List of boolean values where at most one value is True.
"""
if allowable_set is None:
allowable_set = list(range(5))
return one_hot_encoding(atom.GetNumRadicalElectrons(), allowable_set, encode_unknown)
@deprecated('Import it from dgllife.utils instead.')
def atom_num_radical_electrons(atom):
"""Get the number of radical electrons for an atom.
Parameters
----------
atom : rdkit.Chem.rdchem.Atom
RDKit atom instance.
Returns
-------
list
List containing one int only.
"""
return [atom.GetNumRadicalElectrons()]
@deprecated('Import it from dgllife.utils instead.')
def atom_is_aromatic_one_hot(atom, allowable_set=None, encode_unknown=False):
"""One hot encoding for whether the atom is aromatic.
Parameters
----------
atom : rdkit.Chem.rdchem.Atom
RDKit atom instance.
allowable_set : list of bool
Conditions to consider. Default: ``False`` and ``True``.
encode_unknown : bool
If True, map inputs not in the allowable set to the
additional last element. (Default: False)
Returns
-------
list
List of boolean values where at most one value is True.
"""
if allowable_set is None:
allowable_set = [False, True]
return one_hot_encoding(atom.GetIsAromatic(), allowable_set, encode_unknown)
@deprecated('Import it from dgllife.utils instead.')
def atom_is_aromatic(atom):
"""Get whether the atom is aromatic.
Parameters
----------
atom : rdkit.Chem.rdchem.Atom
RDKit atom instance.
Returns
-------
list
List containing one bool only.
"""
return [atom.GetIsAromatic()]
@deprecated('Import it from dgllife.utils instead.')
def atom_chiral_tag_one_hot(atom, allowable_set=None, encode_unknown=False):
"""One hot encoding for the chiral tag of an atom.
Parameters
----------
atom : rdkit.Chem.rdchem.Atom
RDKit atom instance.
allowable_set : list of rdkit.Chem.rdchem.ChiralType
Chiral tags to consider. Default: ``rdkit.Chem.rdchem.ChiralType.CHI_UNSPECIFIED``,
``rdkit.Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CW``,
``rdkit.Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CCW``,
``rdkit.Chem.rdchem.ChiralType.CHI_OTHER``.
"""
if allowable_set is None:
allowable_set = [Chem.rdchem.ChiralType.CHI_UNSPECIFIED,
Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CW,
Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CCW,
Chem.rdchem.ChiralType.CHI_OTHER]
return one_hot_encoding(atom.GetChiralTag(), allowable_set, encode_unknown)
@deprecated('Import it from dgllife.utils instead.')
def atom_mass(atom, coef=0.01):
"""Get the mass of an atom and scale it.
Parameters
----------
atom : rdkit.Chem.rdchem.Atom
RDKit atom instance.
coef : float
The mass will be multiplied by ``coef``.
Returns
-------
list
List containing one float only.
"""
return [atom.GetMass() * coef]
class ConcatFeaturizer(object):
"""Concatenate the evaluation results of multiple functions as a single feature.
Parameters
----------
func_list : list
List of functions for computing molecular descriptors from objects of a same
particular data type, e.g. ``rdkit.Chem.rdchem.Atom``. Each function is of signature
``func(data_type) -> list of float or bool or int``. The resulting order of
the features will follow that of the functions in the list.
"""
@deprecated('Import ConcatFeaturizer from dgllife.utils instead.', 'class')
def __init__(self, func_list):
self.func_list = func_list
def __call__(self, x):
"""Featurize the input data.
Parameters
----------
x :
Data to featurize.
Returns
-------
list
List of feature values, which can be of type bool, float or int.
"""
return list(itertools.chain.from_iterable(
[func(x) for func in self.func_list]))
class BaseAtomFeaturizer(object):
"""An abstract class for atom featurizers.
Loop over all atoms in a molecule and featurize them with the ``featurizer_funcs``.
**We assume the resulting DGLGraph will not contain any virtual nodes.**
Parameters
----------
featurizer_funcs : dict
Mapping feature name to the featurization function.
Each function is of signature ``func(rdkit.Chem.rdchem.Atom) -> list or 1D numpy array``.
feat_sizes : dict
Mapping feature name to the size of the corresponding feature. If None, they will be
computed when needed. Default: None.
Examples
--------
>>> from dgl.data.chem import BaseAtomFeaturizer, atom_mass, atom_degree_one_hot
>>> from rdkit import Chem
>>> mol = Chem.MolFromSmiles('CCO')
>>> atom_featurizer = BaseAtomFeaturizer({'mass': atom_mass, 'degree': atom_degree_one_hot})
>>> atom_featurizer(mol)
{'mass': tensor([[0.1201],
[0.1201],
[0.1600]]),
'degree': tensor([[0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0.]])}
"""
@deprecated('Import BaseAtomFeaturizer from dgllife.utils instead.', 'class')
def __init__(self, featurizer_funcs, feat_sizes=None):
self.featurizer_funcs = featurizer_funcs
if feat_sizes is None:
feat_sizes = dict()
self._feat_sizes = feat_sizes
def feat_size(self, feat_name):
"""Get the feature size for ``feat_name``.
Returns
-------
int
Feature size for the feature with name ``feat_name``.
"""
if feat_name not in self.featurizer_funcs:
return ValueError('Expect feat_name to be in {}, got {}'.format(
list(self.featurizer_funcs.keys()), feat_name))
if feat_name not in self._feat_sizes:
atom = Chem.MolFromSmiles('C').GetAtomWithIdx(0)
self._feat_sizes[feat_name] = len(self.featurizer_funcs[feat_name](atom))
return self._feat_sizes[feat_name]
def __call__(self, mol):
"""Featurize all atoms in a molecule.
Parameters
----------
mol : rdkit.Chem.rdchem.Mol
RDKit molecule instance.
Returns
-------
dict
For each function in self.featurizer_funcs with the key ``k``, store the computed
feature under the key ``k``. Each feature is a tensor of dtype float32 and shape
(N, M), where N is the number of atoms in the molecule.
"""
num_atoms = mol.GetNumAtoms()
atom_features = defaultdict(list)
# Compute features for each atom
for i in range(num_atoms):
atom = mol.GetAtomWithIdx(i)
for feat_name, feat_func in self.featurizer_funcs.items():
atom_features[feat_name].append(feat_func(atom))
# Stack the features and convert them to float arrays
processed_features = dict()
for feat_name, feat_list in atom_features.items():
feat = np.stack(feat_list)
processed_features[feat_name] = F.zerocopy_from_numpy(feat.astype(np.float32))
return processed_features
class CanonicalAtomFeaturizer(BaseAtomFeaturizer):
"""A default featurizer for atoms.
The atom features include:
* **One hot encoding of the atom type**. The supported atom types include
``C``, ``N``, ``O``, ``S``, ``F``, ``Si``, ``P``, ``Cl``, ``Br``, ``Mg``,
``Na``, ``Ca``, ``Fe``, ``As``, ``Al``, ``I``, ``B``, ``V``, ``K``, ``Tl``,
``Yb``, ``Sb``, ``Sn``, ``Ag``, ``Pd``, ``Co``, ``Se``, ``Ti``, ``Zn``,
``H``, ``Li``, ``Ge``, ``Cu``, ``Au``, ``Ni``, ``Cd``, ``In``, ``Mn``, ``Zr``,
``Cr``, ``Pt``, ``Hg``, ``Pb``.
* **One hot encoding of the atom degree**. The supported possibilities
include ``0 - 10``.
* **One hot encoding of the number of implicit Hs on the atom**. The supported
possibilities include ``0 - 6``.
* **Formal charge of the atom**.
* **Number of radical electrons of the atom**.
* **One hot encoding of the atom hybridization**. The supported possibilities include
``SP``, ``SP2``, ``SP3``, ``SP3D``, ``SP3D2``.
* **Whether the atom is aromatic**.
* **One hot encoding of the number of total Hs on the atom**. The supported possibilities
include ``0 - 4``.
**We assume the resulting DGLGraph will not contain any virtual nodes.**
Parameters
----------
atom_data_field : str
Name for storing atom features in DGLGraphs, default to be 'h'.
"""
@deprecated('Import CanonicalAtomFeaturizer from dgllife.utils instead.', 'class')
def __init__(self, atom_data_field='h'):
super(CanonicalAtomFeaturizer, self).__init__(
featurizer_funcs={atom_data_field: ConcatFeaturizer(
[atom_type_one_hot,
atom_degree_one_hot,
atom_implicit_valence_one_hot,
atom_formal_charge,
atom_num_radical_electrons,
atom_hybridization_one_hot,
atom_is_aromatic,
atom_total_num_H_one_hot]
)})
@deprecated('Import it from dgllife.utils instead.')
def bond_type_one_hot(bond, allowable_set=None, encode_unknown=False):
"""One hot encoding for the type of a bond.
Parameters
----------
bond : rdkit.Chem.rdchem.Bond
RDKit bond instance.
allowable_set : list of Chem.rdchem.BondType
Bond types to consider. Default: ``Chem.rdchem.BondType.SINGLE``,
``Chem.rdchem.BondType.DOUBLE``, ``Chem.rdchem.BondType.TRIPLE``,
``Chem.rdchem.BondType.AROMATIC``.
encode_unknown : bool
If True, map inputs not in the allowable set to the
additional last element. (Default: False)
Returns
-------
list
List of boolean values where at most one value is True.
"""
if allowable_set is None:
allowable_set = [Chem.rdchem.BondType.SINGLE,
Chem.rdchem.BondType.DOUBLE,
Chem.rdchem.BondType.TRIPLE,
Chem.rdchem.BondType.AROMATIC]
return one_hot_encoding(bond.GetBondType(), allowable_set, encode_unknown)
@deprecated('Import it from dgllife.utils instead.')
def bond_is_conjugated_one_hot(bond, allowable_set=None, encode_unknown=False):
"""One hot encoding for whether the bond is conjugated.
Parameters
----------
bond : rdkit.Chem.rdchem.Bond
RDKit bond instance.
allowable_set : list of bool
Conditions to consider. Default: ``False`` and ``True``.
encode_unknown : bool
If True, map inputs not in the allowable set to the
additional last element. (Default: False)
Returns
-------
list
List of boolean values where at most one value is True.
"""
if allowable_set is None:
allowable_set = [False, True]
return one_hot_encoding(bond.GetIsConjugated(), allowable_set, encode_unknown)
@deprecated('Import it from dgllife.utils instead.')
def bond_is_conjugated(bond):
"""Get whether the bond is conjugated.
Parameters
----------
bond : rdkit.Chem.rdchem.Bond
RDKit bond instance.
Returns
-------
list
List containing one bool only.
"""
return [bond.GetIsConjugated()]
@deprecated('Import it from dgllife.utils instead.')
def bond_is_in_ring_one_hot(bond, allowable_set=None, encode_unknown=False):
"""One hot encoding for whether the bond is in a ring of any size.
Parameters
----------
bond : rdkit.Chem.rdchem.Bond
RDKit bond instance.
allowable_set : list of bool
Conditions to consider. Default: ``False`` and ``True``.
encode_unknown : bool
If True, map inputs not in the allowable set to the
additional last element. (Default: False)
Returns
-------
list
List of boolean values where at most one value is True.
"""
if allowable_set is None:
allowable_set = [False, True]
return one_hot_encoding(bond.IsInRing(), allowable_set, encode_unknown)
@deprecated('Import it from dgllife.utils instead.')
def bond_is_in_ring(bond):
"""Get whether the bond is in a ring of any size.
Parameters
----------
bond : rdkit.Chem.rdchem.Bond
RDKit bond instance.
Returns
-------
list
List containing one bool only.
"""
return [bond.IsInRing()]
@deprecated('Import it from dgllife.utils instead.')
def bond_stereo_one_hot(bond, allowable_set=None, encode_unknown=False):
"""One hot encoding for the stereo configuration of a bond.
Parameters
----------
bond : rdkit.Chem.rdchem.Bond
RDKit bond instance.
allowable_set : list of rdkit.Chem.rdchem.BondStereo
Stereo configurations to consider. Default: ``rdkit.Chem.rdchem.BondStereo.STEREONONE``,
``rdkit.Chem.rdchem.BondStereo.STEREOANY``, ``rdkit.Chem.rdchem.BondStereo.STEREOZ``,
``rdkit.Chem.rdchem.BondStereo.STEREOE``, ``rdkit.Chem.rdchem.BondStereo.STEREOCIS``,
``rdkit.Chem.rdchem.BondStereo.STEREOTRANS``.
encode_unknown : bool
If True, map inputs not in the allowable set to the
additional last element. (Default: False)
Returns
-------
list
List of boolean values where at most one value is True.
"""
if allowable_set is None:
allowable_set = [Chem.rdchem.BondStereo.STEREONONE,
Chem.rdchem.BondStereo.STEREOANY,
Chem.rdchem.BondStereo.STEREOZ,
Chem.rdchem.BondStereo.STEREOE,
Chem.rdchem.BondStereo.STEREOCIS,
Chem.rdchem.BondStereo.STEREOTRANS]
return one_hot_encoding(bond.GetStereo(), allowable_set, encode_unknown)
class BaseBondFeaturizer(object):
"""An abstract class for bond featurizers.
Loop over all bonds in a molecule and featurize them with the ``featurizer_funcs``.
We assume the constructed ``DGLGraph`` is a bi-directed graph where the **i** th bond in the
molecule, i.e. ``mol.GetBondWithIdx(i)``, corresponds to the **(2i)**-th and **(2i+1)**-th edges
in the DGLGraph.
**We assume the resulting DGLGraph will be created with :func:`smiles_to_bigraph` without
self loops.**
Parameters
----------
featurizer_funcs : dict
Mapping feature name to the featurization function.
Each function is of signature ``func(rdkit.Chem.rdchem.Bond) -> list or 1D numpy array``.
feat_sizes : dict
Mapping feature name to the size of the corresponding feature. If None, they will be
computed when needed. Default: None.
Examples
--------
>>> from dgl.data.chem import BaseBondFeaturizer, bond_type_one_hot, bond_is_in_ring
>>> from rdkit import Chem
>>> mol = Chem.MolFromSmiles('CCO')
>>> bond_featurizer = BaseBondFeaturizer({'bond_type': bond_type_one_hot, 'in_ring': bond_is_in_ring})
>>> bond_featurizer(mol)
{'bond_type': tensor([[1., 0., 0., 0.],
[1., 0., 0., 0.],
[1., 0., 0., 0.],
[1., 0., 0., 0.]]),
'in_ring': tensor([[0.], [0.], [0.], [0.]])}
"""
@deprecated('Import BaseBondFeaturizer from dgllife.utils instead.', 'class')
def __init__(self, featurizer_funcs, feat_sizes=None):
self.featurizer_funcs = featurizer_funcs
if feat_sizes is None:
feat_sizes = dict()
self._feat_sizes = feat_sizes
def feat_size(self, feat_name):
"""Get the feature size for ``feat_name``.
Returns
-------
int
Feature size for the feature with name ``feat_name``.
"""
if feat_name not in self.featurizer_funcs:
return ValueError('Expect feat_name to be in {}, got {}'.format(
list(self.featurizer_funcs.keys()), feat_name))
if feat_name not in self._feat_sizes:
bond = Chem.MolFromSmiles('CO').GetBondWithIdx(0)
self._feat_sizes[feat_name] = len(self.featurizer_funcs[feat_name](bond))
return self._feat_sizes[feat_name]
def __call__(self, mol):
"""Featurize all bonds in a molecule.
Parameters
----------
mol : rdkit.Chem.rdchem.Mol
RDKit molecule instance.
Returns
-------
dict
For each function in self.featurizer_funcs with the key ``k``, store the computed
feature under the key ``k``. Each feature is a tensor of dtype float32 and shape
(N, M), where N is the number of atoms in the molecule.
"""
num_bonds = mol.GetNumBonds()
bond_features = defaultdict(list)
# Compute features for each bond
for i in range(num_bonds):
bond = mol.GetBondWithIdx(i)
for feat_name, feat_func in self.featurizer_funcs.items():
feat = feat_func(bond)
bond_features[feat_name].extend([feat, feat.copy()])
# Stack the features and convert them to float arrays
processed_features = dict()
for feat_name, feat_list in bond_features.items():
feat = np.stack(feat_list)
processed_features[feat_name] = F.zerocopy_from_numpy(feat.astype(np.float32))
return processed_features
class CanonicalBondFeaturizer(BaseBondFeaturizer):
"""A default featurizer for bonds.
The bond features include:
* **One hot encoding of the bond type**. The supported bond types include
``SINGLE``, ``DOUBLE``, ``TRIPLE``, ``AROMATIC``.
* **Whether the bond is conjugated.**.
* **Whether the bond is in a ring of any size.**
* **One hot encoding of the stereo configuration of a bond**. The supported bond stereo
configurations include ``STEREONONE``, ``STEREOANY``, ``STEREOZ``, ``STEREOE``,
``STEREOCIS``, ``STEREOTRANS``.
**We assume the resulting DGLGraph will be created with :func:`smiles_to_bigraph` without
self loops.**
"""
@deprecated('Import CanonicalBondFeaturizer from dgllife.utils instead.', 'class')
def __init__(self, bond_data_field='e'):
super(CanonicalBondFeaturizer, self).__init__(
featurizer_funcs={bond_data_field: ConcatFeaturizer(
[bond_type_one_hot,
bond_is_conjugated,
bond_is_in_ring,
bond_stereo_one_hot]
)})
"""Convert molecules into DGLGraphs."""
import numpy as np
from functools import partial
from .... import DGLGraph
from ....contrib.deprecation import deprecated
try:
import mdtraj
from rdkit import Chem
from rdkit.Chem import rdmolfiles, rdmolops
except ImportError:
pass
__all__ = ['mol_to_graph',
'smiles_to_bigraph',
'mol_to_bigraph',
'smiles_to_complete_graph',
'mol_to_complete_graph',
'k_nearest_neighbors']
@deprecated('Import it from dgllife.utils instead.')
def mol_to_graph(mol, graph_constructor, node_featurizer, edge_featurizer):
"""Convert an RDKit molecule object into a DGLGraph and featurize for it.
Parameters
----------
mol : rdkit.Chem.rdchem.Mol
RDKit molecule holder
graph_constructor : callable
Takes an RDKit molecule as input and returns a DGLGraph
node_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict
Featurization for nodes like atoms in a molecule, which can be used to
update ndata for a DGLGraph.
edge_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict
Featurization for edges like bonds in a molecule, which can be used to
update edata for a DGLGraph.
Returns
-------
g : DGLGraph
Converted DGLGraph for the molecule
"""
new_order = rdmolfiles.CanonicalRankAtoms(mol)
mol = rdmolops.RenumberAtoms(mol, new_order)
g = graph_constructor(mol)
if node_featurizer is not None:
g.ndata.update(node_featurizer(mol))
if edge_featurizer is not None:
g.edata.update(edge_featurizer(mol))
return g
def construct_bigraph_from_mol(mol, add_self_loop=False):
"""Construct a bi-directed DGLGraph with topology only for the molecule.
The **i** th atom in the molecule, i.e. ``mol.GetAtomWithIdx(i)``, corresponds to the
**i** th node in the returned DGLGraph.
The **i** th bond in the molecule, i.e. ``mol.GetBondWithIdx(i)``, corresponds to the
**(2i)**-th and **(2i+1)**-th edges in the returned DGLGraph. The **(2i)**-th and
**(2i+1)**-th edges will be separately from **u** to **v** and **v** to **u**, where
**u** is ``bond.GetBeginAtomIdx()`` and **v** is ``bond.GetEndAtomIdx()``.
If self loops are added, the last **n** edges will separately be self loops for
atoms ``0, 1, ..., n-1``.
Parameters
----------
mol : rdkit.Chem.rdchem.Mol
RDKit molecule holder
add_self_loop : bool
Whether to add self loops in DGLGraphs. Default to False.
Returns
-------
g : DGLGraph
Empty bigraph topology of the molecule
"""
g = DGLGraph()
# Add nodes
num_atoms = mol.GetNumAtoms()
g.add_nodes(num_atoms)
# Add edges
src_list = []
dst_list = []
num_bonds = mol.GetNumBonds()
for i in range(num_bonds):
bond = mol.GetBondWithIdx(i)
u = bond.GetBeginAtomIdx()
v = bond.GetEndAtomIdx()
src_list.extend([u, v])
dst_list.extend([v, u])
g.add_edges(src_list, dst_list)
if add_self_loop:
nodes = g.nodes()
g.add_edges(nodes, nodes)
return g
@deprecated('Import it from dgllife.utils instead.')
def mol_to_bigraph(mol, add_self_loop=False,
node_featurizer=None,
edge_featurizer=None):
"""Convert an RDKit molecule object into a bi-directed DGLGraph and featurize for it.
Parameters
----------
mol : rdkit.Chem.rdchem.Mol
RDKit molecule holder
add_self_loop : bool
Whether to add self loops in DGLGraphs. Default to False.
node_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict
Featurization for nodes like atoms in a molecule, which can be used to update
ndata for a DGLGraph. Default to None.
edge_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict
Featurization for edges like bonds in a molecule, which can be used to update
edata for a DGLGraph. Default to None.
Returns
-------
g : DGLGraph
Bi-directed DGLGraph for the molecule
"""
return mol_to_graph(mol, partial(construct_bigraph_from_mol, add_self_loop=add_self_loop),
node_featurizer, edge_featurizer)
@deprecated('Import it from dgllife.utils instead.')
def smiles_to_bigraph(smiles, add_self_loop=False,
node_featurizer=None,
edge_featurizer=None):
"""Convert a SMILES into a bi-directed DGLGraph and featurize for it.
Parameters
----------
smiles : str
String of SMILES
add_self_loop : bool
Whether to add self loops in DGLGraphs. Default to False.
node_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict
Featurization for nodes like atoms in a molecule, which can be used to update
ndata for a DGLGraph. Default to None.
edge_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict
Featurization for edges like bonds in a molecule, which can be used to update
edata for a DGLGraph. Default to None.
Returns
-------
g : DGLGraph
Bi-directed DGLGraph for the molecule
"""
mol = Chem.MolFromSmiles(smiles)
return mol_to_bigraph(mol, add_self_loop, node_featurizer, edge_featurizer)
def construct_complete_graph_from_mol(mol, add_self_loop=False):
"""Construct a complete graph with topology only for the molecule
The **i** th atom in the molecule, i.e. ``mol.GetAtomWithIdx(i)``, corresponds to the
**i** th node in the returned DGLGraph.
The edges are in the order of (0, 0), (1, 0), (2, 0), ... (0, 1), (1, 1), (2, 1), ...
If self loops are not created, we will not have (0, 0), (1, 1), ...
Parameters
----------
mol : rdkit.Chem.rdchem.Mol
RDKit molecule holder
add_self_loop : bool
Whether to add self loops in DGLGraphs. Default to False.
Returns
-------
g : DGLGraph
Empty complete graph topology of the molecule
"""
g = DGLGraph()
num_atoms = mol.GetNumAtoms()
g.add_nodes(num_atoms)
if add_self_loop:
g.add_edges(
[i for i in range(num_atoms) for j in range(num_atoms)],
[j for i in range(num_atoms) for j in range(num_atoms)])
else:
g.add_edges(
[i for i in range(num_atoms) for j in range(num_atoms - 1)], [
j for i in range(num_atoms)
for j in range(num_atoms) if i != j
])
return g
@deprecated('Import it from dgllife.utils instead.')
def mol_to_complete_graph(mol, add_self_loop=False,
node_featurizer=None,
edge_featurizer=None):
"""Convert an RDKit molecule into a complete DGLGraph and featurize for it.
Parameters
----------
mol : rdkit.Chem.rdchem.Mol
RDKit molecule holder
add_self_loop : bool
Whether to add self loops in DGLGraphs. Default to False.
node_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict
Featurization for nodes like atoms in a molecule, which can be used to update
ndata for a DGLGraph. Default to None.
edge_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict
Featurization for edges like bonds in a molecule, which can be used to update
edata for a DGLGraph. Default to None.
Returns
-------
g : DGLGraph
Complete DGLGraph for the molecule
"""
return mol_to_graph(mol, partial(construct_complete_graph_from_mol, add_self_loop=add_self_loop),
node_featurizer, edge_featurizer)
@deprecated('Import it from dgllife.utils instead.')
def smiles_to_complete_graph(smiles, add_self_loop=False,
node_featurizer=None,
edge_featurizer=None):
"""Convert a SMILES into a complete DGLGraph and featurize for it.
Parameters
----------
smiles : str
String of SMILES
add_self_loop : bool
Whether to add self loops in DGLGraphs. Default to False.
node_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict
Featurization for nodes like atoms in a molecule, which can be used to update
ndata for a DGLGraph. Default to None.
edge_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict
Featurization for edges like bonds in a molecule, which can be used to update
edata for a DGLGraph. Default to None.
Returns
-------
g : DGLGraph
Complete DGLGraph for the molecule
"""
mol = Chem.MolFromSmiles(smiles)
return mol_to_complete_graph(mol, add_self_loop, node_featurizer, edge_featurizer)
@deprecated('Import it from dgllife.utils instead.')
def k_nearest_neighbors(coordinates, neighbor_cutoff, max_num_neighbors):
"""Find k nearest neighbors for each atom based on the 3D coordinates.
Parameters
----------
coordinates : numpy.ndarray of shape (N, 3)
The 3D coordinates of atoms in the molecule. N for the number of atoms.
neighbor_cutoff : float
Distance cutoff to define 'neighboring'.
max_num_neighbors : int or None.
If not None, then this specifies the maximum number of closest neighbors
allowed for each atom.
Returns
-------
Returns
-------
srcs : list of int
Source nodes.
dsts : list of int
Destination nodes.
distances : list of float
Distances between the end nodes.
"""
num_atoms = coordinates.shape[0]
traj = mdtraj.Trajectory(coordinates.reshape((1, num_atoms, 3)), None)
neighbors = mdtraj.geometry.compute_neighborlist(traj, neighbor_cutoff)
srcs, dsts, distances = [], [], []
for i in range(num_atoms):
delta = coordinates[i] - coordinates.take(neighbors[i], axis=0)
dist = np.linalg.norm(delta, axis=1)
if max_num_neighbors is not None and len(neighbors[i]) > max_num_neighbors:
sorted_neighbors = list(zip(dist, neighbors[i]))
# Sort neighbors based on distance from smallest to largest
sorted_neighbors.sort(key=lambda tup: tup[0])
dsts.extend([i for _ in range(max_num_neighbors)])
srcs.extend([int(sorted_neighbors[j][1]) for j in range(max_num_neighbors)])
distances.extend([float(sorted_neighbors[j][0]) for j in range(max_num_neighbors)])
else:
dsts.extend([i for _ in range(len(neighbors[i]))])
srcs.extend(neighbors[i].tolist())
distances.extend(dist.tolist())
return srcs, dsts, distances
"""Utils for RDKit, mostly adapted from DeepChem
(https://github.com/deepchem/deepchem/blob/master/deepchem)."""
import warnings
from functools import partial
from multiprocessing import Pool
from ....contrib.deprecation import deprecated
try:
import pdbfixer
import simtk
from rdkit import Chem
from rdkit.Chem import AllChem
from StringIO import StringIO
except ImportError:
from io import StringIO
__all__ = ['add_hydrogens_to_mol',
'get_mol_3D_coordinates',
'load_molecule',
'multiprocess_load_molecules']
@deprecated('')
def add_hydrogens_to_mol(mol):
"""Add hydrogens to an RDKit molecule instance.
Parameters
----------
mol : rdkit.Chem.rdchem.Mol
RDKit molecule instance.
Returns
-------
mol : rdkit.Chem.rdchem.Mol
RDKit molecule instance with hydrogens added. For failures in adding hydrogens,
the original RDKit molecule instance will be returned.
"""
try:
pdbblock = Chem.MolToPDBBlock(mol)
pdb_stringio = StringIO()
pdb_stringio.write(pdbblock)
pdb_stringio.seek(0)
fixer = pdbfixer.PDBFixer(pdbfile=pdb_stringio)
fixer.findMissingResidues()
fixer.findMissingAtoms()
fixer.addMissingAtoms()
fixer.addMissingHydrogens(7.4)
hydrogenated_io = StringIO()
simtk.openmm.app.PDBFile.writeFile(fixer.topology, fixer.positions,
hydrogenated_io)
hydrogenated_io.seek(0)
mol = Chem.MolFromPDBBlock(hydrogenated_io.read(), sanitize=False, removeHs=False)
pdb_stringio.close()
hydrogenated_io.close()
except ValueError:
warnings.warn('Failed to add hydrogens to the molecule.')
return mol
@deprecated('Import it from dgllife.utils.rdkit_utils instead.')
def get_mol_3D_coordinates(mol):
"""Get 3D coordinates of the molecule.
Parameters
----------
mol : rdkit.Chem.rdchem.Mol
RDKit molecule instance.
Returns
-------
numpy.ndarray of shape (N, 3) or None
The 3D coordinates of atoms in the molecule. N for the number of atoms in
the molecule. For failures in getting the conformations, None will be returned.
"""
try:
conf = mol.GetConformer()
conf_num_atoms = conf.GetNumAtoms()
mol_num_atoms = mol.GetNumAtoms()
assert mol_num_atoms == conf_num_atoms, \
'Expect the number of atoms in the molecule and its conformation ' \
'to be the same, got {:d} and {:d}'.format(mol_num_atoms, conf_num_atoms)
return conf.GetPositions()
except:
warnings.warn('Unable to get conformation of the molecule.')
return None
@deprecated('Import it from dgllife.utils.rdkit_utils instead.')
def load_molecule(molecule_file, add_hydrogens=False, sanitize=False, calc_charges=False,
remove_hs=False, use_conformation=True):
"""Load a molecule from a file.
Parameters
----------
molecule_file : str
Path to file for storing a molecule, which can be of format '.mol2', '.sdf',
'.pdbqt', or '.pdb'.
add_hydrogens : bool
Whether to add hydrogens via pdbfixer. Default to False.
sanitize : bool
Whether sanitization is performed in initializing RDKit molecule instances. See
https://www.rdkit.org/docs/RDKit_Book.html for details of the sanitization.
Default to False.
calc_charges : bool
Whether to add Gasteiger charges via RDKit. Setting this to be True will enforce
``add_hydrogens`` and ``sanitize`` to be True. Default to False.
remove_hs : bool
Whether to remove hydrogens via RDKit. Note that removing hydrogens can be quite
slow for large molecules. Default to False.
use_conformation : bool
Whether we need to extract molecular conformation from proteins and ligands.
Default to True.
Returns
-------
mol : rdkit.Chem.rdchem.Mol
RDKit molecule instance for the loaded molecule.
coordinates : np.ndarray of shape (N, 3) or None
The 3D coordinates of atoms in the molecule. N for the number of atoms in
the molecule. None will be returned if ``use_conformation`` is False or
we failed to get conformation information.
"""
if molecule_file.endswith('.mol2'):
mol = Chem.MolFromMol2File(molecule_file, sanitize=False, removeHs=False)
elif molecule_file.endswith('.sdf'):
supplier = Chem.SDMolSupplier(molecule_file, sanitize=False, removeHs=False)
mol = supplier[0]
elif molecule_file.endswith('.pdbqt'):
with open(molecule_file) as f:
pdbqt_data = f.readlines()
pdb_block = ''
for line in pdbqt_data:
pdb_block += '{}\n'.format(line[:66])
mol = Chem.MolFromPDBBlock(pdb_block, sanitize=False, removeHs=False)
elif molecule_file.endswith('.pdb'):
mol = Chem.MolFromPDBFile(molecule_file, sanitize=False, removeHs=False)
else:
return ValueError('Expect the format of the molecule_file to be '
'one of .mol2, .sdf, .pdbqt and .pdb, got {}'.format(molecule_file))
try:
if add_hydrogens or calc_charges:
mol = add_hydrogens_to_mol(mol)
if sanitize or calc_charges:
Chem.SanitizeMol(mol)
if calc_charges:
# Compute Gasteiger charges on the molecule.
try:
AllChem.ComputeGasteigerCharges(mol)
except:
warnings.warn('Unable to compute charges for the molecule.')
if remove_hs:
mol = Chem.RemoveHs(mol)
except:
return None, None
if use_conformation:
coordinates = get_mol_3D_coordinates(mol)
else:
coordinates = None
return mol, coordinates
@deprecated('Import it from dgllife.utils.rdkit_utils instead.')
def multiprocess_load_molecules(files, add_hydrogens=False, sanitize=False, calc_charges=False,
remove_hs=False, use_conformation=True, num_processes=2):
"""Load molecules from files with multiprocessing.
Parameters
----------
files : list of str
Each element is a path to a file storing a molecule, which can be of format '.mol2',
'.sdf', '.pdbqt', or '.pdb'.
add_hydrogens : bool
Whether to add hydrogens via pdbfixer. Default to False.
sanitize : bool
Whether sanitization is performed in initializing RDKit molecule instances. See
https://www.rdkit.org/docs/RDKit_Book.html for details of the sanitization.
Default to False.
calc_charges : bool
Whether to add Gasteiger charges via RDKit. Setting this to be True will enforce
``add_hydrogens`` and ``sanitize`` to be True. Default to False.
remove_hs : bool
Whether to remove hydrogens via RDKit. Note that removing hydrogens can be quite
slow for large molecules. Default to False.
use_conformation : bool
Whether we need to extract molecular conformation from proteins and ligands.
Default to True.
num_processes : int or None
Number of worker processes to use. If None,
then we will use the number of CPUs in the systetm. Default to 2.
Returns
-------
list of 2-tuples
The first element of each 2-tuple is an RDKit molecule instance. The second element
of each 2-tuple is the 3D atom coordinates of the corresponding molecule if
use_conformation is True and the coordinates has been successfully loaded. Otherwise,
it will be None.
"""
if num_processes == 1:
mols_loaded = []
for i, f in enumerate(files):
mols_loaded.append(load_molecule(
f, add_hydrogens=add_hydrogens, sanitize=sanitize, calc_charges=calc_charges,
remove_hs=remove_hs, use_conformation=use_conformation))
else:
with Pool(processes=num_processes) as pool:
mols_loaded = pool.map_async(partial(
load_molecule, add_hydrogens=add_hydrogens, sanitize=sanitize,
calc_charges=calc_charges, remove_hs=remove_hs,
use_conformation=use_conformation), files)
mols_loaded = mols_loaded.get()
return mols_loaded
"""Various methods for splitting chemical datasets.
We mostly adapt them from deepchem
(https://github.com/deepchem/deepchem/blob/master/deepchem/splits/splitters.py).
"""
import numpy as np
from collections import defaultdict
from functools import partial
from itertools import accumulate, chain
from ...utils import split_dataset, Subset
from .... import backend as F
from ....contrib.deprecation import deprecated
try:
from rdkit import Chem
from rdkit.Chem import rdMolDescriptors
from rdkit.Chem.rdmolops import FastFindRings
from rdkit.Chem.Scaffolds import MurckoScaffold
except ImportError:
pass
__all__ = ['ConsecutiveSplitter',
'RandomSplitter',
'MolecularWeightSplitter',
'ScaffoldSplitter',
'SingleTaskStratifiedSplitter']
def base_k_fold_split(split_method, dataset, k, log):
"""Split dataset for k-fold cross validation.
Parameters
----------
split_method : callable
Arbitrary method for splitting the dataset
into training, validation and test subsets.
dataset
We assume ``len(dataset)`` gives the size for the dataset and ``dataset[i]``
gives the ith datapoint.
k : int
Number of folds to use and should be no smaller than 2.
log : bool
Whether to print a message at the start of preparing each fold.
Returns
-------
all_folds : list of 2-tuples
Each element of the list represents a fold and is a 2-tuple (train_set, val_set).
"""
assert k >= 2, 'Expect the number of folds to be no smaller than 2, got {:d}'.format(k)
all_folds = []
frac_per_part = 1./ k
for i in range(k):
if log:
print('Processing fold {:d}/{:d}'.format(i+1, k))
# We are reusing the code for train-validation-test split.
train_set1, val_set, train_set2 = split_method(dataset,
frac_train=i * frac_per_part,
frac_val=frac_per_part,
frac_test=1. - (i + 1) * frac_per_part)
# For cross validation, each fold consists of only a train subset and
# a validation subset.
train_set = Subset(dataset, train_set1.indices + train_set2.indices)
all_folds.append((train_set, val_set))
return all_folds
def train_val_test_sanity_check(frac_train, frac_val, frac_test):
"""Sanity check for train-val-test split
Ensure that the fractions of the dataset to use for training,
validation and test add up to 1.
Parameters
----------
frac_train : float
Fraction of the dataset to use for training.
frac_val : float
Fraction of the dataset to use for validation.
frac_test : float
Fraction of the dataset to use for test.
"""
total_fraction = frac_train + frac_val + frac_test
assert np.allclose(total_fraction, 1.), \
'Expect the sum of fractions for training, validation and ' \
'test to be 1, got {:.4f}'.format(total_fraction)
def indices_split(dataset, frac_train, frac_val, frac_test, indices):
"""Reorder datapoints based on the specified indices and then take consecutive
chunks as subsets.
Parameters
----------
dataset
We assume ``len(dataset)`` gives the size for the dataset and ``dataset[i]``
gives the ith datapoint.
frac_train : float
Fraction of data to use for training.
frac_val : float
Fraction of data to use for validation.
frac_test : float
Fraction of data to use for test.
indices : list or ndarray
Indices specifying the order of datapoints.
Returns
-------
list of length 3
Subsets for training, validation and test, which are all :class:`Subset` instances.
"""
frac_list = np.asarray([frac_train, frac_val, frac_test])
assert np.allclose(np.sum(frac_list), 1.), \
'Expect frac_list sum to 1, got {:.4f}'.format(np.sum(frac_list))
num_data = len(dataset)
lengths = (num_data * frac_list).astype(int)
lengths[-1] = num_data - np.sum(lengths[:-1])
return [Subset(dataset, list(indices[offset - length:offset]))
for offset, length in zip(accumulate(lengths), lengths)]
def count_and_log(message, i, total, log_every_n):
"""Print a message to reflect the progress of processing once a while.
Parameters
----------
message : str
Message to print.
i : int
Current index.
total : int
Total count.
log_every_n : None or int
Molecule related computation can take a long time for a large dataset and we want
to learn the progress of processing. This can be done by printing a message whenever
a batch of ``log_every_n`` molecules have been processed. If None, no messages will
be printed.
"""
if (log_every_n is not None) and ((i+1) % log_every_n == 0):
print('{} {:d}/{:d}'.format(message, i+1, total))
def prepare_mols(dataset, mols, sanitize, log_every_n=1000):
"""Prepare RDKit molecule instances.
Parameters
----------
dataset
We assume ``len(dataset)`` gives the size for the dataset, ``dataset[i]``
gives the ith datapoint and ``dataset.smiles[i]`` gives the SMILES for the
ith datapoint.
mols : None or list of rdkit.Chem.rdchem.Mol
None or pre-computed RDKit molecule instances. If not None, we expect a
one-on-one correspondence between ``dataset.smiles`` and ``mols``, i.e.
``mols[i]`` corresponds to ``dataset.smiles[i]``.
sanitize : bool
This argument only comes into effect when ``mols`` is None and decides whether
sanitization is performed in initializing RDKit molecule instances. See
https://www.rdkit.org/docs/RDKit_Book.html for details of the sanitization.
log_every_n : None or int
Molecule related computation can take a long time for a large dataset and we want
to learn the progress of processing. This can be done by printing a message whenever
a batch of ``log_every_n`` molecules have been processed. If None, no messages will
be printed. Default to 1000.
Returns
-------
mols : list of rdkit.Chem.rdchem.Mol
RDkit molecule instances where there is a one-on-one correspondence between
``dataset.smiles`` and ``mols``, i.e. ``mols[i]`` corresponds to ``dataset.smiles[i]``.
"""
if mols is not None:
# Sanity check
assert len(mols) == len(dataset), \
'Expect mols to be of the same size as that of the dataset, ' \
'got {:d} and {:d}'.format(len(mols), len(dataset))
else:
if log_every_n is not None:
print('Start initializing RDKit molecule instances...')
mols = []
for i, s in enumerate(dataset.smiles):
count_and_log('Creating RDKit molecule instance',
i, len(dataset.smiles), log_every_n)
mols.append(Chem.MolFromSmiles(s, sanitize=sanitize))
return mols
class ConsecutiveSplitter(object):
"""Split datasets with the input order.
The dataset is split without permutation, so the splitting is deterministic.
"""
@staticmethod
@deprecated('Import ConsecutiveSplitter from dgllife.utils.splitters instead.', 'class')
def train_val_test_split(dataset, frac_train=0.8, frac_val=0.1, frac_test=0.1):
"""Split the dataset into three consecutive chunks for training, validation and test.
Parameters
----------
dataset
We assume ``len(dataset)`` gives the size for the dataset and ``dataset[i]``
gives the ith datapoint.
frac_train : float
Fraction of data to use for training. By default, we set this to be 0.8, i.e.
80% of the dataset is used for training.
frac_val : float
Fraction of data to use for validation. By default, we set this to be 0.1, i.e.
10% of the dataset is used for validation.
frac_test : float
Fraction of data to use for test. By default, we set this to be 0.1, i.e.
10% of the dataset is used for test.
Returns
-------
list of length 3
Subsets for training, validation and test, which are all :class:`Subset` instances.
"""
return split_dataset(dataset, frac_list=[frac_train, frac_val, frac_test], shuffle=False)
@staticmethod
@deprecated('Import ConsecutiveSplitter from dgllife.utils.splitters instead.', 'class')
def k_fold_split(dataset, k=5, log=True):
"""Split the dataset for k-fold cross validation by taking consecutive chunks.
Parameters
----------
dataset
We assume ``len(dataset)`` gives the size for the dataset and ``dataset[i]``
gives the ith datapoint.
k : int
Number of folds to use and should be no smaller than 2. Default to be 5.
log : bool
Whether to print a message at the start of preparing each fold.
Returns
-------
list of 2-tuples
Each element of the list represents a fold and is a 2-tuple (train_set, val_set).
"""
return base_k_fold_split(ConsecutiveSplitter.train_val_test_split, dataset, k, log)
class RandomSplitter(object):
"""Randomly reorder datasets and then split them.
The dataset is split with permutation and the splitting is hence random.
"""
@staticmethod
@deprecated('Import RandomSplitter from dgllife.utils.splitters instead.', 'class')
def train_val_test_split(dataset, frac_train=0.8, frac_val=0.1,
frac_test=0.1, random_state=None):
"""Randomly permute the dataset and then split it into
three consecutive chunks for training, validation and test.
Parameters
----------
dataset
We assume ``len(dataset)`` gives the size for the dataset and ``dataset[i]``
gives the ith datapoint.
frac_train : float
Fraction of data to use for training. By default, we set this to be 0.8, i.e.
80% of the dataset is used for training.
frac_val : float
Fraction of data to use for validation. By default, we set this to be 0.1, i.e.
10% of the dataset is used for validation.
frac_test : float
Fraction of data to use for test. By default, we set this to be 0.1, i.e.
10% of the dataset is used for test.
random_state : None, int or array_like, optional
Random seed used to initialize the pseudo-random number generator.
Can be any integer between 0 and 2**32 - 1 inclusive, an array
(or other sequence) of such integers, or None (the default).
If seed is None, then RandomState will try to read data from /dev/urandom
(or the Windows analogue) if available or seed from the clock otherwise.
Returns
-------
list of length 3
Subsets for training, validation and test.
"""
return split_dataset(dataset, frac_list=[frac_train, frac_val, frac_test],
shuffle=True, random_state=random_state)
@staticmethod
@deprecated('Import RandomSplitter from dgllife.utils.splitters instead.', 'class')
def k_fold_split(dataset, k=5, random_state=None, log=True):
"""Randomly permute the dataset and then split it
for k-fold cross validation by taking consecutive chunks.
Parameters
----------
dataset
We assume ``len(dataset)`` gives the size for the dataset and ``dataset[i]``
gives the ith datapoint.
k : int
Number of folds to use and should be no smaller than 2. Default to be 5.
random_state : None, int or array_like, optional
Random seed used to initialize the pseudo-random number generator.
Can be any integer between 0 and 2**32 - 1 inclusive, an array
(or other sequence) of such integers, or None (the default).
If seed is None, then RandomState will try to read data from /dev/urandom
(or the Windows analogue) if available or seed from the clock otherwise.
log : bool
Whether to print a message at the start of preparing each fold. Default to True.
Returns
-------
list of 2-tuples
Each element of the list represents a fold and is a 2-tuple (train_set, val_set).
"""
# Permute the dataset only once so that each datapoint
# will appear once in exactly one fold.
indices = np.random.RandomState(seed=random_state).permutation(len(dataset))
return base_k_fold_split(partial(indices_split, indices=indices), dataset, k, log)
class MolecularWeightSplitter(object):
"""Sort molecules based on their weights and then split them."""
@staticmethod
@deprecated('Import MolecularWeightSplitter from dgllife.utils.splitters instead.', 'class')
def molecular_weight_indices(molecules, log_every_n):
"""Reorder molecules based on molecular weights.
Parameters
----------
molecules : list of rdkit.Chem.rdchem.Mol
Pre-computed RDKit molecule instances. We expect a one-on-one
correspondence between ``dataset.smiles`` and ``mols``, i.e.
``mols[i]`` corresponds to ``dataset.smiles[i]``.
log_every_n : None or int
Molecule related computation can take a long time for a large dataset and we want
to learn the progress of processing. This can be done by printing a message whenever
a batch of ``log_every_n`` molecules have been processed. If None, no messages will
be printed.
Returns
-------
indices : list or ndarray
Indices specifying the order of datapoints, which are basically
argsort of the molecular weights.
"""
if log_every_n is not None:
print('Start computing molecular weights.')
mws = []
for i, mol in enumerate(molecules):
count_and_log('Computing molecular weight for compound',
i, len(molecules), log_every_n)
mws.append(Chem.rdMolDescriptors.CalcExactMolWt(mol))
return np.argsort(mws)
@staticmethod
@deprecated('Import MolecularWeightSplitter from dgllife.utils.splitters instead.', 'class')
def train_val_test_split(dataset, mols=None, sanitize=True, frac_train=0.8,
frac_val=0.1, frac_test=0.1, log_every_n=1000):
"""Sort molecules based on their weights and then split them into
three consecutive chunks for training, validation and test.
Parameters
----------
dataset
We assume ``len(dataset)`` gives the size for the dataset, ``dataset[i]``
gives the ith datapoint and ``dataset.smiles[i]`` gives the SMILES for the
ith datapoint.
mols : None or list of rdkit.Chem.rdchem.Mol
None or pre-computed RDKit molecule instances. If not None, we expect a
one-on-one correspondence between ``dataset.smiles`` and ``mols``, i.e.
``mols[i]`` corresponds to ``dataset.smiles[i]``. Default to None.
sanitize : bool
This argument only comes into effect when ``mols`` is None and decides whether
sanitization is performed in initializing RDKit molecule instances. See
https://www.rdkit.org/docs/RDKit_Book.html for details of the sanitization.
Default to be True.
frac_train : float
Fraction of data to use for training. By default, we set this to be 0.8, i.e.
80% of the dataset is used for training.
frac_val : float
Fraction of data to use for validation. By default, we set this to be 0.1, i.e.
10% of the dataset is used for validation.
frac_test : float
Fraction of data to use for test. By default, we set this to be 0.1, i.e.
10% of the dataset is used for test.
log_every_n : None or int
Molecule related computation can take a long time for a large dataset and we want
to learn the progress of processing. This can be done by printing a message whenever
a batch of ``log_every_n`` molecules have been processed. If None, no messages will
be printed. Default to 1000.
Returns
-------
list of length 3
Subsets for training, validation and test, which are all :class:`Subset` instances.
"""
# Perform sanity check first as molecule instance initialization and descriptor
# computation can take a long time.
train_val_test_sanity_check(frac_train, frac_val, frac_test)
molecules = prepare_mols(dataset, mols, sanitize, log_every_n)
sorted_indices = MolecularWeightSplitter.molecular_weight_indices(molecules, log_every_n)
return indices_split(dataset, frac_train, frac_val, frac_test, sorted_indices)
@staticmethod
@deprecated('Import MolecularWeightSplitter from dgllife.utils.splitters instead.', 'class')
def k_fold_split(dataset, mols=None, sanitize=True, k=5, log_every_n=1000):
"""Sort molecules based on their weights and then split them
for k-fold cross validation by taking consecutive chunks.
Parameters
----------
dataset
We assume ``len(dataset)`` gives the size for the dataset, ``dataset[i]``
gives the ith datapoint and ``dataset.smiles[i]`` gives the SMILES for the
ith datapoint.
mols : None or list of rdkit.Chem.rdchem.Mol
None or pre-computed RDKit molecule instances. If not None, we expect a
one-on-one correspondence between ``dataset.smiles`` and ``mols``, i.e.
``mols[i]`` corresponds to ``dataset.smiles[i]``. Default to None.
sanitize : bool
This argument only comes into effect when ``mols`` is None and decides whether
sanitization is performed in initializing RDKit molecule instances. See
https://www.rdkit.org/docs/RDKit_Book.html for details of the sanitization.
Default to be True.
k : int
Number of folds to use and should be no smaller than 2. Default to be 5.
log_every_n : None or int
Molecule related computation can take a long time for a large dataset and we want
to learn the progress of processing. This can be done by printing a message whenever
a batch of ``log_every_n`` molecules have been processed. If None, no messages will
be printed. Default to 1000.
Returns
-------
list of 2-tuples
Each element of the list represents a fold and is a 2-tuple (train_set, val_set).
"""
molecules = prepare_mols(dataset, mols, sanitize, log_every_n)
sorted_indices = MolecularWeightSplitter.molecular_weight_indices(molecules, log_every_n)
return base_k_fold_split(partial(indices_split, indices=sorted_indices), dataset, k,
log=(log_every_n is not None))
class ScaffoldSplitter(object):
"""Group molecules based on their Bemis-Murcko scaffolds and then split the groups.
Group molecules so that all molecules in a group have a same scaffold (see reference).
The dataset is then split at the level of groups.
References
----------
Bemis, G. W.; Murcko, M. A. “The Properties of Known Drugs.
1. Molecular Frameworks.” J. Med. Chem. 39:2887-93 (1996).
"""
@staticmethod
@deprecated('Import ScaffoldSplitter from dgllife.utils.splitters instead.', 'class')
def get_ordered_scaffold_sets(molecules, include_chirality, log_every_n):
"""Group molecules based on their Bemis-Murcko scaffolds and
order these groups based on their sizes.
The order is decided by comparing the size of groups, where groups with a larger size
are placed before the ones with a smaller size.
Parameters
----------
molecules : list of rdkit.Chem.rdchem.Mol
Pre-computed RDKit molecule instances. We expect a one-on-one
correspondence between ``dataset.smiles`` and ``mols``, i.e.
``mols[i]`` corresponds to ``dataset.smiles[i]``.
include_chirality : bool
Whether to consider chirality in computing scaffolds.
log_every_n : None or int
Molecule related computation can take a long time for a large dataset and we want
to learn the progress of processing. This can be done by printing a message whenever
a batch of ``log_every_n`` molecules have been processed. If None, no messages will
be printed.
Returns
-------
scaffold_sets : list
Each element of the list is a list of int,
representing the indices of compounds with a same scaffold.
"""
if log_every_n is not None:
print('Start computing Bemis-Murcko scaffolds.')
scaffolds = defaultdict(list)
for i, mol in enumerate(molecules):
count_and_log('Computing Bemis-Murcko for compound',
i, len(molecules), log_every_n)
# For mols that have not been sanitized, we need to compute their ring information
try:
FastFindRings(mol)
mol_scaffold = MurckoScaffold.MurckoScaffoldSmiles(
mol=mol, includeChirality=include_chirality)
# Group molecules that have the same scaffold
scaffolds[mol_scaffold].append(i)
except:
print('Failed to compute the scaffold for molecule {:d} '
'and it will be excluded.'.format(i+1))
# Order groups of molecules by first comparing the size of groups
# and then the index of the first compound in the group.
scaffold_sets = [
scaffold_set for (scaffold, scaffold_set) in sorted(
scaffolds.items(), key=lambda x: (len(x[1]), x[1][0]), reverse=True)
]
return scaffold_sets
@staticmethod
@deprecated('Import ScaffoldSplitter from dgllife.utils.splitters instead.', 'class')
def train_val_test_split(dataset, mols=None, sanitize=True, include_chirality=False,
frac_train=0.8, frac_val=0.1, frac_test=0.1, log_every_n=1000):
"""Split the dataset into training, validation and test set based on molecular scaffolds.
This spliting method ensures that molecules with a same scaffold will be collectively
in only one of the training, validation or test set. As a result, the fraction
of dataset to use for training and validation tend to be smaller than ``frac_train``
and ``frac_val``, while the fraction of dataset to use for test tends to be larger
than ``frac_test``.
Parameters
----------
dataset
We assume ``len(dataset)`` gives the size for the dataset, ``dataset[i]``
gives the ith datapoint and ``dataset.smiles[i]`` gives the SMILES for the
ith datapoint.
mols : None or list of rdkit.Chem.rdchem.Mol
None or pre-computed RDKit molecule instances. If not None, we expect a
one-on-one correspondence between ``dataset.smiles`` and ``mols``, i.e.
``mols[i]`` corresponds to ``dataset.smiles[i]``. Default to None.
sanitize : bool
This argument only comes into effect when ``mols`` is None and decides whether
sanitization is performed in initializing RDKit molecule instances. See
https://www.rdkit.org/docs/RDKit_Book.html for details of the sanitization.
Default to True.
include_chirality : bool
Whether to consider chirality in computing scaffolds. Default to False.
frac_train : float
Fraction of data to use for training. By default, we set this to be 0.8, i.e.
80% of the dataset is used for training.
frac_val : float
Fraction of data to use for validation. By default, we set this to be 0.1, i.e.
10% of the dataset is used for validation.
frac_test : float
Fraction of data to use for test. By default, we set this to be 0.1, i.e.
10% of the dataset is used for test.
log_every_n : None or int
Molecule related computation can take a long time for a large dataset and we want
to learn the progress of processing. This can be done by printing a message whenever
a batch of ``log_every_n`` molecules have been processed. If None, no messages will
be printed. Default to 1000.
Returns
-------
list of length 3
Subsets for training, validation and test, which are all :class:`Subset` instances.
"""
# Perform sanity check first as molecule related computation can take a long time.
train_val_test_sanity_check(frac_train, frac_val, frac_test)
molecules = prepare_mols(dataset, mols, sanitize)
scaffold_sets = ScaffoldSplitter.get_ordered_scaffold_sets(
molecules, include_chirality, log_every_n)
train_indices, val_indices, test_indices = [], [], []
train_cutoff = int(frac_train * len(molecules))
val_cutoff = int((frac_train + frac_val) * len(molecules))
for group_indices in scaffold_sets:
if len(train_indices) + len(group_indices) > train_cutoff:
if len(train_indices) + len(val_indices) + len(group_indices) > val_cutoff:
test_indices.extend(group_indices)
else:
val_indices.extend(group_indices)
else:
train_indices.extend(group_indices)
return [Subset(dataset, train_indices),
Subset(dataset, val_indices),
Subset(dataset, test_indices)]
@staticmethod
@deprecated('Import ScaffoldSplitter from dgllife.utils.splitters instead.', 'class')
def k_fold_split(dataset, mols=None, sanitize=True,
include_chirality=False, k=5, log_every_n=1000):
"""Group molecules based on their scaffolds and sort groups based on their sizes.
The groups are then split for k-fold cross validation.
Same as usual k-fold splitting methods, each molecule will appear only once
in the validation set among all folds. In addition, this method ensures that
molecules with a same scaffold will be collectively in either the training
set or the validation set for each fold.
Note that the folds can be highly imbalanced depending on the
scaffold distribution in the dataset.
Parameters
----------
dataset
We assume ``len(dataset)`` gives the size for the dataset, ``dataset[i]``
gives the ith datapoint and ``dataset.smiles[i]`` gives the SMILES for the
ith datapoint.
mols : None or list of rdkit.Chem.rdchem.Mol
None or pre-computed RDKit molecule instances. If not None, we expect a
one-on-one correspondence between ``dataset.smiles`` and ``mols``, i.e.
``mols[i]`` corresponds to ``dataset.smiles[i]``. Default to None.
sanitize : bool
This argument only comes into effect when ``mols`` is None and decides whether
sanitization is performed in initializing RDKit molecule instances. See
https://www.rdkit.org/docs/RDKit_Book.html for details of the sanitization.
Default to True.
include_chirality : bool
Whether to consider chirality in computing scaffolds. Default to False.
k : int
Number of folds to use and should be no smaller than 2. Default to be 5.
log_every_n : None or int
Molecule related computation can take a long time for a large dataset and we want
to learn the progress of processing. This can be done by printing a message whenever
a batch of ``log_every_n`` molecules have been processed. If None, no messages will
be printed. Default to 1000.
Returns
-------
list of 2-tuples
Each element of the list represents a fold and is a 2-tuple (train_set, val_set).
"""
assert k >= 2, 'Expect the number of folds to be no smaller than 2, got {:d}'.format(k)
molecules = prepare_mols(dataset, mols, sanitize)
scaffold_sets = ScaffoldSplitter.get_ordered_scaffold_sets(
molecules, include_chirality, log_every_n)
# k buckets that form a relatively balanced partition of the dataset
index_buckets = [[] for _ in range(k)]
for group_indices in scaffold_sets:
bucket_chosen = int(np.argmin([len(bucket) for bucket in index_buckets]))
index_buckets[bucket_chosen].extend(group_indices)
all_folds = []
for i in range(k):
if log_every_n is not None:
print('Processing fold {:d}/{:d}'.format(i + 1, k))
train_indices = list(chain.from_iterable(index_buckets[:i] + index_buckets[i+1:]))
val_indices = index_buckets[i]
all_folds.append((Subset(dataset, train_indices), Subset(dataset, val_indices)))
return all_folds
class SingleTaskStratifiedSplitter(object):
"""Splits the dataset by stratification on a single task.
We sort the molecules based on their label values for a task and then repeatedly
take buckets of datapoints to augment the training, validation and test subsets.
"""
@staticmethod
@deprecated('Import SingleTaskStratifiedSplitter from '
'dgllife.utils.splitters instead.', 'class')
def train_val_test_split(dataset, labels, task_id, frac_train=0.8, frac_val=0.1,
frac_test=0.1, bucket_size=10, random_state=None):
"""Split the dataset into training, validation and test subsets as stated above.
Parameters
----------
dataset
We assume ``len(dataset)`` gives the size for the dataset, ``dataset[i]``
gives the ith datapoint and ``dataset.smiles[i]`` gives the SMILES for the
ith datapoint.
labels : tensor of shape (N, T)
Dataset labels all tasks. N for the number of datapoints and T for the number
of tasks.
task_id : int
Index for the task.
frac_train : float
Fraction of data to use for training. By default, we set this to be 0.8, i.e.
80% of the dataset is used for training.
frac_val : float
Fraction of data to use for validation. By default, we set this to be 0.1, i.e.
10% of the dataset is used for validation.
frac_test : float
Fraction of data to use for test. By default, we set this to be 0.1, i.e.
10% of the dataset is used for test.
bucket_size : int
Size of bucket of datapoints. Default to 10.
random_state : None, int or array_like, optional
Random seed used to initialize the pseudo-random number generator.
Can be any integer between 0 and 2**32 - 1 inclusive, an array
(or other sequence) of such integers, or None (the default).
If seed is None, then RandomState will try to read data from /dev/urandom
(or the Windows analogue) if available or seed from the clock otherwise.
Returns
-------
list of length 3
Subsets for training, validation and test, which are all :class:`Subset` instances.
"""
train_val_test_sanity_check(frac_train, frac_val, frac_test)
if random_state is not None:
np.random.seed(random_state)
if not isinstance(labels, np.ndarray):
labels = F.asnumpy(labels)
task_labels = labels[:, task_id]
sorted_indices = np.argsort(task_labels)
train_bucket_cutoff = int(np.round(frac_train * bucket_size))
val_bucket_cutoff = int(np.round(frac_val * bucket_size)) + train_bucket_cutoff
train_indices, val_indices, test_indices = [], [], []
while sorted_indices.shape[0] >= bucket_size:
current_batch, sorted_indices = np.split(sorted_indices, [bucket_size])
shuffled = np.random.permutation(range(bucket_size))
train_indices.extend(
current_batch[shuffled[:train_bucket_cutoff]].tolist())
val_indices.extend(
current_batch[shuffled[train_bucket_cutoff:val_bucket_cutoff]].tolist())
test_indices.extend(
current_batch[shuffled[val_bucket_cutoff:]].tolist())
# Place rest samples in the training set.
train_indices.extend(sorted_indices.tolist())
return [Subset(dataset, train_indices),
Subset(dataset, val_indices),
Subset(dataset, test_indices)]
@staticmethod
@deprecated('Import SingleTaskStratifiedSplitter from '
'dgllife.utils.splitters instead.', 'class')
def k_fold_split(dataset, labels, task_id, k=5, log=True):
"""Sort molecules based on their label values for a task and then split them
for k-fold cross validation by taking consecutive chunks.
Parameters
----------
dataset
We assume ``len(dataset)`` gives the size for the dataset, ``dataset[i]``
gives the ith datapoint and ``dataset.smiles[i]`` gives the SMILES for the
ith datapoint.
labels : tensor of shape (N, T)
Dataset labels all tasks. N for the number of datapoints and T for the number
of tasks.
task_id : int
Index for the task.
k : int
Number of folds to use and should be no smaller than 2. Default to be 5.
log : bool
Whether to print a message at the start of preparing each fold.
Returns
-------
list of 2-tuples
Each element of the list represents a fold and is a 2-tuple (train_set, val_set).
"""
if not isinstance(labels, np.ndarray):
labels = F.asnumpy(labels)
task_labels = labels[:, task_id]
sorted_indices = np.argsort(task_labels).tolist()
return base_k_fold_split(partial(indices_split, indices=sorted_indices), dataset, k, log)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment