Unverified commit 577cf2e6, authored by Mufei Li, committed by GitHub

[Model Zoo] Refactor and Add Utils for Chemistry (#928)

* Refactor

* Add note

* Update

* CI
parent 0bf3b6dd
@@ -142,19 +142,66 @@ Molecular Graphs

To work on molecular graphs, make sure you have installed `RDKit 2018.09.3 <https://www.rdkit.org/docs/Install.html>`__.

-Featurization
-`````````````
+Featurization Utils
+```````````````````

-For the use of graph neural networks, we need to featurize nodes (atoms) and edges (bonds). Below we list some
-featurization methods/utilities:
+For the use of graph neural networks, we need to featurize nodes (atoms) and edges (bonds).
+
+General utils:

.. autosummary::
    :toctree: ../../generated/

    chem.one_hot_encoding
+    chem.ConcatFeaturizer
+    chem.ConcatFeaturizer.__call__
+Utils for atom featurization:
+
+.. autosummary::
+    :toctree: ../../generated/
+
+    chem.atom_type_one_hot
+    chem.atomic_number_one_hot
+    chem.atomic_number
+    chem.atom_degree_one_hot
+    chem.atom_degree
+    chem.atom_total_degree_one_hot
+    chem.atom_total_degree
+    chem.atom_implicit_valence_one_hot
+    chem.atom_implicit_valence
+    chem.atom_hybridization_one_hot
+    chem.atom_total_num_H_one_hot
+    chem.atom_total_num_H
+    chem.atom_formal_charge_one_hot
+    chem.atom_formal_charge
+    chem.atom_num_radical_electrons_one_hot
+    chem.atom_num_radical_electrons
+    chem.atom_is_aromatic_one_hot
+    chem.atom_is_aromatic
+    chem.atom_chiral_tag_one_hot
+    chem.atom_mass
    chem.BaseAtomFeaturizer
+    chem.BaseAtomFeaturizer.feat_size
+    chem.BaseAtomFeaturizer.__call__
    chem.CanonicalAtomFeaturizer
+
+Utils for bond featurization:
+
+.. autosummary::
+    :toctree: ../../generated/
+
+    chem.bond_type_one_hot
+    chem.bond_is_conjugated_one_hot
+    chem.bond_is_conjugated
+    chem.bond_is_in_ring_one_hot
+    chem.bond_is_in_ring
+    chem.bond_stereo_one_hot
+    chem.BaseBondFeaturizer
+    chem.BaseBondFeaturizer.feat_size
+    chem.BaseBondFeaturizer.__call__
+    chem.CanonicalBondFeaturizer
Graph Construction
``````````````````

@@ -164,9 +211,9 @@ Several methods for constructing DGLGraphs from SMILES/RDKit molecule objects ar

    :toctree: ../../generated/

    chem.mol_to_graph
-    chem.smile_to_bigraph
+    chem.smiles_to_bigraph
    chem.mol_to_bigraph
-    chem.smile_to_complete_graph
+    chem.smiles_to_complete_graph
    chem.mol_to_complete_graph

Dataset Classes
......
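As a quick illustration of the featurization and graph-construction utils documented above, here is a minimal sketch, assuming the dgl.data.chem namespace introduced by this commit (the SMILES string is an arbitrary example):

from rdkit import Chem
from dgl.data.chem import (CanonicalAtomFeaturizer, ConcatFeaturizer,
                           atom_degree, atom_type_one_hot, smiles_to_bigraph)

# ConcatFeaturizer chains several per-atom functions into one feature list.
atom_feat_fn = ConcatFeaturizer(
    [lambda atom: atom_type_one_hot(atom, ['C', 'N', 'O']), atom_degree])

mol = Chem.MolFromSmiles('CCO')
print(atom_feat_fn(mol.GetAtomWithIdx(0)))  # one-hot over C/N/O plus the degree

# CanonicalAtomFeaturizer bundles the common atom descriptors (74 features),
# stored in ndata under its atom_data_field (default 'h').
g = smiles_to_bigraph('CCO', atom_featurizer=CanonicalAtomFeaturizer())
print(g.ndata['h'].shape)  # torch.Size([3, 74])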
@@ -448,7 +448,7 @@ def get_atom_and_bond_types(smiles, log=True):

    for i, s in enumerate(smiles):
        if log:
-            print('Processing smile {:d}/{:d}'.format(i + 1, n_smiles))
+            print('Processing smiles {:d}/{:d}'.format(i + 1, n_smiles))
        mol = smiles_to_standard_mol(s)
        if mol is None:

@@ -517,7 +517,7 @@ def eval_decisions(env, decisions):

    return env.get_current_smiles()

def get_DGMG_smile(env, mol):
-    """Mimics the reproduced SMILE with DGMG for a molecule.
+    """Mimics the reproduced SMILES with DGMG for a molecule.

    Given a molecule, we are interested in what SMILES we will
    get if we want to generate it with DGMG. This is an important
......
+import numpy as np
import torch
from torch.nn import BCEWithLogitsLoss
from torch.optim import Adam
from torch.utils.data import DataLoader

from dgl import model_zoo
-from dgl.data.utils import split_dataset

-from utils import Meter, EarlyStopping, collate_molgraphs_for_classification, set_random_seed
+from utils import Meter, EarlyStopping, collate_molgraphs, set_random_seed, \
+    load_dataset_for_classification
def run_a_train_epoch(args, epoch, model, data_loader, loss_criterion, optimizer):
    model.train()
    train_meter = Meter()
    for batch_id, batch_data in enumerate(data_loader):
-        smiles, bg, labels, mask = batch_data
+        smiles, bg, labels, masks = batch_data
        atom_feats = bg.ndata.pop(args['atom_data_field'])
-        atom_feats, labels, mask = atom_feats.to(args['device']), \
-                                   labels.to(args['device']), \
-                                   mask.to(args['device'])
+        atom_feats, labels, masks = atom_feats.to(args['device']), \
+                                    labels.to(args['device']), \
+                                    masks.to(args['device'])
        logits = model(bg, atom_feats)
        # Mask non-existing labels
-        loss = (loss_criterion(logits, labels) * (mask != 0).float()).mean()
+        loss = (loss_criterion(logits, labels) * (masks != 0).float()).mean()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        print('epoch {:d}/{:d}, batch {:d}/{:d}, loss {:.4f}'.format(
            epoch + 1, args['num_epochs'], batch_id + 1, len(data_loader), loss.item()))
-        train_meter.update(logits, labels, mask)
-    train_roc_auc = train_meter.roc_auc_averaged_over_tasks()
-    print('epoch {:d}/{:d}, training roc-auc score {:.4f}'.format(
-        epoch + 1, args['num_epochs'], train_roc_auc))
+        train_meter.update(logits, labels, masks)
+    train_score = np.mean(train_meter.compute_metric(args['metric_name']))
+    print('epoch {:d}/{:d}, training {} {:.4f}'.format(
+        epoch + 1, args['num_epochs'], args['metric_name'], train_score))
def run_an_eval_epoch(args, model, data_loader):
    model.eval()
    eval_meter = Meter()
    with torch.no_grad():
        for batch_id, batch_data in enumerate(data_loader):
-            smiles, bg, labels, mask = batch_data
+            smiles, bg, labels, masks = batch_data
            atom_feats = bg.ndata.pop(args['atom_data_field'])
            atom_feats, labels = atom_feats.to(args['device']), labels.to(args['device'])
            logits = model(bg, atom_feats)
-            eval_meter.update(logits, labels, mask)
-    return eval_meter.roc_auc_averaged_over_tasks()
+            eval_meter.update(logits, labels, masks)
+    return np.mean(eval_meter.compute_metric(args['metric_name']))
def main(args):
    args['device'] = "cuda" if torch.cuda.is_available() else "cpu"
    set_random_seed()

    # Interchangeable with other datasets
-    if args['dataset'] == 'Tox21':
-        from dgl.data.chem import Tox21
-        dataset = Tox21()
-
-    trainset, valset, testset = split_dataset(dataset, args['train_val_test_split'])
-    train_loader = DataLoader(trainset, batch_size=args['batch_size'],
-                              collate_fn=collate_molgraphs_for_classification)
-    val_loader = DataLoader(valset, batch_size=args['batch_size'],
-                            collate_fn=collate_molgraphs_for_classification)
-    test_loader = DataLoader(testset, batch_size=args['batch_size'],
-                             collate_fn=collate_molgraphs_for_classification)
+    dataset, train_set, val_set, test_set = load_dataset_for_classification(args)
+    train_loader = DataLoader(train_set, batch_size=args['batch_size'],
+                              collate_fn=collate_molgraphs)
+    val_loader = DataLoader(val_set, batch_size=args['batch_size'],
+                            collate_fn=collate_molgraphs)
+    test_loader = DataLoader(test_set, batch_size=args['batch_size'],
+                             collate_fn=collate_molgraphs)

    if args['pre_trained']:
        args['num_epochs'] = 0
@@ -87,17 +84,18 @@ def main(args):

        run_a_train_epoch(args, epoch, model, train_loader, loss_criterion, optimizer)

        # Validation and early stop
-        val_roc_auc = run_an_eval_epoch(args, model, val_loader)
-        early_stop = stopper.step(val_roc_auc, model)
-        print('epoch {:d}/{:d}, validation roc-auc score {:.4f}, best validation roc-auc score {:.4f}'.format(
-            epoch + 1, args['num_epochs'], val_roc_auc, stopper.best_score))
+        val_score = run_an_eval_epoch(args, model, val_loader)
+        early_stop = stopper.step(val_score, model)
+        print('epoch {:d}/{:d}, validation {} {:.4f}, best validation {} {:.4f}'.format(
+            epoch + 1, args['num_epochs'], args['metric_name'],
+            val_score, args['metric_name'], stopper.best_score))

        if early_stop:
            break

    if not args['pre_trained']:
        stopper.load_checkpoint(model)
-        test_roc_auc = run_an_eval_epoch(args, model, test_loader)
-        print('test roc-auc score {:.4f}'.format(test_roc_auc))
+        test_score = run_an_eval_epoch(args, model, test_loader)
+        print('test {} {:.4f}'.format(args['metric_name'], test_score))

if __name__ == '__main__':
    import argparse
......
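The masking trick in run_a_train_epoch above only works if the loss criterion is built with reduction='none', so the elementwise loss keeps its (B, T) shape and unlabeled (molecule, task) pairs can be zeroed out before averaging. A minimal sketch of the pattern (shapes chosen for Tox21's 12 tasks; the tensors are random placeholders):

import torch
from torch.nn import BCEWithLogitsLoss

logits = torch.randn(4, 12)                    # B = 4 molecules, T = 12 tasks
labels = torch.randint(0, 2, (4, 12)).float()
masks = torch.randint(0, 2, (4, 12)).float()   # 0 marks a missing label

loss_criterion = BCEWithLogitsLoss(reduction='none')
per_entry = loss_criterion(logits, labels)     # shape (4, 12), one loss per entry
loss = (per_entry * (masks != 0).float()).mean()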
+from dgl.data.chem import CanonicalAtomFeaturizer

GCN_Tox21 = {
    'batch_size': 128,
    'lr': 1e-3,

@@ -7,7 +9,9 @@ GCN_Tox21 = {

    'in_feats': 74,
    'gcn_hidden_feats': [64, 64],
    'classifier_hidden_feats': 64,
-    'patience': 10
+    'patience': 10,
+    'atom_featurizer': CanonicalAtomFeaturizer(),
+    'metric_name': 'roc_auc'
}
GAT_Tox21 = {

@@ -20,15 +24,20 @@ GAT_Tox21 = {

    'gat_hidden_feats': [32, 32],
    'classifier_hidden_feats': 64,
    'num_heads': [4, 4],
-    'patience': 10
+    'patience': 10,
+    'atom_featurizer': CanonicalAtomFeaturizer(),
+    'metric_name': 'roc_auc'
}
MPNN_Alchemy = {
    'batch_size': 16,
    'num_epochs': 250,
+    'node_in_feats': 15,
+    'edge_in_feats': 5,
    'output_dim': 12,
    'lr': 0.0001,
-    'patience': 50
+    'patience': 50,
+    'metric_name': 'l1'
}
SCHNET_Alchemy = {

@@ -37,7 +46,8 @@ SCHNET_Alchemy = {

    'norm': True,
    'output_dim': 12,
    'lr': 0.0001,
-    'patience': 50
+    'patience': 50,
+    'metric_name': 'l1'
}
MGCN_Alchemy = {

@@ -46,7 +56,8 @@ MGCN_Alchemy = {

    'norm': True,
    'output_dim': 12,
    'lr': 0.0001,
-    'patience': 50
+    'patience': 50,
+    'metric_name': 'l1'
}

experiment_configures = {
......
+import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader

from dgl import model_zoo

-from utils import set_random_seed, collate_molgraphs_for_regression, EarlyStopping
+from utils import Meter, set_random_seed, collate_molgraphs, EarlyStopping, \
+    load_dataset_for_regression
def regress(args, model, bg):
    if args['model'] == 'MPNN':

@@ -20,36 +22,35 @@ def regress(args, model, bg):

        return model(bg, node_types, edge_distances)

def run_a_train_epoch(args, epoch, model, data_loader,
-                      loss_criterion, score_criterion, optimizer):
+                      loss_criterion, optimizer):
    model.train()
-    total_loss, total_score = 0, 0
+    train_meter = Meter()
+    total_loss = 0
    for batch_id, batch_data in enumerate(data_loader):
-        smiles, bg, labels = batch_data
-        labels = labels.to(args['device'])
+        smiles, bg, labels, masks = batch_data
+        labels, masks = labels.to(args['device']), masks.to(args['device'])
        prediction = regress(args, model, bg)
-        loss = loss_criterion(prediction, labels)
-        score = score_criterion(prediction, labels)
+        loss = (loss_criterion(prediction, labels) * (masks != 0).float()).mean()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.detach().item() * bg.batch_size
-        total_score += score.detach().item() * bg.batch_size
+        train_meter.update(prediction, labels, masks)
    total_loss /= len(data_loader.dataset)
-    total_score /= len(data_loader.dataset)
-    print('epoch {:d}/{:d}, training loss {:.4f}, training score {:.4f}'.format(
-        epoch + 1, args['num_epochs'], total_loss, total_score))
+    total_score = np.mean(train_meter.compute_metric(args['metric_name']))
+    print('epoch {:d}/{:d}, training loss {:.4f}, training {} {:.4f}'.format(
+        epoch + 1, args['num_epochs'], total_loss, args['metric_name'], total_score))
-def run_an_eval_epoch(args, model, data_loader, score_criterion):
+def run_an_eval_epoch(args, model, data_loader):
    model.eval()
-    total_score = 0
+    eval_meter = Meter()
    with torch.no_grad():
        for batch_id, batch_data in enumerate(data_loader):
-            smiles, bg, labels = batch_data
+            smiles, bg, labels, masks = batch_data
            labels = labels.to(args['device'])
            prediction = regress(args, model, bg)
-            score = score_criterion(prediction, labels)
-            total_score += score.detach().item() * bg.batch_size
-    total_score /= len(data_loader.dataset)
+            eval_meter.update(prediction, labels, masks)
+    total_score = np.mean(eval_meter.compute_metric(args['metric_name']))
    return total_score
def main(args):

@@ -57,20 +58,22 @@ def main(args):

    set_random_seed()

    # Interchangeable with other datasets
-    if args['dataset'] == 'Alchemy':
-        from dgl.data.chem import TencentAlchemyDataset
-        train_set = TencentAlchemyDataset(mode='dev')
-        val_set = TencentAlchemyDataset(mode='valid')
+    train_set, val_set, test_set = load_dataset_for_regression(args)
    train_loader = DataLoader(dataset=train_set,
                              batch_size=args['batch_size'],
-                              collate_fn=collate_molgraphs_for_regression)
+                              collate_fn=collate_molgraphs)
    val_loader = DataLoader(dataset=val_set,
                            batch_size=args['batch_size'],
-                            collate_fn=collate_molgraphs_for_regression)
+                            collate_fn=collate_molgraphs)
+    if test_set is not None:
+        test_loader = DataLoader(dataset=test_set,
+                                 batch_size=args['batch_size'],
+                                 collate_fn=collate_molgraphs)

    if args['model'] == 'MPNN':
-        model = model_zoo.chem.MPNNModel(output_dim=args['output_dim'])
+        model = model_zoo.chem.MPNNModel(node_input_dim=args['node_in_feats'],
+                                         edge_input_dim=args['edge_in_feats'],
+                                         output_dim=args['output_dim'])
    elif args['model'] == 'SCHNET':
        model = model_zoo.chem.SchNet(norm=args['norm'], output_dim=args['output_dim'])
        model.set_mean_std(train_set.mean, train_set.std, args['device'])
@@ -79,23 +82,28 @@ def main(args):

        model.set_mean_std(train_set.mean, train_set.std, args['device'])
    model.to(args['device'])

-    loss_fn = nn.MSELoss()
-    score_fn = nn.L1Loss()
+    loss_fn = nn.MSELoss(reduction='none')
    optimizer = torch.optim.Adam(model.parameters(), lr=args['lr'])
    stopper = EarlyStopping(mode='lower', patience=args['patience'])

    for epoch in range(args['num_epochs']):
        # Train
-        run_a_train_epoch(args, epoch, model, train_loader, loss_fn, score_fn, optimizer)
+        run_a_train_epoch(args, epoch, model, train_loader, loss_fn, optimizer)

        # Validation and early stop
-        val_score = run_an_eval_epoch(args, model, val_loader, score_fn)
+        val_score = run_an_eval_epoch(args, model, val_loader)
        early_stop = stopper.step(val_score, model)
-        print('epoch {:d}/{:d}, validation score {:.4f}, best validation score {:.4f}'.format(
-            epoch + 1, args['num_epochs'], val_score, stopper.best_score))
+        print('epoch {:d}/{:d}, validation {} {:.4f}, best validation {} {:.4f}'.format(
+            epoch + 1, args['num_epochs'], args['metric_name'], val_score,
+            args['metric_name'], stopper.best_score))

        if early_stop:
            break

+    if test_set is not None:
+        stopper.load_checkpoint(model)
+        test_score = run_an_eval_epoch(args, model, test_loader)
+        print('test {} {:.4f}'.format(args['metric_name'], test_score))

if __name__ == "__main__":
    import argparse
......
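Because Alchemy datapoints are 3-tuples (SMILES, graph, labels), collate_molgraphs fills the masks with ones, and the masked MSE above then reduces to a plain mean. A small sketch of that equivalence (random placeholder tensors):

import torch
import torch.nn as nn

prediction, labels = torch.randn(2, 12), torch.randn(2, 12)
masks = torch.ones(labels.shape)               # what collate_molgraphs returns

loss_fn = nn.MSELoss(reduction='none')
masked = (loss_fn(prediction, labels) * (masks != 0).float()).mean()
assert torch.allclose(masked, nn.MSELoss()(prediction, labels))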
import datetime
import dgl
+import math
import numpy as np
import random
import torch
-from sklearn.metrics import roc_auc_score
+import torch.nn.functional as F
+
+from dgl.data.utils import split_dataset
+from sklearn.metrics import roc_auc_score, mean_squared_error

def set_random_seed(seed=0):
    """Set random seed.
@@ -45,13 +49,13 @@ class Meter(object):

        self.y_true.append(y_true.detach().cpu())
        self.mask.append(mask.detach().cpu())

-    def roc_auc_averaged_over_tasks(self):
-        """Compute roc-auc score for each task and return the average.
+    def roc_auc_score(self):
+        """Compute roc-auc score for each task.

        Returns
        -------
-        float
-            roc-auc score averaged over all tasks
+        list of float
+            roc-auc score for all tasks
        """
        mask = torch.cat(self.mask, dim=0)
        y_pred = torch.cat(self.y_pred, dim=0)
@@ -60,13 +64,83 @@ class Meter(object):

        # This assumes binary case only
        y_pred = torch.sigmoid(y_pred)
        n_tasks = y_true.shape[1]
-        total_score = 0
-        for task in range(n_tasks):
-            task_w = mask[:, task]
-            task_y_true = y_true[:, task][task_w != 0].numpy()
-            task_y_pred = y_pred[:, task][task_w != 0].numpy()
-            total_score += roc_auc_score(task_y_true, task_y_pred)
-        return total_score / n_tasks
+        scores = []
+        for task in range(n_tasks):
+            task_w = mask[:, task]
+            task_y_true = y_true[:, task][task_w != 0].numpy()
+            task_y_pred = y_pred[:, task][task_w != 0].numpy()
+            scores.append(roc_auc_score(task_y_true, task_y_pred))
+        return scores
+
+    def l1_loss(self, reduction):
+        """Compute l1 loss for each task.
+
+        Parameters
+        ----------
+        reduction : str
+            * 'mean': average the metric over all labeled data points for each task
+            * 'sum': sum the metric over all labeled data points for each task
+
+        Returns
+        -------
+        list of float
+            l1 loss for all tasks
+        """
+        mask = torch.cat(self.mask, dim=0)
+        y_pred = torch.cat(self.y_pred, dim=0)
+        y_true = torch.cat(self.y_true, dim=0)
+        n_tasks = y_true.shape[1]
+        scores = []
+        for task in range(n_tasks):
+            task_w = mask[:, task]
+            task_y_true = y_true[:, task][task_w != 0]
+            task_y_pred = y_pred[:, task][task_w != 0]
+            scores.append(F.l1_loss(task_y_true, task_y_pred, reduction=reduction).item())
+        return scores
+
+    def rmse(self):
+        """Compute RMSE for each task.
+
+        Returns
+        -------
+        list of float
+            rmse for all tasks
+        """
+        mask = torch.cat(self.mask, dim=0)
+        y_pred = torch.cat(self.y_pred, dim=0)
+        y_true = torch.cat(self.y_true, dim=0)
+        n_data, n_tasks = y_true.shape
+        scores = []
+        for task in range(n_tasks):
+            task_w = mask[:, task]
+            task_y_true = y_true[:, task][task_w != 0].numpy()
+            task_y_pred = y_pred[:, task][task_w != 0].numpy()
+            scores.append(math.sqrt(mean_squared_error(task_y_true, task_y_pred)))
+        return scores
+
+    def compute_metric(self, metric_name, reduction='mean'):
+        """Compute metric for each task.
+
+        Parameters
+        ----------
+        metric_name : str
+            Name for the metric to compute.
+        reduction : str
+            Only comes into effect when metric_name is 'l1'.
+            * 'mean': average the metric over all labeled data points for each task
+            * 'sum': sum the metric over all labeled data points for each task
+
+        Returns
+        -------
+        list of float
+            Metric value for each task
+        """
+        assert metric_name in ['roc_auc', 'l1', 'rmse'], \
+            'Expect metric name to be "roc_auc", "l1" or "rmse", got {}'.format(metric_name)
+        assert reduction in ['mean', 'sum']
+        if metric_name == 'roc_auc':
+            return self.roc_auc_score()
+        if metric_name == 'l1':
+            return self.l1_loss(reduction)
+        if metric_name == 'rmse':
+            return self.rmse()
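A minimal sketch of the Meter workflow defined above (labels constructed so every task has both classes, which roc_auc_score requires):

import torch

meter = Meter()
y_pred = torch.randn(4, 2)                           # 4 datapoints, 2 tasks
y_true = torch.tensor([[0., 1.], [1., 0.], [1., 1.], [0., 0.]])
masks = torch.ones(4, 2)                             # every label observed
meter.update(y_pred, y_true, masks)

print(meter.compute_metric('roc_auc'))               # one score per task
print(meter.compute_metric('l1', reduction='sum'))   # summed l1 per task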
class EarlyStopping(object):
    """Early stop performing

@@ -131,14 +205,15 @@ class EarlyStopping(object):

        '''Load model saved with early stopping.'''
        model.load_state_dict(torch.load(self.filename)['model_state_dict'])
-def collate_molgraphs_for_classification(data):
-    """Batching a list of datapoints for dataloader in classification tasks.
+def collate_molgraphs(data):
+    """Batching a list of datapoints for dataloader.

    Parameters
    ----------
-    data : list of 4-tuples
+    data : list of 3-tuples or 4-tuples.
        Each tuple is for a single datapoint, consisting of
-        a SMILE, a DGLGraph, all-task labels and all-task weights
+        a SMILES, a DGLGraph, all-task labels and optionally
+        a binary mask indicating the existence of labels.

    Returns
    -------

@@ -149,40 +224,79 @@ def collate_molgraphs_for_classification(data):

    labels : Tensor of dtype float32 and shape (B, T)
        Batched datapoint labels. B is len(data) and
        T is the number of total tasks.
-    weights : Tensor of dtype float32 and shape (B, T)
-        Batched datapoint weights. T is the number of
-        total tasks.
+    masks : Tensor of dtype float32 and shape (B, T)
+        Batched datapoint binary mask, indicating the
+        existence of labels. If binary masks are not
+        provided, return a tensor with ones.
    """
-    smiles, graphs, labels, mask = map(list, zip(*data))
+    assert len(data[0]) in [3, 4], \
+        'Expect the tuple to be of length 3 or 4, got {:d}'.format(len(data[0]))
+    if len(data[0]) == 3:
+        smiles, graphs, labels = map(list, zip(*data))
+        masks = None
+    else:
+        smiles, graphs, labels, masks = map(list, zip(*data))
    bg = dgl.batch(graphs)
    bg.set_n_initializer(dgl.init.zero_initializer)
    bg.set_e_initializer(dgl.init.zero_initializer)
    labels = torch.stack(labels, dim=0)
-    mask = torch.stack(mask, dim=0)
-    return smiles, bg, labels, mask
-
-def collate_molgraphs_for_regression(data):
-    """Batching a list of datapoints for dataloader in regression tasks.
-
-    Parameters
-    ----------
-    data : list of 3-tuples
-        Each tuple is for a single datapoint, consisting of
-        a SMILE, a DGLGraph and all-task labels.
-
-    Returns
-    -------
-    smiles : list
-        List of smiles
-    bg : BatchedDGLGraph
-        Batched DGLGraphs
-    labels : Tensor of dtype float32 and shape (B, T)
-        Batched datapoint labels. B is len(data) and
-        T is the number of total tasks.
-    """
-    smiles, graphs, labels = map(list, zip(*data))
-    bg = dgl.batch(graphs)
-    bg.set_n_initializer(dgl.init.zero_initializer)
-    bg.set_e_initializer(dgl.init.zero_initializer)
-    labels = torch.stack(labels, dim=0)
-    return smiles, bg, labels
+    if masks is None:
+        masks = torch.ones(labels.shape)
+    else:
+        masks = torch.stack(masks, dim=0)
+    return smiles, bg, labels, masks
+
+def load_dataset_for_classification(args):
+    """Load dataset for classification tasks.
+
+    Parameters
+    ----------
+    args : dict
+        Configurations.
+
+    Returns
+    -------
+    dataset
+        The whole dataset.
+    train_set
+        Subset for training.
+    val_set
+        Subset for validation.
+    test_set
+        Subset for test.
+    """
+    assert args['dataset'] in ['Tox21']
+    if args['dataset'] == 'Tox21':
+        from dgl.data.chem import Tox21
+        dataset = Tox21(atom_featurizer=args['atom_featurizer'])
+        train_set, val_set, test_set = split_dataset(dataset, args['train_val_test_split'])
+
+    return dataset, train_set, val_set, test_set
+
+def load_dataset_for_regression(args):
+    """Load dataset for regression tasks.
+
+    Parameters
+    ----------
+    args : dict
+        Configurations.
+
+    Returns
+    -------
+    train_set
+        Subset for training.
+    val_set
+        Subset for validation.
+    test_set
+        Subset for test.
+    """
+    assert args['dataset'] in ['Alchemy']
+    if args['dataset'] == 'Alchemy':
+        from dgl.data.chem import TencentAlchemyDataset
+        train_set = TencentAlchemyDataset(mode='dev')
+        val_set = TencentAlchemyDataset(mode='valid')
+        test_set = None
+
+    return train_set, val_set, test_set
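A sketch of the two collate paths above: 4-tuples keep their masks, 3-tuples get all-ones masks, so training code can treat classification and regression batches uniformly (a toy 3-node graph and zero labels as placeholders):

import dgl
import torch

g = dgl.DGLGraph()
g.add_nodes(3)
labels = torch.zeros(12)

smiles, bg, labels, masks = collate_molgraphs([('CCO', g, labels)])
print(masks.shape)   # torch.Size([1, 12]), all ones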
@@ -10,7 +10,8 @@ import pickle

import zipfile
from collections import defaultdict

-from .utils import mol_to_complete_graph
+from .utils import mol_to_complete_graph, atom_type_one_hot, atom_hybridization_one_hot, \
+    atom_is_aromatic
from ..utils import download, get_download_dir, _get_dgl_url, save_graphs, load_graphs
from ... import backend as F
@@ -59,25 +60,19 @@ def alchemy_nodes(mol):

    num_atoms = mol.GetNumAtoms()
    for u in range(num_atoms):
        atom = mol.GetAtomWithIdx(u)
-        symbol = atom.GetSymbol()
        atom_type = atom.GetAtomicNum()
-        aromatic = atom.GetIsAromatic()
-        hybridization = atom.GetHybridization()
        num_h = atom.GetTotalNumHs()
        atom_feats_dict['node_type'].append(atom_type)

        h_u = []
-        h_u += [int(symbol == x) for x in ['H', 'C', 'N', 'O', 'F', 'S', 'Cl']]
+        h_u += atom_type_one_hot(atom, ['H', 'C', 'N', 'O', 'F', 'S', 'Cl'])
        h_u.append(atom_type)
        h_u.append(is_acceptor[u])
        h_u.append(is_donor[u])
-        h_u.append(int(aromatic))
-        h_u += [
-            int(hybridization == x)
-            for x in (Chem.rdchem.HybridizationType.SP,
-                      Chem.rdchem.HybridizationType.SP2,
-                      Chem.rdchem.HybridizationType.SP3)
-        ]
+        h_u += atom_is_aromatic(atom)
+        h_u += atom_hybridization_one_hot(atom, [Chem.rdchem.HybridizationType.SP,
+                                                 Chem.rdchem.HybridizationType.SP2,
+                                                 Chem.rdchem.HybridizationType.SP3])
        h_u.append(num_h)
        atom_feats_dict['n_feat'].append(F.tensor(np.array(h_u).astype(np.float32)))
@@ -155,9 +150,34 @@ class TencentAlchemyDataset(object):

        contest is ongoing.
    from_raw : bool
        Whether to process the dataset from scratch or use a
-        processed one for faster speed. Default to be False.
+        processed one for faster speed. If you use different ways
+        to featurize atoms or bonds, you should set this to be True.
+        Default to be False.
+    mol_to_graph : callable, rdkit.Chem.rdchem.Mol -> DGLGraph
+        A function turning an RDKit molecule instance into a DGLGraph.
+        Default to :func:`dgl.data.chem.mol_to_complete_graph`.
+    atom_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict
+        Featurization for atoms in a molecule, which can be used to update
+        ndata for a DGLGraph. By default, we store the atom atomic numbers
+        under the name ``"node_type"`` and store the atom features under the
+        name ``"n_feat"``. The atom features include:
+
+        * One hot encoding for atom types
+        * Atomic number of atoms
+        * Whether the atom is a donor
+        * Whether the atom is an acceptor
+        * Whether the atom is aromatic
+        * One hot encoding for atom hybridization
+        * Total number of Hs on the atom
+    bond_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict
+        Featurization for bonds in a molecule, which can be used to update
+        edata for a DGLGraph. By default, we store the distance between the
+        end atoms under the name ``"distance"`` and store the bond features under
+        the name ``"e_feat"``. The bond features are one-hot encodings of the bond type.
    """
-    def __init__(self, mode='dev', from_raw=False):
+    def __init__(self, mode='dev', from_raw=False,
+                 mol_to_graph=mol_to_complete_graph,
+                 atom_featurizer=alchemy_nodes,
+                 bond_featurizer=alchemy_edges):
        if mode == 'test':
            raise ValueError('The test mode is not supported before '
                             'the Alchemy contest finishes.')
@@ -185,9 +205,9 @@ class TencentAlchemyDataset(object):

            archive.extractall(file_dir)
            archive.close()

-        self._load()
+        self._load(mol_to_graph, atom_featurizer, bond_featurizer)

-    def _load(self):
+    def _load(self, mol_to_graph, atom_featurizer, bond_featurizer):
        if not self.from_raw:
            self.graphs, label_dict = load_graphs(osp.join(self.file_dir, "%s_graphs.bin" % self.mode))
            self.labels = label_dict['labels']

@@ -210,8 +230,8 @@ class TencentAlchemyDataset(object):

            for mol, label in zip(supp, self.target.iterrows()):
                cnt += 1
                print('Processing molecule {:d}/{:d}'.format(cnt, dataset_size))
-                graph = mol_to_complete_graph(mol, atom_featurizer=alchemy_nodes,
-                                              bond_featurizer=alchemy_edges)
+                graph = mol_to_graph(mol, atom_featurizer=atom_featurizer,
+                                     bond_featurizer=bond_featurizer)
                smiles = Chem.MolToSmiles(mol)
                self.smiles.append(smiles)
                self.graphs.append(graph)
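A sketch of plugging custom featurization into the refactored TencentAlchemyDataset; my_atom_featurizer is a hypothetical stand-in, and from_raw=True matters because cached graphs would otherwise carry the default features:

import torch
from dgl.data.chem import TencentAlchemyDataset, mol_to_complete_graph

def my_atom_featurizer(mol):
    # hypothetical featurizer: one feature per atom (its atomic number),
    # returned as a dict of tensors used to update g.ndata
    feats = torch.tensor([[atom.GetAtomicNum()] for atom in mol.GetAtoms()],
                         dtype=torch.float32)
    return {'n_feat': feats}

dataset = TencentAlchemyDataset(mode='dev', from_raw=True,
                                mol_to_graph=mol_to_complete_graph,
                                atom_featurizer=my_atom_featurizer)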
......
@@ -5,10 +5,8 @@ import numpy as np

import os
import sys

-from .utils import smile_to_bigraph
from ..utils import save_graphs, load_graphs
from ... import backend as F
-from ...graph import DGLGraph

class MoleculeCSVDataset(object):
    """MoleculeCSVDataset
@@ -27,28 +25,33 @@ class MoleculeCSVDataset(object):

        Dataframe including smiles and labels. Can be loaded by pandas.read_csv(file_path).
        One column includes smiles and other columns for labels.
        Column names other than smiles column would be considered as task names.
-    smile_to_graph: callable, str -> DGLGraph
-        A function turns smiles into a DGLGraph. Default one can be found
-        at python/dgl/data/chem/utils.py named with smile_to_bigraph.
-    smile_column: str
-        Column name that including smiles
+    smiles_to_graph: callable, str -> DGLGraph
+        A function turning a SMILES into a DGLGraph.
+    atom_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict
+        Featurization for atoms in a molecule, which can be used to update
+        ndata for a DGLGraph.
+    bond_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict
+        Featurization for bonds in a molecule, which can be used to update
+        edata for a DGLGraph.
+    smiles_column: str
+        Column name that includes smiles.
    cache_file_path: str
-        Path to store the preprocessed data
+        Path to store the preprocessed data.
    """
-    def __init__(self, df, smile_to_graph=smile_to_bigraph, smile_column='smiles',
-                 cache_file_path="csvdata_dglgraph.bin"):
+    def __init__(self, df, smiles_to_graph, atom_featurizer, bond_featurizer,
+                 smiles_column, cache_file_path):
        if 'rdkit' not in sys.modules:
            from ...base import dgl_warning
            dgl_warning(
                "Please install RDKit (Recommended Version is 2018.09.3)")

        self.df = df
-        self.smiles = self.df[smile_column].tolist()
-        self.task_names = self.df.columns.drop([smile_column]).tolist()
+        self.smiles = self.df[smiles_column].tolist()
+        self.task_names = self.df.columns.drop([smiles_column]).tolist()
        self.n_tasks = len(self.task_names)
        self.cache_file_path = cache_file_path
-        self._pre_process(smile_to_graph)
+        self._pre_process(smiles_to_graph, atom_featurizer, bond_featurizer)

-    def _pre_process(self, smile_to_graph):
+    def _pre_process(self, smiles_to_graph, atom_featurizer, bond_featurizer):
        """Pre-process the dataset

        * Convert molecules from smiles format into DGLGraphs
@@ -58,8 +61,14 @@ class MoleculeCSVDataset(object):

        Parameters
        ----------
-        smile_to_graph : callable, SMILES -> DGLGraph
-            Function for converting a SMILES (str) into a DGLGraph
+        smiles_to_graph : callable, SMILES -> DGLGraph
+            Function for converting a SMILES (str) into a DGLGraph.
+        atom_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict
+            Featurization for atoms in a molecule, which can be used to update
+            ndata for a DGLGraph.
+        bond_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict
+            Featurization for bonds in a molecule, which can be used to update
+            edata for a DGLGraph.
        """
        if os.path.exists(self.cache_file_path):
            # DGLGraphs have been constructed before, reload them

@@ -72,7 +81,8 @@ class MoleculeCSVDataset(object):

            self.graphs = []
            for i, s in enumerate(self.smiles):
                print('Processing molecule {:d}/{:d}'.format(i+1, len(self)))
-                self.graphs.append(smile_to_graph(s))
+                self.graphs.append(smiles_to_graph(s, atom_featurizer=atom_featurizer,
+                                                   bond_featurizer=bond_featurizer))
            _label_values = self.df[self.task_names].values
            # np.nan_to_num will also turn inf into a very large number
            self.labels = F.zerocopy_from_numpy(np.nan_to_num(_label_values).astype(np.float32))
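A sketch of constructing a MoleculeCSVDataset directly with the new positional signature; the CSV file name is hypothetical, and the class is assumed importable from dgl.data.chem.csv_dataset:

import pandas as pd
from dgl.data.chem import CanonicalAtomFeaturizer, smiles_to_bigraph
from dgl.data.chem.csv_dataset import MoleculeCSVDataset

df = pd.read_csv('my_tasks.csv')   # a 'smiles' column plus one column per task
dataset = MoleculeCSVDataset(df, smiles_to_bigraph, CanonicalAtomFeaturizer(),
                             None, 'smiles', 'my_tasks_dglgraph.bin')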
......
@@ -2,7 +2,7 @@ import numpy as np

import sys

from .csv_dataset import MoleculeCSVDataset
-from .utils import smile_to_bigraph
+from .utils import smiles_to_bigraph
from ..utils import get_download_dir, download, _get_dgl_url
from ... import backend as F
@@ -30,11 +30,19 @@ class Tox21(MoleculeCSVDataset):

    Parameters
    ----------
-    smile_to_graph: callable, str -> DGLGraph
-        A function turns smiles into a DGLGraph. Default one can be found
-        at python/dgl/data/chem/utils.py named with smile_to_bigraph.
+    smiles_to_graph: callable, str -> DGLGraph
+        A function turning smiles into a DGLGraph.
+        Default to :func:`dgl.data.chem.smiles_to_bigraph`.
+    atom_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict
+        Featurization for atoms in a molecule, which can be used to update
+        ndata for a DGLGraph. Default to None.
+    bond_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict
+        Featurization for bonds in a molecule, which can be used to update
+        edata for a DGLGraph. Default to None.
    """
-    def __init__(self, smile_to_graph=smile_to_bigraph):
+    def __init__(self, smiles_to_graph=smiles_to_bigraph,
+                 atom_featurizer=None,
+                 bond_featurizer=None):
        if 'pandas' not in sys.modules:
            from ...base import dgl_warning
            dgl_warning("Please install pandas")
@@ -47,10 +55,10 @@ class Tox21(MoleculeCSVDataset):

        df = df.drop(columns=['mol_id'])

-        super().__init__(df, smile_to_graph, cache_file_path="tox21_dglgraph.bin")
+        super(Tox21, self).__init__(df, smiles_to_graph, atom_featurizer, bond_featurizer,
+                                    "smiles", "tox21_dglgraph.bin")

        self._weight_balancing()

    def _weight_balancing(self):
        """Perform re-balancing for each task.

@@ -71,7 +79,6 @@ class Tox21(MoleculeCSVDataset):

        num_pos = F.sum(self.labels, dim=0)
        num_indices = F.sum(self.mask, dim=0)
        self._task_pos_weights = (num_indices - num_pos) / num_pos

    @property
    def task_pos_weights(self):
......
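A closing sketch of the refactored Tox21 entry point, mirroring what load_dataset_for_classification does with the configs above (atom features assumed under the featurizer's default field 'h'):

from dgl.data.chem import CanonicalAtomFeaturizer, Tox21

dataset = Tox21(atom_featurizer=CanonicalAtomFeaturizer())
smiles, g, label, mask = dataset[0]    # datapoints are 4-tuples
print(g.ndata['h'].shape)              # (num_atoms, 74)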