Unverified commit 52c7ef49, authored by Mufei Li, committed by GitHub

[Model Zoo] AttentiveFP (#955)

* Update

* Fix style

* Update

* Update

* Fix

* Update

* Update
parent 43028f70
@@ -224,13 +224,17 @@ If your dataset is stored in a ``.csv`` file, you may find it helpful to use
.. autoclass:: dgl.data.chem.CSVDataset
    :members: __getitem__, __len__

Currently three datasets are supported:

* Tox21
* TencentAlchemyDataset
* PubChemBioAssayAromaticity

.. autoclass:: dgl.data.chem.Tox21
    :members: __getitem__, __len__, task_pos_weights

.. autoclass:: dgl.data.chem.TencentAlchemyDataset
    :members: __getitem__, __len__, set_mean_and_std

.. autoclass:: dgl.data.chem.PubChemBioAssayAromaticity
    :members: __getitem__, __len__
@@ -26,6 +26,7 @@ Currently supported model architectures:

* MPNN
* SchNet
* MGCN
* AttentiveFP

.. autoclass:: dgl.model_zoo.chem.GCNClassifier
    :members: forward

@@ -42,6 +43,9 @@ Currently supported model architectures:

.. autoclass:: dgl.model_zoo.chem.MGCNModel
    :members: forward

.. autoclass:: dgl.model_zoo.chem.AttentiveFP
    :members: forward

Generative Models
`````````````````
...
@@ -34,8 +34,8 @@ We use GPU whenever it is available.

#### GCN on Tox21

| Source | Averaged Test ROC-AUC Score |
| ---------------- | --------------------------- |
| MoleculeNet [1] | 0.829 |
| [DeepChem example](https://github.com/deepchem/deepchem/blob/master/examples/tox21/tox21_tensorgraph_graph_conv.py) | 0.813 |
| Pretrained model | 0.826 |

@@ -45,8 +45,8 @@ a real difference.

#### GAT on Tox21

| Source | Averaged Test ROC-AUC Score |
| ---------------- | --------------------------- |
| Pretrained model | 0.827 |

## Regression
@@ -60,6 +60,10 @@ machine learning models useful for chemistry and materials science. The dataset
molecules comprising up to 12 heavy atoms (C, N, O, S, F and Cl), sampled from the [GDBMedChem](http://gdb.unibe.ch/downloads/) database.
These properties have been calculated using the open-source computational chemistry program Python-based Simulation of Chemistry Framework
([PySCF](https://github.com/pyscf/pyscf)). The Alchemy dataset expands on the volume and diversity of existing molecular datasets such as QM9.
- **PubChem BioAssay Aromaticity**. Introduced in
[Pushing the Boundaries of Molecular Representation for Drug Discovery with the Graph Attention Mechanism](https://www.ncbi.nlm.nih.gov/pubmed/31408336)
for the task of predicting the number of aromatic atoms in molecules, the dataset was constructed by sampling 3945 molecules
with 0-40 aromatic atoms from the PubChem BioAssay dataset.
### Models

@@ -70,15 +74,20 @@ without requiring them to lie on grids.

- **Multilevel Graph Convolutional Neural Network** [5]. Multilevel Graph Convolutional Neural Networks (MGCN) are
hierarchical graph neural networks that extract features from the conformation and spatial information, followed by
multilevel interactions.
- **AttentiveFP** [8]. AttentiveFP combines graph attention with GRU-based updates for greater model capacity and shows
competitive performance across datasets.
### Usage

Use `regression.py` with arguments

```
-m {MPNN, SCHNET, MGCN, AttentiveFP}, Model to use
-d {Alchemy, Aromaticity}, Dataset to use
```

To use a pre-trained model instead of training from scratch, simply add `-p`. Currently we only provide a pre-trained
AttentiveFP model on the PubChem BioAssay Aromaticity dataset.
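For example, the following command skips training and evaluates the pre-trained AttentiveFP model on the test set:

```
python regression.py -m AttentiveFP -d Aromaticity -p
```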
### Performance

#### Alchemy

@@ -92,6 +101,20 @@ on the training and validation set for reference.

| MGCN [5] | 0.2395 | 0.6463 |
| MPNN [6] | 0.2452 | 0.6259 |
#### PubChem BioAssay Aromaticity
| Model | Test RMSE |
| --------------- | --------- |
| AttentiveFP [8] | 0.6998 |
## Interpretation

The authors of [8] visualize the atom weights computed during readout to interpret model predictions, as in the figure below.
We provide a Jupyter notebook for performing the visualization, which you can download with
`wget https://s3.us-east-2.amazonaws.com/dgl.ai/model_zoo/drug_discovery/AttentiveFP/atom_weight_visualization.ipynb`.
![](https://s3.us-east-2.amazonaws.com/dgl.ai/model_zoo/drug_discovery/AttentiveFP/vis_example.png)
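The atom weights come from the model's forward pass with `get_node_weight=True`. A minimal sketch, assuming `bg`,
`atom_feats` and `bond_feats` hold a featurized batched molecular graph prepared as in `regression.py`:

```
model = model_zoo.chem.load_pretrained('AttentiveFP_Aromaticity')
model.eval()
with torch.no_grad():
    preds, node_weights = model(bg, atom_feats, bond_feats, get_node_weight=True)
# node_weights is a list with one (V, 1) atom-weight tensor per readout timestep
```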
## Dataset Customization

To customize your own dataset, see the instructions
@@ -117,3 +140,6 @@ Machine Learning*, JMLR. 1263-1272.

[7] Veličković et al. (2018) Graph Attention Networks.
*The International Conference on Learning Representations (ICLR)*.

[8] Xiong et al. (2019) Pushing the Boundaries of Molecular Representation for Drug
Discovery with the Graph Attention Mechanism. *Journal of Medicinal Chemistry*.
@@ -7,7 +7,7 @@ from torch.utils.data import DataLoader

from dgl import model_zoo

from utils import Meter, EarlyStopping, collate_molgraphs, set_random_seed, \
    load_dataset_for_classification, load_model

def run_a_train_epoch(args, epoch, model, data_loader, loss_criterion, optimizer):
    model.train()
@@ -60,19 +60,8 @@ def main(args):
        args['num_epochs'] = 0
        model = model_zoo.chem.load_pretrained(args['exp'])
    else:
        args['n_tasks'] = dataset.n_tasks
        model = load_model(args)

    loss_criterion = BCEWithLogitsLoss(pos_weight=dataset.task_pos_weights.to(args['device']),
                                       reduction='none')
    optimizer = Adam(model.parameters(), lr=args['lr'])
...
from dgl.data.chem import BaseAtomFeaturizer, CanonicalAtomFeaturizer, ConcatFeaturizer, \
    atom_type_one_hot, atom_degree_one_hot, atom_formal_charge, atom_num_radical_electrons, \
    atom_hybridization_one_hot, atom_total_num_H_one_hot, BaseBondFeaturizer
from functools import partial
from utils import chirality

GCN_Tox21 = {
    'batch_size': 128,
@@ -37,7 +42,8 @@ MPNN_Alchemy = {
    'output_dim': 12,
    'lr': 0.0001,
    'patience': 50,
    'metric_name': 'l1',
    'weight_decay': 0
}

SCHNET_Alchemy = {
@@ -47,7 +53,8 @@ SCHNET_Alchemy = {
    'output_dim': 12,
    'lr': 0.0001,
    'patience': 50,
    'metric_name': 'l1',
    'weight_decay': 0
}

MGCN_Alchemy = {
@@ -57,7 +64,43 @@ MGCN_Alchemy = {
    'output_dim': 12,
    'lr': 0.0001,
    'patience': 50,
    'metric_name': 'l1',
    'weight_decay': 0
}
AttentiveFP_Aromaticity = {
    'random_seed': 8,
    'graph_feat_size': 200,
    'num_layers': 2,
    'num_timesteps': 2,
    'node_feat_size': 39,
    'edge_feat_size': 10,
    'output_size': 1,
    'dropout': 0.2,
    'weight_decay': 10 ** (-5.0),
    'lr': 10 ** (-2.5),
    'batch_size': 128,
    'num_epochs': 800,
    'train_val_test_split': [0.8, 0.1, 0.1],
    'patience': 80,
    'metric_name': 'rmse',
    # Follow the atom featurization in the original work
    'atom_featurizer': BaseAtomFeaturizer(
        featurizer_funcs={'hv': ConcatFeaturizer([
            partial(atom_type_one_hot, allowable_set=[
                'B', 'C', 'N', 'O', 'F', 'Si', 'P', 'S', 'Cl', 'As', 'Se', 'Br', 'Te', 'I', 'At'],
                encode_unknown=True),
            partial(atom_degree_one_hot, allowable_set=list(range(6))),
            atom_formal_charge, atom_num_radical_electrons,
            partial(atom_hybridization_one_hot, encode_unknown=True),
            lambda atom: [0],  # a placeholder for aromaticity information
            atom_total_num_H_one_hot, chirality
        ])}
    ),
    'bond_featurizer': BaseBondFeaturizer({
        'he': lambda bond: [0 for _ in range(10)]
    })
}
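# Editor's note, not part of the commit: a quick sanity check that the featurizer
# above indeed yields 'node_feat_size' = 39, assuming the default allowable sets
# for hybridization (5 states) and total number of Hs (0-4):
# 16 (15 atom types + unknown) + 6 (degrees 0-5) + 1 (formal charge)
# + 1 (radical electrons) + 6 (hybridization + unknown) + 1 (aromaticity placeholder)
# + 5 (total num H) + 3 (chirality) = 39
assert 16 + 6 + 1 + 1 + 6 + 1 + 5 + 3 == 39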
experiment_configures = {
@@ -65,7 +108,8 @@ experiment_configures = {
    'GAT_Tox21': GAT_Tox21,
    'MPNN_Alchemy': MPNN_Alchemy,
    'SCHNET_Alchemy': SCHNET_Alchemy,
    'MGCN_Alchemy': MGCN_Alchemy,
    'AttentiveFP_Aromaticity': AttentiveFP_Aromaticity
}

def get_exp_configure(exp_name):
...
@@ -6,7 +6,7 @@ from torch.utils.data import DataLoader

from dgl import model_zoo

from utils import Meter, set_random_seed, collate_molgraphs, EarlyStopping, \
    load_dataset_for_regression, load_model

def regress(args, model, bg):
    if args['model'] == 'MPNN':
@@ -14,18 +14,21 @@ def regress(args, model, bg):
        e = bg.edata.pop('e_feat')
        h, e = h.to(args['device']), e.to(args['device'])
        return model(bg, h, e)
    elif args['model'] in ['SCHNET', 'MGCN']:
        node_types = bg.ndata.pop('node_type')
        edge_distances = bg.edata.pop('distance')
        node_types, edge_distances = node_types.to(args['device']), \
                                     edge_distances.to(args['device'])
        return model(bg, node_types, edge_distances)
    else:
        # AttentiveFP takes atom ('hv') and bond ('he') features
        atom_feats, bond_feats = bg.ndata.pop('hv'), bg.edata.pop('he')
        atom_feats, bond_feats = atom_feats.to(args['device']), bond_feats.to(args['device'])
        return model(bg, atom_feats, bond_feats)
def run_a_train_epoch(args, epoch, model, data_loader,
                      loss_criterion, optimizer):
    model.train()
    train_meter = Meter()
    for batch_id, batch_data in enumerate(data_loader):
        smiles, bg, labels, masks = batch_data
        labels, masks = labels.to(args['device']), masks.to(args['device'])
@@ -34,12 +37,10 @@ def run_a_train_epoch(args, epoch, model, data_loader,
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        train_meter.update(prediction, labels, masks)
    total_score = np.mean(train_meter.compute_metric(args['metric_name']))
    print('epoch {:d}/{:d}, training {} {:.4f}'.format(
        epoch + 1, args['num_epochs'], args['metric_name'], total_score))
def run_an_eval_epoch(args, model, data_loader):
    model.eval()
@@ -57,34 +58,31 @@ def main(args):
    args['device'] = "cuda" if torch.cuda.is_available() else "cpu"
    set_random_seed()

    train_set, val_set, test_set = load_dataset_for_regression(args)
    train_loader = DataLoader(dataset=train_set,
                              batch_size=args['batch_size'],
                              shuffle=True,
                              collate_fn=collate_molgraphs)
    val_loader = DataLoader(dataset=val_set,
                            batch_size=args['batch_size'],
                            shuffle=True,
                            collate_fn=collate_molgraphs)
    if test_set is not None:
        test_loader = DataLoader(dataset=test_set,
                                 batch_size=args['batch_size'],
                                 collate_fn=collate_molgraphs)

    if args['pre_trained']:
        args['num_epochs'] = 0
        model = model_zoo.chem.load_pretrained(args['exp'])
    else:
        model = load_model(args)
        if args['model'] in ['SCHNET', 'MGCN']:
            model.set_mean_std(train_set.mean, train_set.std, args['device'])

    loss_fn = nn.MSELoss(reduction='none')
    optimizer = torch.optim.Adam(model.parameters(), lr=args['lr'],
                                 weight_decay=args['weight_decay'])
    stopper = EarlyStopping(mode='lower', patience=args['patience'])
    model.to(args['device'])

    for epoch in range(args['num_epochs']):
        # Train
@@ -96,10 +94,12 @@ def main(args):
        print('epoch {:d}/{:d}, validation {} {:.4f}, best validation {} {:.4f}'.format(
            epoch + 1, args['num_epochs'], args['metric_name'], val_score,
            args['metric_name'], stopper.best_score))

        if early_stop:
            break

    if test_set is not None:
        if not args['pre_trained']:
            stopper.load_checkpoint(model)
        test_score = run_an_eval_epoch(args, model, test_loader)
        print('test {} {:.4f}'.format(args['metric_name'], test_score))
@@ -110,10 +110,13 @@ if __name__ == "__main__":
    from configure import get_exp_configure

    parser = argparse.ArgumentParser(description='Molecule Regression')
    parser.add_argument('-m', '--model', type=str,
                        choices=['MPNN', 'SCHNET', 'MGCN', 'AttentiveFP'],
                        help='Model to use')
    parser.add_argument('-d', '--dataset', type=str, choices=['Alchemy', 'Aromaticity'],
                        help='Dataset to use')
    parser.add_argument('-p', '--pre-trained', action='store_true',
                        help='Whether to skip training and use a pre-trained model')
    args = parser.parse_args().__dict__
    args['exp'] = '_'.join([args['model'], args['dataset']])
    args.update(get_exp_configure(args['exp']))
...
import datetime
import dgl
import numpy as np
import random
import torch
import torch.nn.functional as F

from dgl import model_zoo
from dgl.data.chem import one_hot_encoding
from dgl.data.utils import split_dataset
from sklearn.metrics import roc_auc_score
def set_random_seed(seed=0):
    """Set random seed.
@@ -23,6 +24,13 @@ def set_random_seed(seed=0):
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
def chirality(atom):
    """One-hot encode the CIP code (R/S) of an atom and whether it is a possible chiral center."""
    try:
        return one_hot_encoding(atom.GetProp('_CIPCode'), ['R', 'S']) + \
               [atom.HasProp('_ChiralityPossible')]
    except KeyError:
        # The atom has no assigned CIP code
        return [False, False] + [atom.HasProp('_ChiralityPossible')]
class Meter(object):
    """Track and summarize model performance on a dataset for
    (multi-label) binary classification."""
@@ -110,9 +118,9 @@ class Meter(object):
        scores = []
        for task in range(n_tasks):
            task_w = mask[:, task]
            task_y_true = y_true[:, task][task_w != 0]
            task_y_pred = y_pred[:, task][task_w != 0]
            scores.append(np.sqrt(F.mse_loss(task_y_pred, task_y_true).cpu().item()))
        return scores
    def compute_metric(self, metric_name, reduction='mean'):
@@ -292,11 +300,55 @@ def load_dataset_for_regression(args):
    test_set
        Subset for test.
    """
    assert args['dataset'] in ['Alchemy', 'Aromaticity']

    if args['dataset'] == 'Alchemy':
        from dgl.data.chem import TencentAlchemyDataset
        train_set = TencentAlchemyDataset(mode='dev')
        val_set = TencentAlchemyDataset(mode='valid')
        test_set = None

    if args['dataset'] == 'Aromaticity':
        from dgl.data.chem import PubChemBioAssayAromaticity
        dataset = PubChemBioAssayAromaticity(atom_featurizer=args['atom_featurizer'],
                                             bond_featurizer=args['bond_featurizer'])
        train_set, val_set, test_set = split_dataset(dataset, frac_list=args['train_val_test_split'],
                                                     shuffle=True, random_state=args['random_seed'])

    return train_set, val_set, test_set
def load_model(args):
    if args['model'] == 'GCN':
        model = model_zoo.chem.GCNClassifier(in_feats=args['in_feats'],
                                             gcn_hidden_feats=args['gcn_hidden_feats'],
                                             classifier_hidden_feats=args['classifier_hidden_feats'],
                                             n_tasks=args['n_tasks'])

    if args['model'] == 'GAT':
        model = model_zoo.chem.GATClassifier(in_feats=args['in_feats'],
                                             gat_hidden_feats=args['gat_hidden_feats'],
                                             num_heads=args['num_heads'],
                                             classifier_hidden_feats=args['classifier_hidden_feats'],
                                             n_tasks=args['n_tasks'])

    if args['model'] == 'MPNN':
        model = model_zoo.chem.MPNNModel(node_input_dim=args['node_in_feats'],
                                         edge_input_dim=args['edge_in_feats'],
                                         output_dim=args['output_dim'])

    if args['model'] == 'SCHNET':
        model = model_zoo.chem.SchNet(norm=args['norm'], output_dim=args['output_dim'])

    if args['model'] == 'MGCN':
        model = model_zoo.chem.MGCNModel(norm=args['norm'], output_dim=args['output_dim'])

    if args['model'] == 'AttentiveFP':
        model = model_zoo.chem.AttentiveFP(node_feat_size=args['node_feat_size'],
                                           edge_feat_size=args['edge_feat_size'],
                                           num_layers=args['num_layers'],
                                           num_timesteps=args['num_timesteps'],
                                           graph_feat_size=args['graph_feat_size'],
                                           output_size=args['output_size'],
                                           dropout=args['dropout'])

    return model
@@ -2,3 +2,4 @@ from .utils import *
from .csv_dataset import MoleculeCSVDataset
from .tox21 import Tox21
from .alchemy import TencentAlchemyDataset
from .pubchem_aromaticity import PubChemBioAssayAromaticity
import pandas as pd
import sys

from .csv_dataset import MoleculeCSVDataset
from .utils import smiles_to_bigraph
from ..utils import get_download_dir, download, _get_dgl_url

class PubChemBioAssayAromaticity(MoleculeCSVDataset):
    """Subset of the PubChem BioAssay dataset for aromaticity prediction.

    The dataset was constructed in `Pushing the Boundaries of Molecular Representation for Drug
    Discovery with the Graph Attention Mechanism
    <https://www.ncbi.nlm.nih.gov/pubmed/31408336>`__ and is accompanied by the task of predicting
    the number of aromatic atoms in molecules.

    The dataset was constructed by sampling 3945 molecules with 0-40 aromatic atoms from the
    PubChem BioAssay dataset.
    """
    def __init__(self, smiles_to_graph=smiles_to_bigraph,
                 atom_featurizer=None,
                 bond_featurizer=None):
        if 'pandas' not in sys.modules:
            from ...base import dgl_warning
            dgl_warning("Please install pandas")

        self._url = 'dataset/pubchem_bioassay_aromaticity.csv'
        data_path = get_download_dir() + '/pubchem_bioassay_aromaticity.csv'
        download(_get_dgl_url(self._url), path=data_path)
        df = pd.read_csv(data_path)

        super(PubChemBioAssayAromaticity, self).__init__(df, smiles_to_graph, atom_featurizer, bond_featurizer,
                                                         "cano_smiles", "pubchem_aromaticity_dglgraph.bin")
import numpy as np
import sys

from .csv_dataset import MoleculeCSVDataset
...
@@ -8,3 +8,4 @@ from .mpnn import MPNNModel
from .dgmg import DGMG
from .jtnn import DGLJTNNVAE
from .pretrain import load_pretrained
from .attentive_fp import AttentiveFP
# pylint: disable=C0103, W0612, E1101
"""Pushing the Boundaries of Molecular Representation for Drug Discovery
with the Graph Attention Mechanism"""
import dgl
import torch
import torch.nn as nn
import torch.nn.functional as F

from ... import function as fn
from ...nn.pytorch.softmax import edge_softmax
class AttentiveGRU1(nn.Module):
    """Update node features with attention and GRU.

    Parameters
    ----------
    node_feat_size : int
        Size for the input node (atom) features.
    edge_feat_size : int
        Size for the input edge (bond) features.
    edge_hidden_size : int
        Size for the intermediate edge (bond) representations.
    dropout : float
        The probability for performing dropout.
    """
    def __init__(self, node_feat_size, edge_feat_size, edge_hidden_size, dropout):
        super(AttentiveGRU1, self).__init__()

        self.edge_transform = nn.Sequential(
            nn.Dropout(dropout),
            nn.Linear(edge_feat_size, edge_hidden_size)
        )
        self.gru = nn.GRUCell(edge_hidden_size, node_feat_size)

    def forward(self, g, edge_logits, edge_feats, node_feats):
        """
        Parameters
        ----------
        g : DGLGraph
        edge_logits : float32 tensor of shape (E, 1)
            The edge logits based on which softmax will be performed for weighting
            edges within 1-hop neighborhoods. E represents the number of edges.
        edge_feats : float32 tensor of shape (E, M1)
            Previous edge features.
        node_feats : float32 tensor of shape (V, M2)
            Previous node features.

        Returns
        -------
        float32 tensor of shape (V, M2)
            Updated node features.
        """
        g = g.local_var()
        g.edata['e'] = edge_softmax(g, edge_logits) * self.edge_transform(edge_feats)
        g.update_all(fn.copy_edge('e', 'm'), fn.sum('m', 'c'))
        context = F.elu(g.ndata['c'])
        return F.relu(self.gru(context, node_feats))
class AttentiveGRU2(nn.Module):
    """Update node features with attention and GRU.

    Parameters
    ----------
    node_feat_size : int
        Size for the input node (atom) features.
    edge_hidden_size : int
        Size for the intermediate edge (bond) representations.
    dropout : float
        The probability for performing dropout.
    """
    def __init__(self, node_feat_size, edge_hidden_size, dropout):
        super(AttentiveGRU2, self).__init__()

        self.project_node = nn.Sequential(
            nn.Dropout(dropout),
            nn.Linear(node_feat_size, edge_hidden_size)
        )
        self.gru = nn.GRUCell(edge_hidden_size, node_feat_size)

    def forward(self, g, edge_logits, node_feats):
        """
        Parameters
        ----------
        g : DGLGraph
        edge_logits : float32 tensor of shape (E, 1)
            The edge logits based on which softmax will be performed for weighting
            edges within 1-hop neighborhoods. E represents the number of edges.
        node_feats : float32 tensor of shape (V, M2)
            Previous node features.

        Returns
        -------
        float32 tensor of shape (V, M2)
            Updated node features.
        """
        g = g.local_var()
        g.edata['a'] = edge_softmax(g, edge_logits)
        g.ndata['hv'] = self.project_node(node_feats)
        g.update_all(fn.src_mul_edge('hv', 'a', 'm'), fn.sum('m', 'c'))
        context = F.elu(g.ndata['c'])
        return F.relu(self.gru(context, node_feats))
class GetContext(nn.Module):
    """Generate context for each node (atom) by message passing at the beginning.

    Parameters
    ----------
    node_feat_size : int
        Size for the input node (atom) features.
    edge_feat_size : int
        Size for the input edge (bond) features.
    graph_feat_size : int
        Size of the learned graph representation (molecular fingerprint).
    dropout : float
        The probability for performing dropout.
    """
    def __init__(self, node_feat_size, edge_feat_size, graph_feat_size, dropout):
        super(GetContext, self).__init__()

        self.project_node = nn.Sequential(
            nn.Linear(node_feat_size, graph_feat_size),
            nn.LeakyReLU()
        )
        self.project_edge1 = nn.Sequential(
            nn.Linear(node_feat_size + edge_feat_size, graph_feat_size),
            nn.LeakyReLU()
        )
        self.project_edge2 = nn.Sequential(
            nn.Dropout(dropout),
            nn.Linear(2 * graph_feat_size, 1),
            nn.LeakyReLU()
        )
        self.attentive_gru = AttentiveGRU1(graph_feat_size, graph_feat_size,
                                           graph_feat_size, dropout)

    def apply_edges1(self, edges):
        """Edge feature update."""
        return {'he1': torch.cat([edges.src['hv'], edges.data['he']], dim=1)}

    def apply_edges2(self, edges):
        """Edge feature update."""
        return {'he2': torch.cat([edges.dst['hv_new'], edges.data['he1']], dim=1)}

    def forward(self, g, node_feats, edge_feats):
        """
        Parameters
        ----------
        g : DGLGraph or BatchedDGLGraph
            Constructed DGLGraphs.
        node_feats : float32 tensor of shape (V, N1)
            Input node features. V for the number of nodes and N1 for the feature size.
        edge_feats : float32 tensor of shape (E, N2)
            Input edge features. E for the number of edges and N2 for the feature size.

        Returns
        -------
        float32 tensor of shape (V, N3)
            Updated node features.
        """
        g = g.local_var()
        g.ndata['hv'] = node_feats
        g.ndata['hv_new'] = self.project_node(node_feats)
        g.edata['he'] = edge_feats

        g.apply_edges(self.apply_edges1)
        g.edata['he1'] = self.project_edge1(g.edata['he1'])
        g.apply_edges(self.apply_edges2)
        logits = self.project_edge2(g.edata['he2'])

        return self.attentive_gru(g, logits, g.edata['he1'], g.ndata['hv_new'])
class GNNLayer(nn.Module):
    """GNNLayer for updating node features.

    Parameters
    ----------
    node_feat_size : int
        Size for the input node features.
    graph_feat_size : int
        Size for the input graph features.
    dropout : float
        The probability for performing dropout.
    """
    def __init__(self, node_feat_size, graph_feat_size, dropout):
        super(GNNLayer, self).__init__()

        self.project_edge = nn.Sequential(
            nn.Dropout(dropout),
            nn.Linear(2 * node_feat_size, 1),
            nn.LeakyReLU()
        )
        self.attentive_gru = AttentiveGRU2(node_feat_size, graph_feat_size, dropout)

    def apply_edges(self, edges):
        """Edge feature update by concatenating the features of the destination
        and source nodes."""
        return {'he': torch.cat([edges.dst['hv'], edges.src['hv']], dim=1)}

    def forward(self, g, node_feats):
        """
        Parameters
        ----------
        g : DGLGraph or BatchedDGLGraph
            Constructed DGLGraphs.
        node_feats : float32 tensor of shape (V, N1)
            Input node features. V for the number of nodes and N1 for the feature size.

        Returns
        -------
        float32 tensor of shape (V, N1)
            Updated node features.
        """
        g = g.local_var()
        g.ndata['hv'] = node_feats
        g.apply_edges(self.apply_edges)
        logits = self.project_edge(g.edata['he'])

        return self.attentive_gru(g, logits, node_feats)
class GlobalPool(nn.Module):
    """Graph feature update.

    Parameters
    ----------
    node_feat_size : int
        Size for the input node features.
    graph_feat_size : int
        Size for the input graph features.
    dropout : float
        The probability for performing dropout.
    """
    def __init__(self, node_feat_size, graph_feat_size, dropout):
        super(GlobalPool, self).__init__()

        self.compute_logits = nn.Sequential(
            nn.Linear(node_feat_size + graph_feat_size, 1),
            nn.LeakyReLU()
        )
        self.project_nodes = nn.Sequential(
            nn.Dropout(dropout),
            nn.Linear(node_feat_size, graph_feat_size)
        )
        self.gru = nn.GRUCell(graph_feat_size, graph_feat_size)

    def forward(self, g, node_feats, g_feats, get_node_weight=False):
        """
        Parameters
        ----------
        g : DGLGraph or BatchedDGLGraph
            Constructed DGLGraphs.
        node_feats : float32 tensor of shape (V, N1)
            Input node features. V for the number of nodes and N1 for the feature size.
        g_feats : float32 tensor of shape (G, N2)
            Input graph features. G for the number of graphs and N2 for the feature size.
        get_node_weight : bool
            Whether to get the weights of atoms during readout.

        Returns
        -------
        float32 tensor of shape (G, N2)
            Updated graph features.
        float32 tensor of shape (V, 1)
            The weights of nodes in readout.
        """
        with g.local_scope():
            g.ndata['z'] = self.compute_logits(
                torch.cat([dgl.broadcast_nodes(g, F.relu(g_feats)), node_feats], dim=1))
            g.ndata['a'] = dgl.softmax_nodes(g, 'z')
            g.ndata['hv'] = self.project_nodes(node_feats)

            context = F.elu(dgl.sum_nodes(g, 'hv', 'a'))
            if get_node_weight:
                return self.gru(context, g_feats), g.ndata['a']
            else:
                return self.gru(context, g_feats)
class AttentiveFP(nn.Module):
    """`Pushing the Boundaries of Molecular Representation for Drug Discovery with the Graph
    Attention Mechanism <https://www.ncbi.nlm.nih.gov/pubmed/31408336>`__

    Parameters
    ----------
    node_feat_size : int
        Size for the input node (atom) features.
    edge_feat_size : int
        Size for the input edge (bond) features.
    num_layers : int
        Number of GNN layers.
    num_timesteps : int
        Number of timesteps for updating the molecular representation with GRU.
    graph_feat_size : int
        Size of the learned graph representation (molecular fingerprint).
    output_size : int
        Size of the prediction (target labels).
    dropout : float
        The probability for performing dropout.
    """
    def __init__(self,
                 node_feat_size,
                 edge_feat_size,
                 num_layers,
                 num_timesteps,
                 graph_feat_size,
                 output_size,
                 dropout):
        super(AttentiveFP, self).__init__()

        self.init_context = GetContext(node_feat_size, edge_feat_size, graph_feat_size, dropout)
        self.gnn_layers = nn.ModuleList()
        for i in range(num_layers - 1):
            self.gnn_layers.append(GNNLayer(graph_feat_size, graph_feat_size, dropout))
        self.readouts = nn.ModuleList()
        for t in range(num_timesteps):
            self.readouts.append(GlobalPool(graph_feat_size, graph_feat_size, dropout))
        self.predict = nn.Sequential(
            nn.Dropout(dropout),
            nn.Linear(graph_feat_size, output_size)
        )

    def forward(self, g, node_feats, edge_feats, get_node_weight=False):
        """
        Parameters
        ----------
        g : DGLGraph or BatchedDGLGraph
            Constructed DGLGraphs.
        node_feats : float32 tensor of shape (V, N1)
            Input node features. V for the number of nodes and N1 for the feature size.
        edge_feats : float32 tensor of shape (E, N2)
            Input edge features. E for the number of edges and N2 for the feature size.
        get_node_weight : bool
            Whether to get the weights of atoms during readout.

        Returns
        -------
        float32 tensor of shape (G, N3)
            Prediction for the graphs. G for the number of graphs and N3 for the output size.
        node_weights : list of float32 tensors of shape (V, 1)
            Weights of nodes in all readout operations.
        """
        node_feats = self.init_context(g, node_feats, edge_feats)
        for gnn in self.gnn_layers:
            node_feats = gnn(g, node_feats)

        with g.local_scope():
            g.ndata['hv'] = node_feats
            g_feats = dgl.sum_nodes(g, 'hv')

        if get_node_weight:
            node_weights = []

        for readout in self.readouts:
            if get_node_weight:
                g_feats, node_weights_t = readout(g, node_feats, g_feats, get_node_weight)
                node_weights.append(node_weights_t)
            else:
                g_feats = readout(g, node_feats, g_feats)

        if get_node_weight:
            return self.predict(g_feats), node_weights
        else:
            return self.predict(g_feats)
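# Editor's sketch, not part of the commit: a toy end-to-end forward pass with
# random features standing in for featurizer output, using the Aromaticity
# hyperparameters (DGL 0.4-era APIs, as in the module above).
if __name__ == '__main__':
    model = AttentiveFP(node_feat_size=39, edge_feat_size=10, num_layers=2,
                        num_timesteps=2, graph_feat_size=200, output_size=1,
                        dropout=0.2)
    g = dgl.DGLGraph()
    g.add_nodes(3)                           # a 3-atom toy "molecule"
    g.add_edges([0, 1, 1, 2], [1, 0, 2, 1])  # bonds in both directions
    bg = dgl.batch([g])                      # batch of a single graph
    node_feats = torch.randn(3, 39)
    edge_feats = torch.randn(4, 10)
    pred, node_weights = model(bg, node_feats, edge_feats, get_node_weight=True)
    print(pred.shape)         # torch.Size([1, 1]): one prediction per graph
    print(len(node_weights))  # 2: one atom-weight tensor per readout timestep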
@@ -9,6 +9,7 @@ from .dgmg import DGMG
from .mgcn import MGCNModel
from .mpnn import MPNNModel
from .schnet import SchNet
from .attentive_fp import AttentiveFP

from ...data.utils import _get_dgl_url, download, get_download_dir, extract_archive

URL = {
@@ -17,6 +18,7 @@ URL = {
    'MGCN_Alchemy': 'pre_trained/mgcn_alchemy.pth',
    'SCHNET_Alchemy': 'pre_trained/schnet_alchemy.pth',
    'MPNN_Alchemy': 'pre_trained/mpnn_alchemy.pth',
    'AttentiveFP_Aromaticity': 'pre_trained/attentivefp_aromaticity.pth',
    'DGMG_ChEMBL_canonical' : 'pre_trained/dgmg_ChEMBL_canonical.pth',
    'DGMG_ChEMBL_random' : 'pre_trained/dgmg_ChEMBL_random.pth',
    'DGMG_ZINC_canonical' : 'pre_trained/dgmg_ZINC_canonical.pth',
@@ -69,6 +71,7 @@ def load_pretrained(model_name, log=True):
    * ``'MGCN_Alchemy'``
    * ``'SCHNET_Alchemy'``
    * ``'MPNN_Alchemy'``
    * ``'AttentiveFP_Aromaticity'``
    * ``'DGMG_ChEMBL_canonical'``
    * ``'DGMG_ChEMBL_random'``
    * ``'DGMG_ZINC_canonical'``
@@ -122,6 +125,15 @@ def load_pretrained(model_name, log=True):
    elif model_name == 'MPNN_Alchemy':
        model = MPNNModel(output_dim=12)

    elif model_name == 'AttentiveFP_Aromaticity':
        model = AttentiveFP(node_feat_size=39,
                            edge_feat_size=10,
                            num_layers=2,
                            num_timesteps=2,
                            graph_feat_size=200,
                            output_size=1,
                            dropout=0.2)

    elif model_name == "JTNN_ZINC":
        default_dir = get_download_dir()
        vocab_file = '{}/jtnn/{}.txt'.format(default_dir, 'vocab')
...