Unverified Commit 577cf2e6 authored by Mufei Li's avatar Mufei Li Committed by GitHub
Browse files

[Model Zoo] Refactor and Add Utils for Chemistry (#928)

* Refactor

* Add note

* Update

* CI
parent 0bf3b6dd
......@@ -142,19 +142,66 @@ Molecular Graphs
To work on molecular graphs, make sure you have installed `RDKit 2018.09.3 <https://www.rdkit.org/docs/Install.html>`__.
Featurization
`````````````
Featurization Utils
```````````````````
For the use of graph neural networks, we need to featurize nodes (atoms) and edges (bonds). Below we list some
featurization methods/utilities:
For the use of graph neural networks, we need to featurize nodes (atoms) and edges (bonds).
General utils:
.. autosummary::
:toctree: ../../generated/
chem.one_hot_encoding
chem.ConcatFeaturizer
chem.ConcatFeaturizer.__call__
Utils for atom featurization:
.. autosummary::
:toctree: ../../generated/
chem.atom_type_one_hot
chem.atomic_number_one_hot
chem.atomic_number
chem.atom_degree_one_hot
chem.atom_degree
chem.atom_total_degree_one_hot
chem.atom_total_degree
chem.atom_implicit_valence_one_hot
chem.atom_implicit_valence
chem.atom_hybridization_one_hot
chem.atom_total_num_H_one_hot
chem.atom_total_num_H
chem.atom_formal_charge_one_hot
chem.atom_formal_charge
chem.atom_num_radical_electrons_one_hot
chem.atom_num_radical_electrons
chem.atom_is_aromatic_one_hot
chem.atom_is_aromatic
chem.atom_chiral_tag_one_hot
chem.atom_mass
chem.BaseAtomFeaturizer
chem.BaseAtomFeaturizer.feat_size
chem.BaseAtomFeaturizer.__call__
chem.CanonicalAtomFeaturizer
Utils for bond featurization:
.. autosummary::
:toctree: ../../generated/
chem.bond_type_one_hot
chem.bond_is_conjugated_one_hot
chem.bond_is_conjugated
chem.bond_is_in_ring_one_hot
chem.bond_is_in_ring
chem.bond_stereo_one_hot
chem.BaseBondFeaturizer
chem.BaseBondFeaturizer.feat_size
chem.BaseBondFeaturizer.__call__
chem.CanonicalBondFeaturizer
Graph Construction
``````````````````
......@@ -164,9 +211,9 @@ Several methods for constructing DGLGraphs from SMILES/RDKit molecule objects ar
:toctree: ../../generated/
chem.mol_to_graph
chem.smile_to_bigraph
chem.smiles_to_bigraph
chem.mol_to_bigraph
chem.smile_to_complete_graph
chem.smiles_to_complete_graph
chem.mol_to_complete_graph
Dataset Classes
......
......@@ -448,7 +448,7 @@ def get_atom_and_bond_types(smiles, log=True):
for i, s in enumerate(smiles):
if log:
print('Processing smile {:d}/{:d}'.format(i + 1, n_smiles))
print('Processing smiles {:d}/{:d}'.format(i + 1, n_smiles))
mol = smiles_to_standard_mol(s)
if mol is None:
......@@ -517,7 +517,7 @@ def eval_decisions(env, decisions):
return env.get_current_smiles()
def get_DGMG_smile(env, mol):
"""Mimics the reproduced SMILE with DGMG for a molecule.
"""Mimics the reproduced SMILES with DGMG for a molecule.
Given a molecule, we are interested in what SMILES we will
get if we want to generate it with DGMG. This is an important
......
import numpy as np
import torch
from torch.nn import BCEWithLogitsLoss
from torch.optim import Adam
from torch.utils.data import DataLoader
from dgl import model_zoo
from dgl.data.utils import split_dataset
from utils import Meter, EarlyStopping, collate_molgraphs_for_classification, set_random_seed
from utils import Meter, EarlyStopping, collate_molgraphs, set_random_seed, \
load_dataset_for_classification
def run_a_train_epoch(args, epoch, model, data_loader, loss_criterion, optimizer):
    """Train the classifier for one epoch and print the training metric.

    Parameters
    ----------
    args : dict
        Configuration. Reads ``'atom_data_field'``, ``'device'``,
        ``'num_epochs'`` and ``'metric_name'``.
    epoch : int
        Zero-based index of the current epoch (printed as ``epoch + 1``).
    model
        Model called as ``model(bg, atom_feats)`` to produce logits.
    data_loader
        Iterable yielding ``(smiles, bg, labels, masks)`` batches.
    loss_criterion : callable
        Called as ``loss_criterion(logits, labels)``. Its output is multiplied
        elementwise by the mask, so it is assumed to be unreduced
        (per-element) -- TODO confirm against the caller.
    optimizer
        Optimizer updating the model parameters.
    """
    model.train()
    train_meter = Meter()
    for batch_id, batch_data in enumerate(data_loader):
        smiles, bg, labels, masks = batch_data
        # Pop the atom features so they are not left stored on the batched graph.
        atom_feats = bg.ndata.pop(args['atom_data_field'])
        atom_feats, labels, masks = atom_feats.to(args['device']), \
                                    labels.to(args['device']), \
                                    masks.to(args['device'])
        logits = model(bg, atom_feats)
        # Mask non-existing labels: entries with mask == 0 contribute zero loss,
        # but the mean is still taken over all entries.
        loss = (loss_criterion(logits, labels) * (masks != 0).float()).mean()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        print('epoch {:d}/{:d}, batch {:d}/{:d}, loss {:.4f}'.format(
            epoch + 1, args['num_epochs'], batch_id + 1, len(data_loader), loss.item()))
        train_meter.update(logits, labels, masks)
    # compute_metric returns one value per task; report their average.
    train_score = np.mean(train_meter.compute_metric(args['metric_name']))
    print('epoch {:d}/{:d}, training {} {:.4f}'.format(
        epoch + 1, args['num_epochs'], args['metric_name'], train_score))
def run_an_eval_epoch(args, model, data_loader):
    """Evaluate the classifier on a data loader.

    Parameters
    ----------
    args : dict
        Configuration. Reads ``'atom_data_field'``, ``'device'`` and
        ``'metric_name'``.
    model
        Model called as ``model(bg, atom_feats)`` to produce logits.
    data_loader
        Iterable yielding ``(smiles, bg, labels, masks)`` batches.

    Returns
    -------
    float
        Mean over tasks of the metric named by ``args['metric_name']``.
    """
    model.eval()
    eval_meter = Meter()
    # No gradients are needed for evaluation.
    with torch.no_grad():
        for batch_id, batch_data in enumerate(data_loader):
            smiles, bg, labels, masks = batch_data
            atom_feats = bg.ndata.pop(args['atom_data_field'])
            atom_feats, labels = atom_feats.to(args['device']), labels.to(args['device'])
            logits = model(bg, atom_feats)
            eval_meter.update(logits, labels, masks)
    # Score once after the whole loader has been consumed, not per batch.
    return np.mean(eval_meter.compute_metric(args['metric_name']))
def main(args):
args['device'] = "cuda" if torch.cuda.is_available() else "cpu"
set_random_seed()
# Interchangeable with other datasets
if args['dataset'] == 'Tox21':
from dgl.data.chem import Tox21
dataset = Tox21()
trainset, valset, testset = split_dataset(dataset, args['train_val_test_split'])
train_loader = DataLoader(trainset, batch_size=args['batch_size'],
collate_fn=collate_molgraphs_for_classification)
val_loader = DataLoader(valset, batch_size=args['batch_size'],
collate_fn=collate_molgraphs_for_classification)
test_loader = DataLoader(testset, batch_size=args['batch_size'],
collate_fn=collate_molgraphs_for_classification)
dataset, train_set, val_set, test_set = load_dataset_for_classification(args)
train_loader = DataLoader(train_set, batch_size=args['batch_size'],
collate_fn=collate_molgraphs)
val_loader = DataLoader(val_set, batch_size=args['batch_size'],
collate_fn=collate_molgraphs)
test_loader = DataLoader(test_set, batch_size=args['batch_size'],
collate_fn=collate_molgraphs)
if args['pre_trained']:
args['num_epochs'] = 0
......@@ -87,17 +84,18 @@ def main(args):
run_a_train_epoch(args, epoch, model, train_loader, loss_criterion, optimizer)
# Validation and early stop
val_roc_auc = run_an_eval_epoch(args, model, val_loader)
early_stop = stopper.step(val_roc_auc, model)
print('epoch {:d}/{:d}, validation roc-auc score {:.4f}, best validation roc-auc score {:.4f}'.format(
epoch + 1, args['num_epochs'], val_roc_auc, stopper.best_score))
val_score = run_an_eval_epoch(args, model, val_loader)
early_stop = stopper.step(val_score, model)
print('epoch {:d}/{:d}, validation {} {:.4f}, best validation {} {:.4f}'.format(
epoch + 1, args['num_epochs'], args['metric_name'],
val_score, args['metric_name'], stopper.best_score))
if early_stop:
break
if not args['pre_trained']:
stopper.load_checkpoint(model)
test_roc_auc = run_an_eval_epoch(args, model, test_loader)
print('test roc-auc score {:.4f}'.format(test_roc_auc))
test_score = run_an_eval_epoch(args, model, test_loader)
print('test {} {:.4f}'.format(args['metric_name'], test_score))
if __name__ == '__main__':
import argparse
......
from dgl.data.chem import CanonicalAtomFeaturizer
GCN_Tox21 = {
'batch_size': 128,
'lr': 1e-3,
......@@ -7,7 +9,9 @@ GCN_Tox21 = {
'in_feats': 74,
'gcn_hidden_feats': [64, 64],
'classifier_hidden_feats': 64,
'patience': 10
'patience': 10,
'atom_featurizer': CanonicalAtomFeaturizer(),
'metric_name': 'roc_auc'
}
GAT_Tox21 = {
......@@ -20,15 +24,20 @@ GAT_Tox21 = {
'gat_hidden_feats': [32, 32],
'classifier_hidden_feats': 64,
'num_heads': [4, 4],
'patience': 10
'patience': 10,
'atom_featurizer': CanonicalAtomFeaturizer(),
'metric_name': 'roc_auc'
}
# Hyperparameters for training MPNN on the Alchemy dataset.
MPNN_Alchemy = {
    'batch_size': 16,
    'num_epochs': 250,
    # Sizes of the input node/edge features passed to the model constructor.
    'node_in_feats': 15,
    'edge_in_feats': 5,
    'output_dim': 12,
    'lr': 0.0001,
    # Early-stopping patience.
    'patience': 50,
    # Metric name understood by Meter.compute_metric.
    'metric_name': 'l1'
}
SCHNET_Alchemy = {
......@@ -37,7 +46,8 @@ SCHNET_Alchemy = {
'norm': True,
'output_dim': 12,
'lr': 0.0001,
'patience': 50
'patience': 50,
'metric_name': 'l1'
}
MGCN_Alchemy = {
......@@ -46,7 +56,8 @@ MGCN_Alchemy = {
'norm': True,
'output_dim': 12,
'lr': 0.0001,
'patience': 50
'patience': 50,
'metric_name': 'l1'
}
experiment_configures = {
......
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from dgl import model_zoo
from utils import set_random_seed, collate_molgraphs_for_regression, EarlyStopping
from utils import Meter, set_random_seed, collate_molgraphs, EarlyStopping, \
load_dataset_for_regression
def regress(args, model, bg):
if args['model'] == 'MPNN':
......@@ -20,36 +22,35 @@ def regress(args, model, bg):
return model(bg, node_types, edge_distances)
def run_a_train_epoch(args, epoch, model, data_loader,
                      loss_criterion, optimizer):
    """Train the regressor for one epoch and print loss and training metric.

    Parameters
    ----------
    args : dict
        Configuration. Reads ``'device'``, ``'num_epochs'`` and
        ``'metric_name'`` (and whatever ``regress`` reads).
    epoch : int
        Zero-based index of the current epoch (printed as ``epoch + 1``).
    model
        Model passed to ``regress`` to compute predictions.
    data_loader
        Loader yielding ``(smiles, bg, labels, masks)`` batches; must expose
        ``.dataset`` with a length.
    loss_criterion : callable
        Called as ``loss_criterion(prediction, labels)``. Its output is
        multiplied elementwise by the mask, so it is assumed to be unreduced
        (per-element) -- TODO confirm against the caller.
    optimizer
        Optimizer updating the model parameters.
    """
    model.train()
    train_meter = Meter()
    total_loss = 0
    for batch_id, batch_data in enumerate(data_loader):
        smiles, bg, labels, masks = batch_data
        labels, masks = labels.to(args['device']), masks.to(args['device'])
        prediction = regress(args, model, bg)
        # Entries with mask == 0 contribute zero loss; the mean is over all entries.
        loss = (loss_criterion(prediction, labels) * (masks != 0).float()).mean()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # Weight each batch loss by its size so the epoch loss is a per-sample average.
        total_loss += loss.detach().item() * bg.batch_size
        train_meter.update(prediction, labels, masks)
    total_loss /= len(data_loader.dataset)
    # compute_metric returns one value per task; report their average.
    total_score = np.mean(train_meter.compute_metric(args['metric_name']))
    print('epoch {:d}/{:d}, training loss {:.4f}, training {} {:.4f}'.format(
        epoch + 1, args['num_epochs'], total_loss, args['metric_name'], total_score))
def run_an_eval_epoch(args, model, data_loader):
    """Evaluate the regressor on a data loader.

    Parameters
    ----------
    args : dict
        Configuration. Reads ``'device'`` and ``'metric_name'`` (and whatever
        ``regress`` reads).
    model
        Model passed to ``regress`` to compute predictions.
    data_loader
        Iterable yielding ``(smiles, bg, labels, masks)`` batches.

    Returns
    -------
    float
        Mean over tasks of the metric named by ``args['metric_name']``.
    """
    model.eval()
    eval_meter = Meter()
    # No gradients are needed for evaluation.
    with torch.no_grad():
        for batch_id, batch_data in enumerate(data_loader):
            smiles, bg, labels, masks = batch_data
            labels = labels.to(args['device'])
            prediction = regress(args, model, bg)
            eval_meter.update(prediction, labels, masks)
    # Score once over everything accumulated by the meter.
    total_score = np.mean(eval_meter.compute_metric(args['metric_name']))
    return total_score
def main(args):
......@@ -57,20 +58,22 @@ def main(args):
set_random_seed()
# Interchangeable with other datasets
if args['dataset'] == 'Alchemy':
from dgl.data.chem import TencentAlchemyDataset
train_set = TencentAlchemyDataset(mode='dev')
val_set = TencentAlchemyDataset(mode='valid')
train_set, val_set, test_set = load_dataset_for_regression(args)
train_loader = DataLoader(dataset=train_set,
batch_size=args['batch_size'],
collate_fn=collate_molgraphs_for_regression)
collate_fn=collate_molgraphs)
val_loader = DataLoader(dataset=val_set,
batch_size=args['batch_size'],
collate_fn=collate_molgraphs_for_regression)
collate_fn=collate_molgraphs)
if test_set is not None:
test_loader = DataLoader(dataset=test_set,
batch_size=args['batch_size'],
collate_fn=collate_molgraphs)
if args['model'] == 'MPNN':
model = model_zoo.chem.MPNNModel(output_dim=args['output_dim'])
model = model_zoo.chem.MPNNModel(node_input_dim=args['node_in_feats'],
edge_input_dim=args['edge_in_feats'],
output_dim=args['output_dim'])
elif args['model'] == 'SCHNET':
model = model_zoo.chem.SchNet(norm=args['norm'], output_dim=args['output_dim'])
model.set_mean_std(train_set.mean, train_set.std, args['device'])
......@@ -79,23 +82,28 @@ def main(args):
model.set_mean_std(train_set.mean, train_set.std, args['device'])
model.to(args['device'])
loss_fn = nn.MSELoss()
score_fn = nn.L1Loss()
loss_fn = nn.MSELoss(reduction='none')
optimizer = torch.optim.Adam(model.parameters(), lr=args['lr'])
stopper = EarlyStopping(mode='lower', patience=args['patience'])
for epoch in range(args['num_epochs']):
# Train
run_a_train_epoch(args, epoch, model, train_loader, loss_fn, score_fn, optimizer)
run_a_train_epoch(args, epoch, model, train_loader, loss_fn, optimizer)
# Validation and early stop
val_score = run_an_eval_epoch(args, model, val_loader, score_fn)
val_score = run_an_eval_epoch(args, model, val_loader)
early_stop = stopper.step(val_score, model)
print('epoch {:d}/{:d}, validation score {:.4f}, best validation score {:.4f}'.format(
epoch + 1, args['num_epochs'], val_score, stopper.best_score))
print('epoch {:d}/{:d}, validation {} {:.4f}, best validation {} {:.4f}'.format(
epoch + 1, args['num_epochs'], args['metric_name'], val_score,
args['metric_name'], stopper.best_score))
if early_stop:
break
if test_set is not None:
stopper.load_checkpoint(model)
test_score = run_an_eval_epoch(args, model, test_loader)
print('test {} {:.4f}'.format(args['metric_name'], test_score))
if __name__ == "__main__":
import argparse
......
import datetime
import dgl
import math
import numpy as np
import random
import torch
from sklearn.metrics import roc_auc_score
import torch.nn.functional as F
from dgl.data.utils import split_dataset
from sklearn.metrics import roc_auc_score, mean_squared_error
def set_random_seed(seed=0):
"""Set random seed.
......@@ -45,13 +49,13 @@ class Meter(object):
self.y_true.append(y_true.detach().cpu())
self.mask.append(mask.detach().cpu())
def roc_auc_averaged_over_tasks(self):
"""Compute roc-auc score for each task and return the average.
def roc_auc_score(self):
"""Compute roc-auc score for each task.
Returns
-------
float
roc-auc score averaged over all tasks
list of float
roc-auc score for all tasks
"""
mask = torch.cat(self.mask, dim=0)
y_pred = torch.cat(self.y_pred, dim=0)
......@@ -60,13 +64,83 @@ class Meter(object):
# This assumes binary case only
y_pred = torch.sigmoid(y_pred)
n_tasks = y_true.shape[1]
total_score = 0
scores = []
for task in range(n_tasks):
task_w = mask[:, task]
task_y_true = y_true[:, task][task_w != 0].numpy()
task_y_pred = y_pred[:, task][task_w != 0].numpy()
scores.append(roc_auc_score(task_y_true, task_y_pred))
return scores
def l1_loss(self, reduction):
    """Compute the l1 loss separately for every task.

    Only data points whose mask entry is non-zero for a task are used
    for that task.

    Parameters
    ----------
    reduction : str
        * 'mean': average the metric over all labeled data points for each task
        * 'sum': sum the metric over all labeled data points for each task

    Returns
    -------
    list of float
        l1 loss for all tasks
    """
    # Concatenate the per-batch records along the data point dimension.
    masks = torch.cat(self.mask, dim=0)
    preds = torch.cat(self.y_pred, dim=0)
    targets = torch.cat(self.y_true, dim=0)

    per_task = []
    for col in range(targets.shape[1]):
        # Boolean selector of the data points labeled for this task.
        labeled = masks[:, col] != 0
        col_targets = targets[:, col][labeled]
        col_preds = preds[:, col][labeled]
        task_loss = F.l1_loss(col_targets, col_preds, reduction=reduction)
        per_task.append(task_loss.item())
    return per_task
def rmse(self):
    """Compute RMSE for each task.

    Only data points whose mask entry is non-zero for a task are used
    for that task.

    Returns
    -------
    list of float
        rmse for all tasks
    """
    # Concatenate the per-batch records along the data point dimension.
    mask = torch.cat(self.mask, dim=0)
    y_pred = torch.cat(self.y_pred, dim=0)
    y_true = torch.cat(self.y_true, dim=0)
    n_tasks = y_true.shape[1]
    scores = []
    for task in range(n_tasks):
        task_w = mask[:, task]
        # Keep only the labeled data points of this task.
        task_y_true = y_true[:, task][task_w != 0].numpy()
        task_y_pred = y_pred[:, task][task_w != 0].numpy()
        scores.append(math.sqrt(mean_squared_error(task_y_true, task_y_pred)))
    return scores
def compute_metric(self, metric_name, reduction='mean'):
    """Dispatch to the per-task metric selected by name.

    Parameters
    ----------
    metric_name : str
        Name for the metric to compute: "roc_auc", "l1" or "rmse".
    reduction : str
        Only comes into effect when the metric_name is "l1".

        * 'mean': average the metric over all labeled data points for each task
        * 'sum': sum the metric over all labeled data points for each task

    Returns
    -------
    list of float
        Metric value for each task
    """
    # Only the selected entry is ever called.
    dispatch = {
        'roc_auc': lambda: self.roc_auc_score(),
        'l1': lambda: self.l1_loss(reduction),
        'rmse': lambda: self.rmse(),
    }
    assert metric_name in ['roc_auc', 'l1', 'rmse'], \
        'Expect metric name to be "roc_auc", "l1" or "rmse", got {}'.format(metric_name)
    assert reduction in ['mean', 'sum']
    return dispatch[metric_name]()
class EarlyStopping(object):
"""Early stop performing
......@@ -131,14 +205,15 @@ class EarlyStopping(object):
'''Load model saved with early stopping.'''
model.load_state_dict(torch.load(self.filename)['model_state_dict'])
def collate_molgraphs_for_classification(data):
"""Batching a list of datapoints for dataloader in classification tasks.
def collate_molgraphs(data):
"""Batching a list of datapoints for dataloader.
Parameters
----------
data : list of 4-tuples
data : list of 3-tuples or 4-tuples.
Each tuple is for a single datapoint, consisting of
a SMILE, a DGLGraph, all-task labels and all-task weights
a SMILES, a DGLGraph, all-task labels and optionally
a binary mask indicating the existence of labels.
Returns
-------
......@@ -149,40 +224,79 @@ def collate_molgraphs_for_classification(data):
labels : Tensor of dtype float32 and shape (B, T)
Batched datapoint labels. B is len(data) and
T is the number of total tasks.
weights : Tensor of dtype float32 and shape (B, T)
Batched datapoint weights. T is the number of
total tasks.
masks : Tensor of dtype float32 and shape (B, T)
Batched datapoint binary mask, indicating the
existence of labels. If binary masks are not
provided, return a tensor with ones.
"""
smiles, graphs, labels, mask = map(list, zip(*data))
assert len(data[0]) in [3, 4], \
'Expect the tuple to be of length 3 or 4, got {:d}'.format(len(data[0]))
if len(data[0]) == 3:
smiles, graphs, labels = map(list, zip(*data))
masks = None
else:
smiles, graphs, labels, masks = map(list, zip(*data))
bg = dgl.batch(graphs)
bg.set_n_initializer(dgl.init.zero_initializer)
bg.set_e_initializer(dgl.init.zero_initializer)
labels = torch.stack(labels, dim=0)
mask = torch.stack(mask, dim=0)
return smiles, bg, labels, mask
def collate_molgraphs_for_regression(data):
"""Batching a list of datapoints for dataloader in regression tasks.
if masks is None:
masks = torch.ones(labels.shape)
else:
masks = torch.stack(masks, dim=0)
return smiles, bg, labels, masks
def load_dataset_for_classification(args):
    """Load dataset for classification tasks.

    Parameters
    ----------
    args : dict
        Configurations. Reads ``'dataset'``, ``'atom_featurizer'`` and
        ``'train_val_test_split'``.

    Returns
    -------
    dataset
        The whole dataset.
    train_set
        Subset for training.
    val_set
        Subset for validation.
    test_set
        Subset for test.
    """
    assert args['dataset'] in ['Tox21']
    if args['dataset'] == 'Tox21':
        # Imported lazily so the dataset module is only loaded when requested.
        from dgl.data.chem import Tox21
        dataset = Tox21(atom_featurizer=args['atom_featurizer'])
        train_set, val_set, test_set = split_dataset(dataset, args['train_val_test_split'])

    return dataset, train_set, val_set, test_set
def load_dataset_for_regression(args):
    """Build the training/validation/test subsets for a regression task.

    Parameters
    ----------
    args : dict
        Configurations. Only ``args['dataset']`` is read.

    Returns
    -------
    train_set
        Subset for training.
    val_set
        Subset for validation.
    test_set
        Subset for test. ``None`` for Alchemy.
    """
    dataset_name = args['dataset']
    assert dataset_name in ['Alchemy']

    if dataset_name == 'Alchemy':
        from dgl.data.chem import TencentAlchemyDataset
        # The 'dev' split is used for training and the 'valid' split for
        # validation; no test subset is created here.
        train_set, val_set = (TencentAlchemyDataset(mode=split)
                              for split in ('dev', 'valid'))
        test_set = None

    return train_set, val_set, test_set
......@@ -10,7 +10,8 @@ import pickle
import zipfile
from collections import defaultdict
from .utils import mol_to_complete_graph
from .utils import mol_to_complete_graph, atom_type_one_hot, atom_hybridization_one_hot, \
atom_is_aromatic
from ..utils import download, get_download_dir, _get_dgl_url, save_graphs, load_graphs
from ... import backend as F
......@@ -59,25 +60,19 @@ def alchemy_nodes(mol):
num_atoms = mol.GetNumAtoms()
for u in range(num_atoms):
atom = mol.GetAtomWithIdx(u)
symbol = atom.GetSymbol()
atom_type = atom.GetAtomicNum()
aromatic = atom.GetIsAromatic()
hybridization = atom.GetHybridization()
num_h = atom.GetTotalNumHs()
atom_feats_dict['node_type'].append(atom_type)
h_u = []
h_u += [int(symbol == x) for x in ['H', 'C', 'N', 'O', 'F', 'S', 'Cl']]
h_u += atom_type_one_hot(atom, ['H', 'C', 'N', 'O', 'F', 'S', 'Cl'])
h_u.append(atom_type)
h_u.append(is_acceptor[u])
h_u.append(is_donor[u])
h_u.append(int(aromatic))
h_u += [
int(hybridization == x)
for x in (Chem.rdchem.HybridizationType.SP,
Chem.rdchem.HybridizationType.SP2,
Chem.rdchem.HybridizationType.SP3)
]
h_u += atom_is_aromatic(atom)
h_u += atom_hybridization_one_hot(atom, [Chem.rdchem.HybridizationType.SP,
Chem.rdchem.HybridizationType.SP2,
Chem.rdchem.HybridizationType.SP3])
h_u.append(num_h)
atom_feats_dict['n_feat'].append(F.tensor(np.array(h_u).astype(np.float32)))
......@@ -155,9 +150,34 @@ class TencentAlchemyDataset(object):
contest is ongoing.
from_raw : bool
Whether to process the dataset from scratch or use a
processed one for faster speed. Default to be False.
processed one for faster speed. If you use different ways
to featurize atoms or bonds, you should set this to be True.
Default to be False.
mol_to_graph: callable, str -> DGLGraph
A function turning an RDKit molecule instance into a DGLGraph.
Default to :func:`dgl.data.chem.mol_to_complete_graph`.
atom_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict
Featurization for atoms in a molecule, which can be used to update
ndata for a DGLGraph. By default, we store the atom atomic numbers
under the name ``"node_type"`` and store the atom features under the
name ``"n_feat"``. The atom features include:
* One hot encoding for atom types
* Atomic number of atoms
* Whether the atom is a donor
* Whether the atom is an acceptor
* Whether the atom is aromatic
* One hot encoding for atom hybridization
* Total number of Hs on the atom
bond_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict
Featurization for bonds in a molecule, which can be used to update
edata for a DGLGraph. By default, we store the distance between the
end atoms under the name ``"distance"`` and store the bond features under
the name ``"e_feat"``. The bond features are one-hot encodings of the bond type.
"""
def __init__(self, mode='dev', from_raw=False):
def __init__(self, mode='dev', from_raw=False,
mol_to_graph=mol_to_complete_graph,
atom_featurizer=alchemy_nodes,
bond_featurizer=alchemy_edges):
if mode == 'test':
raise ValueError('The test mode is not supported before '
'the Alchemy contest finishes.')
......@@ -185,9 +205,9 @@ class TencentAlchemyDataset(object):
archive.extractall(file_dir)
archive.close()
self._load()
self._load(mol_to_graph, atom_featurizer, bond_featurizer)
def _load(self):
def _load(self, mol_to_graph, atom_featurizer, bond_featurizer):
if not self.from_raw:
self.graphs, label_dict = load_graphs(osp.join(self.file_dir, "%s_graphs.bin" % self.mode))
self.labels = label_dict['labels']
......@@ -210,8 +230,8 @@ class TencentAlchemyDataset(object):
for mol, label in zip(supp, self.target.iterrows()):
cnt += 1
print('Processing molecule {:d}/{:d}'.format(cnt, dataset_size))
graph = mol_to_complete_graph(mol, atom_featurizer=alchemy_nodes,
bond_featurizer=alchemy_edges)
graph = mol_to_graph(mol, atom_featurizer=atom_featurizer,
bond_featurizer=bond_featurizer)
smiles = Chem.MolToSmiles(mol)
self.smiles.append(smiles)
self.graphs.append(graph)
......
......@@ -5,10 +5,8 @@ import numpy as np
import os
import sys
from .utils import smile_to_bigraph
from ..utils import save_graphs, load_graphs
from ... import backend as F
from ...graph import DGLGraph
class MoleculeCSVDataset(object):
"""MoleculeCSVDataset
......@@ -27,28 +25,33 @@ class MoleculeCSVDataset(object):
Dataframe including smiles and labels. Can be loaded by pandas.read_csv(file_path).
One column includes smiles and other columns for labels.
Column names other than smiles column would be considered as task names.
smile_to_graph: callable, str -> DGLGraph
A function turns smiles into a DGLGraph. Default one can be found
at python/dgl/data/chem/utils.py named with smile_to_bigraph.
smile_column: str
Column name that including smiles
smiles_to_graph: callable, str -> DGLGraph
A function turning a SMILES into a DGLGraph.
atom_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict
Featurization for atoms in a molecule, which can be used to update
ndata for a DGLGraph.
bond_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict
Featurization for bonds in a molecule, which can be used to update
edata for a DGLGraph.
smiles_column: str
Column name that including smiles.
cache_file_path: str
Path to store the preprocessed data
Path to store the preprocessed data.
"""
def __init__(self, df, smile_to_graph=smile_to_bigraph, smile_column='smiles',
cache_file_path="csvdata_dglgraph.bin"):
def __init__(self, df, smiles_to_graph, atom_featurizer, bond_featurizer,
smiles_column, cache_file_path):
if 'rdkit' not in sys.modules:
from ...base import dgl_warning
dgl_warning(
"Please install RDKit (Recommended Version is 2018.09.3)")
self.df = df
self.smiles = self.df[smile_column].tolist()
self.task_names = self.df.columns.drop([smile_column]).tolist()
self.smiles = self.df[smiles_column].tolist()
self.task_names = self.df.columns.drop([smiles_column]).tolist()
self.n_tasks = len(self.task_names)
self.cache_file_path = cache_file_path
self._pre_process(smile_to_graph)
self._pre_process(smiles_to_graph, atom_featurizer, bond_featurizer)
def _pre_process(self, smile_to_graph):
def _pre_process(self, smiles_to_graph, atom_featurizer, bond_featurizer):
"""Pre-process the dataset
* Convert molecules from smiles format into DGLGraphs
......@@ -58,8 +61,14 @@ class MoleculeCSVDataset(object):
Parameters
----------
smile_to_graph : callable, SMILES -> DGLGraph
Function for converting a SMILES (str) into a DGLGraph
smiles_to_graph : callable, SMILES -> DGLGraph
Function for converting a SMILES (str) into a DGLGraph.
atom_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict
Featurization for atoms in a molecule, which can be used to update
ndata for a DGLGraph.
bond_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict
Featurization for bonds in a molecule, which can be used to update
edata for a DGLGraph.
"""
if os.path.exists(self.cache_file_path):
# DGLGraphs have been constructed before, reload them
......@@ -72,7 +81,8 @@ class MoleculeCSVDataset(object):
self.graphs = []
for i, s in enumerate(self.smiles):
print('Processing molecule {:d}/{:d}'.format(i+1, len(self)))
self.graphs.append(smile_to_graph(s))
self.graphs.append(smiles_to_graph(s, atom_featurizer=atom_featurizer,
bond_featurizer=bond_featurizer))
_label_values = self.df[self.task_names].values
# np.nan_to_num will also turn inf into a very large number
self.labels = F.zerocopy_from_numpy(np.nan_to_num(_label_values).astype(np.float32))
......
......@@ -2,7 +2,7 @@ import numpy as np
import sys
from .csv_dataset import MoleculeCSVDataset
from .utils import smile_to_bigraph
from .utils import smiles_to_bigraph
from ..utils import get_download_dir, download, _get_dgl_url
from ... import backend as F
......@@ -30,11 +30,19 @@ class Tox21(MoleculeCSVDataset):
Parameters
----------
smile_to_graph: callable, str -> DGLGraph
A function turns smiles into a DGLGraph. Default one can be found
at python/dgl/data/chem/utils.py named with smile_to_bigraph.
smiles_to_graph: callable, str -> DGLGraph
A function turning smiles into a DGLGraph.
Default to :func:`dgl.data.chem.smiles_to_bigraph`.
atom_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict
Featurization for atoms in a molecule, which can be used to update
ndata for a DGLGraph. Default to None.
bond_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict
Featurization for bonds in a molecule, which can be used to update
edata for a DGLGraph. Default to None.
"""
def __init__(self, smile_to_graph=smile_to_bigraph):
def __init__(self, smiles_to_graph=smiles_to_bigraph,
atom_featurizer=None,
bond_featurizer=None):
if 'pandas' not in sys.modules:
from ...base import dgl_warning
dgl_warning("Please install pandas")
......@@ -47,10 +55,10 @@ class Tox21(MoleculeCSVDataset):
df = df.drop(columns=['mol_id'])
super().__init__(df, smile_to_graph, cache_file_path="tox21_dglgraph.bin")
super(Tox21, self).__init__(df, smiles_to_graph, atom_featurizer, bond_featurizer,
"smiles", "tox21_dglgraph.bin")
self._weight_balancing()
def _weight_balancing(self):
"""Perform re-balancing for each task.
......@@ -71,7 +79,6 @@ class Tox21(MoleculeCSVDataset):
num_pos = F.sum(self.labels, dim=0)
num_indices = F.sum(self.mask, dim=0)
self._task_pos_weights = (num_indices - num_pos) / num_pos
@property
def task_pos_weights(self):
......
import dgl.backend as F
import itertools
import numpy as np
from functools import partial
from collections import defaultdict
from dgl import DGLGraph
try:
......@@ -10,42 +12,568 @@ try:
except ImportError:
pass
__all__ = ['one_hot_encoding', 'BaseAtomFeaturizer', 'CanonicalAtomFeaturizer',
'mol_to_graph', 'smile_to_bigraph', 'mol_to_bigraph',
'smile_to_complete_graph', 'mol_to_complete_graph']
__all__ = ['one_hot_encoding', 'atom_type_one_hot', 'atomic_number_one_hot', 'atomic_number',
'atom_degree_one_hot', 'atom_degree', 'atom_total_degree_one_hot', 'atom_total_degree',
'atom_implicit_valence_one_hot', 'atom_implicit_valence', 'atom_hybridization_one_hot',
'atom_total_num_H_one_hot', 'atom_total_num_H', 'atom_formal_charge_one_hot',
'atom_formal_charge', 'atom_num_radical_electrons_one_hot',
'atom_num_radical_electrons', 'atom_is_aromatic_one_hot', 'atom_is_aromatic',
'atom_chiral_tag_one_hot', 'atom_mass', 'ConcatFeaturizer', 'BaseAtomFeaturizer',
'CanonicalAtomFeaturizer', 'mol_to_graph', 'smiles_to_bigraph',
'mol_to_bigraph', 'smiles_to_complete_graph', 'mol_to_complete_graph',
'bond_type_one_hot', 'bond_is_conjugated_one_hot', 'bond_is_conjugated',
'bond_is_in_ring_one_hot', 'bond_is_in_ring', 'bond_stereo_one_hot',
'BaseBondFeaturizer', 'CanonicalBondFeaturizer']
def one_hot_encoding(x, allowable_set, encode_unknown=False):
    """One-hot encoding.

    The input ``allowable_set`` is never modified.

    Parameters
    ----------
    x
        Value to encode.
    allowable_set : list
        The elements of the allowable_set should be of the
        same type as x.
    encode_unknown : bool
        If True, map inputs not in the allowable set to the
        additional last element.

    Returns
    -------
    list
        List of boolean values where at most one value is True.
        The list is of length ``len(allowable_set)`` if ``encode_unknown=False``
        and ``len(allowable_set) + 1`` otherwise.
    """
    if encode_unknown:
        # Extend a copy rather than the caller's list: appending in place would
        # make the list (and the encoding length) grow on every call.
        allowable_set = list(allowable_set) + [None]

    # Unknown values are mapped to None, which only matches the extra slot
    # added above (if any).
    if x not in allowable_set:
        x = None

    return [x == s for s in allowable_set]
#################################################################
# Atom featurization
#################################################################
def atom_type_one_hot(atom, allowable_set=None, encode_unknown=False):
    """One-hot encode the element symbol of an atom.

    Parameters
    ----------
    atom : rdkit.Chem.rdchem.Atom
        RDKit atom instance.
    allowable_set : list of str
        Atom types to consider. Default: ``C``, ``N``, ``O``, ``S``, ``F``, ``Si``, ``P``,
        ``Cl``, ``Br``, ``Mg``, ``Na``, ``Ca``, ``Fe``, ``As``, ``Al``, ``I``, ``B``, ``V``,
        ``K``, ``Tl``, ``Yb``, ``Sb``, ``Sn``, ``Ag``, ``Pd``, ``Co``, ``Se``, ``Ti``, ``Zn``,
        ``H``, ``Li``, ``Ge``, ``Cu``, ``Au``, ``Ni``, ``Cd``, ``In``, ``Mn``, ``Zr``, ``Cr``,
        ``Pt``, ``Hg``, ``Pb``.
    encode_unknown : bool
        If True, map inputs not in the allowable set to the
        additional last element. (Default: False)

    Returns
    -------
    list
        List of boolean values where at most one value is True.
    """
    symbols = allowable_set
    if symbols is None:
        symbols = [
            'C', 'N', 'O', 'S', 'F', 'Si', 'P', 'Cl', 'Br', 'Mg',
            'Na', 'Ca', 'Fe', 'As', 'Al', 'I', 'B', 'V', 'K', 'Tl',
            'Yb', 'Sb', 'Sn', 'Ag', 'Pd', 'Co', 'Se', 'Ti', 'Zn', 'H',
            'Li', 'Ge', 'Cu', 'Au', 'Ni', 'Cd', 'In', 'Mn', 'Zr', 'Cr',
            'Pt', 'Hg', 'Pb',
        ]
    symbol = atom.GetSymbol()
    return one_hot_encoding(symbol, symbols, encode_unknown)
def atomic_number_one_hot(atom, allowable_set=None, encode_unknown=False):
    """One-hot encode the atomic number of an atom.

    Parameters
    ----------
    atom : rdkit.Chem.rdchem.Atom
        RDKit atom instance.
    allowable_set : list of int
        Atomic numbers to consider. Default: ``1`` - ``100``.
    encode_unknown : bool
        If True, map inputs not in the allowable set to the
        additional last element. (Default: False)

    Returns
    -------
    list
        List of boolean values where at most one value is True.
    """
    numbers = list(range(1, 101)) if allowable_set is None else allowable_set
    number = atom.GetAtomicNum()
    return one_hot_encoding(number, numbers, encode_unknown)
def atomic_number(atom):
    """Return the atomic number of an atom wrapped in a list.

    Parameters
    ----------
    atom : rdkit.Chem.rdchem.Atom
        RDKit atom instance.

    Returns
    -------
    list
        List containing one int only.
    """
    number = atom.GetAtomicNum()
    return [number]
def atom_degree_one_hot(atom, allowable_set=None, encode_unknown=False):
    """One-hot encode the degree of an atom.

    Note that the result will be different depending on whether the Hs are
    explicitly modeled in the graph.

    Parameters
    ----------
    atom : rdkit.Chem.rdchem.Atom
        RDKit atom instance.
    allowable_set : list of int
        Atom degrees to consider. Default: ``0`` - ``10``.
    encode_unknown : bool
        If True, map inputs not in the allowable set to the
        additional last element. (Default: False)

    Returns
    -------
    list
        List of boolean values where at most one value is True.

    See Also
    --------
    atom_total_degree_one_hot
    """
    degrees = list(range(11)) if allowable_set is None else allowable_set
    degree = atom.GetDegree()
    return one_hot_encoding(degree, degrees, encode_unknown)
def atom_degree(atom):
    """Return the degree of an atom wrapped in a list.

    Note that the result will be different depending on whether the Hs are
    explicitly modeled in the graph.

    Parameters
    ----------
    atom : rdkit.Chem.rdchem.Atom
        RDKit atom instance.

    Returns
    -------
    list
        List containing one int only.

    See Also
    --------
    atom_total_degree
    """
    degree = atom.GetDegree()
    return [degree]
def atom_total_degree_one_hot(atom, allowable_set=None, encode_unknown=False):
    """One-hot encode the degree of an atom including Hs.

    Parameters
    ----------
    atom : rdkit.Chem.rdchem.Atom
        RDKit atom instance.
    allowable_set : list
        Total degrees to consider. Default: ``0`` - ``5``.
    encode_unknown : bool
        If True, map inputs not in the allowable set to the
        additional last element. (Default: False)

    Returns
    -------
    list
        List of boolean values where at most one value is True.

    See Also
    --------
    atom_degree_one_hot
    """
    degrees = list(range(6)) if allowable_set is None else allowable_set
    total_degree = atom.GetTotalDegree()
    return one_hot_encoding(total_degree, degrees, encode_unknown)
def atom_total_degree(atom):
    """Return the degree of an atom including Hs, wrapped in a list.

    Parameters
    ----------
    atom : rdkit.Chem.rdchem.Atom
        RDKit atom instance.

    Returns
    -------
    list
        List containing one int only.

    See Also
    --------
    atom_degree
    """
    total_degree = atom.GetTotalDegree()
    return [total_degree]
def atom_implicit_valence_one_hot(atom, allowable_set=None, encode_unknown=False):
    """One-hot encode the implicit valence of an atom.

    Parameters
    ----------
    atom : rdkit.Chem.rdchem.Atom
        RDKit atom instance.
    allowable_set : list of int
        Atom implicit valences to consider. Default: ``0`` - ``6``.
    encode_unknown : bool
        If True, map inputs not in the allowable set to the
        additional last element. (Default: False)

    Returns
    -------
    list
        List of boolean values where at most one value is True.
    """
    valences = list(range(7)) if allowable_set is None else allowable_set
    valence = atom.GetImplicitValence()
    return one_hot_encoding(valence, valences, encode_unknown)
def atom_implicit_valence(atom):
    """Return the implicit valence of an atom wrapped in a list.

    Parameters
    ----------
    atom : rdkit.Chem.rdchem.Atom
        RDKit atom instance.

    Returns
    -------
    list
        List containing one int only.
    """
    valence = atom.GetImplicitValence()
    return [valence]
def atom_hybridization_one_hot(atom, allowable_set=None, encode_unknown=False):
    """One-hot encode the hybridization of an atom.

    Parameters
    ----------
    atom : rdkit.Chem.rdchem.Atom
        RDKit atom instance.
    allowable_set : list of rdkit.Chem.rdchem.HybridizationType
        Atom hybridizations to consider. Default: ``Chem.rdchem.HybridizationType.SP``,
        ``Chem.rdchem.HybridizationType.SP2``, ``Chem.rdchem.HybridizationType.SP3``,
        ``Chem.rdchem.HybridizationType.SP3D``, ``Chem.rdchem.HybridizationType.SP3D2``.
    encode_unknown : bool
        If True, map inputs not in the allowable set to the
        additional last element. (Default: False)

    Returns
    -------
    list
        List of boolean values where at most one value is True.
    """
    hybridizations = allowable_set
    if hybridizations is None:
        hybrid_type = Chem.rdchem.HybridizationType
        hybridizations = [
            hybrid_type.SP,
            hybrid_type.SP2,
            hybrid_type.SP3,
            hybrid_type.SP3D,
            hybrid_type.SP3D2,
        ]
    return one_hot_encoding(atom.GetHybridization(), hybridizations, encode_unknown)
def atom_total_num_H_one_hot(atom, allowable_set=None, encode_unknown=False):
    """One-hot encode the total number of Hs of an atom.

    Parameters
    ----------
    atom : rdkit.Chem.rdchem.Atom
        RDKit atom instance.
    allowable_set : list of int
        Total number of Hs to consider. Default: ``0`` - ``4``.
    encode_unknown : bool
        If True, map inputs not in the allowable set to the
        additional last element. (Default: False)

    Returns
    -------
    list
        List of boolean values where at most one value is True.
    """
    counts = list(range(5)) if allowable_set is None else allowable_set
    num_hs = atom.GetTotalNumHs()
    return one_hot_encoding(num_hs, counts, encode_unknown)
def atom_total_num_H(atom):
    """Return the total number of Hs of an atom wrapped in a list.

    Parameters
    ----------
    atom : rdkit.Chem.rdchem.Atom
        RDKit atom instance.

    Returns
    -------
    list
        List containing one int only.
    """
    num_hs = atom.GetTotalNumHs()
    return [num_hs]
def atom_formal_charge_one_hot(atom, allowable_set=None, encode_unknown=False):
    """One-hot encode the formal charge of an atom.

    Parameters
    ----------
    atom : rdkit.Chem.rdchem.Atom
        RDKit atom instance.
    allowable_set : list of int
        Formal charges to consider. Default: ``-2`` - ``2``.
    encode_unknown : bool
        If True, map inputs not in the allowable set to the
        additional last element. (Default: False)

    Returns
    -------
    list
        List of boolean values where at most one value is True.
    """
    charges = list(range(-2, 3)) if allowable_set is None else allowable_set
    charge = atom.GetFormalCharge()
    return one_hot_encoding(charge, charges, encode_unknown)
def atom_formal_charge(atom):
    """Return the formal charge of an atom wrapped in a list.

    Parameters
    ----------
    atom : rdkit.Chem.rdchem.Atom
        RDKit atom instance.

    Returns
    -------
    list
        List containing one int only.
    """
    charge = atom.GetFormalCharge()
    return [charge]
def atom_num_radical_electrons_one_hot(atom, allowable_set=None, encode_unknown=False):
    """One-hot encode the number of radical electrons of an atom.

    Parameters
    ----------
    atom : rdkit.Chem.rdchem.Atom
        RDKit atom instance.
    allowable_set : list of int
        Number of radical electrons to consider. Default: ``0`` - ``4``.
    encode_unknown : bool
        If True, map inputs not in the allowable set to the
        additional last element. (Default: False)

    Returns
    -------
    list
        List of boolean values where at most one value is True.
    """
    counts = list(range(5)) if allowable_set is None else allowable_set
    num_radicals = atom.GetNumRadicalElectrons()
    return one_hot_encoding(num_radicals, counts, encode_unknown)
def atom_num_radical_electrons(atom):
    """Return the number of radical electrons of an atom wrapped in a list.

    Parameters
    ----------
    atom : rdkit.Chem.rdchem.Atom
        RDKit atom instance.

    Returns
    -------
    list
        List containing one int only.
    """
    num_radicals = atom.GetNumRadicalElectrons()
    return [num_radicals]
def atom_is_aromatic_one_hot(atom, allowable_set=None, encode_unknown=False):
    """One-hot encode whether the atom is aromatic.

    Parameters
    ----------
    atom : rdkit.Chem.rdchem.Atom
        RDKit atom instance.
    allowable_set : list of bool
        Conditions to consider. Default: ``False`` and ``True``.
    encode_unknown : bool
        If True, map inputs not in the allowable set to the
        additional last element. (Default: False)

    Returns
    -------
    list
        List of boolean values where at most one value is True.
    """
    conditions = [False, True] if allowable_set is None else allowable_set
    is_aromatic = atom.GetIsAromatic()
    return one_hot_encoding(is_aromatic, conditions, encode_unknown)
def atom_is_aromatic(atom):
    """Return whether the atom is aromatic, wrapped in a list.

    Parameters
    ----------
    atom : rdkit.Chem.rdchem.Atom
        RDKit atom instance.

    Returns
    -------
    list
        List containing one bool only.
    """
    is_aromatic = atom.GetIsAromatic()
    return [is_aromatic]
def atom_chiral_tag_one_hot(atom, allowable_set=None, encode_unknown=False):
    """One-hot encode the chiral tag of an atom.

    Parameters
    ----------
    atom : rdkit.Chem.rdchem.Atom
        RDKit atom instance.
    allowable_set : list of rdkit.Chem.rdchem.ChiralType
        Chiral tags to consider. Default: ``rdkit.Chem.rdchem.ChiralType.CHI_UNSPECIFIED``,
        ``rdkit.Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CW``,
        ``rdkit.Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CCW``,
        ``rdkit.Chem.rdchem.ChiralType.CHI_OTHER``.
    encode_unknown : bool
        If True, map inputs not in the allowable set to the
        additional last element. (Default: False)

    Returns
    -------
    list
        List of boolean values where at most one value is True.
    """
    chiral_tags = allowable_set
    if chiral_tags is None:
        chiral_type = Chem.rdchem.ChiralType
        chiral_tags = [
            chiral_type.CHI_UNSPECIFIED,
            chiral_type.CHI_TETRAHEDRAL_CW,
            chiral_type.CHI_TETRAHEDRAL_CCW,
            chiral_type.CHI_OTHER,
        ]
    return one_hot_encoding(atom.GetChiralTag(), chiral_tags, encode_unknown)
def atom_mass(atom, coef=0.01):
    """Return the scaled mass of an atom wrapped in a list.

    Parameters
    ----------
    atom : rdkit.Chem.rdchem.Atom
        RDKit atom instance.
    coef : float
        The mass will be multiplied by ``coef``. (Default: 0.01)

    Returns
    -------
    list
        List containing one float only.
    """
    scaled_mass = atom.GetMass() * coef
    return [scaled_mass]
class ConcatFeaturizer(object):
    """Chain the outputs of several featurization functions into one feature.

    Parameters
    ----------
    func_list : list
        List of functions for computing molecular descriptors from objects of a same
        particular data type, e.g. ``rdkit.Chem.rdchem.Atom``. Each function is of signature
        ``func(data_type) -> list of float or bool or int``. The resulting order of
        the features will follow that of the functions in the list.
    """
    def __init__(self, func_list):
        self.func_list = func_list

    def __call__(self, x):
        """Featurize the input data.

        Parameters
        ----------
        x :
            Data to featurize.

        Returns
        -------
        list
            List of feature values, which can be of type bool, float or int.
        """
        # Evaluate every function first, then flatten in function order.
        partial_results = [func(x) for func in self.func_list]
        combined = []
        for values in partial_results:
            combined.extend(values)
        return combined
class BaseAtomFeaturizer(object):
    """An abstract class for atom featurizers.

    Loop over all atoms in a molecule and featurize them with the ``featurizer_funcs``.

    **We assume the resulting DGLGraph will not contain any virtual nodes.**

    Parameters
    ----------
    featurizer_funcs : dict
        Mapping feature name to the featurization function.
        Each function is of signature ``func(rdkit.Chem.rdchem.Atom) -> list or 1D numpy array``.
    feat_sizes : dict
        Mapping feature name to the size of the corresponding feature. If None, they will be
        computed when needed. Default: None.

    Examples
    --------

    >>> from dgl.data.chem import BaseAtomFeaturizer, atom_mass, atom_degree_one_hot
    >>> from rdkit import Chem

    >>> mol = Chem.MolFromSmiles('CCO')
    >>> atom_featurizer = BaseAtomFeaturizer({'mass': atom_mass, 'degree': atom_degree_one_hot})
    >>> atom_featurizer(mol)
    {'mass': tensor([[0.1201],
                     [0.1201],
                     [0.1600]]),
     'degree': tensor([[0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
                       [0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.],
                       [0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0.]])}
    """
    def __init__(self, featurizer_funcs, feat_sizes=None):
        self.featurizer_funcs = featurizer_funcs
        if feat_sizes is None:
            feat_sizes = dict()
        self._feat_sizes = feat_sizes

    def feat_size(self, feat_name):
        """Get the feature size for ``feat_name``.

        Parameters
        ----------
        feat_name : str
            Feature name, which must be a key of ``featurizer_funcs``.

        Returns
        -------
        int
            Feature size for the feature with name ``feat_name``.

        Raises
        ------
        ValueError
            If ``feat_name`` is not a key of ``featurizer_funcs``.
        """
        if feat_name not in self.featurizer_funcs:
            # Raise the error; returning the exception object would hand the
            # caller a ValueError instance in place of an int.
            raise ValueError('Expect feat_name to be in {}, got {}'.format(
                list(self.featurizer_funcs.keys()), feat_name))

        if feat_name not in self._feat_sizes:
            # Size not supplied up front: measure it once on the single atom
            # of the SMILES 'C' and cache the result.
            atom = Chem.MolFromSmiles('C').GetAtomWithIdx(0)
            self._feat_sizes[feat_name] = len(self.featurizer_funcs[feat_name](atom))

        return self._feat_sizes[feat_name]

    def __call__(self, mol):
        """Featurize all atoms in a molecule.

        Parameters
        ----------
        mol : rdkit.Chem.rdchem.Mol
            RDKit molecule instance.

        Returns
        -------
        dict
            For each function in self.featurizer_funcs with the key ``k``, store the computed
            feature under the key ``k``. Each feature is a tensor of dtype float32 and shape
            (N, M), where N is the number of atoms in the molecule.
        """
        num_atoms = mol.GetNumAtoms()
        atom_features = defaultdict(list)

        # Compute features for each atom
        for i in range(num_atoms):
            atom = mol.GetAtomWithIdx(i)
            for feat_name, feat_func in self.featurizer_funcs.items():
                atom_features[feat_name].append(feat_func(atom))

        # Stack the features and convert them to float arrays
        processed_features = dict()
        for feat_name, feat_list in atom_features.items():
            feat = np.stack(feat_list)
            processed_features[feat_name] = F.zerocopy_from_numpy(feat.astype(np.float32))

        return processed_features
class CanonicalAtomFeaturizer(BaseAtomFeaturizer):
"""A default featurizer for atoms.
......@@ -70,55 +598,207 @@ class CanonicalAtomFeaturizer(BaseAtomFeaturizer):
* **One hot encoding of the number of total Hs on the atom**. The supported possibilities
include ``0 - 4``.
**We assume the resulting DGLGraph will not contain any virtual nodes.**
Parameters
----------
atom_data_field : str
Name for storing atom features in DGLGraphs, default to be 'h'.
"""
def __init__(self, atom_data_field='h'):
super(CanonicalAtomFeaturizer, self).__init__()
self.atom_data_field = atom_data_field
super(CanonicalAtomFeaturizer, self).__init__(
featurizer_funcs={atom_data_field: ConcatFeaturizer(
[atom_type_one_hot,
atom_degree_one_hot,
atom_implicit_valence_one_hot,
atom_formal_charge,
atom_num_radical_electrons,
atom_hybridization_one_hot,
atom_is_aromatic,
atom_total_num_H_one_hot]
)})
@property
def feat_size(self):
"""Returns feature size"""
return 74
def bond_type_one_hot(bond, allowable_set=None, encode_unknown=False):
    """One hot encoding for the type of a bond.

    Parameters
    ----------
    bond : rdkit.Chem.rdchem.Bond
        RDKit bond instance.
    allowable_set : list of Chem.rdchem.BondType
        Bond types to consider. Default: ``Chem.rdchem.BondType.SINGLE``,
        ``Chem.rdchem.BondType.DOUBLE``, ``Chem.rdchem.BondType.TRIPLE``,
        ``Chem.rdchem.BondType.AROMATIC``.
    encode_unknown : bool
        If True, map inputs not in the allowable set to the
        additional last element. (Default: False)

    Returns
    -------
    list
        List of boolean values where at most one value is True.
    """
    if allowable_set is None:
        allowable_set = [Chem.rdchem.BondType.SINGLE,
                         Chem.rdchem.BondType.DOUBLE,
                         Chem.rdchem.BondType.TRIPLE,
                         Chem.rdchem.BondType.AROMATIC]
    return one_hot_encoding(bond.GetBondType(), allowable_set, encode_unknown)
def bond_is_conjugated_one_hot(bond, allowable_set=None, encode_unknown=False):
    """One-hot encode whether the bond is conjugated.

    Parameters
    ----------
    bond : rdkit.Chem.rdchem.Bond
        RDKit bond instance.
    allowable_set : list of bool
        Conditions to consider. Default: ``False`` and ``True``.
    encode_unknown : bool
        If True, map inputs not in the allowable set to the
        additional last element. (Default: False)

    Returns
    -------
    list
        List of boolean values where at most one value is True.
    """
    conditions = [False, True] if allowable_set is None else allowable_set
    is_conjugated = bond.GetIsConjugated()
    return one_hot_encoding(is_conjugated, conditions, encode_unknown)
def bond_is_conjugated(bond):
    """Return whether the bond is conjugated, wrapped in a list.

    Parameters
    ----------
    bond : rdkit.Chem.rdchem.Bond
        RDKit bond instance.

    Returns
    -------
    list
        List containing one bool only.
    """
    is_conjugated = bond.GetIsConjugated()
    return [is_conjugated]
def bond_is_in_ring_one_hot(bond, allowable_set=None, encode_unknown=False):
    """One-hot encode whether the bond is in a ring of any size.

    Parameters
    ----------
    bond : rdkit.Chem.rdchem.Bond
        RDKit bond instance.
    allowable_set : list of bool
        Conditions to consider. Default: ``False`` and ``True``.
    encode_unknown : bool
        If True, map inputs not in the allowable set to the
        additional last element. (Default: False)

    Returns
    -------
    list
        List of boolean values where at most one value is True.
    """
    conditions = [False, True] if allowable_set is None else allowable_set
    in_ring = bond.IsInRing()
    return one_hot_encoding(in_ring, conditions, encode_unknown)
def bond_is_in_ring(bond):
    """Return whether the bond is in a ring of any size, wrapped in a list.

    Parameters
    ----------
    bond : rdkit.Chem.rdchem.Bond
        RDKit bond instance.

    Returns
    -------
    list
        List containing one bool only.
    """
    in_ring = bond.IsInRing()
    return [in_ring]
def bond_stereo_one_hot(bond, allowable_set=None, encode_unknown=False):
    """One-hot encode the stereo configuration of a bond.

    Parameters
    ----------
    bond : rdkit.Chem.rdchem.Bond
        RDKit bond instance.
    allowable_set : list of rdkit.Chem.rdchem.BondStereo
        Stereo configurations to consider. Default: ``rdkit.Chem.rdchem.BondStereo.STEREONONE``,
        ``rdkit.Chem.rdchem.BondStereo.STEREOANY``, ``rdkit.Chem.rdchem.BondStereo.STEREOZ``,
        ``rdkit.Chem.rdchem.BondStereo.STEREOE``, ``rdkit.Chem.rdchem.BondStereo.STEREOCIS``,
        ``rdkit.Chem.rdchem.BondStereo.STEREOTRANS``.
    encode_unknown : bool
        If True, map inputs not in the allowable set to the
        additional last element. (Default: False)

    Returns
    -------
    list
        List of boolean values where at most one value is True.
    """
    stereo_configs = allowable_set
    if stereo_configs is None:
        bond_stereo = Chem.rdchem.BondStereo
        stereo_configs = [
            bond_stereo.STEREONONE,
            bond_stereo.STEREOANY,
            bond_stereo.STEREOZ,
            bond_stereo.STEREOE,
            bond_stereo.STEREOCIS,
            bond_stereo.STEREOTRANS,
        ]
    return one_hot_encoding(bond.GetStereo(), stereo_configs, encode_unknown)
class BaseBondFeaturizer(object):
    """An abstract class for bond featurizers.

    Loop over all bonds in a molecule and featurize them with the ``featurizer_funcs``.
    We assume the constructed ``DGLGraph`` is a bi-directed graph where the **i** th bond in the
    molecule, i.e. ``mol.GetBondWithIdx(i)``, corresponds to the **(2i)**-th and **(2i+1)**-th edges
    in the DGLGraph.

    **We assume the resulting DGLGraph will be created with :func:`smiles_to_bigraph` without
    self loops.**

    Parameters
    ----------
    featurizer_funcs : dict
        Mapping feature name to the featurization function.
        Each function is of signature ``func(rdkit.Chem.rdchem.Bond) -> list or 1D numpy array``.
    feat_sizes : dict
        Mapping feature name to the size of the corresponding feature. If None, they will be
        computed when needed. Default: None.

    Examples
    --------

    >>> from dgl.data.chem import BaseBondFeaturizer, bond_type_one_hot, bond_is_in_ring
    >>> from rdkit import Chem

    >>> mol = Chem.MolFromSmiles('CCO')
    >>> bond_featurizer = BaseBondFeaturizer({'bond_type': bond_type_one_hot, 'in_ring': bond_is_in_ring})
    >>> bond_featurizer(mol)
    {'bond_type': tensor([[1., 0., 0., 0.],
                          [1., 0., 0., 0.],
                          [1., 0., 0., 0.],
                          [1., 0., 0., 0.]]),
     'in_ring': tensor([[0.], [0.], [0.], [0.]])}
    """
    def __init__(self, featurizer_funcs, feat_sizes=None):
        self.featurizer_funcs = featurizer_funcs
        if feat_sizes is None:
            feat_sizes = dict()
        self._feat_sizes = feat_sizes

    def feat_size(self, feat_name):
        """Get the feature size for ``feat_name``.

        Parameters
        ----------
        feat_name : str
            Feature name, which must be a key of ``featurizer_funcs``.

        Returns
        -------
        int
            Feature size for the feature with name ``feat_name``.

        Raises
        ------
        ValueError
            If ``feat_name`` is not a key of ``featurizer_funcs``.
        """
        if feat_name not in self.featurizer_funcs:
            # Raise the error; returning the exception object would hand the
            # caller a ValueError instance in place of an int.
            raise ValueError('Expect feat_name to be in {}, got {}'.format(
                list(self.featurizer_funcs.keys()), feat_name))

        if feat_name not in self._feat_sizes:
            # Size not supplied up front: measure it once on the single bond
            # of the SMILES 'CO' and cache the result.
            bond = Chem.MolFromSmiles('CO').GetBondWithIdx(0)
            self._feat_sizes[feat_name] = len(self.featurizer_funcs[feat_name](bond))

        return self._feat_sizes[feat_name]

    def __call__(self, mol):
        """Featurize all bonds in a molecule.

        Parameters
        ----------
        mol : rdkit.Chem.rdchem.Mol
            RDKit molecule instance.

        Returns
        -------
        dict
            For each function in self.featurizer_funcs with the key ``k``, store the computed
            feature under the key ``k``. Each feature is a tensor of dtype float32 and shape
            (N, M), where N is twice the number of bonds in the molecule: the feature of
            every bond is stored two times in a row, once per edge direction.
        """
        num_bonds = mol.GetNumBonds()
        bond_features = defaultdict(list)

        # Compute features for each bond
        for i in range(num_bonds):
            bond = mol.GetBondWithIdx(i)
            for feat_name, feat_func in self.featurizer_funcs.items():
                feat = feat_func(bond)
                # One entry per direction of the bi-directed edge pair
                bond_features[feat_name].extend([feat, feat.copy()])

        # Stack the features and convert them to float arrays
        processed_features = dict()
        for feat_name, feat_list in bond_features.items():
            feat = np.stack(feat_list)
            processed_features[feat_name] = F.zerocopy_from_numpy(feat.astype(np.float32))

        return processed_features
class CanonicalBondFeaturizer(BaseBondFeaturizer):
    """A default featurizer for bonds.

    The bond features include:

    * **One hot encoding of the bond type**. The supported bond types include
      ``SINGLE``, ``DOUBLE``, ``TRIPLE``, ``AROMATIC``.
    * **Whether the bond is conjugated.**
    * **Whether the bond is in a ring of any size.**
    * **One hot encoding of the stereo configuration of a bond**. The supported bond stereo
      configurations include ``STEREONONE``, ``STEREOANY``, ``STEREOZ``, ``STEREOE``,
      ``STEREOCIS``, ``STEREOTRANS``.

    **We assume the resulting DGLGraph will be created with :func:`smiles_to_bigraph` without
    self loops.**

    Parameters
    ----------
    bond_data_field : str
        Feature name under which the concatenated bond feature is registered
        in ``featurizer_funcs``. Default: ``'e'``.
    """
    def __init__(self, bond_data_field='e'):
        # The order of this list fixes the layout of the concatenated feature.
        descriptor_funcs = [
            bond_type_one_hot,
            bond_is_conjugated,
            bond_is_in_ring,
            bond_stereo_one_hot,
        ]
        combined = ConcatFeaturizer(descriptor_funcs)
        super(CanonicalBondFeaturizer, self).__init__(
            featurizer_funcs={bond_data_field: combined})
#################################################################
# DGLGraph Construction
#################################################################
def mol_to_graph(mol, graph_constructor, atom_featurizer, bond_featurizer):
"""Convert an RDKit molecule object into a DGLGraph and featurize for it.
......@@ -193,7 +910,7 @@ def construct_bigraph_from_mol(mol, add_self_loop=False):
mol : rdkit.Chem.rdchem.Mol
RDKit molecule holder
add_self_loop : bool
Whether to add self loops in DGLGraphs.
Whether to add self loops in DGLGraphs. Default to False.
Returns
-------
......@@ -225,7 +942,7 @@ def construct_bigraph_from_mol(mol, add_self_loop=False):
return g
def mol_to_bigraph(mol, add_self_loop=False,
atom_featurizer=CanonicalAtomFeaturizer(),
atom_featurizer=None,
bond_featurizer=None):
"""Convert an RDKit molecule object into a bi-directed DGLGraph and featurize for it.
......@@ -234,13 +951,13 @@ def mol_to_bigraph(mol, add_self_loop=False,
mol : rdkit.Chem.rdchem.Mol
RDKit molecule holder
add_self_loop : bool
Whether to add self loops in DGLGraphs.
Whether to add self loops in DGLGraphs. Default to False.
atom_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict
Featurization for atoms in a molecule, which can be used to update
ndata for a DGLGraph. Default to CanonicalAtomFeaturizer().
ndata for a DGLGraph. Default to None.
bond_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict
Featurization for bonds in a molecule, which can be used to update
edata for a DGLGraph.
edata for a DGLGraph. Default to None.
Returns
-------
......@@ -250,30 +967,30 @@ def mol_to_bigraph(mol, add_self_loop=False,
return mol_to_graph(mol, partial(construct_bigraph_from_mol, add_self_loop=add_self_loop),
atom_featurizer, bond_featurizer)
def smiles_to_bigraph(smiles, add_self_loop=False,
                      atom_featurizer=None,
                      bond_featurizer=None):
    """Convert a SMILES into a bi-directed DGLGraph and featurize for it.

    Parameters
    ----------
    smiles : str
        String of SMILES
    add_self_loop : bool
        Whether to add self loops in DGLGraphs. Default to False.
    atom_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict
        Featurization for atoms in a molecule, which can be used to update
        ndata for a DGLGraph. Default to None.
    bond_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict
        Featurization for bonds in a molecule, which can be used to update
        edata for a DGLGraph. Default to None.

    Returns
    -------
    g : DGLGraph
        Bi-directed DGLGraph for the molecule
    """
    # Parse the SMILES once; graph construction and featurization are
    # delegated to mol_to_bigraph.
    mol = Chem.MolFromSmiles(smiles)
    return mol_to_bigraph(mol, add_self_loop, atom_featurizer, bond_featurizer)
def construct_complete_graph_from_mol(mol, add_self_loop=False):
......@@ -290,7 +1007,7 @@ def construct_complete_graph_from_mol(mol, add_self_loop=False):
mol : rdkit.Chem.rdchem.Mol
RDKit molecule holder
add_self_loop : bool
Whether to add self loops in DGLGraphs.
Whether to add self loops in DGLGraphs. Default to False.
Returns
-------
......@@ -324,13 +1041,13 @@ def mol_to_complete_graph(mol, add_self_loop=False,
mol : rdkit.Chem.rdchem.Mol
RDKit molecule holder
add_self_loop : bool
Whether to add self loops in DGLGraphs.
Whether to add self loops in DGLGraphs. Default to False.
atom_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict
Featurization for atoms in a molecule, which can be used to update
ndata for a DGLGraph. Default to CanonicalAtomFeaturizer().
ndata for a DGLGraph. Default to None.
bond_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict
Featurization for bonds in a molecule, which can be used to update
edata for a DGLGraph.
edata for a DGLGraph. Default to None.
Returns
-------
......@@ -340,28 +1057,28 @@ def mol_to_complete_graph(mol, add_self_loop=False,
return mol_to_graph(mol, partial(construct_complete_graph_from_mol, add_self_loop=add_self_loop),
atom_featurizer, bond_featurizer)
def smiles_to_complete_graph(smiles, add_self_loop=False,
                             atom_featurizer=None,
                             bond_featurizer=None):
    """Convert a SMILES into a complete DGLGraph and featurize for it.

    Parameters
    ----------
    smiles : str
        String of SMILES
    add_self_loop : bool
        Whether to add self loops in DGLGraphs. Default to False.
    atom_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict
        Featurization for atoms in a molecule, which can be used to update
        ndata for a DGLGraph. Default to None.
    bond_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict
        Featurization for bonds in a molecule, which can be used to update
        edata for a DGLGraph. Default to None.

    Returns
    -------
    g : DGLGraph
        Complete DGLGraph for the molecule
    """
    # Parse the SMILES once; graph construction and featurization are
    # delegated to mol_to_complete_graph.
    mol = Chem.MolFromSmiles(smiles)
    return mol_to_complete_graph(mol, add_self_loop, atom_featurizer, bond_featurizer)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment