Unverified Commit e590feeb authored by Mufei Li's avatar Mufei Li Committed by GitHub

[Model Zoo] GAT on Tox21 (#793)

* GAT

* Fix mistake

* Fix

* hotfix

* Fix

* Fix

* Fix

* Fix

* Fix

* Fix

* Fix

* Update

* Update

* Update

* Fix style

* Hotfix

* Hotfix

* Hotfix

* Fix

* Fix

* Update

* CI trial

* Update

* Update

* Update
parent ad15947f
......@@ -17,10 +17,10 @@ Contributors
* [Tianyi Zhang](https://github.com/Tiiiger): SGC in Pytorch
* [Jun Chen](https://github.com/kitaev-chen): GIN in Pytorch
* [Aymen Waheb](https://github.com/aymenwah): APPNP in Pytorch
* [Chengqiang Lu](https://github.com/geekinglcq): MGCN, SchNet and MPNN in PyTorch
Other improvements
* [Brett Koonce](https://github.com/brettkoonce)
* [@giuseppefutia](https://github.com/giuseppefutia)
* [@mori97](https://github.com/mori97)
* Hao Jin
# DGL for Chemistry
With atoms being nodes and bonds being edges, molecular graphs are among the core objects for study in Chemistry.
Deep learning on graphs can be beneficial for various applications in Chemistry like drug and material discovery
[1], [2], [12].
To make it easy for domain scientists, the DGL team releases a model zoo for Chemistry, focusing on two particular cases
-- property prediction and target generation/optimization.
......@@ -14,7 +14,7 @@ the chemistry community and the deep learning community to further their researc
Before you proceed, make sure you have installed the dependencies below:
- PyTorch 1.2
- Check the [official website](https://pytorch.org/) for installation guide.
- RDKit 2018.09.3
- We recommend installation with `conda install -c conda-forge rdkit==2018.09.3`. For other installation recipes,
see the [official documentation](https://www.rdkit.org/docs/Install.html).
......@@ -40,16 +40,21 @@ Graph neural networks make it possible for a data-driven representation of molec
molecular graph topology, which may be viewed as a learned fingerprint [3].
### Models
- **Graph Convolutional Networks** [3], [9]: Graph Convolutional Networks (GCN) have been one of the most popular graph
neural networks and they can be easily extended for graph-level prediction.
- **Graph Attention Networks** [10]: Graph Attention Networks (GATs) incorporate multi-head attention into GCNs,
explicitly modeling the interactions between adjacent atoms.
- **SchNet** [4]: SchNet is a deep learning architecture that models quantum interactions in molecules with
continuous-filter convolutional layers.
- **Multilevel Graph Convolutional neural Network** [5]: Multilevel Graph Convolutional neural Network (MGCN) is a
hierarchical graph neural network that directly extracts features from the conformation and spatial information,
followed by multilevel interactions.
- **Message Passing Neural Network** [6]: Message Passing Neural Network (MPNN) is a network with an edge network (enn)
as its front end and Set2Set for output prediction.
### Example Usage of Pre-trained Models
![](https://s3.us-east-2.amazonaws.com/dgl.ai/model_zoo/drug_discovery/gcn_model_zoo_example.png)
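The minimal sketch below mirrors the screenshot above and the training script in this repository. It assumes the
Tox21 dataset and the `GCN_Tox21` checkpoint shipped with the model zoo, and that dataset items are
`(smiles, graph, label, mask)` tuples as consumed by `collate_molgraphs`.

```
import torch
from dgl import model_zoo
from dgl.data.chem import Tox21

dataset = Tox21()
model = model_zoo.chem.load_pretrained('GCN_Tox21')
model.eval()

smiles, g, label, mask = dataset[0]
with torch.no_grad():
    # Atom features are stored under the 'h' field of the graph
    logits = model(g, g.ndata.pop('h'))
    probs = torch.sigmoid(logits)  # soft predictions, one per task
```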
## Generative Models
......@@ -66,8 +71,14 @@ Generative models are known to be difficult for evaluation. [GuacaMol](https://g
are also two accompanying review papers that are well written [7], [8].
### Models
- **Deep Generative Models of Graphs (DGMG)** [11]: A very general framework for graph distribution learning by
progressively adding atoms and bonds.
### Example Usage of Pre-trained Models
![](https://s3.us-east-2.amazonaws.com/dgl.ai/model_zoo/drug_discovery/dgmg_model_zoo_example1.png)
![](https://s3.us-east-2.amazonaws.com/dgl.ai/model_zoo/drug_discovery/dgmg_model_zoo_example2.png)
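Pre-trained generative models are loaded through the same `load_pretrained` call used for property prediction. A
hedged sketch follows; the checkpoint name `DGMG_ChEMBL_canonical` is an assumption for illustration and should be
checked against the names the model zoo actually ships.

```
from dgl import model_zoo

# Hypothetical checkpoint name -- verify against the released model zoo
model = model_zoo.chem.load_pretrained('DGMG_ChEMBL_canonical')
model.eval()
```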
## References
......@@ -85,12 +96,20 @@ information processing systems (NeurIPS)*, 2224-2232.
[5] Lu et al. Molecular Property Prediction: A Multilevel Quantum Interactions Modeling Perspective.
*The 33rd AAAI Conference on Artificial Intelligence*.
[6] Gilmer et al. (2017) Neural Message Passing for Quantum Chemistry. *Proceedings of the 34th International Conference on
Machine Learning*, JMLR. 1263-1272.
[7] Brown et al. (2019) GuacaMol: Benchmarking Models for de Novo Molecular Design. *J. Chem. Inf. Model*, 2019, 59, 3,
1096-1108.
[8] Polykovskiy et al. (2019) Molecular Sets (MOSES): A Benchmarking Platform for Molecular Generation Models. *arXiv*.
[9] Kipf et al. (2017) Semi-Supervised Classification with Graph Convolutional Networks.
*The International Conference on Learning Representations (ICLR)*.
[10] Veličković et al. (2018) Graph Attention Networks.
*The International Conference on Learning Representations (ICLR)*.
[11] Li et al. (2018) Learning Deep Generative Models of Graphs. *arXiv preprint arXiv:1803.03324*.
[12] Goh et al. (2017) Deep learning for computational chemistry. *Journal of Computational Chemistry* 16, 1291-1307.
......@@ -12,14 +12,21 @@ stress response pathways. Each target yields a binary prediction problem. Molecu
into training, validation and test sets with an 80/10/10 ratio. By default we follow their split method.
### Models
- **Graph Convolutional Network** [2], [3]. Graph Convolutional Networks (GCN) have been one of the most popular graph neural
networks and they can be easily extended for graph level prediction. MoleculeNet [1] reports baseline results of graph
convolutions over multiple datasets.
- **Graph Attention Networks** [7]: Graph Attention Networks (GATs) incorporate multi-head attention into GCNs,
explicitly modeling the interactions between adjacent atoms.
### Usage
Use `classification.py` with the following arguments:
```
-m {GCN, GAT}, MODEL, model to use
-d {Tox21}, DATASET, dataset to use
```
If you want to use the pre-trained model, simply add `-p`.
We use GPU whenever it is available.
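For example, the following trains a GAT on Tox21 from scratch and then evaluates with the pre-trained model:

```
python classification.py -m GAT -d Tox21
python classification.py -m GAT -d Tox21 -p
```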
......@@ -31,10 +38,16 @@ We use GPU whenever it is available.
| ---------------- | ---------------------- |
| MoleculeNet [1] | 0.829 |
| [DeepChem example](https://github.com/deepchem/deepchem/blob/master/examples/tox21/tox21_tensorgraph_graph_conv.py) | 0.813 |
| Pretrained model | 0.826 |

Note that the dataset is randomly split, so these numbers are only for reference and do not necessarily indicate
a real difference. Due to this randomness, you may also get different numbers for the DeepChem example and our model.
To reproduce the exact results for this model, please use the pre-trained model as in the usage section.
#### GAT on Tox21
| Source | Averaged ROC-AUC Score |
| ---------------- | ---------------------- |
| Pretrained model | 0.827 |
## Dataset Customization
......@@ -47,16 +60,20 @@ Regression tasks require assigning continuous labels to a molecule, e.g. molecul
### Dataset
- **Alchemy**. The [Alchemy Dataset](https://alchemy.tencent.com/) is introduced by Tencent Quantum Lab to facilitate the development of new
machine learning models useful for chemistry and materials science. The dataset lists 12 quantum mechanical properties of 130,000+ organic
molecules comprising up to 12 heavy atoms (C, N, O, S, F and Cl), sampled from the [GDBMedChem](http://gdb.unibe.ch/downloads/) database.
These properties have been calculated using the open-source computational chemistry program Python-based Simulation of Chemistry Framework
([PySCF](https://github.com/pyscf/pyscf)). The Alchemy dataset expands on the volume and diversity of existing molecular datasets such as QM9.
### Models
- **SchNet**: SchNet is a deep learning architecture that models quantum interactions in molecules with continuous-filter
convolutional layers [4].
- **Multilevel Graph Convolutional neural Network**: Multilevel Graph Convolutional neural Network (MGCN) is a hierarchical
graph neural network that directly extracts features from the conformation and spatial information, followed by
multilevel interactions [5].
- **Message Passing Neural Network**: Message Passing Neural Network (MPNN) is a network with an edge network (enn) as its
front end and Set2Set for output prediction [6].
### Usage
......@@ -71,22 +88,27 @@ The model option must be one of 'sch', 'mgcn' or 'mpnn'.
|Model |Mean Absolute Error (MAE)|
|-------------|-------------------------|
|SchNet [4] |0.065|
|MGCN [5] |0.050|
|MPNN [6] |0.056|
## References
[1] Wu et al. (2017) MoleculeNet: a benchmark for molecular machine learning. *Chemical Science* 9, 513-530.
[2] Duvenaud et al. (2015) Convolutional networks on graphs for learning molecular fingerprints. *Advances in neural
information processing systems (NeurIPS)*, 2224-2232.
[3] Kipf et al. (2017) Semi-Supervised Classification with Graph Convolutional Networks.
*The International Conference on Learning Representations (ICLR)*.
[4] Schütt et al. (2017) SchNet: A continuous-filter convolutional neural network for modeling quantum interactions.
*Advances in Neural Information Processing Systems (NeurIPS)*, 992-1002.
[5] Lu et al. (2019) Molecular Property Prediction: A Multilevel Quantum Interactions Modeling Perspective.
*The 33rd AAAI Conference on Artificial Intelligence*.
[6] Gilmer et al. (2017) Neural Message Passing for Quantum Chemistry. *Proceedings of the 34th International Conference on
Machine Learning*, JMLR. 1263-1272.
[7] Veličković et al. (2018) Graph Attention Networks.
*The International Conference on Learning Representations (ICLR)*.
from dgl.data import Tox21
from dgl.data.utils import split_dataset
from dgl import model_zoo
import torch
......@@ -8,93 +7,108 @@ from torch.utils.data import DataLoader
from utils import Meter, EarlyStopping, collate_molgraphs, set_random_seed
def main(args):
device = "cuda" if torch.cuda.is_available() else "cpu"
batch_size = 128
learning_rate = 0.001
num_epochs = 100
set_random_seed()
# Interchangeable with other Dataset
dataset = Tox21()
atom_data_field = 'h'
trainset, valset, testset = split_dataset(dataset, [0.8, 0.1, 0.1])
train_loader = DataLoader(
trainset, batch_size=batch_size, collate_fn=collate_molgraphs)
val_loader = DataLoader(
valset, batch_size=batch_size, collate_fn=collate_molgraphs)
test_loader = DataLoader(
testset, batch_size=batch_size, collate_fn=collate_molgraphs)
if args.pre_trained:
num_epochs = 0
model = model_zoo.chem.load_pretrained('GCN_Tox21')
else:
# Interchangeable with other models
model = model_zoo.chem.GCNClassifier(in_feats=74,
gcn_hidden_feats=[64, 64],
n_tasks=dataset.n_tasks)
loss_criterion = BCEWithLogitsLoss(pos_weight=torch.tensor(
dataset.task_pos_weights).to(device), reduction='none')
optimizer = Adam(model.parameters(), lr=learning_rate)
stopper = EarlyStopping(patience=10)
model.to(device)
for epoch in range(num_epochs):
def run_a_train_epoch(args, epoch, model, data_loader, loss_criterion, optimizer):
model.train()
print('Start training')
train_meter = Meter()
for batch_id, batch_data in enumerate(train_loader):
for batch_id, batch_data in enumerate(data_loader):
smiles, bg, labels, mask = batch_data
atom_feats = bg.ndata.pop(atom_data_field)
atom_feats, labels, mask = atom_feats.to(device), labels.to(device), mask.to(device)
logits = model(atom_feats, bg)
atom_feats = bg.ndata.pop(args['atom_data_field'])
atom_feats, labels, mask = atom_feats.to(args['device']), \
labels.to(args['device']), \
mask.to(args['device'])
logits = model(bg, atom_feats)
# Mask non-existing labels
loss = (loss_criterion(logits, labels)
* (mask != 0).float()).mean()
loss = (loss_criterion(logits, labels) * (mask != 0).float()).mean()
optimizer.zero_grad()
loss.backward()
optimizer.step()
print('epoch {:d}/{:d}, batch {:d}/{:d}, loss {:.4f}'.format(
epoch + 1, num_epochs, batch_id + 1, len(train_loader), loss.item()))
epoch + 1, args['num_epochs'], batch_id + 1, len(data_loader), loss.item()))
train_meter.update(logits, labels, mask)
train_roc_auc = train_meter.roc_auc_averaged_over_tasks()
print('epoch {:d}/{:d}, training roc-auc score {:.4f}'.format(
epoch + 1, num_epochs, train_roc_auc))
epoch + 1, args['num_epochs'], train_roc_auc))
val_meter = Meter()
def run_an_eval_epoch(args, model, data_loader):
model.eval()
eval_meter = Meter()
with torch.no_grad():
for batch_id, batch_data in enumerate(val_loader):
for batch_id, batch_data in enumerate(data_loader):
smiles, bg, labels, mask = batch_data
atom_feats = bg.ndata.pop(atom_data_field)
atom_feats, labels = atom_feats.to(device), labels.to(device)
logits = model(atom_feats, bg)
val_meter.update(logits, labels, mask)
atom_feats = bg.ndata.pop(args['atom_data_field'])
atom_feats, labels = atom_feats.to(args['device']), labels.to(args['device'])
logits = model(bg, atom_feats)
eval_meter.update(logits, labels, mask)
return eval_meter.roc_auc_averaged_over_tasks()
val_roc_auc = val_meter.roc_auc_averaged_over_tasks()
if stopper.step(val_roc_auc, model):
break
def main(args):
args['device'] = "cuda" if torch.cuda.is_available() else "cpu"
set_random_seed()
# Interchangeable with other Dataset
if args['dataset'] == 'Tox21':
from dgl.data.chem import Tox21
dataset = Tox21()
trainset, valset, testset = split_dataset(dataset, args['train_val_test_split'])
train_loader = DataLoader(trainset, batch_size=args['batch_size'], collate_fn=collate_molgraphs)
val_loader = DataLoader(valset, batch_size=args['batch_size'], collate_fn=collate_molgraphs)
test_loader = DataLoader(testset, batch_size=args['batch_size'], collate_fn=collate_molgraphs)
if args['pre_trained']:
args['num_epochs'] = 0
model = model_zoo.chem.load_pretrained(args['exp'])
else:
# Interchangeable with other models
if args['model'] == 'GCN':
model = model_zoo.chem.GCNClassifier(in_feats=args['in_feats'],
gcn_hidden_feats=args['gcn_hidden_feats'],
classifier_hidden_feats=args['classifier_hidden_feats'],
n_tasks=dataset.n_tasks)
elif args['model'] == 'GAT':
model = model_zoo.chem.GATClassifier(in_feats=args['in_feats'],
gat_hidden_feats=args['gat_hidden_feats'],
num_heads=args['num_heads'],
classifier_hidden_feats=args['classifier_hidden_feats'],
n_tasks=dataset.n_tasks)
loss_criterion = BCEWithLogitsLoss(pos_weight=torch.tensor(
dataset.task_pos_weights).to(args['device']), reduction='none')
optimizer = Adam(model.parameters(), lr=args['lr'])
stopper = EarlyStopping(patience=args['patience'])
model.to(args['device'])
for epoch in range(args['num_epochs']):
# Train
run_a_train_epoch(args, epoch, model, train_loader, loss_criterion, optimizer)
# Validation and early stop
val_roc_auc = run_an_eval_epoch(args, model, val_loader)
early_stop = stopper.step(val_roc_auc, model)
print('epoch {:d}/{:d}, validation roc-auc score {:.4f}, best validation roc-auc score {:.4f}'.format(
epoch + 1, num_epochs, val_roc_auc, stopper.best_score))
epoch + 1, args['num_epochs'], val_roc_auc, stopper.best_score))
if early_stop:
break
test_meter = Meter()
model.eval()
for batch_id, batch_data in enumerate(test_loader):
smiles, bg, labels, mask = batch_data
atom_feats = bg.ndata.pop(atom_data_field)
atom_feats, labels = atom_feats.to(device), labels.to(device)
logits = model(atom_feats, bg)
test_meter.update(logits, labels, mask)
print('test roc-auc score {:.4f}'.format(test_meter.roc_auc_averaged_over_tasks()))
if not args['pre_trained']:
stopper.load_checkpoint(model)
test_roc_auc = run_an_eval_epoch(args, model, test_loader)
print('test roc-auc score {:.4f}'.format(test_roc_auc))
if __name__ == '__main__':
import argparse
from configure import get_exp_configure
parser = argparse.ArgumentParser(description='Molecule Classification')
parser.add_argument('-m', '--model', type=str, choices=['GCN', 'GAT'],
help='Model to use')
parser.add_argument('-d', '--dataset', type=str, choices=['Tox21'],
help='Dataset to use')
parser.add_argument('-p', '--pre-trained', action='store_true',
help='Whether to skip training and use a pre-trained model')
args = parser.parse_args()
args = parser.parse_args().__dict__
args['exp'] = '_'.join([args['model'], args['dataset']])
args.update(get_exp_configure(args['exp']))
main(args)
GCN_Tox21 = {
'batch_size': 128,
'lr': 1e-3,
'num_epochs': 100,
'atom_data_field': 'h',
'train_val_test_split': [0.8, 0.1, 0.1],
'in_feats': 74,
'gcn_hidden_feats': [64, 64],
'classifier_hidden_feats': 64,
'patience': 10
}
GAT_Tox21 = {
'batch_size': 128,
'lr': 1e-3,
'num_epochs': 100,
'atom_data_field': 'h',
'train_val_test_split': [0.8, 0.1, 0.1],
'in_feats': 74,
'gat_hidden_feats': [32, 32],
'classifier_hidden_feats': 64,
'num_heads': [4, 4],
'patience': 10
}
experiment_configures = {
'GCN_Tox21': GCN_Tox21,
'GAT_Tox21': GAT_Tox21
}
def get_exp_configure(exp_name):
return experiment_configures[exp_name]
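# Example of how these configurations are consumed (see classification.py):
# get_exp_configure('GAT_Tox21') returns the GAT_Tox21 dict above, which the
# training script merges into its runtime arguments via args.update(...).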
......@@ -74,11 +74,11 @@ class EarlyStopping(object):
def save_checkpoint(self, model):
'''Saves model when the metric on the validation set improves.'''
torch.save(model.state_dict(), self.filename)
torch.save({'model_state_dict': model.state_dict()}, self.filename)
def load_checkpoint(self, model):
'''Load model saved with early stopping.'''
model.load_state_dict(torch.load(self.filename))
model.load_state_dict(torch.load(self.filename)['model_state_dict'])
def collate_molgraphs(data):
"""Batching a list of datapoints for dataloader
......
......@@ -11,7 +11,6 @@ from .reddit import RedditDataset
from .ppi import PPIDataset
from .tu import TUDataset
from .gindt import GINDataset
# from .chem import Tox21, alchemy
def register_data_args(parser):
......
......@@ -10,11 +10,10 @@ import pickle
import zipfile
from collections import defaultdict
import dgl
import dgl.backend as F
from dgl.data.utils import download, get_download_dir
from .utils import mol_to_complete_graph
from ..utils import download, get_download_dir
from ...batched_graph import batch
from ... import backend as F
try:
import pandas as pd
......@@ -40,12 +39,12 @@ class AlchemyBatcher(object):
self.graph = graph
self.label = label
def batcher_dev(batch):
def batcher_dev(batch_data):
"""Batch datapoints
Parameters
----------
batch : list
batch_data : list
batch_data[i][0] gives the DGLGraph for the ith datapoint,
and batch_data[i][1] gives the label for the ith datapoint.
......@@ -54,15 +53,79 @@ def batcher_dev(batch):
AlchemyBatcher
An object holding the batch of data
"""
graphs, labels = zip(*batch)
batch_graphs = dgl.batch(graphs)
graphs, labels = zip(*batch_data)
batch_graphs = batch(graphs)
labels = F.stack(labels, 0)
return AlchemyBatcher(graph=batch_graphs, label=labels)
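# A minimal sketch (as comments) of plugging batcher_dev into a PyTorch
# DataLoader; the dataset below is an assumption and can be any sequence of
# (DGLGraph, label) pairs, e.g. a TencentAlchemyDataset:
#
# from torch.utils.data import DataLoader
# loader = DataLoader(dataset, batch_size=32, collate_fn=batcher_dev)
# for batch in loader:
#     graphs, labels = batch.graph, batch.label  # AlchemyBatcher fields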
class TencentAlchemyDataset(object):
fdef_name = osp.join(RDConfig.RDDataDir, 'BaseFeatures.fdef')
chem_feature_factory = ChemicalFeatures.BuildFeatureFactory(fdef_name)
"""`Tencent Alchemy Dataset <https://arxiv.org/abs/1906.09427>`__
Parameters
----------
mode : str
'dev', 'valid' or 'test', default to be 'dev'
transform : transform operation on DGLGraphs
Default to be None.
from_raw : bool
Whether to process dataset from scratch or use a
processed one for faster speed. Default to be False.
"""
def __init__(self, mode='dev', transform=None, from_raw=False):
assert mode in ['dev', 'valid', 'test'], "mode should be dev/valid/test"
self.mode = mode
self.transform = transform
# Construct DGLGraphs from raw data or use the preprocessed data
self.from_raw = from_raw
file_dir = osp.join(get_download_dir(), './Alchemy_data')
if not from_raw:
file_name = "%s_processed" % (mode)
else:
file_name = "%s_single_sdf" % (mode)
self.file_dir = pathlib.Path(file_dir, file_name)
self.zip_file_path = pathlib.Path(file_dir, file_name + '.zip')
download(_urls['Alchemy'] + file_name + '.zip',
path=str(self.zip_file_path))
if not os.path.exists(str(self.file_dir)):
archive = zipfile.ZipFile(self.zip_file_path)
archive.extractall(file_dir)
archive.close()
self._load()
def _load(self):
if self.mode == 'dev':
if not self.from_raw:
with open(osp.join(self.file_dir, "dev_graphs.pkl"), "rb") as f:
self.graphs = pickle.load(f)
with open(osp.join(self.file_dir, "dev_labels.pkl"), "rb") as f:
self.labels = pickle.load(f)
else:
target_file = pathlib.Path(self.file_dir, "dev_target.csv")
self.target = pd.read_csv(
target_file,
index_col=0,
usecols=['gdb_idx',] + ['property_%d' % x for x in range(12)])
self.target = self.target[['property_%d' % x for x in range(12)]]
self.graphs, self.labels = [], []
supp = Chem.SDMolSupplier(
osp.join(self.file_dir, self.mode + ".sdf"))
cnt = 0
for sdf, label in zip(supp, self.target.iterrows()):
graph = mol_to_complete_graph(sdf, atom_featurizer=self.alchemy_nodes,
bond_featurizer=self.alchemy_edges)
cnt += 1
self.graphs.append(graph)
label = F.tensor(np.array(label[1].tolist()).astype(np.float32))
self.labels.append(label)
self.normalize()
print(len(self.graphs), "loaded!")
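# A hedged usage sketch, assuming the processed files can be downloaded:
# alchemy = TencentAlchemyDataset(mode='dev')
# train_set, test_set = alchemy.split(train_size=0.8)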
def alchemy_nodes(self, mol):
"""Featurization for all atoms in a molecule. The atom indices
......@@ -135,8 +198,8 @@ class TencentAlchemyDataset(object):
return atom_feats_dict
def alchemy_edges(self, mol, self_loop=False):
"""Featurization for all bonds in a molecule. The bond indices
will be preserved.
"""Featurization for all bonds in a molecule.
The bond indices will be preserved.
Parameters
----------
......@@ -182,66 +245,16 @@ class TencentAlchemyDataset(object):
return bond_feats_dict
def __init__(self, mode='dev', transform=None, from_raw=False):
assert mode in ['dev', 'valid', 'test'], "mode should be dev/valid/test"
self.mode = mode
self.transform = transform
# Construct the dgl graph from raw data or use the preprocessed data directly
self.from_raw = from_raw
file_dir = osp.join(get_download_dir(), './Alchemy_data')
if not from_raw:
file_name = "%s_processed" % (mode)
else:
file_name = "%s_single_sdf" % (mode)
self.file_dir = pathlib.Path(file_dir, file_name)
self.zip_file_path = pathlib.Path(file_dir, file_name + '.zip')
download(_urls['Alchemy'] + file_name + '.zip',
path=str(self.zip_file_path))
if not os.path.exists(str(self.file_dir)):
archive = zipfile.ZipFile(self.zip_file_path)
archive.extractall(file_dir)
archive.close()
self._load()
def _load(self):
if self.mode == 'dev':
if not self.from_raw:
with open(osp.join(self.file_dir, "dev_graphs.pkl"), "rb") as f:
self.graphs = pickle.load(f)
with open(osp.join(self.file_dir, "dev_labels.pkl"), "rb") as f:
self.labels = pickle.load(f)
else:
target_file = pathlib.Path(self.file_dir, "dev_target.csv")
self.target = pd.read_csv(
target_file,
index_col=0,
usecols=[
'gdb_idx',
] + ['property_%d' % x for x in range(12)])
self.target = self.target[[
'property_%d' % x for x in range(12)
]]
self.graphs, self.labels = [], []
supp = Chem.SDMolSupplier(
osp.join(self.file_dir, self.mode + ".sdf"))
cnt = 0
for sdf, label in zip(supp, self.target.iterrows()):
graph = mol_to_complete_graph(sdf, atom_featurizer=self.alchemy_nodes,
bond_featurizer=self.alchemy_edges)
cnt += 1
self.graphs.append(graph)
label = F.tensor(np.array(label[1].tolist()).astype(np.float32))
self.labels.append(label)
self.normalize()
print(len(self.graphs), "loaded!")
def normalize(self, mean=None, std=None):
"""Set mean and std or compute from labels for future normalization.
Parameters
----------
mean : int or float
Default to be None.
std : int or float
Default to be None.
"""
labels = np.array([i.numpy() for i in self.labels])
if mean is None:
mean = np.mean(labels, axis=0)
......@@ -260,7 +273,20 @@ class TencentAlchemyDataset(object):
return g, l
def split(self, train_size=0.8):
"""Split the dataset into two AlchemySubset for train&test."""
"""Split the dataset into two AlchemySubset for train&test.
Parameters
----------
train_size : float
Proportion of dataset to use for training. Default to be 0.8.
Returns
-------
train_set : AlchemySubset
Dataset for training
test_set : AlchemySubset
Dataset for test
"""
assert 0 < train_size < 1
train_num = int(len(self.graphs) * train_size)
train_set = AlchemySubset(self.graphs[:train_num],
......@@ -275,6 +301,19 @@ class AlchemySubset(TencentAlchemyDataset):
"""
Sub-dataset split from TencentAlchemyDataset.
Used to construct the training & test set.
Parameters
----------
graphs : list of DGLGraphs
DGLGraphs for datapoints in the subset
labels : list of tensors
Labels for datapoints in the subset
mean : int or float
Mean of labels in the subset
std : int or float
Std of labels in the subset
transform : transform operation on DGLGraphs
Default to be None.
"""
def __init__(self, graphs, labels, mean=0, std=1, transform=None):
super(AlchemySubset, self).__init__()
......
# pylint: disable=C0111
"""Model Zoo Package"""
from .gcn import GCNClassifier
from .classifiers import GCNClassifier, GATClassifier
from .sch import SchNetModel
from .mgcn import MGCNModel
from .mpnn import MPNNModel
......
# pylint: disable=C0111, C0103, C0200
import torch
import torch.nn as nn
import torch.nn.functional as F
import dgl
from dgl import BatchedDGLGraph
from .gnn import GCNLayer, GATLayer
from ...nn.pytorch import WeightAndSum
class MLPBinaryClassifier(nn.Module):
"""MLP for soft binary classification over multiple tasks from molecule representations.
Parameters
----------
in_feats : int
Number of input molecular graph features
hidden_feats : int
Number of molecular graph features in hidden layers
n_tasks : int
Number of tasks, also output size
dropout : float
The probability for dropout. Default to be 0., i.e. no
dropout is performed.
"""
def __init__(self, in_feats, hidden_feats, n_tasks, dropout=0.):
super(MLPBinaryClassifier, self).__init__()
self.predict = nn.Sequential(
nn.Dropout(dropout),
nn.Linear(in_feats, hidden_feats),
nn.ReLU(),
nn.BatchNorm1d(hidden_feats),
nn.Linear(hidden_feats, n_tasks)
)
def forward(self, h):
"""Perform soft binary classification over multiple tasks
Parameters
----------
h : FloatTensor of shape (B, M3)
* B is the number of molecules in a batch
* M3 is the input molecule feature size, must match in_feats in initialization
Returns
-------
FloatTensor of shape (B, n_tasks)
"""
return self.predict(h)
class BaseGNNClassifier(nn.Module):
"""GCN based predictor for multitask prediction on molecular graphs
We assume each task is a binary classification problem.
Parameters
----------
gnn_out_feats : int
Number of atom representation features after using GNN
n_tasks : int
Number of prediction tasks
classifier_hidden_feats : int
Number of molecular graph features in hidden layers of the MLP Classifier
dropout : float
The probability for dropout. Default to be 0., i.e. no
dropout is performed.
"""
def __init__(self, gnn_out_feats, n_tasks, classifier_hidden_feats=128, dropout=0.):
super(BaseGNNClassifier, self).__init__()
self.gnn_layers = nn.ModuleList()
self.weighted_sum_readout = WeightAndSum(gnn_out_feats)
self.g_feats = 2 * gnn_out_feats
self.soft_classifier = MLPBinaryClassifier(
self.g_feats, classifier_hidden_feats, n_tasks, dropout)
def forward(self, bg, feats):
"""Multi-task prediction for a batch of molecules
Parameters
----------
bg : BatchedDGLGraph
Batched DGLGraphs for processing multiple molecules in parallel
feats : FloatTensor of shape (N, M0)
Initial features for all atoms in the batch of molecules
Returns
-------
FloatTensor of shape (B, n_tasks)
Soft prediction for all tasks on the batch of molecules
"""
# Update atom features with GNNs
for gnn in self.gnn_layers:
feats = gnn(bg, feats)
# Compute molecule features from atom features
h_g_sum = self.weighted_sum_readout(bg, feats)
with bg.local_scope():
bg.ndata['h'] = feats
h_g_max = dgl.max_nodes(bg, 'h')
if not isinstance(bg, BatchedDGLGraph):
h_g_sum = h_g_sum.unsqueeze(0)
h_g_max = h_g_max.unsqueeze(0)
h_g = torch.cat([h_g_sum, h_g_max], dim=1)
# Multi-task prediction
return self.soft_classifier(h_g)
class GCNClassifier(BaseGNNClassifier):
"""GCN based predictor for multitask prediction on molecular graphs
We assume each task is a binary classification problem.
Parameters
----------
in_feats : int
Number of input atom features
gcn_hidden_feats : list of int
gcn_hidden_feats[i] gives the number of output atom features
in the i+1-th gcn layer
n_tasks : int
Number of prediction tasks
classifier_hidden_feats : int
Number of molecular graph features in hidden layers of the MLP Classifier
dropout : float
The probability for dropout. Default to be 0., i.e. no
dropout is performed.
"""
def __init__(self, in_feats, gcn_hidden_feats, n_tasks,
classifier_hidden_feats=128, dropout=0.):
super(GCNClassifier, self).__init__(gnn_out_feats=gcn_hidden_feats[-1],
n_tasks=n_tasks,
classifier_hidden_feats=classifier_hidden_feats,
dropout=dropout)
for i in range(len(gcn_hidden_feats)):
out_feats = gcn_hidden_feats[i]
self.gnn_layers.append(GCNLayer(in_feats, out_feats))
in_feats = out_feats
class GATClassifier(BaseGNNClassifier):
"""GAT based predictor for multitask prediction on molecular graphs.
We assume each task is a binary classification problem.
Parameters
----------
in_feats : int
Number of input atom features
gat_hidden_feats : list of int
gat_hidden_feats[i] gives the number of output atom features
for each attention head in the i+1-th GAT layer
num_heads : list of int
num_heads[i] gives the number of attention heads in the i+1-th GAT layer
n_tasks : int
Number of prediction tasks
classifier_hidden_feats : int
Number of molecular graph features in hidden layers of the MLP Classifier
dropout : float
The probability for dropout. Default to be 0., i.e. no
dropout is performed.
"""
def __init__(self, in_feats, gat_hidden_feats, num_heads,
n_tasks, classifier_hidden_feats=128, dropout=0):
super(GATClassifier, self).__init__(gnn_out_feats=gat_hidden_feats[-1],
n_tasks=n_tasks,
classifier_hidden_feats=classifier_hidden_feats,
dropout=dropout)
assert len(gat_hidden_feats) == len(num_heads), \
'Got gat_hidden_feats with length {:d} and num_heads with length {:d}, ' \
'expect them to be the same.'.format(len(gat_hidden_feats), len(num_heads))
num_layers = len(num_heads)
for l in range(num_layers):
if l > 0:
in_feats = gat_hidden_feats[l - 1] * num_heads[l - 1]
if l == num_layers - 1:
agg_mode = 'mean'
agg_act = None
else:
agg_mode = 'flatten'
agg_act = F.elu
self.gnn_layers.append(GATLayer(in_feats, gat_hidden_feats[l], num_heads[l],
feat_drop=dropout, attn_drop=dropout,
agg_mode=agg_mode, activation=agg_act))
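# A hedged example (as comments): a two-layer GATClassifier matching the
# GAT_Tox21 configuration in this PR (Tox21 has 12 tasks):
# model = GATClassifier(in_feats=74, gat_hidden_feats=[32, 32],
#                       num_heads=[4, 4], n_tasks=12)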
# pylint: disable=C0111, C0103, C0200
import torch
import torch.nn as nn
import torch.nn.functional as F
import dgl
from dgl.nn.pytorch import GraphConv
class GCNLayer(nn.Module):
def __init__(self, in_feats, out_feats, activation=F.relu,
residual=True, batchnorm=True, dropout=0.):
"""Single layer GCN for updating node features
Parameters
----------
in_feats : int
Number of input atom features
out_feats : int
Number of output atom features
activation : activation function
Default to be ReLU
residual : bool
Whether to use residual connection, default to be True
batchnorm : bool
Whether to use batch normalization on the output,
default to be True
dropout : float
The probability for dropout. Default to be 0., i.e. no
dropout is performed.
"""
super(GCNLayer, self).__init__()
self.activation = activation
self.graph_conv = GraphConv(in_feats=in_feats, out_feats=out_feats,
norm=False, activation=activation)
self.dropout = nn.Dropout(dropout)
self.residual = residual
if residual:
self.res_connection = nn.Linear(in_feats, out_feats)
self.bn = batchnorm
if batchnorm:
self.bn_layer = nn.BatchNorm1d(out_feats)
def forward(self, feats, bg):
"""Update atom representations
Parameters
----------
feats : FloatTensor of shape (N, M1)
* N is the total number of atoms in the batched graph
* M1 is the input atom feature size, must match in_feats in initialization
bg : BatchedDGLGraph
Batched DGLGraphs for processing multiple molecules in parallel
Returns
-------
new_feats : FloatTensor of shape (N, M2)
* M2 is the output atom feature size, must match out_feats in initialization
"""
new_feats = self.graph_conv(feats, bg)
if self.residual:
res_feats = self.activation(self.res_connection(feats))
new_feats = new_feats + res_feats
new_feats = self.dropout(new_feats)
if self.bn:
new_feats = self.bn_layer(new_feats)
return new_feats
class MLPBinaryClassifier(nn.Module):
def __init__(self, in_feats, hidden_feats, n_tasks, dropout=0.):
"""MLP for soft binary classification over multiple tasks from molecule representations.
Parameters
----------
in_feats : int
Number of input molecular graph features
hidden_feats : int
Number of molecular graph features in hidden layers
n_tasks : int
Number of tasks, also output size
dropout : float
The probability for dropout. Default to be 0., i.e. no
dropout is performed.
"""
super(MLPBinaryClassifier, self).__init__()
self.predict = nn.Sequential(
nn.Dropout(dropout),
nn.Linear(in_feats, hidden_feats),
nn.ReLU(),
nn.BatchNorm1d(hidden_feats),
nn.Linear(hidden_feats, n_tasks)
)
def forward(self, h):
"""Perform soft binary classification over multiple tasks
Parameters
----------
h : FloatTensor of shape (B, M3)
* B is the number of molecules in a batch
* M3 is the input molecule feature size, must match in_feats in initialization
Returns
-------
FloatTensor of shape (B, n_tasks)
"""
return self.predict(h)
class GCNClassifier(nn.Module):
def __init__(self, in_feats, gcn_hidden_feats, n_tasks, classifier_hidden_feats=128,
dropout=0., atom_data_field='h', atom_weight_field='w'):
"""GCN based predictor for multitask prediction on molecular graphs
We assume each task is a binary classification problem.
Parameters
----------
in_feats : int
Number of input atom features
gcn_hidden_feats : list of int
gcn_hidden_feats[i] gives the number of output atom features
in the i+1-th gcn layer
n_tasks : int
Number of prediction tasks
classifier_hidden_feats : int
Number of molecular graph features in hidden layers of the MLP Classifier
dropout : float
The probability for dropout. Default to be 0., i.e. no
dropout is performed.
atom_data_field : str
Name for storing atom features in DGLGraphs
atom_weight_field : str
Name for storing atom weights in DGLGraphs
"""
super(GCNClassifier, self).__init__()
self.atom_data_field = atom_data_field
self.gcn_layers = nn.ModuleList()
for i in range(len(gcn_hidden_feats)):
out_feats = gcn_hidden_feats[i]
self.gcn_layers.append(GCNLayer(in_feats, out_feats))
in_feats = out_feats
self.atom_weight_field = atom_weight_field
self.atom_weighting = nn.Sequential(
nn.Linear(in_feats, 1),
nn.Sigmoid()
)
self.g_feats = 2 * in_feats
self.soft_classifier = MLPBinaryClassifier(
self.g_feats, classifier_hidden_feats, n_tasks, dropout)
def forward(self, feats, bg):
"""Multi-task prediction for a batch of molecules
Parameters
----------
feats : FloatTensor of shape (N, M0)
Initial features for all atoms in the batch of molecules
bg : BatchedDGLGraph
Batched DGLGraphs for processing multiple molecules in parallel
Returns
-------
FloatTensor of shape (B, n_tasks)
Soft prediction for all tasks on the batch of molecules
"""
# Update atom features
for gcn in self.gcn_layers:
feats = gcn(feats, bg)
# Compute molecule features from atom features
bg.ndata[self.atom_data_field] = feats
bg.ndata[self.atom_weight_field] = self.atom_weighting(feats)
h_g_sum = dgl.sum_nodes(
bg, self.atom_data_field, self.atom_weight_field)
h_g_max = dgl.max_nodes(bg, self.atom_data_field)
h_g = torch.cat([h_g_sum, h_g_max], dim=1)
# Multi-task prediction
return self.soft_classifier(h_g)
# pylint: disable=C0103, E1101
"""GNN layers for updating atom representations"""
import torch.nn as nn
import torch.nn.functional as F
from ...nn.pytorch import GraphConv, GATConv
class GCNLayer(nn.Module):
"""Single layer GCN for updating node features
Parameters
----------
in_feats : int
Number of input atom features
out_feats : int
Number of output atom features
activation : activation function
Default to be ReLU
residual : bool
Whether to use residual connection, default to be True
batchnorm : bool
Whether to use batch normalization on the output,
default to be True
dropout : float
The probability for dropout. Default to be 0., i.e. no
dropout is performed.
"""
def __init__(self, in_feats, out_feats, activation=F.relu,
residual=True, batchnorm=True, dropout=0.):
super(GCNLayer, self).__init__()
self.activation = activation
self.graph_conv = GraphConv(in_feats=in_feats, out_feats=out_feats,
norm=False, activation=activation)
self.dropout = nn.Dropout(dropout)
self.residual = residual
if residual:
self.res_connection = nn.Linear(in_feats, out_feats)
self.bn = batchnorm
if batchnorm:
self.bn_layer = nn.BatchNorm1d(out_feats)
def forward(self, bg, feats):
"""Update atom representations
Parameters
----------
bg : BatchedDGLGraph
Batched DGLGraphs for processing multiple molecules in parallel
feats : FloatTensor of shape (N, M1)
* N is the total number of atoms in the batched graph
* M1 is the input atom feature size, must match in_feats in initialization
Returns
-------
new_feats : FloatTensor of shape (N, M2)
* M2 is the output atom feature size, must match out_feats in initialization
"""
new_feats = self.graph_conv(bg, feats)
if self.residual:
res_feats = self.activation(self.res_connection(feats))
new_feats = new_feats + res_feats
new_feats = self.dropout(new_feats)
if self.bn:
new_feats = self.bn_layer(new_feats)
return new_feats
class GATLayer(nn.Module):
"""Single layer GAT for updating node features
Parameters
----------
in_feats : int
Number of input atom features
out_feats : int
Number of output atom features for each attention head
num_heads : int
Number of attention heads
feat_drop : float
Dropout applied to the input features
attn_drop : float
Dropout applied to attention values of edges
alpha : float
Hyperparameter in LeakyReLU, slope for negative values. Default to be 0.2
residual : bool
Whether to perform skip connection, default to be False
agg_mode : str
The way to aggregate multi-head attention results, can be either
'flatten' for concatenating all head results or 'mean' for averaging
all head results
activation : activation function or None
Activation function applied to aggregated multi-head results, default to be None.
"""
def __init__(self, in_feats, out_feats, num_heads, feat_drop, attn_drop,
alpha=0.2, residual=True, agg_mode='flatten', activation=None):
super(GATLayer, self).__init__()
self.gnn = GATConv(in_feats=in_feats, out_feats=out_feats, num_heads=num_heads,
feat_drop=feat_drop, attn_drop=attn_drop,
negative_slope=alpha, residual=residual)
assert agg_mode in ['flatten', 'mean']
self.agg_mode = agg_mode
self.activation = activation
def forward(self, bg, feats):
"""Update atom representations
Parameters
----------
bg : BatchedDGLGraph
Batched DGLGraphs for processing multiple molecules in parallel
feats : FloatTensor of shape (N, M1)
* N is the total number of atoms in the batched graph
* M1 is the input atom feature size, must match in_feats in initialization
Returns
-------
new_feats : FloatTensor of shape (N, M2)
* M2 is the output atom feature size. If self.agg_mode == 'flatten', this would
be out_feats * num_heads, else it would be just out_feats.
"""
new_feats = self.gnn(bg, feats)
if self.agg_mode == 'flatten':
new_feats = new_feats.flatten(1)
else:
new_feats = new_feats.mean(1)
if self.activation is not None:
new_feats = self.activation(new_feats)
return new_feats
......@@ -3,39 +3,46 @@
"""
The implementation of neural network layers used in SchNet and MGCN.
"""
import torch as th
import torch.nn as nn
from torch.nn import Softplus
import numpy as np
import dgl.function as fn
from ... import function as fn
class AtomEmbedding(nn.Module):
"""
Convert the atom(node) list to atom embeddings.
The atoms with the same element share the same initial embedding.
Parameters
----------
dim : int
Dim of embeddings, default to be 128.
type_num : int
The largest atomic number of atoms in the dataset, default to be 100.
pre_train : None or pre-trained embeddings
Pre-trained embeddings, default to be None.
"""
def __init__(self, dim=128, type_num=100, pre_train=None):
"""
Randomly init the element embeddings.
Args:
dim: the dim of embeddings
type_num: the largest atomic number of atoms in the dataset
pre_train: the pre_trained embeddings
"""
super().__init__()
super(AtomEmbedding, self).__init__()
self._dim = dim
self._type_num = type_num
if pre_train is not None:
self.embedding = nn.Embedding.from_pretrained(pre_train,
padding_idx=0)
self.embedding = nn.Embedding.from_pretrained(pre_train, padding_idx=0)
else:
self.embedding = nn.Embedding(type_num, dim, padding_idx=0)
def forward(self, g, p_name="node"):
"""Input type is dgl graph"""
"""
Parameters
----------
g : DGLGraph
Input DGLGraph(s)
p_name : str
Name for storing atom embeddings
"""
atom_list = g.ndata["node_type"]
g.ndata[p_name] = self.embedding(atom_list)
return g.ndata[p_name]
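# Note: g.ndata['node_type'] is expected to hold integer atom types; the
# returned tensor has shape (number of atoms, dim) and is also cached in
# g.ndata[p_name].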
......@@ -43,47 +50,69 @@ class AtomEmbedding(nn.Module):
class EdgeEmbedding(nn.Module):
"""
Module for embedding edges. Edges linking the same pairs of atoms share
the same initial embedding.
Parameters
----------
dim : int
Dim of edge embeddings, default to be 128.
edge_num : int
Maximum number of edge types, default to be 128.
pre_train : Edge embeddings or None
Pre-trained edge embeddings
"""
def __init__(self, dim=128, edge_num=3000, pre_train=None):
"""
Randomly init the edge embeddings.
Args:
dim: the dim of embeddings
edge_num: the maximum type of edges
pre_train: the pre_trained embeddings
"""
super().__init__()
super(EdgeEmbedding, self).__init__()
self._dim = dim
self._edge_num = edge_num
if pre_train is not None:
self.embedding = nn.Embedding.from_pretrained(pre_train,
padding_idx=0)
self.embedding = nn.Embedding.from_pretrained(pre_train, padding_idx=0)
else:
self.embedding = nn.Embedding(edge_num, dim, padding_idx=0)
def generate_edge_type(self, edges):
"""
Generate the edge type based on the src&dst atom type of the edge.
"""Generate edge type.
The edge type is based on the type of the src&dst atom.
Note that C-O and O-C are the same edge type.
To map a pair of nodes to one number, we use an unordered pairing function here
See more detail in this disscussion:
https://math.stackexchange.com/questions/23503/create-unique-number-from-2-numbers
Note that, the edge_num should be larger than the square of maximum atomic number
in the dataset.
Parameters
----------
edges : EdgeBatch
Edges for deciding types
Returns
-------
dict
Stores the edge types in "type"
"""
atom_type_x = edges.src["node_type"]
atom_type_y = edges.dst["node_type"]
return {
"type":
atom_type_x * atom_type_y +
"type": atom_type_x * atom_type_y +
(th.abs(atom_type_x - atom_type_y) - 1)**2 / 4
}
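# Worked example (assuming integer tensor division): for C (6) and O (8),
# 6 * 8 + (|6 - 8| - 1)**2 / 4 = 48 + 0 = 48, and O-C gives the same value,
# so both directions of a bond map to one edge type.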
def forward(self, g, p_name="edge_f"):
"""Compute edge embeddings
Parameters
----------
g : DGLGraph
The graph to compute edge embeddings
p_name : str
Name for storing edge embeddings
Returns
-------
tensor
Computed edge embeddings
"""
g.apply_edges(self.generate_edge_type)
g.edata[p_name] = self.embedding(g.edata["type"])
return g.edata[p_name]
......@@ -93,14 +122,20 @@ class ShiftSoftplus(Softplus):
"""
ShiftSoftplus activation function:
1/beta * (log(1 + exp(beta * x)) - log(shift))
"""
Parameters
----------
beta : int
shift : int
threshold : int
"""
def __init__(self, beta=1, shift=2, threshold=20):
super().__init__(beta, threshold)
super(ShiftSoftplus, self).__init__(beta, threshold)
self.shift = shift
self.softplus = Softplus(beta, threshold)
def forward(self, x):
"""Applies the activation function"""
return self.softplus(x) - np.log(float(self.shift))
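# Sanity check: with the defaults beta=1 and shift=2, the output at x = 0 is
# softplus(0) - log(2) = log(2) - log(2) = 0.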
......@@ -112,9 +147,8 @@ class RBFLayer(nn.Module):
gamma = 10
0 <= mu_k <= 30 for k=1~300
"""
def __init__(self, low=0, high=30, gap=0.1, dim=1):
super().__init__()
super(RBFLayer, self).__init__()
self._low = low
self._high = high
self._gap = gap
......@@ -145,25 +179,25 @@ class RBFLayer(nn.Module):
class CFConv(nn.Module):
"""
The continuous-filter convolution layer in SchNet.
One CFConv contains one RBF layer and three linear layers
(two of them are followed by an activation function).
"""
def __init__(self, rbf_dim, dim=64, act="sp"):
"""
Args:
rbf_dim: the dimsion of the RBF layer
dim: the dimension of linear layers
act: activation function (default shifted softplus)
Parameters
----------
rbf_dim : int
Dimension of the RBF layer output
dim : int
Dimension of linear layers, default to be 64
act : str or activation function
Activation function, default to be shifted softplus
"""
super().__init__()
def __init__(self, rbf_dim, dim=64, act=None):
super(CFConv, self).__init__()
self._rbf_dim = rbf_dim
self._dim = dim
self.linear_layer1 = nn.Linear(self._rbf_dim, self._dim)
self.linear_layer2 = nn.Linear(self._dim, self._dim)
if act == "sp":
if act is None:
self.activation = nn.Softplus(beta=0.5, threshold=14)
else:
self.activation = act
......@@ -187,10 +221,16 @@ class CFConv(nn.Module):
class Interaction(nn.Module):
"""
The interaction layer in the SchNet model.
"""
Parameters
----------
rbf_dim : int
Dimension of the RBF layer output
dim : int
Dimension of intermediate representations
"""
def __init__(self, rbf_dim, dim):
super().__init__()
super(Interaction, self).__init__()
self._node_dim = dim
self.activation = nn.Softplus(beta=0.5, threshold=14)
self.node_layer1 = nn.Linear(dim, dim, bias=False)
......@@ -199,7 +239,16 @@ class Interaction(nn.Module):
self.node_layer3 = nn.Linear(dim, dim)
def forward(self, g):
"""Interaction layer forward."""
"""
Parameters
----------
g : DGLGraph
Returns
-------
tensor
Updated atom representations
"""
g.ndata["new_node"] = self.node_layer1(g.ndata["node"])
cf_node = self.cfconv(g)
cf_node_1 = self.node_layer2(cf_node)
......@@ -213,16 +262,18 @@ class VEConv(nn.Module):
"""
The Vertex-Edge convolution layer in MGCN which takes edge & vertex features
into consideration at the same time.
"""
def __init__(self, rbf_dim, dim=64, update_edge=True):
Parameters
----------
rbf_dim : int
Dimension of the RBF layer output
dim : int
Dimension of intermediate representations, default to be 64.
update_edge : bool
Whether to apply a linear layer to update edge representations, default to be True.
"""
Args:
rbf_dim: the dimension of the RBF layer
dim: the dimension of linear layers
update_edge: whether to update the edge embedding in each conv-layer
"""
super().__init__()
def __init__(self, rbf_dim, dim=64, update_edge=True):
super(VEConv, self).__init__()
self._rbf_dim = rbf_dim
self._dim = dim
self._update_edge = update_edge
......@@ -234,7 +285,17 @@ class VEConv(nn.Module):
self.activation = nn.Softplus(beta=0.5, threshold=14)
def update_rbf(self, edges):
"""Update the RBF features."""
"""Update the RBF features
Parameters
----------
edges : EdgeBatch
Returns
-------
dict
Stores updated features in 'h'
"""
rbf = edges.data["rbf"]
h = self.linear_layer1(rbf)
h = self.activation(h)
......@@ -242,27 +303,43 @@ class VEConv(nn.Module):
return {"h": h}
def update_edge(self, edges):
"""Update the edge features."""
"""Update the edge features.
Parameters
----------
edges : EdgeBatch
Returns
-------
dict
Stores updated features in 'edge_f'
"""
edge_f = edges.data["edge_f"]
h = self.linear_layer3(edge_f)
return {"edge_f": h}
def forward(self, g):
"""VEConv layer forward"""
"""VEConv layer forward
Parameters
----------
g : DGLGraph
Returns
-------
tensor
Updated atom representations
"""
g.apply_edges(self.update_rbf)
if self._update_edge:
g.apply_edges(self.update_edge)
g.update_all(message_func=[
fn.u_mul_e("new_node", "h", "m_0"),
fn.copy_e("edge_f", "m_1")
],
reduce_func=[
fn.sum("m_0", "new_node_0"),
fn.sum("m_1", "new_node_1")
])
g.ndata["new_node"] = g.ndata.pop("new_node_0") + g.ndata.pop(
"new_node_1")
g.update_all(message_func=[fn.u_mul_e("new_node", "h", "m_0"),
fn.copy_e("edge_f", "m_1")],
reduce_func=[fn.sum("m_0", "new_node_0"),
fn.sum("m_1", "new_node_1")])
g.ndata["new_node"] = g.ndata.pop("new_node_0") + \
g.ndata.pop("new_node_1")
return g.ndata["new_node"]
......@@ -270,15 +347,19 @@ class VEConv(nn.Module):
class MultiLevelInteraction(nn.Module):
"""
The multilevel interaction in the MGCN model.
"""
Parameters
----------
rbf_dim : int
Dimension of the RBF layer output
dim : int
Dimension of intermediate representations
"""
def __init__(self, rbf_dim, dim):
super().__init__()
super(MultiLevelInteraction, self).__init__()
self._atom_dim = dim
self.activation = nn.Softplus(beta=0.5, threshold=14)
self.node_layer1 = nn.Linear(dim, dim, bias=True)
self.edge_layer1 = nn.Linear(dim, dim, bias=True)
self.conv_layer = VEConv(rbf_dim, dim)
......@@ -287,22 +368,26 @@ class MultiLevelInteraction(nn.Module):
def forward(self, g, level=1):
"""
MultiLevel Interaction Layer forward.
Args:
g: DGLGraph
level: current level of this layer
Parameters
----------
g : DGLGraph
level : int
Level of interaction
Returns
-------
tensor
Updated atom representations
"""
g.ndata["new_node"] = self.node_layer1(g.ndata["node_%s" %
(level - 1)])
g.ndata["new_node"] = self.node_layer1(
g.ndata["node_%s" % (level - 1)])
node = self.conv_layer(g)
g.edata["edge_f"] = self.activation(self.edge_layer1(
g.edata["edge_f"]))
g.edata["edge_f"] = self.activation(
self.edge_layer1(g.edata["edge_f"]))
node_1 = self.node_layer2(node)
node_1a = self.activation(node_1)
new_node = self.node_layer3(node_1a)
g.ndata["node_%s" % (level)] = g.ndata["node_%s" %
(level - 1)] + new_node
g.ndata["node_%s" % (level)] = g.ndata["node_%s" % (level - 1)] + new_node
return g.ndata["node_%s" % (level)]
# -*- coding:utf-8 -*-
# pylint: disable=C0103, C0111, W0621
"""Implementation of MGCN model"""
import dgl
import torch as th
import torch.nn as nn
from .layers import AtomEmbedding, RBFLayer, EdgeEmbedding, \
MultiLevelInteraction
from ...batched_graph import sum_nodes
class MGCNModel(nn.Module):
"""
MGCN from `Molecular Property Prediction: A Multilevel
Quantum Interactions Modeling Perspective <https://arxiv.org/abs/1906.11081>`__
Parameters
----------
dim : int
Dimension of feature maps, default to be 128.
output_dim : int
Number of target properties to predict, default to be 1.
edge_dim : int
Dimension of edge feature, default to be 128.
cutoff : float
The maximum distance between nodes, default to be 5.0.
width : int
Width in the RBF layer, default to be 1.
n_conv : int
Number of convolutional layers, default to be 3.
norm : bool
Whether to perform normalization, default to be False.
atom_ref : Atom embeddings or None
If None, random representation initialization will be used. Otherwise,
they will be used to initialize atom representations. Default to be None.
pre_train : Atom embeddings or None
If None, random representation initialization will be used. Otherwise,
they will be used to initialize atom representations. Default to be None.
"""
def __init__(self,
dim=128,
output_dim=1,
......@@ -27,21 +47,7 @@ class MGCNModel(nn.Module):
norm=False,
atom_ref=None,
pre_train=None):
"""
Args:
dim: dimension of feature maps
output_dim: the number of target properties to predict
edge_dim: dimension of edge feature
cutoff: the maximum distance between nodes
width: width in the RBF layer
n_conv: number of convolutional layers
norm: normalization
atom_ref: atom reference
used as the initial value of atom embeddings,
or set to None with random initialization
pre_train: pre_trained node embeddings
"""
super().__init__()
super(MGCNModel, self).__init__()
self.name = "MGCN"
self._dim = dim
self.output_dim = output_dim
......@@ -73,11 +79,32 @@ class MGCNModel(nn.Module):
self.node_dense_layer2 = nn.Linear(64, output_dim)
def set_mean_std(self, mean, std, device):
"""Set the mean and std of atom representations for normalization.
Parameters
----------
mean : list or numpy array
The mean of labels
std : list or numpy array
The std of labels
device : str or torch.device
Device for storing the mean and std
"""
self.mean_per_node = th.tensor(mean, device=device)
self.std_per_node = th.tensor(std, device=device)
def forward(self, g):
"""Predict molecule labels
Parameters
----------
g : DGLGraph
Input DGLGraph for molecule(s)
Returns
-------
res : Predicted labels
"""
self.embedding_layer(g, "node_0")
if self.atom_ref is not None:
self.e0(g, "e0")
......@@ -104,5 +131,5 @@ class MGCNModel(nn.Module):
if self.norm:
g.ndata["res"] = g.ndata[
"res"] * self.std_per_node + self.mean_per_node
res = dgl.sum_nodes(g, "res")
res = sum_nodes(g, "res")
return res
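# A hedged usage sketch (as comments), assuming g carries the node/edge
# fields prepared by TencentAlchemyDataset (e.g. 'node_type'):
# model = MGCNModel(output_dim=12)
# preds = model(g)  # one prediction per target property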
......@@ -2,36 +2,39 @@
# coding: utf-8
# pylint: disable=C0103, C0111, E1101, W0612
"""Implementation of MPNN model."""
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Parameter
import dgl.function as fn
import dgl.nn.pytorch as dgl_nn
from ... import function as fn
from ...nn.pytorch import Set2Set
class NNConvLayer(nn.Module):
"""
MPNN Conv Layer from Section 5 of
`Neural Message Passing for Quantum Chemistry <https://arxiv.org/abs/1704.01212>`__
Parameters
----------
in_channels : int
Number of input channels
out_channels : int
Number of output channels
edge_net : Module processing edge information
root_weight : bool
Whether to add the root node feature to output
bias : bool
Whether to add bias to the output
"""
def __init__(self,
in_channels,
out_channels,
edge_net,
root_weight=True,
bias=True):
"""
Args:
in_channels: number of input channels
out_channels: number of output channels
edge_net: the network modules process the edge info
root_weight: whether add the root node feature to output
bias: whether add bias to the output
"""
super().__init__()
super(NNConvLayer, self).__init__()
self.in_channels = in_channels
self.out_channels = out_channels
......@@ -50,6 +53,7 @@ class NNConvLayer(nn.Module):
self.reset_parameters()
def reset_parameters(self):
"""Reinitialize model parameters"""
if self.root is not None:
nn.init.xavier_normal_(self.root.data, gain=1.414)
if self.bias is not None:
......@@ -59,6 +63,18 @@ class NNConvLayer(nn.Module):
nn.init.xavier_normal_(m.weight.data, gain=1.414)
def message(self, edges):
"""Function for computing messages from source nodes
Parameters
----------
edges : EdgeBatch
Edges over which we want to send messages
Returns
-------
dict
Stores message in key 'm'
"""
return {
'm':
torch.matmul(edges.src['h'].unsqueeze(1),
......@@ -66,6 +82,17 @@ class NNConvLayer(nn.Module):
}
def apply_node_func(self, nodes):
"""Function for updating node features directly
Parameters
----------
nodes : NodeBatch
A batch of nodes whose features will be updated with aggregated messages
Returns
-------
dict
Stores updated node features in 'h'
"""
aggr_out = nodes.data['aggr_out']
if self.root is not None:
aggr_out = torch.mm(nodes.data['h'], self.root) + aggr_out
......@@ -76,7 +103,23 @@ class NNConvLayer(nn.Module):
return {'h': aggr_out}
def forward(self, g, h, e):
"""MPNN Conv layer forward."""
"""Propagate messages and aggregate results for updating
atom representations
Parameters
----------
g : DGLGraph
DGLGraph for molecule(s)
h : tensor
Input atom representations
e : tensor
Input bond representations
Returns
-------
tensor
Aggregated atom information
"""
h = h.unsqueeze(-1) if h.dim() == 1 else h
e = e.unsqueeze(-1) if e.dim() == 1 else e
......@@ -90,11 +133,28 @@ class NNConvLayer(nn.Module):
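The `edge_net` argument deserves a concrete shape: `forward` feeds the bond features through it and reshapes the result into one `(in_channels, out_channels)` weight matrix per edge, which `message` then multiplies with the source-atom features. Below is a sketch of a compatible edge network, mirroring how `MPNNModel` wires it up; the sizes are chosen only for illustration.

```python
import torch.nn as nn

in_channels, out_channels, edge_input_dim = 64, 64, 5  # illustrative sizes

# Maps each bond's features to a flattened (in_channels x out_channels)
# weight matrix; NNConvLayer reshapes it per edge before message passing.
edge_net = nn.Sequential(
    nn.Linear(edge_input_dim, 128),
    nn.ReLU(),
    nn.Linear(128, in_channels * out_channels))

conv = NNConvLayer(in_channels, out_channels, edge_net, root_weight=False)
```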
class MPNNModel(nn.Module):
"""
MPNN model from:
Gilmer, Justin, et al.
Neural message passing for quantum chemistry.
MPNN from
`Neural Message Passing for Quantum Chemistry <https://arxiv.org/abs/1704.01212>`__
Parameters
----------
node_input_dim : int
Dimension of input node feature, default to be 15.
edge_input_dim : int
Dimension of input edge feature, default to be 5.
output_dim : int
Dimension of prediction, default to be 12.
node_hidden_dim : int
Dimension of node feature in hidden layers, default to be 64.
edge_hidden_dim : int
Dimension of edge feature in hidden layers, default to be 128.
num_step_message_passing : int
Number of message passing steps, default to be 6.
num_step_set2set : int
Number of set2set steps, default to be 6.
num_layer_set2set : int
Number of set2set layers, default to be 3.
"""
def __init__(self,
node_input_dim=15,
edge_input_dim=5,
......@@ -104,20 +164,8 @@ class MPNNModel(nn.Module):
num_step_message_passing=6,
num_step_set2set=6,
num_layer_set2set=3):
"""model parameters setting
Args:
node_input_dim: dimension of input node feature
edge_input_dim: dimension of input edge feature
output_dim: dimension of prediction
node_hidden_dim: dimension of node feature in hidden layers
edge_hidden_dim: dimension of edge feature in hidden layers
num_step_message_passing: number of message passing steps
num_step_set2set: number of set2set steps
num_layer_set2set: number of set2set layers
"""
super(MPNNModel, self).__init__()
super().__init__()
self.name = "MPNN"
self.num_step_message_passing = num_step_message_passing
self.lin0 = nn.Linear(node_input_dim, node_hidden_dim)
......@@ -130,12 +178,22 @@ class MPNNModel(nn.Module):
root_weight=False)
self.gru = nn.GRU(node_hidden_dim, node_hidden_dim)
self.set2set = dgl_nn.glob.Set2Set(node_hidden_dim, num_step_set2set,
num_layer_set2set)
self.set2set = Set2Set(node_hidden_dim, num_step_set2set, num_layer_set2set)
self.lin1 = nn.Linear(2 * node_hidden_dim, node_hidden_dim)
self.lin2 = nn.Linear(node_hidden_dim, output_dim)
def forward(self, g):
"""Predict molecule labels
Parameters
----------
g : DGLGraph
Input DGLGraph for molecule(s)
Returns
-------
res : tensor
Predicted labels
"""
h = g.ndata['n_feat']
out = F.relu(self.lin0(h))
h = out.unsqueeze(0)
......
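For completeness, a short usage sketch of the model above; the defaults already match the Alchemy setup with 12 regression targets. The batched graph `g` is assumed to carry atom features in `g.ndata['n_feat']` (as read in `forward`) and bond features under an edge field such as `e_feat` (name assumed from the Alchemy featurization).

```python
# Hedged usage sketch, not part of this diff; `g` is an assumed name.
model = MPNNModel(node_input_dim=15, edge_input_dim=5, output_dim=12)
preds = model(g)  # shape (batch_size, 12), one value per Alchemy target
```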
......@@ -2,9 +2,9 @@
import torch
from rdkit import Chem
from .dgmg import DGMG
from .gcn import GCNClassifier
from . import DGLJTNNVAE
from .classifiers import GCNClassifier, GATClassifier
from .dgmg import DGMG
from .mgcn import MGCNModel
from .mpnn import MPNNModel
from .sch import SchNetModel
......@@ -12,6 +12,7 @@ from ...data.utils import _get_dgl_url, download, get_download_dir
URL = {
'GCN_Tox21' : 'pre_trained/gcn_tox21.pth',
'GAT_Tox21' : 'pre_trained/gat_tox21.pth',
'MGCN_Alchemy': 'pre_trained/mgcn_alchemy.pth',
'SCHNET_Alchemy': 'pre_trained/schnet_alchemy.pth',
'MPNN_Alchemy': 'pre_trained/mpnn_alchemy.pth',
......@@ -73,8 +74,15 @@ def load_pretrained(model_name, log=True):
if model_name == 'GCN_Tox21':
model = GCNClassifier(in_feats=74,
gcn_hidden_feats=[64, 64],
n_tasks=12,
classifier_hidden_feats=64)
classifier_hidden_feats=64,
n_tasks=12)
elif model_name == 'GAT_Tox21':
model = GATClassifier(in_feats=74,
gat_hidden_feats=[32, 32],
num_heads=[4, 4],
classifier_hidden_feats=64,
n_tasks=12)
elif model_name.startswith('DGMG'):
if model_name.startswith('DGMG_ChEMBL'):
......
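With the registry entry and the branch above in place, the new GAT model can be fetched the same way as the existing ones. A quick sketch, assuming the zoo's public entry point:

```python
from dgl import model_zoo

# Downloads pre_trained/gat_tox21.pth on first use and loads the weights.
model = model_zoo.chem.load_pretrained('GAT_Tox21')
model.eval()
```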
# -*- coding:utf-8 -*-
# pylint: disable=C0103, C0111, W0621
"""Implementation of SchNet model."""
import dgl
import torch as th
import torch.nn as nn
from .layers import AtomEmbedding, Interaction, ShiftSoftplus, RBFLayer
from ...batched_graph import sum_nodes
class SchNetModel(nn.Module):
"""
SchNet Model from:
Schütt, Kristof, et al.
SchNet: A continuous-filter convolutional neural network
for modeling quantum interactions. (NIPS'2017)
"""
`SchNet: A continuous-filter convolutional neural network for modeling
quantum interactions. (NIPS'2017) <https://arxiv.org/abs/1706.08566>`__
Parameters
----------
dim : int
Dimension of features, default to be 64
cutoff : float
Radius cutoff for RBF, default to be 5.0
output_dim : int
Dimension of prediction, default to be 1
width : int
Width in RBF, default to be 1
n_conv : int
Number of conv (interaction) layers, default to be 1
norm : bool
Whether to normalize the output atom representations, default to be False.
atom_ref : Atom embeddings or None
If None, random representation initialization will be used. Otherwise,
they will be used to initialize atom representations. Default to be None.
pre_train : Atom embeddings or None
If None, random representation initialization will be used. Otherwise,
they will be used to initialize atom representations. Default to be None.
"""
def __init__(self,
dim=64,
cutoff=5.0,
......@@ -25,17 +43,6 @@ class SchNetModel(nn.Module):
norm=False,
atom_ref=None,
pre_train=None):
"""
Args:
dim: dimension of features
output_dim: dimension of prediction
cutoff: radius cutoff
width: width in the RBF function
n_conv: number of interaction layers
atom_ref: used as the initial value of atom embeddings,
or set to None with random initialization
norm: normalization
"""
super().__init__()
self.name = "SchNet"
self._dim = dim
......@@ -60,11 +67,32 @@ class SchNetModel(nn.Module):
self.atom_dense_layer2 = nn.Linear(64, output_dim)
def set_mean_std(self, mean, std, device="cpu"):
"""Set the mean and std of atom representations for normalization.
Parameters
----------
mean : list or numpy array
The mean of labels
std : list or numpy array
The std of labels
device : str or torch.device
Device for storing the mean and std
"""
self.mean_per_atom = th.tensor(mean, device=device)
self.std_per_atom = th.tensor(std, device=device)
def forward(self, g):
"""g is the DGLGraph"""
"""Predict molecule labels
Parameters
----------
g : DGLGraph
Input DGLGraph for molecule(s)
Returns
-------
res : tensor
Predicted labels
"""
self.embedding_layer(g)
if self.atom_ref is not None:
self.e0(g, "e0")
......@@ -81,7 +109,6 @@ class SchNetModel(nn.Module):
g.ndata["res"] = g.ndata["res"] + g.ndata["e0"]
if self.norm:
g.ndata["res"] = g.ndata[
"res"] * self.std_per_atom + self.mean_per_atom
res = dgl.sum_nodes(g, "res")
g.ndata["res"] = g.ndata["res"] * self.std_per_atom + self.mean_per_atom
res = sum_nodes(g, "res")
return res
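As with MGCN, the `norm` branch of `forward` relies on label statistics registered through `set_mean_std`. A minimal calling sketch; `label_mean`, `label_std` and the batched graph `g` are assumed to come from the Alchemy loader.

```python
# Hedged sketch: `label_mean`, `label_std` and `g` are assumed names.
model = SchNetModel(dim=64, cutoff=5.0, output_dim=12, norm=True)
model.set_mean_std(label_mean, label_std, device='cpu')
preds = model(g)  # shape (batch_size, 12)
```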
......@@ -12,7 +12,7 @@ from ...batched_graph import sum_nodes, mean_nodes, max_nodes, broadcast_nodes,\
__all__ = ['SumPooling', 'AvgPooling', 'MaxPooling', 'SortPooling',
'GlobalAttentionPooling', 'Set2Set',
'SetTransformerEncoder', 'SetTransformerDecoder']
'SetTransformerEncoder', 'SetTransformerDecoder', 'WeightAndSum']
class SumPooling(nn.Module):
r"""Apply sum pooling over the nodes in the graph.
......@@ -668,3 +668,43 @@ class SetTransformerDecoder(nn.Module):
return feat.view(graph.batch_size, self.k * self.d_model)
else:
return feat.view(self.k * self.d_model)
class WeightAndSum(nn.Module):
"""Compute importance weights for atoms and perform a weighted sum.
Parameters
----------
in_feats : int
Input atom feature size
"""
def __init__(self, in_feats):
super(WeightAndSum, self).__init__()
self.in_feats = in_feats
self.atom_weighting = nn.Sequential(
nn.Linear(in_feats, 1),
nn.Sigmoid()
)
def forward(self, bg, feats):
"""Compute molecule representations out of atom representations
Parameters
----------
bg : BatchedDGLGraph
Batched DGLGraph for processing multiple molecules in parallel
feats : FloatTensor of shape (N, self.in_feats)
Representations for all atoms in the molecules
* N is the total number of atoms in all molecules
Returns
-------
FloatTensor of shape (B, self.in_feats)
Representations for B molecules
"""
with bg.local_scope():
bg.ndata['h'] = feats
bg.ndata['w'] = self.atom_weighting(bg.ndata['h'])
h_g_sum = sum_nodes(bg, 'h', 'w')
return h_g_sum
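A short sketch of the readout in use on a toy batch; in practice `bg` and `feats` come from the chemistry pipeline, so the graphs and feature values here are illustrative only.

```python
import dgl
import torch

# Two toy "molecules" with 3 and 2 atoms respectively.
g1, g2 = dgl.DGLGraph(), dgl.DGLGraph()
g1.add_nodes(3)
g2.add_nodes(2)
bg = dgl.batch([g1, g2])

in_feats = 64
feats = torch.randn(bg.number_of_nodes(), in_feats)  # fake atom features

readout = WeightAndSum(in_feats)
h_g = readout(bg, feats)  # shape (2, 64): one representation per molecule
```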