"git@developer.sourcefind.cn:OpenDAS/ollama.git" did not exist on "2539f2dbf99ec1b8f44ece884bf2c8678fca3127"
Unverified commit 52c7ef49, authored by Mufei Li and committed by GitHub

[Model Zoo] AttentiveFP (#955)

* Update

* Fix style

* Update

* Update

* Fix

* Update

* Update
parent 43028f70
@@ -224,13 +224,17 @@ If your dataset is stored in a ``.csv`` file, you may find it helpful to use
.. autoclass:: dgl.data.chem.CSVDataset
:members: __getitem__, __len__
Currently two datasets are supported:
Currently three datasets are supported:
* Tox21
* TencentAlchemyDataset
* PubChemBioAssayAromaticity
.. autoclass:: dgl.data.chem.Tox21
:members: __getitem__, __len__, task_pos_weights
.. autoclass:: dgl.data.chem.TencentAlchemyDataset
:members: __getitem__, __len__, set_mean_and_std
.. autoclass:: dgl.data.chem.PubChemBioAssayAromaticity
:members: __getitem__, __len__
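For reference, a minimal sketch of wrapping a ``.csv`` file with ``CSVDataset``
(the file name and column layout are hypothetical; the argument order follows
the constructor used in this commit):

.. code:: python

    import pandas as pd
    from dgl.data.chem import CSVDataset, CanonicalAtomFeaturizer, smiles_to_bigraph

    # Expect one SMILES column plus one column per prediction task.
    df = pd.read_csv('my_molecules.csv')
    dataset = CSVDataset(df, smiles_to_bigraph, CanonicalAtomFeaturizer())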
@@ -26,6 +26,7 @@ Currently supported model architectures:
* MPNN
* SchNet
* MGCN
* AttentiveFP
.. autoclass:: dgl.model_zoo.chem.GCNClassifier
:members: forward
@@ -42,6 +43,9 @@ Currently supported model architectures:
.. autoclass:: dgl.model_zoo.chem.MGCNModel
:members: forward
.. autoclass:: dgl.model_zoo.chem.AttentiveFP
:members: forward
Generative Models
`````````````````
......
@@ -34,20 +34,20 @@ We use GPU whenever it is available.
#### GCN on Tox21
| Source | Averaged ROC-AUC Score |
| ---------------- | ---------------------- |
| MoleculeNet [1] | 0.829 |
| Source | Averaged Test ROC-AUC Score |
| ---------------- | --------------------------- |
| MoleculeNet [1] | 0.829 |
| [DeepChem example](https://github.com/deepchem/deepchem/blob/master/examples/tox21/tox21_tensorgraph_graph_conv.py) | 0.813 |
| Pretrained model | 0.826 |
| Pretrained model | 0.826 |
Note that the dataset is randomly split, so these numbers are for reference only and do not necessarily
indicate a real difference.
#### GAT on Tox21
| Source | Averaged ROC-AUC Score |
| ---------------- | ---------------------- |
| Pretrained model | 0.827 |
| Source | Averaged Test ROC-AUC Score |
| ---------------- | --------------------------- |
| Pretrained model | 0.827 |
## Regression
@@ -59,7 +59,11 @@ Regression tasks require assigning continuous labels to a molecule, e.g. molecul
machine learning models useful for chemistry and materials science. The dataset lists 12 quantum mechanical properties of 130,000+ organic
molecules comprising up to 12 heavy atoms (C, N, O, S, F and Cl), sampled from the [GDBMedChem](http://gdb.unibe.ch/downloads/) database.
These properties have been calculated using the open-source computational chemistry program Python-based Simulation of Chemistry Framework
([PySCF](https://github.com/pyscf/pyscf)). The Alchemy dataset expands on the volume and diversity of existing molecular datasets such as QM9.
([PySCF](https://github.com/pyscf/pyscf)). The Alchemy dataset expands on the volume and diversity of existing molecular datasets such as QM9.
- **PubChem BioAssay Aromaticity**. The dataset was introduced in
[Pushing the Boundaries of Molecular Representation for Drug Discovery with the Graph Attention Mechanism](https://www.ncbi.nlm.nih.gov/pubmed/31408336)
for the task of predicting the number of aromatic atoms in molecules. It was constructed by sampling 3945 molecules
with 0-40 aromatic atoms from the PubChem BioAssay dataset.
### Models
@@ -70,15 +74,20 @@ without requiring them to lie on grids.
- **Multilevel Graph Convolutional Neural Network** [5]. Multilevel Graph Convolutional Neural Networks (MGCN) are
hierarchical graph neural networks that extract features from the conformation and spatial information followed by the
multilevel interactions.
- **AttentiveFP** [8]. AttentiveFP combines attention and GRU for better model capacity and shows competitive
performance across datasets.
### Usage
Use `regression.py` with arguments
```
-m {MPNN,SCHNET,MGCN}, Model to use
-d {Alchemy}, Dataset to use
-m {MPNN, SCHNET, MGCN, AttentiveFP}, Model to use
-d {Alchemy, Aromaticity}, Dataset to use
```
If you want to use a pre-trained model, simply add `-p`. Currently we only provide a pre-trained model of AttentiveFP
on the PubChem BioAssay Aromaticity dataset.
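You can also load the pre-trained model directly from the model zoo in Python; a minimal sketch:

```python
from dgl import model_zoo

# Pre-trained model names follow the '{model}_{dataset}' convention.
model = model_zoo.chem.load_pretrained('AttentiveFP_Aromaticity')
model.eval()
```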
### Performance
#### Alchemy
@@ -92,6 +101,20 @@ on the training and validation set for reference.
| MGCN [5] | 0.2395 | 0.6463 |
| MPNN [6] | 0.2452 | 0.6259 |
#### PubChem BioAssay Aromaticity
| Model | Test RMSE |
| --------------- | --------- |
| AttentiveFP [8] | 0.6998 |
## Interpretation
The authors of [8] visualize the atom weights in the readout for possible interpretation, as in the figure below.
We provide a Jupyter notebook for performing the visualization, which you can download with
`wget https://s3.us-east-2.amazonaws.com/dgl.ai/model_zoo/drug_discovery/AttentiveFP/atom_weight_visualization.ipynb`.
![](https://s3.us-east-2.amazonaws.com/dgl.ai/model_zoo/drug_discovery/AttentiveFP/vis_example.png)
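Programmatically, the atom weights come from the model's forward pass. A minimal sketch, assuming `model`, `bg`, `atom_feats` and `bond_feats` are an AttentiveFP instance and a featurized (batched) molecular graph as prepared in `regression.py`:

```python
# Ask the model to also return the per-atom readout weights.
preds, atom_weights = model(bg, atom_feats, bond_feats, get_node_weight=True)
# atom_weights is a list with one (V, 1) tensor per readout timestep,
# which can be mapped back to atoms for visualization.
```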
## Dataset Customization
To customize your own dataset, see the instructions
@@ -117,3 +140,6 @@ Machine Learning*, JMLR. 1263-1272.
[7] Veličković et al. (2018) Graph Attention Networks.
*The International Conference on Learning Representations (ICLR)*.
[8] Xiong et al. (2019) Pushing the Boundaries of Molecular Representation for Drug Discovery with the Graph
Attention Mechanism. *Journal of Medicinal Chemistry*.
@@ -7,7 +7,7 @@ from torch.utils.data import DataLoader
from dgl import model_zoo
from utils import Meter, EarlyStopping, collate_molgraphs, set_random_seed, \
load_dataset_for_classification
load_dataset_for_classification, load_model
def run_a_train_epoch(args, epoch, model, data_loader, loss_criterion, optimizer):
model.train()
@@ -60,19 +60,8 @@ def main(args):
args['num_epochs'] = 0
model = model_zoo.chem.load_pretrained(args['exp'])
else:
# Interchangeable with other models
if args['model'] == 'GCN':
model = model_zoo.chem.GCNClassifier(in_feats=args['in_feats'],
gcn_hidden_feats=args['gcn_hidden_feats'],
classifier_hidden_feats=args['classifier_hidden_feats'],
n_tasks=dataset.n_tasks)
elif args['model'] == 'GAT':
model = model_zoo.chem.GATClassifier(in_feats=args['in_feats'],
gat_hidden_feats=args['gat_hidden_feats'],
num_heads=args['num_heads'],
classifier_hidden_feats=args['classifier_hidden_feats'],
n_tasks=dataset.n_tasks)
args['n_tasks'] = dataset.n_tasks
model = load_model(args)
loss_criterion = BCEWithLogitsLoss(pos_weight=dataset.task_pos_weights.to(args['device']),
reduction='none')
optimizer = Adam(model.parameters(), lr=args['lr'])
......
from dgl.data.chem import CanonicalAtomFeaturizer
from dgl.data.chem import BaseAtomFeaturizer, CanonicalAtomFeaturizer, ConcatFeaturizer, \
atom_type_one_hot, atom_degree_one_hot, atom_formal_charge, atom_num_radical_electrons, \
atom_hybridization_one_hot, atom_total_num_H_one_hot, BaseBondFeaturizer
from functools import partial
from utils import chirality
GCN_Tox21 = {
'batch_size': 128,
@@ -37,7 +42,8 @@ MPNN_Alchemy = {
'output_dim': 12,
'lr': 0.0001,
'patience': 50,
'metric_name': 'l1'
'metric_name': 'l1',
'weight_decay': 0
}
SCHNET_Alchemy = {
@@ -47,7 +53,8 @@ SCHNET_Alchemy = {
'output_dim': 12,
'lr': 0.0001,
'patience': 50,
'metric_name': 'l1'
'metric_name': 'l1',
'weight_decay': 0
}
MGCN_Alchemy = {
@@ -57,7 +64,43 @@ MGCN_Alchemy = {
'output_dim': 12,
'lr': 0.0001,
'patience': 50,
'metric_name': 'l1'
'metric_name': 'l1',
'weight_decay': 0
}
AttentiveFP_Aromaticity = {
'random_seed': 8,
'graph_feat_size': 200,
'num_layers': 2,
'num_timesteps': 2,
'node_feat_size': 39,
'edge_feat_size': 10,
'output_size': 1,
'dropout': 0.2,
'weight_decay': 10 ** (-5.0),
'lr': 10 ** (-2.5),
'batch_size': 128,
'num_epochs': 800,
'train_val_test_split': [0.8, 0.1, 0.1],
'patience': 80,
'metric_name': 'rmse',
# Follow the atom featurization in the original work
'atom_featurizer': BaseAtomFeaturizer(
featurizer_funcs={'hv': ConcatFeaturizer([
partial(atom_type_one_hot, allowable_set=[
'B', 'C', 'N', 'O', 'F', 'Si', 'P', 'S', 'Cl', 'As', 'Se', 'Br', 'Te', 'I', 'At'],
encode_unknown=True),
partial(atom_degree_one_hot, allowable_set=list(range(6))),
atom_formal_charge, atom_num_radical_electrons,
partial(atom_hybridization_one_hot, encode_unknown=True),
lambda atom: [0], # a placeholder for aromatic information
atom_total_num_H_one_hot, chirality
],
)}
),
'bond_featurizer': BaseBondFeaturizer({
'he': lambda bond: [0 for _ in range(10)]  # constant 10-dim placeholder bond features
})
}
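# A hedged consistency check (assuming BaseAtomFeaturizer/BaseBondFeaturizer
# expose feat_size): the sizes configured above should match what the
# featurizers actually produce, e.g.
#
#   assert AttentiveFP_Aromaticity['atom_featurizer'].feat_size('hv') == 39
#   assert AttentiveFP_Aromaticity['bond_featurizer'].feat_size('he') == 10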
experiment_configures = {
@@ -65,7 +108,8 @@ experiment_configures = {
'GAT_Tox21': GAT_Tox21,
'MPNN_Alchemy': MPNN_Alchemy,
'SCHNET_Alchemy': SCHNET_Alchemy,
'MGCN_Alchemy': MGCN_Alchemy
'MGCN_Alchemy': MGCN_Alchemy,
'AttentiveFP_Aromaticity': AttentiveFP_Aromaticity
}
def get_exp_configure(exp_name):
......
@@ -6,7 +6,7 @@ from torch.utils.data import DataLoader
from dgl import model_zoo
from utils import Meter, set_random_seed, collate_molgraphs, EarlyStopping, \
load_dataset_for_regression
load_dataset_for_regression, load_model
def regress(args, model, bg):
if args['model'] == 'MPNN':
@@ -14,18 +14,21 @@ def regress(args, model, bg):
e = bg.edata.pop('e_feat')
h, e = h.to(args['device']), e.to(args['device'])
return model(bg, h, e)
else:
elif args['model'] in ['SCHNET', 'MGCN']:
node_types = bg.ndata.pop('node_type')
edge_distances = bg.edata.pop('distance')
node_types, edge_distances = node_types.to(args['device']), \
edge_distances.to(args['device'])
return model(bg, node_types, edge_distances)
else:
atom_feats, bond_feats = bg.ndata.pop('hv'), bg.edata.pop('he')
atom_feats, bond_feats = atom_feats.to(args['device']), bond_feats.to(args['device'])
return model(bg, atom_feats, bond_feats)
def run_a_train_epoch(args, epoch, model, data_loader,
loss_criterion, optimizer):
model.train()
train_meter = Meter()
total_loss = 0
for batch_id, batch_data in enumerate(data_loader):
smiles, bg, labels, masks = batch_data
labels, masks = labels.to(args['device']), masks.to(args['device'])
@@ -34,12 +37,10 @@ def run_a_train_epoch(args, epoch, model, data_loader,
optimizer.zero_grad()
loss.backward()
optimizer.step()
total_loss += loss.detach().item() * bg.batch_size
train_meter.update(prediction, labels, masks)
total_loss /= len(data_loader.dataset)
total_score = np.mean(train_meter.compute_metric(args['metric_name']))
print('epoch {:d}/{:d}, training loss {:.4f}, training {} {:.4f}'.format(
epoch + 1, args['num_epochs'], total_loss, args['metric_name'], total_score))
print('epoch {:d}/{:d}, training {} {:.4f}'.format(
epoch + 1, args['num_epochs'], args['metric_name'], total_score))
def run_an_eval_epoch(args, model, data_loader):
model.eval()
@@ -57,35 +58,32 @@ def main(args):
args['device'] = "cuda" if torch.cuda.is_available() else "cpu"
set_random_seed()
# Interchangeable with other datasets
train_set, val_set, test_set = load_dataset_for_regression(args)
train_loader = DataLoader(dataset=train_set,
batch_size=args['batch_size'],
shuffle=True,
collate_fn=collate_molgraphs)
val_loader = DataLoader(dataset=val_set,
batch_size=args['batch_size'],
shuffle=True,
collate_fn=collate_molgraphs)
if test_set is not None:
test_loader = DataLoader(dataset=test_set,
batch_size=args['batch_size'],
collate_fn=collate_molgraphs)
if args['model'] == 'MPNN':
model = model_zoo.chem.MPNNModel(node_input_dim=args['node_in_feats'],
edge_input_dim=args['edge_in_feats'],
output_dim=args['output_dim'])
elif args['model'] == 'SCHNET':
model = model_zoo.chem.SchNet(norm=args['norm'], output_dim=args['output_dim'])
model.set_mean_std(train_set.mean, train_set.std, args['device'])
elif args['model'] == 'MGCN':
model = model_zoo.chem.MGCNModel(norm=args['norm'], output_dim=args['output_dim'])
model.set_mean_std(train_set.mean, train_set.std, args['device'])
if args['pre_trained']:
args['num_epochs'] = 0
model = model_zoo.chem.load_pretrained(args['exp'])
else:
model = load_model(args)
if args['model'] in ['SCHNET', 'MGCN']:
model.set_mean_std(train_set.mean, train_set.std, args['device'])
loss_fn = nn.MSELoss(reduction='none')
optimizer = torch.optim.Adam(model.parameters(), lr=args['lr'], weight_decay=args['weight_decay'])
stopper = EarlyStopping(mode='lower', patience=args['patience'])
model.to(args['device'])
loss_fn = nn.MSELoss(reduction='none')
optimizer = torch.optim.Adam(model.parameters(), lr=args['lr'])
stopper = EarlyStopping(mode='lower', patience=args['patience'])
for epoch in range(args['num_epochs']):
# Train
run_a_train_epoch(args, epoch, model, train_loader, loss_fn, optimizer)
@@ -96,11 +94,13 @@ def main(args):
print('epoch {:d}/{:d}, validation {} {:.4f}, best validation {} {:.4f}'.format(
epoch + 1, args['num_epochs'], args['metric_name'], val_score,
args['metric_name'], stopper.best_score))
if early_stop:
break
if test_set is not None:
stopper.load_checkpoint(model)
if not args['pre_trained']:
stopper.load_checkpoint(model)
test_score = run_an_eval_epoch(args, model, test_loader)
print('test {} {:.4f}'.format(args['metric_name'], test_score))
@@ -110,10 +110,13 @@ if __name__ == "__main__":
from configure import get_exp_configure
parser = argparse.ArgumentParser(description='Molecule Regression')
parser.add_argument('-m', '--model', type=str, choices=['MPNN', 'SCHNET', 'MGCN'],
parser.add_argument('-m', '--model', type=str,
choices=['MPNN', 'SCHNET', 'MGCN', 'AttentiveFP'],
help='Model to use')
parser.add_argument('-d', '--dataset', type=str, choices=['Alchemy'],
parser.add_argument('-d', '--dataset', type=str, choices=['Alchemy', 'Aromaticity'],
help='Dataset to use')
parser.add_argument('-p', '--pre-trained', action='store_true',
help='Whether to skip training and use a pre-trained model')
args = parser.parse_args().__dict__
args['exp'] = '_'.join([args['model'], args['dataset']])
args.update(get_exp_configure(args['exp']))
......
import datetime
import dgl
import math
import numpy as np
import random
import torch
import torch.nn.functional as F
from dgl import model_zoo
from dgl.data.chem import one_hot_encoding
from dgl.data.utils import split_dataset
from sklearn.metrics import roc_auc_score, mean_squared_error
from sklearn.metrics import roc_auc_score
def set_random_seed(seed=0):
"""Set random seed.
@@ -23,6 +24,13 @@ def set_random_seed(seed=0):
if torch.cuda.is_available():
torch.cuda.manual_seed(seed)
def chirality(atom):
# One-hot encode the atom's CIP code ('R' or 'S') plus whether chirality is possible.
try:
return one_hot_encoding(atom.GetProp('_CIPCode'), ['R', 'S']) + \
[atom.HasProp('_ChiralityPossible')]
except KeyError:
# GetProp raises KeyError when the atom has no assigned CIP code.
return [False, False] + [atom.HasProp('_ChiralityPossible')]
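# A hedged usage sketch for chirality (requires RDKit, and stereochemistry
# must be assigned for '_CIPCode' to be present):
#
#   >>> from rdkit import Chem
#   >>> mol = Chem.MolFromSmiles('C[C@H](N)C(=O)O')  # L-alanine, one chiral center
#   >>> Chem.AssignStereochemistry(mol, cleanIt=True, force=True)
#   >>> [chirality(atom) for atom in mol.GetAtoms()]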
class Meter(object):
"""Track and summarize model performance on a dataset for
(multi-label) binary classification."""
@@ -110,9 +118,9 @@ class Meter(object):
scores = []
for task in range(n_tasks):
task_w = mask[:, task]
task_y_true = y_true[:, task][task_w != 0].numpy()
task_y_pred = y_pred[:, task][task_w != 0].numpy()
scores.append(math.sqrt(mean_squared_error(task_y_true, task_y_pred)))
task_y_true = y_true[:, task][task_w != 0]
task_y_pred = y_pred[:, task][task_w != 0]
scores.append(np.sqrt(F.mse_loss(task_y_pred, task_y_true).cpu().item()))
return scores
def compute_metric(self, metric_name, reduction='mean'):
@@ -292,11 +300,55 @@ def load_dataset_for_regression(args):
test_set
Subset for test.
"""
assert args['dataset'] in ['Alchemy']
assert args['dataset'] in ['Alchemy', 'Aromaticity']
if args['dataset'] == 'Alchemy':
from dgl.data.chem import TencentAlchemyDataset
train_set = TencentAlchemyDataset(mode='dev')
val_set = TencentAlchemyDataset(mode='valid')
test_set = None
if args['dataset'] == 'Aromaticity':
from dgl.data.chem import PubChemBioAssayAromaticity
dataset = PubChemBioAssayAromaticity(atom_featurizer=args['atom_featurizer'],
bond_featurizer=args['bond_featurizer'])
train_set, val_set, test_set = split_dataset(dataset, frac_list=args['train_val_test_split'],
shuffle=True, random_state=args['random_seed'])
return train_set, val_set, test_set
def load_model(args):
if args['model'] == 'GCN':
model = model_zoo.chem.GCNClassifier(in_feats=args['in_feats'],
gcn_hidden_feats=args['gcn_hidden_feats'],
classifier_hidden_feats=args['classifier_hidden_feats'],
n_tasks=args['n_tasks'])
if args['model'] == 'GAT':
model = model_zoo.chem.GATClassifier(in_feats=args['in_feats'],
gat_hidden_feats=args['gat_hidden_feats'],
num_heads=args['num_heads'],
classifier_hidden_feats=args['classifier_hidden_feats'],
n_tasks=args['n_tasks'])
if args['model'] == 'MPNN':
model = model_zoo.chem.MPNNModel(node_input_dim=args['node_in_feats'],
edge_input_dim=args['edge_in_feats'],
output_dim=args['output_dim'])
if args['model'] == 'SCHNET':
model = model_zoo.chem.SchNet(norm=args['norm'], output_dim=args['output_dim'])
if args['model'] == 'MGCN':
model = model_zoo.chem.MGCNModel(norm=args['norm'], output_dim=args['output_dim'])
if args['model'] == 'AttentiveFP':
model = model_zoo.chem.AttentiveFP(node_feat_size=args['node_feat_size'],
edge_feat_size=args['edge_feat_size'],
num_layers=args['num_layers'],
num_timesteps=args['num_timesteps'],
graph_feat_size=args['graph_feat_size'],
output_size=args['output_size'],
dropout=args['dropout'])
return model
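# A hedged usage sketch: the experiment configures in configure.py carry all
# hyperparameters load_model expects, keyed as '{model}_{dataset}':
#
#   from configure import get_exp_configure
#   args = get_exp_configure('AttentiveFP_Aromaticity')
#   args['model'] = 'AttentiveFP'
#   model = load_model(args)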
@@ -2,3 +2,4 @@ from .utils import *
from .csv_dataset import MoleculeCSVDataset
from .tox21 import Tox21
from .alchemy import TencentAlchemyDataset
from .pubchem_aromaticity import PubChemBioAssayAromaticity
import pandas as pd
import sys
from .csv_dataset import MoleculeCSVDataset
from .utils import smiles_to_bigraph
from ..utils import get_download_dir, download, _get_dgl_url
class PubChemBioAssayAromaticity(MoleculeCSVDataset):
"""Subset of PubChem BioAssay Dataset for aromaticity prediction.
The dataset was introduced in `Pushing the Boundaries of Molecular Representation for Drug
Discovery with the Graph Attention Mechanism
<https://www.ncbi.nlm.nih.gov/pubmed/31408336>`__ for the task of predicting
the number of aromatic atoms in molecules.
The dataset was constructed by sampling 3945 molecules with 0-40 aromatic atoms from the
PubChem BioAssay dataset.
"""
def __init__(self, smiles_to_graph=smiles_to_bigraph,
atom_featurizer=None,
bond_featurizer=None):
if 'pandas' not in sys.modules:
from ...base import dgl_warning
dgl_warning("Please install pandas")
self._url = 'dataset/pubchem_bioassay_aromaticity.csv'
data_path = get_download_dir() + '/pubchem_bioassay_aromaticity.csv'
download(_get_dgl_url(self._url), path=data_path)
df = pd.read_csv(data_path)
super(PubChemBioAssayAromaticity, self).__init__(df, smiles_to_graph, atom_featurizer, bond_featurizer,
"cano_smiles", "pubchem_aromaticity_dglgraph.bin")
import numpy as np
import sys
from .csv_dataset import MoleculeCSVDataset
......
@@ -8,3 +8,4 @@ from .mpnn import MPNNModel
from .dgmg import DGMG
from .jtnn import DGLJTNNVAE
from .pretrain import load_pretrained
from .attentive_fp import AttentiveFP
# pylint: disable=C0103, W0612, E1101
"""Pushing the Boundaries of Molecular Representation for Drug Discovery
with the Graph Attention Mechanism"""
import dgl
import torch
import torch.nn as nn
import torch.nn.functional as F
from ... import function as fn
from ...nn.pytorch.softmax import edge_softmax
class AttentiveGRU1(nn.Module):
"""Update node features with attention and GRU.
Parameters
----------
node_feat_size : int
Size for the input node (atom) features.
edge_feat_size : int
Size for the input edge (bond) features.
edge_hidden_size : int
Size for the intermediate edge (bond) representations.
dropout : float
The probability for performing dropout.
"""
def __init__(self, node_feat_size, edge_feat_size, edge_hidden_size, dropout):
super(AttentiveGRU1, self).__init__()
self.edge_transform = nn.Sequential(
nn.Dropout(dropout),
nn.Linear(edge_feat_size, edge_hidden_size)
)
self.gru = nn.GRUCell(edge_hidden_size, node_feat_size)
def forward(self, g, edge_logits, edge_feats, node_feats):
"""
Parameters
----------
g : DGLGraph
edge_logits : float32 tensor of shape (E, 1)
The edge logits based on which softmax will be performed for weighting
edges within 1-hop neighborhoods. E represents the number of edges.
edge_feats : float32 tensor of shape (E, M1)
Previous edge features.
node_feats : float32 tensor of shape (V, M2)
Previous node features.
Returns
-------
float32 tensor of shape (V, M2)
Updated node features.
"""
g = g.local_var()
g.edata['e'] = edge_softmax(g, edge_logits) * self.edge_transform(edge_feats)
g.update_all(fn.copy_edge('e', 'm'), fn.sum('m', 'c'))
context = F.elu(g.ndata['c'])
return F.relu(self.gru(context, node_feats))
class AttentiveGRU2(nn.Module):
"""Update node features with attention and GRU.
Parameters
----------
node_feat_size : int
Size for the input node (atom) features.
edge_hidden_size : int
Size for the intermediate edge (bond) representations.
dropout : float
The probability for performing dropout.
"""
def __init__(self, node_feat_size, edge_hidden_size, dropout):
super(AttentiveGRU2, self).__init__()
self.project_node = nn.Sequential(
nn.Dropout(dropout),
nn.Linear(node_feat_size, edge_hidden_size)
)
self.gru = nn.GRUCell(edge_hidden_size, node_feat_size)
def forward(self, g, edge_logits, node_feats):
"""
Parameters
----------
g : DGLGraph
edge_logits : float32 tensor of shape (E, 1)
The edge logits based on which softmax will be performed for weighting
edges within 1-hop neighborhoods. E represents the number of edges.
node_feats : float32 tensor of shape (V, M2)
Previous node features.
Returns
-------
float32 tensor of shape (V, M2)
Updated node features.
"""
g = g.local_var()
g.edata['a'] = edge_softmax(g, edge_logits)
g.ndata['hv'] = self.project_node(node_feats)
g.update_all(fn.src_mul_edge('hv', 'a', 'm'), fn.sum('m', 'c'))
context = F.elu(g.ndata['c'])
return F.relu(self.gru(context, node_feats))
class GetContext(nn.Module):
"""Generate context for each node (atom) by message passing at the beginning.
Parameters
----------
node_feat_size : int
Size for the input node (atom) features.
edge_feat_size : int
Size for the input edge (bond) features.
graph_feat_size : int
Size of the learned graph representation (molecular fingerprint).
dropout : float
The probability for performing dropout.
"""
def __init__(self, node_feat_size, edge_feat_size, graph_feat_size, dropout):
super(GetContext, self).__init__()
self.project_node = nn.Sequential(
nn.Linear(node_feat_size, graph_feat_size),
nn.LeakyReLU()
)
self.project_edge1 = nn.Sequential(
nn.Linear(node_feat_size + edge_feat_size, graph_feat_size),
nn.LeakyReLU()
)
self.project_edge2 = nn.Sequential(
nn.Dropout(dropout),
nn.Linear(2 * graph_feat_size, 1),
nn.LeakyReLU()
)
self.attentive_gru = AttentiveGRU1(graph_feat_size, graph_feat_size,
graph_feat_size, dropout)
def apply_edges1(self, edges):
"""Edge feature update."""
return {'he1': torch.cat([edges.src['hv'], edges.data['he']], dim=1)}
def apply_edges2(self, edges):
"""Edge feature update."""
return {'he2': torch.cat([edges.dst['hv_new'], edges.data['he1']], dim=1)}
def forward(self, g, node_feats, edge_feats):
"""
Parameters
----------
g : DGLGraph or BatchedDGLGraph
Constructed DGLGraphs.
node_feats : float32 tensor of shape (V, N1)
Input node features. V for the number of nodes and N1 for the feature size.
edge_feats : float32 tensor of shape (E, N2)
Input edge features. E for the number of edges and N2 for the feature size.
Returns
-------
float32 tensor of shape (V, N3)
Updated node features.
"""
g = g.local_var()
g.ndata['hv'] = node_feats
g.ndata['hv_new'] = self.project_node(node_feats)
g.edata['he'] = edge_feats
g.apply_edges(self.apply_edges1)
g.edata['he1'] = self.project_edge1(g.edata['he1'])
g.apply_edges(self.apply_edges2)
logits = self.project_edge2(g.edata['he2'])
return self.attentive_gru(g, logits, g.edata['he1'], g.ndata['hv_new'])
class GNNLayer(nn.Module):
"""GNNLayer for updating node features.
Parameters
----------
node_feat_size : int
Size for the input node features.
graph_feat_size : int
Size for the input graph features.
dropout : float
The probability for performing dropout.
"""
def __init__(self, node_feat_size, graph_feat_size, dropout):
super(GNNLayer, self).__init__()
self.project_edge = nn.Sequential(
nn.Dropout(dropout),
nn.Linear(2 * node_feat_size, 1),
nn.LeakyReLU()
)
self.attentive_gru = AttentiveGRU2(node_feat_size, graph_feat_size, dropout)
def apply_edges(self, edges):
"""Edge feature update by concatenating the features of the destination
and source nodes."""
return {'he': torch.cat([edges.dst['hv'], edges.src['hv']], dim=1)}
def forward(self, g, node_feats):
"""
Parameters
----------
g : DGLGraph or BatchedDGLGraph
Constructed DGLGraphs.
node_feats : float32 tensor of shape (V, N1)
Input node features. V for the number of nodes and N1 for the feature size.
Returns
-------
float32 tensor of shape (V, N1)
Updated node features.
"""
g = g.local_var()
g.ndata['hv'] = node_feats
g.apply_edges(self.apply_edges)
logits = self.project_edge(g.edata['he'])
return self.attentive_gru(g, logits, node_feats)
class GlobalPool(nn.Module):
"""Graph feature update.
Parameters
----------
node_feat_size : int
Size for the input node features.
graph_feat_size : int
Size for the input graph features.
dropout : float
The probability for performing dropout.
"""
def __init__(self, node_feat_size, graph_feat_size, dropout):
super(GlobalPool, self).__init__()
self.compute_logits = nn.Sequential(
nn.Linear(node_feat_size + graph_feat_size, 1),
nn.LeakyReLU()
)
self.project_nodes = nn.Sequential(
nn.Dropout(dropout),
nn.Linear(node_feat_size, graph_feat_size)
)
self.gru = nn.GRUCell(graph_feat_size, graph_feat_size)
def forward(self, g, node_feats, g_feats, get_node_weight=False):
"""
Parameters
----------
g : DGLGraph or BatchedDGLGraph
Constructed DGLGraphs.
node_feats : float32 tensor of shape (V, N1)
Input node features. V for the number of nodes and N1 for the feature size.
g_feats : float32 tensor of shape (G, N2)
Input graph features. G for the number of graphs and N2 for the feature size.
get_node_weight : bool
Whether to get the weights of atoms during readout.
Returns
-------
float32 tensor of shape (G, N2)
Updated graph features.
float32 tensor of shape (V, 1)
The weights of nodes in readout.
"""
with g.local_scope():
g.ndata['z'] = self.compute_logits(
torch.cat([dgl.broadcast_nodes(g, F.relu(g_feats)), node_feats], dim=1))
g.ndata['a'] = dgl.softmax_nodes(g, 'z')
g.ndata['hv'] = self.project_nodes(node_feats)
context = F.elu(dgl.sum_nodes(g, 'hv', 'a'))
if get_node_weight:
return self.gru(context, g_feats), g.ndata['a']
else:
return self.gru(context, g_feats)
class AttentiveFP(nn.Module):
"""`Pushing the Boundaries of Molecular Representation for Drug Discovery with the Graph
Attention Mechanism <https://www.ncbi.nlm.nih.gov/pubmed/31408336>`__
Parameters
----------
node_feat_size : int
Size for the input node (atom) features.
edge_feat_size : int
Size for the input edge (bond) features.
num_layers : int
Number of GNN layers.
num_timesteps : int
Number of timesteps for updating the molecular representation with GRU.
graph_feat_size : int
Size of the learned graph representation (molecular fingerprint).
output_size : int
Size of the prediction (target labels).
dropout : float
The probability for performing dropout.
"""
def __init__(self,
node_feat_size,
edge_feat_size,
num_layers,
num_timesteps,
graph_feat_size,
output_size,
dropout):
super(AttentiveFP, self).__init__()
self.init_context = GetContext(node_feat_size, edge_feat_size, graph_feat_size, dropout)
self.gnn_layers = nn.ModuleList()
for i in range(num_layers - 1):
self.gnn_layers.append(GNNLayer(graph_feat_size, graph_feat_size, dropout))
self.readouts = nn.ModuleList()
for t in range(num_timesteps):
self.readouts.append(GlobalPool(graph_feat_size, graph_feat_size, dropout))
self.predict = nn.Sequential(
nn.Dropout(dropout),
nn.Linear(graph_feat_size, output_size)
)
def forward(self, g, node_feats, edge_feats, get_node_weight=False):
"""
Parameters
----------
g : DGLGraph or BatchedDGLGraph
Constructed DGLGraphs.
node_feats : float32 tensor of shape (V, N1)
Input node features. V for the number of nodes and N1 for the feature size.
edge_feats : float32 tensor of shape (E, N2)
Input edge features. E for the number of edges and N2 for the feature size.
get_node_weight : bool
Whether to get the weights of atoms during readout.
Returns
-------
float32 tensor of shape (G, N3)
Prediction for the graphs. G for the number of graphs and N3 for the output size.
node_weights : list of float32 tensors of shape (V, 1)
Weights of nodes in all readout operations.
"""
node_feats = self.init_context(g, node_feats, edge_feats)
for gnn in self.gnn_layers:
node_feats = gnn(g, node_feats)
with g.local_scope():
g.ndata['hv'] = node_feats
g_feats = dgl.sum_nodes(g, 'hv')
if get_node_weight:
node_weights = []
for readout in self.readouts:
if get_node_weight:
g_feats, node_weights_t = readout(g, node_feats, g_feats, get_node_weight)
node_weights.append(node_weights_t)
else:
g_feats = readout(g, node_feats, g_feats)
if get_node_weight:
return self.predict(g_feats), node_weights
else:
return self.predict(g_feats)
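# A hedged end-to-end sketch (the featurization below is illustrative, not the
# one used in the aromaticity experiments):
#
#   import torch
#   from dgl.data.chem import smiles_to_bigraph, CanonicalAtomFeaturizer
#
#   g = smiles_to_bigraph('CCO', atom_featurizer=CanonicalAtomFeaturizer())
#   node_feats = g.ndata.pop('h')  # CanonicalAtomFeaturizer stores features in 'h'
#   edge_feats = torch.zeros(g.number_of_edges(), 10)  # placeholder bond features
#   model = AttentiveFP(node_feat_size=node_feats.shape[1], edge_feat_size=10,
#                       num_layers=2, num_timesteps=2, graph_feat_size=200,
#                       output_size=1, dropout=0.2)
#   print(model(g, node_feats, edge_feats).shape)  # torch.Size([1, 1])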
@@ -9,6 +9,7 @@ from .dgmg import DGMG
from .mgcn import MGCNModel
from .mpnn import MPNNModel
from .schnet import SchNet
from .attentive_fp import AttentiveFP
from ...data.utils import _get_dgl_url, download, get_download_dir, extract_archive
URL = {
@@ -17,6 +18,7 @@ URL = {
'MGCN_Alchemy': 'pre_trained/mgcn_alchemy.pth',
'SCHNET_Alchemy': 'pre_trained/schnet_alchemy.pth',
'MPNN_Alchemy': 'pre_trained/mpnn_alchemy.pth',
'AttentiveFP_Aromaticity': 'pre_trained/attentivefp_aromaticity.pth',
'DGMG_ChEMBL_canonical' : 'pre_trained/dgmg_ChEMBL_canonical.pth',
'DGMG_ChEMBL_random' : 'pre_trained/dgmg_ChEMBL_random.pth',
'DGMG_ZINC_canonical' : 'pre_trained/dgmg_ZINC_canonical.pth',
@@ -69,6 +71,7 @@ def load_pretrained(model_name, log=True):
* ``'MGCN_Alchemy'``
* ``'SCHNET_Alchemy'``
* ``'MPNN_Alchemy'``
* ``'AttentiveFP_Aromaticity'``
* ``'DGMG_ChEMBL_canonical'``
* ``'DGMG_ChEMBL_random'``
* ``'DGMG_ZINC_canonical'``
@@ -122,6 +125,15 @@
elif model_name == 'MPNN_Alchemy':
model = MPNNModel(output_dim=12)
elif model_name == 'AttentiveFP_Aromaticity':
model = AttentiveFP(node_feat_size=39,
edge_feat_size=10,
num_layers=2,
num_timesteps=2,
graph_feat_size=200,
output_size=1,
dropout=0.2)
elif model_name == "JTNN_ZINC":
default_dir = get_download_dir()
vocab_file = '{}/jtnn/{}.txt'.format(default_dir, 'vocab')
......