"...git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "d71ecad8cde6edaa231253bfbf10d5f231d72203"
Unverified Commit e590feeb authored by Mufei Li, committed by GitHub

[Model Zoo] GAT on Tox21 (#793)

* GAT

* Fix mistake

* Fix

* hotfix

* Fix

* Fix

* Fix

* Fix

* Fix

* Fix

* Fix

* Update

* Update

* Update

* Fix style

* Hotfix

* Hotfix

* Hotfix

* Fix

* Fix

* Update

* CI trial

* Update

* Update

* Update
parent ad15947f
@@ -17,10 +17,10 @@ Contributors
* [Tianyi Zhang](https://github.com/Tiiiger): SGC in Pytorch
* [Jun Chen](https://github.com/kitaev-chen): GIN in Pytorch
* [Aymen Waheb](https://github.com/aymenwah): APPNP in Pytorch
* [Chengqiang Lu](https://github.com/geekinglcq): MGCN, SchNet and MPNN in PyTorch

Other improvement
* [Brett Koonce](https://github.com/brettkoonce)
* [@giuseppefutia](https://github.com/giuseppefutia)
* [@mori97](https://github.com/mori97)
* Hao Jin
# DGL for Chemistry

With atoms being nodes and bonds being edges, molecular graphs are among the core objects for study in Chemistry.
Deep learning on graphs can be beneficial for various applications in Chemistry like drug and material discovery
[1], [2], [12].

To make it easy for domain scientists, the DGL team releases a model zoo for Chemistry, focusing on two particular cases
-- property prediction and target generation/optimization.
@@ -14,7 +14,7 @@ the chemistry community and the deep learning community to further their research
Before you proceed, make sure you have installed the dependencies below:
- PyTorch 1.2
    - Check the [official website](https://pytorch.org/) for the installation guide.
- RDKit 2018.09.3
    - We recommend installation with `conda install -c conda-forge rdkit==2018.09.3`. For other installation recipes,
      see the [official documentation](https://www.rdkit.org/docs/Install.html).
@@ -39,17 +39,22 @@ mostly developed based on molecule fingerprints.
Graph neural networks make it possible for a data-driven representation of molecules out of the atoms, bonds and
molecular graph topology, which may be viewed as a learned fingerprint [3].

### Models
- **Graph Convolutional Networks** [3], [9]: Graph Convolutional Networks (GCN) have been one of the most popular graph
  neural networks and they can be easily extended for graph level prediction.
- **Graph Attention Networks** [10]: Graph Attention Networks (GATs) incorporate multi-head attention into GCNs,
  explicitly modeling the interactions between adjacent atoms.
- **SchNet** [4]: SchNet is a novel deep learning architecture modeling quantum interactions in molecules, which
  utilizes continuous-filter convolutional layers.
- **Multilevel Graph Convolutional neural Network** [5]: Multilevel Graph Convolutional neural Network (MGCN) is a
  well-designed hierarchical graph neural network that directly extracts features from the conformation and spatial
  information, followed by the multilevel interactions.
- **Message Passing Neural Network** [6]: Message Passing Neural Network (MPNN) is a well-designed network with an
  edge network (enn) as the front end and Set2Set for output prediction.

### Example Usage of Pre-trained Models

![](https://s3.us-east-2.amazonaws.com/dgl.ai/model_zoo/drug_discovery/gcn_model_zoo_example.png)
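The image above walks through using a pre-trained model for property prediction. As a minimal sketch (the
checkpoint name `GCN_Tox21` and the `load_pretrained` call are taken from the code in this commit):

```python
from dgl import model_zoo

# Download and load a GCN classifier pre-trained on the Tox21 tasks
model = model_zoo.chem.load_pretrained('GCN_Tox21')
model.eval()
```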
## Generative Models
@@ -66,8 +71,14 @@ Generative models are known to be difficult for evaluation. [GuacaMol](https://g
are also two accompanying review papers that are well written [7], [8].

### Models

- **Deep Generative Models of Graphs (DGMG)** [11]: A very general framework for graph distribution learning by
  progressively adding atoms and bonds.

### Example Usage of Pre-trained Models

![](https://s3.us-east-2.amazonaws.com/dgl.ai/model_zoo/drug_discovery/dgmg_model_zoo_example1.png)
![](https://s3.us-east-2.amazonaws.com/dgl.ai/model_zoo/drug_discovery/dgmg_model_zoo_example2.png)

## References
@@ -85,12 +96,20 @@ information processing systems (NeurIPS)*, 2224-2232.
[5] Lu et al. Molecular Property Prediction: A Multilevel Quantum Interactions Modeling Perspective.
*The 33rd AAAI Conference on Artificial Intelligence*.

[6] Gilmer et al. (2017) Neural Message Passing for Quantum Chemistry. *Proceedings of the 34th International Conference on
Machine Learning*, JMLR. 1263-1272.

[7] Brown et al. (2019) GuacaMol: Benchmarking Models for de Novo Molecular Design. *J. Chem. Inf. Model*, 59, 3,
1096-1108.

[8] Polykovskiy et al. (2019) Molecular Sets (MOSES): A Benchmarking Platform for Molecular Generation Models. *arXiv*.

[9] Kipf et al. (2017) Semi-Supervised Classification with Graph Convolutional Networks.
*The International Conference on Learning Representations (ICLR)*.

[10] Veličković et al. (2018) Graph Attention Networks.
*The International Conference on Learning Representations (ICLR)*.

[11] Li et al. (2018) Learning Deep Generative Models of Graphs. *arXiv preprint arXiv:1803.03324*.

[12] Goh et al. (2017) Deep learning for computational chemistry. *Journal of Computational Chemistry* 16, 1291-1307.
@@ -12,14 +12,21 @@ stress response pathways. Each target yields a binary prediction problem. Molecules are split
into training, validation and test set with an 80/10/10 ratio. By default we follow their split method.

### Models

- **Graph Convolutional Network** [2], [3]. Graph Convolutional Networks (GCN) have been one of the most popular graph neural
  networks and they can be easily extended for graph level prediction. MoleculeNet [1] reports baseline results of graph
  convolutions over multiple datasets.
- **Graph Attention Networks** [7]: Graph Attention Networks (GATs) incorporate multi-head attention into GCNs,
  explicitly modeling the interactions between adjacent atoms.

### Usage
Use `classification.py` with arguments

```
-m {GCN, GAT}, MODEL, model to use
-d {Tox21}, DATASET, dataset to use
```

If you want to use the pre-trained model, simply add `-p`.
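For example, under the argument choices above, training a GAT from scratch and running the pre-trained GCN
would look like this (a sketch, assuming the dependencies above are installed):

```
python classification.py -m GAT -d Tox21
python classification.py -m GCN -d Tox21 -p
```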
We use GPU whenever it is available.
@@ -31,10 +38,16 @@ We use GPU whenever it is available.
| Source           | Averaged ROC-AUC Score |
| ---------------- | ---------------------- |
| MoleculeNet [1]  | 0.829                  |
| [DeepChem example](https://github.com/deepchem/deepchem/blob/master/examples/tox21/tox21_tensorgraph_graph_conv.py) | 0.813 |
| Pretrained model | 0.826                  |

Note that the dataset is randomly split so these numbers are only for reference and they do not necessarily suggest
a real difference.

#### GAT on Tox21

| Source           | Averaged ROC-AUC Score |
| ---------------- | ---------------------- |
| Pretrained model | 0.827                  |
## Dataset Customization
@@ -47,16 +60,20 @@ Regression tasks require assigning continuous labels to a molecule, e.g. molecular energy.

### Dataset

- **Alchemy**. The [Alchemy Dataset](https://alchemy.tencent.com/) is introduced by Tencent Quantum Lab to facilitate the development of new
  machine learning models useful for chemistry and materials science. The dataset lists 12 quantum mechanical properties of 130,000+ organic
  molecules comprising up to 12 heavy atoms (C, N, O, S, F and Cl), sampled from the [GDBMedChem](http://gdb.unibe.ch/downloads/) database.
  These properties have been calculated using the open-source computational chemistry program Python-based Simulation of Chemistry Framework
  ([PySCF](https://github.com/pyscf/pyscf)). The Alchemy dataset expands on the volume and diversity of existing molecular datasets such as QM9.

### Models

- **SchNet**: SchNet is a novel deep learning architecture modeling quantum interactions in molecules, which utilizes continuous-filter
  convolutional layers [4].
- **Multilevel Graph Convolutional neural Network**: Multilevel Graph Convolutional neural Network (MGCN) is a hierarchical
  graph neural network that directly extracts features from the conformation and spatial information, followed by the multilevel interactions [5].
- **Message Passing Neural Network**: Message Passing Neural Network (MPNN) is a network with an edge network (enn) as the front end
  and Set2Set for output prediction [6].

### Usage
@@ -71,22 +88,27 @@ The model option must be one of 'sch', 'mgcn' or 'mpnn'.
| Model      | Mean Absolute Error (MAE) |
| ---------- | ------------------------- |
| SchNet [4] | 0.065                     |
| MGCN [5]   | 0.050                     |
| MPNN [6]   | 0.056                     |

## References

[1] Wu et al. (2017) MoleculeNet: a benchmark for molecular machine learning. *Chemical Science* 9, 513-530.

[2] Duvenaud et al. (2015) Convolutional networks on graphs for learning molecular fingerprints. *Advances in neural
information processing systems (NeurIPS)*, 2224-2232.

[3] Kipf et al. (2017) Semi-Supervised Classification with Graph Convolutional Networks.
*The International Conference on Learning Representations (ICLR)*.

[4] Schütt et al. (2017) SchNet: A continuous-filter convolutional neural network for modeling quantum interactions.
*Advances in Neural Information Processing Systems (NeurIPS)*, 992-1002.

[5] Lu et al. (2019) Molecular Property Prediction: A Multilevel Quantum Interactions Modeling Perspective.
*The 33rd AAAI Conference on Artificial Intelligence*.

[6] Gilmer et al. (2017) Neural Message Passing for Quantum Chemistry. *Proceedings of the 34th International Conference on
Machine Learning*, JMLR. 1263-1272.

[7] Veličković et al. (2018) Graph Attention Networks.
*The International Conference on Learning Representations (ICLR)*.
from dgl.data.utils import split_dataset
from dgl import model_zoo

import torch
@@ -8,93 +7,108 @@ from torch.utils.data import DataLoader
from utils import Meter, EarlyStopping, collate_molgraphs, set_random_seed

def run_a_train_epoch(args, epoch, model, data_loader, loss_criterion, optimizer):
    model.train()
    train_meter = Meter()
    for batch_id, batch_data in enumerate(data_loader):
        smiles, bg, labels, mask = batch_data
        atom_feats = bg.ndata.pop(args['atom_data_field'])
        atom_feats, labels, mask = atom_feats.to(args['device']), \
                                   labels.to(args['device']), \
                                   mask.to(args['device'])
        logits = model(bg, atom_feats)
        # Mask non-existing labels
        loss = (loss_criterion(logits, labels) * (mask != 0).float()).mean()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        print('epoch {:d}/{:d}, batch {:d}/{:d}, loss {:.4f}'.format(
            epoch + 1, args['num_epochs'], batch_id + 1, len(data_loader), loss.item()))
        train_meter.update(logits, labels, mask)
    train_roc_auc = train_meter.roc_auc_averaged_over_tasks()
    print('epoch {:d}/{:d}, training roc-auc score {:.4f}'.format(
        epoch + 1, args['num_epochs'], train_roc_auc))

def run_an_eval_epoch(args, model, data_loader):
    model.eval()
    eval_meter = Meter()
    with torch.no_grad():
        for batch_id, batch_data in enumerate(data_loader):
            smiles, bg, labels, mask = batch_data
            atom_feats = bg.ndata.pop(args['atom_data_field'])
            atom_feats, labels = atom_feats.to(args['device']), labels.to(args['device'])
            logits = model(bg, atom_feats)
            eval_meter.update(logits, labels, mask)
    return eval_meter.roc_auc_averaged_over_tasks()
def main(args):
    args['device'] = "cuda" if torch.cuda.is_available() else "cpu"
    set_random_seed()

    # Interchangeable with other Dataset
    if args['dataset'] == 'Tox21':
        from dgl.data.chem import Tox21
        dataset = Tox21()

    trainset, valset, testset = split_dataset(dataset, args['train_val_test_split'])
    train_loader = DataLoader(trainset, batch_size=args['batch_size'], collate_fn=collate_molgraphs)
    val_loader = DataLoader(valset, batch_size=args['batch_size'], collate_fn=collate_molgraphs)
    test_loader = DataLoader(testset, batch_size=args['batch_size'], collate_fn=collate_molgraphs)

    if args['pre_trained']:
        args['num_epochs'] = 0
        model = model_zoo.chem.load_pretrained(args['exp'])
    else:
        # Interchangeable with other models
        if args['model'] == 'GCN':
            model = model_zoo.chem.GCNClassifier(in_feats=args['in_feats'],
                                                 gcn_hidden_feats=args['gcn_hidden_feats'],
                                                 classifier_hidden_feats=args['classifier_hidden_feats'],
                                                 n_tasks=dataset.n_tasks)
        elif args['model'] == 'GAT':
            model = model_zoo.chem.GATClassifier(in_feats=args['in_feats'],
                                                 gat_hidden_feats=args['gat_hidden_feats'],
                                                 num_heads=args['num_heads'],
                                                 classifier_hidden_feats=args['classifier_hidden_feats'],
                                                 n_tasks=dataset.n_tasks)

    loss_criterion = BCEWithLogitsLoss(pos_weight=torch.tensor(
        dataset.task_pos_weights).to(args['device']), reduction='none')
    optimizer = Adam(model.parameters(), lr=args['lr'])
    stopper = EarlyStopping(patience=args['patience'])
    model.to(args['device'])

    for epoch in range(args['num_epochs']):
        # Train
        run_a_train_epoch(args, epoch, model, train_loader, loss_criterion, optimizer)

        # Validation and early stop
        val_roc_auc = run_an_eval_epoch(args, model, val_loader)
        early_stop = stopper.step(val_roc_auc, model)
        print('epoch {:d}/{:d}, validation roc-auc score {:.4f}, best validation roc-auc score {:.4f}'.format(
            epoch + 1, args['num_epochs'], val_roc_auc, stopper.best_score))
        if early_stop:
            break

    if not args['pre_trained']:
        stopper.load_checkpoint(model)
    test_roc_auc = run_an_eval_epoch(args, model, test_loader)
    print('test roc-auc score {:.4f}'.format(test_roc_auc))

if __name__ == '__main__':
    import argparse

    from configure import get_exp_configure

    parser = argparse.ArgumentParser(description='Molecule Classification')
    parser.add_argument('-m', '--model', type=str, choices=['GCN', 'GAT'],
                        help='Model to use')
    parser.add_argument('-d', '--dataset', type=str, choices=['Tox21'],
                        help='Dataset to use')
    parser.add_argument('-p', '--pre-trained', action='store_true',
                        help='Whether to skip training and use a pre-trained model')
    args = parser.parse_args().__dict__
    args['exp'] = '_'.join([args['model'], args['dataset']])
    args.update(get_exp_configure(args['exp']))

    main(args)
GCN_Tox21 = {
    'batch_size': 128,
    'lr': 1e-3,
    'num_epochs': 100,
    'atom_data_field': 'h',
    'train_val_test_split': [0.8, 0.1, 0.1],
    'in_feats': 74,
    'gcn_hidden_feats': [64, 64],
    'classifier_hidden_feats': 64,
    'patience': 10
}

GAT_Tox21 = {
    'batch_size': 128,
    'lr': 1e-3,
    'num_epochs': 100,
    'atom_data_field': 'h',
    'train_val_test_split': [0.8, 0.1, 0.1],
    'in_feats': 74,
    'gat_hidden_feats': [32, 32],
    'classifier_hidden_feats': 64,
    'num_heads': [4, 4],
    'patience': 10
}

experiment_configures = {
    'GCN_Tox21': GCN_Tox21,
    'GAT_Tox21': GAT_Tox21
}

def get_exp_configure(exp_name):
    return experiment_configures[exp_name]
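# Usage sketch (comment added for illustration, mirroring classification.py):
#     get_exp_configure('GAT_Tox21')['num_heads']         # [4, 4]
#     get_exp_configure('GCN_Tox21')['gcn_hidden_feats']  # [64, 64]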
@@ -74,11 +74,11 @@ class EarlyStopping(object):
    def save_checkpoint(self, model):
        '''Saves model when the metric on the validation set gets improved.'''
        torch.save({'model_state_dict': model.state_dict()}, self.filename)

    def load_checkpoint(self, model):
        '''Load model saved with early stopping.'''
        model.load_state_dict(torch.load(self.filename)['model_state_dict'])

def collate_molgraphs(data):
    """Batching a list of datapoints for dataloader
...
@@ -11,7 +11,6 @@ from .reddit import RedditDataset
from .ppi import PPIDataset
from .tu import TUDataset
from .gindt import GINDataset

def register_data_args(parser):
...
@@ -10,11 +10,10 @@ import pickle
import zipfile
from collections import defaultdict

from .utils import mol_to_complete_graph
from ..utils import download, get_download_dir
from ...batched_graph import batch
from ... import backend as F

try:
    import pandas as pd
@@ -40,12 +39,12 @@ class AlchemyBatcher(object):
        self.graph = graph
        self.label = label

def batcher_dev(batch_data):
    """Batch datapoints

    Parameters
    ----------
    batch_data : list
        batch_data[i][0] gives the DGLGraph for the ith datapoint,
        and batch_data[i][1] gives the label for the ith datapoint.
@@ -54,15 +53,79 @@ def batcher_dev(batch_data):
    Returns
    -------
    AlchemyBatcher
        An object holding the batch of data
    """
    graphs, labels = zip(*batch_data)
    batch_graphs = batch(graphs)
    labels = F.stack(labels, 0)

    return AlchemyBatcher(graph=batch_graphs, label=labels)
class TencentAlchemyDataset(object):
    """`Tencent Alchemy Dataset <https://arxiv.org/abs/1906.09427>`__

    Parameters
    ----------
    mode : str
        'dev', 'valid' or 'test', default to be 'dev'
    transform : transform operation on DGLGraphs
        Default to be None.
    from_raw : bool
        Whether to process dataset from scratch or use a
        processed one for faster speed. Default to be False.
    """
    def __init__(self, mode='dev', transform=None, from_raw=False):
        assert mode in ['dev', 'valid', 'test'], "mode should be dev/valid/test"
        self.mode = mode
        self.transform = transform

        # Construct DGLGraphs from raw data or use the preprocessed data
        self.from_raw = from_raw
        file_dir = osp.join(get_download_dir(), './Alchemy_data')

        if not from_raw:
            file_name = "%s_processed" % (mode)
        else:
            file_name = "%s_single_sdf" % (mode)
        self.file_dir = pathlib.Path(file_dir, file_name)

        self.zip_file_path = pathlib.Path(file_dir, file_name + '.zip')
        download(_urls['Alchemy'] + file_name + '.zip',
                 path=str(self.zip_file_path))
        if not os.path.exists(str(self.file_dir)):
            archive = zipfile.ZipFile(self.zip_file_path)
            archive.extractall(file_dir)
            archive.close()

        self._load()

    def _load(self):
        if self.mode == 'dev':
            if not self.from_raw:
                with open(osp.join(self.file_dir, "dev_graphs.pkl"), "rb") as f:
                    self.graphs = pickle.load(f)
                with open(osp.join(self.file_dir, "dev_labels.pkl"), "rb") as f:
                    self.labels = pickle.load(f)
            else:
                target_file = pathlib.Path(self.file_dir, "dev_target.csv")
                self.target = pd.read_csv(
                    target_file,
                    index_col=0,
                    usecols=['gdb_idx',] + ['property_%d' % x for x in range(12)])
                self.target = self.target[['property_%d' % x for x in range(12)]]
                self.graphs, self.labels = [], []

                supp = Chem.SDMolSupplier(
                    osp.join(self.file_dir, self.mode + ".sdf"))
                cnt = 0
                for sdf, label in zip(supp, self.target.iterrows()):
                    graph = mol_to_complete_graph(sdf, atom_featurizer=self.alchemy_nodes,
                                                  bond_featurizer=self.alchemy_edges)
                    cnt += 1
                    self.graphs.append(graph)
                    label = F.tensor(np.array(label[1].tolist()).astype(np.float32))
                    self.labels.append(label)

        self.normalize()
        print(len(self.graphs), "loaded!")
    def alchemy_nodes(self, mol):
        """Featurization for all atoms in a molecule. The atom indices
@@ -135,8 +198,8 @@ class TencentAlchemyDataset(object):
        return atom_feats_dict

    def alchemy_edges(self, mol, self_loop=False):
        """Featurization for all bonds in a molecule.
        The bond indices will be preserved.

        Parameters
        ----------
@@ -182,66 +245,16 @@ class TencentAlchemyDataset(object):
        return bond_feats_dict
    def normalize(self, mean=None, std=None):
        """Set mean and std or compute from labels for future normalization.

        Parameters
        ----------
        mean : int or float
            Default to be None.
        std : int or float
            Default to be None.
        """
        labels = np.array([i.numpy() for i in self.labels])
        if mean is None:
            mean = np.mean(labels, axis=0)
@@ -260,7 +273,20 @@ class TencentAlchemyDataset(object):
        return g, l

    def split(self, train_size=0.8):
        """Split the dataset into two AlchemySubset for train&test.

        Parameters
        ----------
        train_size : float
            Proportion of dataset to use for training. Default to be 0.8.

        Returns
        -------
        train_set : AlchemySubset
            Dataset for training
        test_set : AlchemySubset
            Dataset for test
        """
        assert 0 < train_size < 1
        train_num = int(len(self.graphs) * train_size)

        train_set = AlchemySubset(self.graphs[:train_num],
@@ -275,6 +301,19 @@ class AlchemySubset(TencentAlchemyDataset):
    """
    Sub-dataset split from TencentAlchemyDataset.
    Used to construct the training & test set.

    Parameters
    ----------
    graphs : list of DGLGraphs
        DGLGraphs for datapoints in the subset
    labels : list of tensors
        Labels for datapoints in the subset
    mean : int or float
        Mean of labels in the subset
    std : int or float
        Std of labels in the subset
    transform : transform operation on DGLGraphs
        Default to be None.
    """
    def __init__(self, graphs, labels, mean=0, std=1, transform=None):
        super(AlchemySubset, self).__init__()
...
# pylint: disable=C0111
"""Model Zoo Package"""
from .classifiers import GCNClassifier, GATClassifier
from .sch import SchNetModel
from .mgcn import MGCNModel
from .mpnn import MPNNModel
...
# pylint: disable=C0111, C0103, C0200
import torch
import torch.nn as nn
import torch.nn.functional as F

import dgl
from dgl import BatchedDGLGraph

from .gnn import GCNLayer, GATLayer
from ...nn.pytorch import WeightAndSum

class MLPBinaryClassifier(nn.Module):
    """MLP for soft binary classification over multiple tasks from molecule representations.

    Parameters
    ----------
    in_feats : int
        Number of input molecular graph features
    hidden_feats : int
        Number of molecular graph features in hidden layers
    n_tasks : int
        Number of tasks, also output size
    dropout : float
        The probability for dropout. Default to be 0., i.e. no
        dropout is performed.
    """
    def __init__(self, in_feats, hidden_feats, n_tasks, dropout=0.):
        super(MLPBinaryClassifier, self).__init__()
        self.predict = nn.Sequential(
            nn.Dropout(dropout),
            nn.Linear(in_feats, hidden_feats),
            nn.ReLU(),
            nn.BatchNorm1d(hidden_feats),
            nn.Linear(hidden_feats, n_tasks)
        )

    def forward(self, h):
        """Perform soft binary classification over multiple tasks

        Parameters
        ----------
        h : FloatTensor of shape (B, M3)
            * B is the number of molecules in a batch
            * M3 is the input molecule feature size, must match in_feats in initialization

        Returns
        -------
        FloatTensor of shape (B, n_tasks)
        """
        return self.predict(h)

class BaseGNNClassifier(nn.Module):
    """GNN based predictor for multitask prediction on molecular graphs
    We assume each task requires to perform a binary classification.

    Parameters
    ----------
    gnn_out_feats : int
        Number of atom representation features after using GNN
    n_tasks : int
        Number of prediction tasks
    classifier_hidden_feats : int
        Number of molecular graph features in hidden layers of the MLP Classifier
    dropout : float
        The probability for dropout. Default to be 0., i.e. no
        dropout is performed.
    """
    def __init__(self, gnn_out_feats, n_tasks, classifier_hidden_feats=128, dropout=0.):
        super(BaseGNNClassifier, self).__init__()
        self.gnn_layers = nn.ModuleList()
        self.weighted_sum_readout = WeightAndSum(gnn_out_feats)
        self.g_feats = 2 * gnn_out_feats
        self.soft_classifier = MLPBinaryClassifier(
            self.g_feats, classifier_hidden_feats, n_tasks, dropout)

    def forward(self, bg, feats):
        """Multi-task prediction for a batch of molecules

        Parameters
        ----------
        bg : BatchedDGLGraph
            Batched DGLGraphs for processing multiple molecules in parallel
        feats : FloatTensor of shape (N, M0)
            Initial features for all atoms in the batch of molecules

        Returns
        -------
        FloatTensor of shape (B, n_tasks)
            Soft prediction for all tasks on the batch of molecules
        """
        # Update atom features with GNNs
        for gnn in self.gnn_layers:
            feats = gnn(bg, feats)

        # Compute molecule features from atom features
        h_g_sum = self.weighted_sum_readout(bg, feats)

        with bg.local_scope():
            bg.ndata['h'] = feats
            h_g_max = dgl.max_nodes(bg, 'h')

        if not isinstance(bg, BatchedDGLGraph):
            h_g_sum = h_g_sum.unsqueeze(0)
            h_g_max = h_g_max.unsqueeze(0)
        h_g = torch.cat([h_g_sum, h_g_max], dim=1)

        # Multi-task prediction
        return self.soft_classifier(h_g)

class GCNClassifier(BaseGNNClassifier):
    """GCN based predictor for multitask prediction on molecular graphs
    We assume each task requires to perform a binary classification.

    Parameters
    ----------
    in_feats : int
        Number of input atom features
    gcn_hidden_feats : list of int
        gcn_hidden_feats[i] gives the number of output atom features
        in the i+1-th gcn layer
    n_tasks : int
        Number of prediction tasks
    classifier_hidden_feats : int
        Number of molecular graph features in hidden layers of the MLP Classifier
    dropout : float
        The probability for dropout. Default to be 0., i.e. no
        dropout is performed.
    """
    def __init__(self, in_feats, gcn_hidden_feats, n_tasks,
                 classifier_hidden_feats=128, dropout=0.):
        super(GCNClassifier, self).__init__(gnn_out_feats=gcn_hidden_feats[-1],
                                            n_tasks=n_tasks,
                                            classifier_hidden_feats=classifier_hidden_feats,
                                            dropout=dropout)
        for i in range(len(gcn_hidden_feats)):
            out_feats = gcn_hidden_feats[i]
            self.gnn_layers.append(GCNLayer(in_feats, out_feats))
            in_feats = out_feats

class GATClassifier(BaseGNNClassifier):
    """GAT based predictor for multitask prediction on molecular graphs.
    We assume each task requires to perform a binary classification.

    Parameters
    ----------
    in_feats : int
        Number of input atom features
    """
    def __init__(self, in_feats, gat_hidden_feats, num_heads,
                 n_tasks, classifier_hidden_feats=128, dropout=0):
        super(GATClassifier, self).__init__(gnn_out_feats=gat_hidden_feats[-1],
                                            n_tasks=n_tasks,
                                            classifier_hidden_feats=classifier_hidden_feats,
                                            dropout=dropout)
        assert len(gat_hidden_feats) == len(num_heads), \
            'Got gat_hidden_feats with length {:d} and num_heads with length {:d}, ' \
            'expect them to be the same.'.format(len(gat_hidden_feats), len(num_heads))
        num_layers = len(num_heads)
        for l in range(num_layers):
            if l > 0:
                in_feats = gat_hidden_feats[l - 1] * num_heads[l - 1]

            if l == num_layers - 1:
                agg_mode = 'mean'
                agg_act = None
            else:
                agg_mode = 'flatten'
                agg_act = F.elu

            self.gnn_layers.append(GATLayer(in_feats, gat_hidden_feats[l], num_heads[l],
                                            feat_drop=dropout, attn_drop=dropout,
                                            agg_mode=agg_mode, activation=agg_act))
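# A construction sketch (comment added for illustration, with hyperparameters
# taken from configure.py in this commit): the Tox21 GAT classifier would be
#     GATClassifier(in_feats=74, gat_hidden_feats=[32, 32], num_heads=[4, 4],
#                   classifier_hidden_feats=64, n_tasks=12)
# i.e. two GAT layers with 4 attention heads each, followed by the MLP classifier.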
# pylint: disable=C0103, E1101
"""GNN layers for updating atom representations"""
import torch.nn as nn
import torch.nn.functional as F

from ...nn.pytorch import GraphConv, GATConv

class GCNLayer(nn.Module):
    """Single layer GCN for updating node features

    Parameters
    ----------
    in_feats : int
        Number of input atom features
    out_feats : int
        Number of output atom features
    activation : activation function
        Default to be ReLU
    residual : bool
        Whether to use residual connection, default to be True
    batchnorm : bool
        Whether to use batch normalization on the output,
        default to be True
    dropout : float
        The probability for dropout. Default to be 0., i.e. no
        dropout is performed.
    """
    def __init__(self, in_feats, out_feats, activation=F.relu,
                 residual=True, batchnorm=True, dropout=0.):
        super(GCNLayer, self).__init__()

        self.activation = activation
        self.graph_conv = GraphConv(in_feats=in_feats, out_feats=out_feats,
                                    norm=False, activation=activation)
        self.dropout = nn.Dropout(dropout)

        self.residual = residual
        if residual:
            self.res_connection = nn.Linear(in_feats, out_feats)

        self.bn = batchnorm
        if batchnorm:
            self.bn_layer = nn.BatchNorm1d(out_feats)

    def forward(self, bg, feats):
        """Update atom representations

        Parameters
        ----------
        bg : BatchedDGLGraph
            Batched DGLGraphs for processing multiple molecules in parallel
        feats : FloatTensor of shape (N, M1)
            * N is the total number of atoms in the batched graph
            * M1 is the input atom feature size, must match in_feats in initialization

        Returns
        -------
        new_feats : FloatTensor of shape (N, M2)
            * M2 is the output atom feature size, must match out_feats in initialization
        """
        new_feats = self.graph_conv(bg, feats)
        if self.residual:
            res_feats = self.activation(self.res_connection(feats))
            new_feats = new_feats + res_feats
        new_feats = self.dropout(new_feats)

        if self.bn:
            new_feats = self.bn_layer(new_feats)

        return new_feats

class GATLayer(nn.Module):
    """Single layer GAT for updating node features

    Parameters
    ----------
    in_feats : int
        Number of input atom features
    out_feats : int
        Number of output atom features for each attention head
    num_heads : int
        Number of attention heads
    feat_drop : float
        Dropout applied to the input features
    attn_drop : float
        Dropout applied to attention values of edges
    alpha : float
        Hyperparameter in LeakyReLU, slope for negative values. Default to be 0.2
    residual : bool
        Whether to perform skip connection, default to be True
    agg_mode : str
        The way to aggregate multi-head attention results, can be either
        'flatten' for concatenating all head results or 'mean' for averaging
        all head results
    activation : activation function or None
        Activation function applied to aggregated multi-head results, default to be None.
    """
    def __init__(self, in_feats, out_feats, num_heads, feat_drop, attn_drop,
                 alpha=0.2, residual=True, agg_mode='flatten', activation=None):
        super(GATLayer, self).__init__()
        self.gnn = GATConv(in_feats=in_feats, out_feats=out_feats, num_heads=num_heads,
                           feat_drop=feat_drop, attn_drop=attn_drop,
                           negative_slope=alpha, residual=residual)
        assert agg_mode in ['flatten', 'mean']
        self.agg_mode = agg_mode
        self.activation = activation

    def forward(self, bg, feats):
        """Update atom representations

        Parameters
        ----------
        bg : BatchedDGLGraph
            Batched DGLGraphs for processing multiple molecules in parallel
        feats : FloatTensor of shape (N, M1)
            * N is the total number of atoms in the batched graph
            * M1 is the input atom feature size, must match in_feats in initialization

        Returns
        -------
        new_feats : FloatTensor of shape (N, M2)
            * M2 is the output atom feature size. If self.agg_mode == 'flatten', this would
              be out_feats * num_heads, else it would be just out_feats.
        """
        new_feats = self.gnn(bg, feats)
        if self.agg_mode == 'flatten':
            new_feats = new_feats.flatten(1)
        else:
            new_feats = new_feats.mean(1)

        if self.activation is not None:
            new_feats = self.activation(new_feats)

        return new_feats
@@ -3,39 +3,46 @@
"""
The implementation of neural network layers used in SchNet and MGCN.
"""
import torch as th
import torch.nn as nn
from torch.nn import Softplus
import numpy as np

from ... import function as fn

class AtomEmbedding(nn.Module):
    """
    Convert the atom(node) list to atom embeddings.
    The atoms with the same element share the same initial embedding.

    Parameters
    ----------
    dim : int
        Dim of embeddings, default to be 128.
    type_num : int
        The largest atomic number of atoms in the dataset, default to be 100.
    pre_train : None or pre-trained embeddings
        Pre-trained embeddings, default to be None.
    """
    def __init__(self, dim=128, type_num=100, pre_train=None):
        super(AtomEmbedding, self).__init__()
        self._dim = dim
        self._type_num = type_num
        if pre_train is not None:
            self.embedding = nn.Embedding.from_pretrained(pre_train, padding_idx=0)
        else:
            self.embedding = nn.Embedding(type_num, dim, padding_idx=0)

    def forward(self, g, p_name="node"):
        """
        Parameters
        ----------
        g : DGLGraph
            Input DGLGraph(s)
        p_name : str
            Name for storing atom embeddings
        """
        atom_list = g.ndata["node_type"]
        g.ndata[p_name] = self.embedding(atom_list)
        return g.ndata[p_name]
@@ -43,47 +50,69 @@ class AtomEmbedding(nn.Module):
class EdgeEmbedding(nn.Module):
    """
    Module for embedding edges. Edges linking same pairs of atoms share
    the same initial embedding.

    Parameters
    ----------
    dim : int
        Dim of edge embeddings, default to be 128.
    edge_num : int
        Maximum number of edge types, default to be 3000.
    pre_train : Edge embeddings or None
        Pre-trained edge embeddings
    """
    def __init__(self, dim=128, edge_num=3000, pre_train=None):
        super(EdgeEmbedding, self).__init__()
        self._dim = dim
        self._edge_num = edge_num
        if pre_train is not None:
            self.embedding = nn.Embedding.from_pretrained(pre_train, padding_idx=0)
        else:
            self.embedding = nn.Embedding(edge_num, dim, padding_idx=0)

    def generate_edge_type(self, edges):
""" """Generate edge type.
Generate the edge type based on the src&dst atom type of the edge.
The edge type is based on the type of the src&dst atom.
Note that C-O and O-C are the same edge type. Note that C-O and O-C are the same edge type.
To map a pair of nodes to one number, we use an unordered pairing function here To map a pair of nodes to one number, we use an unordered pairing function here
See more detail in this disscussion: See more detail in this disscussion:
https://math.stackexchange.com/questions/23503/create-unique-number-from-2-numbers https://math.stackexchange.com/questions/23503/create-unique-number-from-2-numbers
Note that, the edge_num should be larger than the square of maximum atomic number Note that, the edge_num should be larger than the square of maximum atomic number
in the dataset. in the dataset.
Parameters
----------
edges : EdgeBatch
Edges for deciding types
Returns
-------
dict
Stores the edge types in "type"
""" """
atom_type_x = edges.src["node_type"] atom_type_x = edges.src["node_type"]
atom_type_y = edges.dst["node_type"] atom_type_y = edges.dst["node_type"]
return { return {
"type": "type": atom_type_x * atom_type_y +
atom_type_x * atom_type_y + (th.abs(atom_type_x - atom_type_y) - 1)**2 / 4
(th.abs(atom_type_x - atom_type_y) - 1)**2 / 4
} }
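    # A worked check of the pairing function above (comment added for
    # illustration): a C-O bond has (x, y) = (6, 8) and an O-C bond has
    # (x, y) = (8, 6); both give 6 * 8 + (|6 - 8| - 1)**2 / 4 = 48.25,
    # so the two directions map to the same edge type, as intended.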
    def forward(self, g, p_name="edge_f"):
        """Compute edge embeddings

        Parameters
        ----------
        g : DGLGraph
            The graph to compute edge embeddings
        p_name : str
            Name for storing edge embeddings

        Returns
        -------
        tensor
            Computed edge embeddings
        """
        g.apply_edges(self.generate_edge_type)
        g.edata[p_name] = self.embedding(g.edata["type"])
        return g.edata[p_name]
@@ -93,14 +122,20 @@ class ShiftSoftplus(Softplus):
    """
    Shiftsoft plus activation function:
        1/beta * (log(1 + exp(beta * x)) - log(shift))

    Parameters
    ----------
    beta : int
    shift : int
    threshold : int
    """
    def __init__(self, beta=1, shift=2, threshold=20):
        super(ShiftSoftplus, self).__init__(beta, threshold)
        self.shift = shift
        self.softplus = Softplus(beta, threshold)

    def forward(self, x):
        """Applies the activation function"""
        return self.softplus(x) - np.log(float(self.shift))
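    # Note (comment added for illustration): with the defaults beta = 1 and
    # shift = 2, forward(x) = log(1 + exp(x)) - log(2), so the activation
    # passes through 0 at x = 0 rather than softplus's log(2).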
...@@ -112,9 +147,8 @@ class RBFLayer(nn.Module): ...@@ -112,9 +147,8 @@ class RBFLayer(nn.Module):
gamma = 10 gamma = 10
0 <= mu_k <= 30 for k=1~300 0 <= mu_k <= 30 for k=1~300
""" """
def __init__(self, low=0, high=30, gap=0.1, dim=1): def __init__(self, low=0, high=30, gap=0.1, dim=1):
super().__init__() super(RBFLayer, self).__init__()
self._low = low self._low = low
self._high = high self._high = high
self._gap = gap self._gap = gap
...@@ -145,25 +179,25 @@ class RBFLayer(nn.Module): ...@@ -145,25 +179,25 @@ class RBFLayer(nn.Module):
class CFConv(nn.Module): class CFConv(nn.Module):
""" """
The continuous-filter convolution layer in SchNet. The continuous-filter convolution layer in SchNet.
One CFConv contains one rbf layer and three linear layer
(two of them have activation funct).
"""
def __init__(self, rbf_dim, dim=64, act="sp"): Parameters
""" ----------
Args: rbf_dim : int
rbf_dim: the dimsion of the RBF layer Dimension of the RBF layer output
dim: the dimension of linear layers dim : int
act: activation function (default shifted softplus) Dimension of linear layers, default to be 64
""" act : str or activation function
super().__init__() Activation function, default to be shifted softplus
"""
def __init__(self, rbf_dim, dim=64, act=None):
super(CFConv, self).__init__()
self._rbf_dim = rbf_dim self._rbf_dim = rbf_dim
self._dim = dim self._dim = dim
self.linear_layer1 = nn.Linear(self._rbf_dim, self._dim) self.linear_layer1 = nn.Linear(self._rbf_dim, self._dim)
self.linear_layer2 = nn.Linear(self._dim, self._dim) self.linear_layer2 = nn.Linear(self._dim, self._dim)
if act == "sp": if act is None:
self.activation = nn.Softplus(beta=0.5, threshold=14) self.activation = nn.Softplus(beta=0.5, threshold=14)
else: else:
self.activation = act self.activation = act
@@ -187,10 +221,16 @@ class CFConv(nn.Module):
class Interaction(nn.Module):
    """
    The interaction layer in the SchNet model.

    Parameters
    ----------
    rbf_dim : int
        Dimension of the RBF layer output
    dim : int
        Dimension of intermediate representations
    """
    def __init__(self, rbf_dim, dim):
        super(Interaction, self).__init__()
        self._node_dim = dim
        self.activation = nn.Softplus(beta=0.5, threshold=14)
        self.node_layer1 = nn.Linear(dim, dim, bias=False)
@@ -199,7 +239,16 @@ class Interaction(nn.Module):
        self.node_layer3 = nn.Linear(dim, dim)

    def forward(self, g):
        """Interaction layer forward

        Parameters
        ----------
        g : DGLGraph

        Returns
        -------
        tensor
            Updated atom representations
        """
        g.ndata["new_node"] = self.node_layer1(g.ndata["node"])
        cf_node = self.cfconv(g)
        cf_node_1 = self.node_layer2(cf_node)
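Taken together, an interaction block follows the SchNet residual pattern: the input representation passes through a linear layer, the continuous-filter convolution, and a two-layer MLP, and the result is added back to the input. A minimal sketch of that update with a stand-in for the convolution:

```python
import torch
import torch.nn as nn

dim = 64
node = torch.randn(5, dim)          # current atom representations
node_layer1 = nn.Linear(dim, dim, bias=False)
node_layer2 = nn.Linear(dim, dim)
node_layer3 = nn.Linear(dim, dim)
act = nn.Softplus(beta=0.5, threshold=14)

cf_node = node_layer1(node)         # stand-in for CFConv over neighbours
new_node = node + node_layer3(act(node_layer2(cf_node)))
```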
@@ -213,16 +262,18 @@ class VEConv(nn.Module):
    """
    The Vertex-Edge convolution layer in MGCN, which takes edge and vertex
    features into consideration at the same time.

    Parameters
    ----------
    rbf_dim : int
        Dimension of the RBF layer output
    dim : int
        Dimension of intermediate representations, default to be 64.
    update_edge : bool
        Whether to apply a linear layer to update edge representations, default to be True.
    """
    def __init__(self, rbf_dim, dim=64, update_edge=True):
        super(VEConv, self).__init__()
        self._rbf_dim = rbf_dim
        self._dim = dim
        self._update_edge = update_edge
@@ -234,7 +285,17 @@ class VEConv(nn.Module):
        self.activation = nn.Softplus(beta=0.5, threshold=14)

    def update_rbf(self, edges):
        """Update the RBF features

        Parameters
        ----------
        edges : EdgeBatch

        Returns
        -------
        dict
            Stores updated features in 'h'
        """
        rbf = edges.data["rbf"]
        h = self.linear_layer1(rbf)
        h = self.activation(h)
@@ -242,27 +303,43 @@ class VEConv(nn.Module):
        return {"h": h}

    def update_edge(self, edges):
        """Update the edge features.

        Parameters
        ----------
        edges : EdgeBatch

        Returns
        -------
        dict
            Stores updated features in 'edge_f'
        """
        edge_f = edges.data["edge_f"]
        h = self.linear_layer3(edge_f)
        return {"edge_f": h}

    def forward(self, g):
        """VEConv layer forward

        Parameters
        ----------
        g : DGLGraph

        Returns
        -------
        tensor
            Updated atom representations
        """
        g.apply_edges(self.update_rbf)
        if self._update_edge:
            g.apply_edges(self.update_edge)

        g.update_all(message_func=[fn.u_mul_e("new_node", "h", "m_0"),
                                   fn.copy_e("edge_f", "m_1")],
                     reduce_func=[fn.sum("m_0", "new_node_0"),
                                  fn.sum("m_1", "new_node_1")])
        g.ndata["new_node"] = g.ndata.pop("new_node_0") + \
            g.ndata.pop("new_node_1")

        return g.ndata["new_node"]
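In dense-tensor terms, the two message flows above amount to summing, for each atom, the neighbour features modulated by the distance filter plus the incoming edge embeddings. A toy sketch for a fully connected 3-atom graph (shapes are illustrative):

```python
import torch

n, dim = 3, 4
h = torch.randn(n, dim)          # atom features, playing the role of "new_node"
filt = torch.randn(n, n, dim)    # per-edge filters from the RBF branch ("h")
edge_f = torch.randn(n, n, dim)  # per-edge embeddings ("edge_f")

# Sum over source atoms j for every destination atom i
out = (h.unsqueeze(0) * filt).sum(dim=1) + edge_f.sum(dim=1)
print(out.shape)                 # torch.Size([3, 4])
```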
@@ -270,15 +347,19 @@ class VEConv(nn.Module):
class MultiLevelInteraction(nn.Module):
    """
    The multilevel interaction in the MGCN model.

    Parameters
    ----------
    rbf_dim : int
        Dimension of the RBF layer output
    dim : int
        Dimension of intermediate representations
    """
    def __init__(self, rbf_dim, dim):
        super(MultiLevelInteraction, self).__init__()
        self._atom_dim = dim

        self.activation = nn.Softplus(beta=0.5, threshold=14)
        self.node_layer1 = nn.Linear(dim, dim, bias=True)
        self.edge_layer1 = nn.Linear(dim, dim, bias=True)
        self.conv_layer = VEConv(rbf_dim, dim)
@@ -287,22 +368,26 @@ class MultiLevelInteraction(nn.Module):
    def forward(self, g, level=1):
        """Multilevel interaction layer forward

        Parameters
        ----------
        g : DGLGraph
        level : int
            Level of interaction

        Returns
        -------
        tensor
            Updated atom representations
        """
        g.ndata["new_node"] = self.node_layer1(
            g.ndata["node_%s" % (level - 1)])
        node = self.conv_layer(g)
        g.edata["edge_f"] = self.activation(
            self.edge_layer1(g.edata["edge_f"]))
        node_1 = self.node_layer2(node)
        node_1a = self.activation(node_1)
        new_node = self.node_layer3(node_1a)

        g.ndata["node_%s" % (level)] = g.ndata["node_%s" % (level - 1)] + new_node

        return g.ndata["node_%s" % (level)]
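Each level writes its output to a separate node field, so after L levels the graph holds node_0 through node_L, all of which remain available to the model. A tensor-only sketch of this residual, level-indexed bookkeeping:

```python
import torch

dim, n_levels = 8, 3
feats = {"node_0": torch.randn(5, dim)}  # initial atom embeddings
for level in range(1, n_levels + 1):
    update = torch.randn(5, dim)         # stand-in for the conv + MLP output
    feats["node_%s" % level] = feats["node_%s" % (level - 1)] + update
# node_0 ... node_3 all remain available for the multilevel readout
```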
# -*- coding:utf-8 -*-
# pylint: disable=C0103, C0111, W0621
"""Implementation of MGCN model"""
import torch as th
import torch.nn as nn

from .layers import AtomEmbedding, RBFLayer, EdgeEmbedding, \
    MultiLevelInteraction
from ...batched_graph import sum_nodes
class MGCNModel(nn.Module):
    """
    MGCN from `Molecular Property Prediction: A Multilevel
    Quantum Interactions Modeling Perspective <https://arxiv.org/abs/1906.11081>`__

    Parameters
    ----------
    dim : int
        Dimension of feature maps, default to be 128.
    output_dim : int
        Number of target properties to predict, default to be 1.
    edge_dim : int
        Dimension of edge feature, default to be 128.
    cutoff : float
        The maximum distance between nodes, default to be 5.0.
    width : int
        Width in the RBF layer, default to be 1.
    n_conv : int
        Number of convolutional layers, default to be 3.
    norm : bool
        Whether to perform normalization, default to be False.
    atom_ref : Atom embeddings or None
        If not None, they will be used to compute a per-atom reference term
        added to the prediction. Default to be None.
    pre_train : Atom embeddings or None
        If None, random representation initialization will be used. Otherwise,
        they will be used to initialize atom representations. Default to be None.
    """
    def __init__(self,
                 dim=128,
                 output_dim=1,
@@ -27,21 +47,7 @@ class MGCNModel(nn.Module):
                 norm=False,
                 atom_ref=None,
                 pre_train=None):
        super(MGCNModel, self).__init__()
        self.name = "MGCN"
        self._dim = dim
        self.output_dim = output_dim
@@ -73,11 +79,32 @@ class MGCNModel(nn.Module):
        self.node_dense_layer2 = nn.Linear(64, output_dim)

    def set_mean_std(self, mean, std, device):
        """Set the mean and std of atom representations for normalization.

        Parameters
        ----------
        mean : list or numpy array
            The mean of labels
        std : list or numpy array
            The std of labels
        device : str or torch.device
            Device for storing the mean and std
        """
        self.mean_per_node = th.tensor(mean, device=device)
        self.std_per_node = th.tensor(std, device=device)

    def forward(self, g):
        """Predict molecule labels

        Parameters
        ----------
        g : DGLGraph
            Input DGLGraph for molecule(s)

        Returns
        -------
        res : Predicted labels
        """
        self.embedding_layer(g, "node_0")
        if self.atom_ref is not None:
            self.e0(g, "e0")
@@ -104,5 +131,5 @@ class MGCNModel(nn.Module):
        if self.norm:
            g.ndata["res"] = g.ndata["res"] * self.std_per_node + self.mean_per_node
        res = sum_nodes(g, "res")
        return res
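A hedged usage sketch for the model; the `dgl.model_zoo.chem` import path is assumed from this release, and graph construction (atoms, distance-based edges and their features) is left to the dataset utilities:

```python
from dgl import model_zoo

# Assumed entry point; the arguments shown follow the defaults documented above
model = model_zoo.chem.MGCNModel(dim=128, output_dim=1, n_conv=3, norm=True)
model.set_mean_std(mean=[0.0], std=[1.0], device="cpu")
# preds = model(g)  # g: DGLGraph(s) for molecules with prepared features
```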
@@ -2,36 +2,39 @@
# coding: utf-8
# pylint: disable=C0103, C0111, E1101, W0612
"""Implementation of MPNN model."""
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Parameter

from ... import function as fn
from ...nn.pytorch import Set2Set


class NNConvLayer(nn.Module):
    """
    MPNN Conv Layer from Section 5 of
    `Neural Message Passing for Quantum Chemistry <https://arxiv.org/abs/1704.01212>`__

    Parameters
    ----------
    in_channels : int
        Number of input channels
    out_channels : int
        Number of output channels
    edge_net : nn.Module
        Network that processes edge information
    root_weight : bool
        Whether to add the root node feature to output
    bias : bool
        Whether to add bias to the output
    """
    def __init__(self,
                 in_channels,
                 out_channels,
                 edge_net,
                 root_weight=True,
                 bias=True):
        super(NNConvLayer, self).__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
@@ -50,6 +53,7 @@ class NNConvLayer(nn.Module):
        self.reset_parameters()

    def reset_parameters(self):
        """Reinitialize model parameters"""
        if self.root is not None:
            nn.init.xavier_normal_(self.root.data, gain=1.414)
        if self.bias is not None:
@@ -59,6 +63,18 @@ class NNConvLayer(nn.Module):
            nn.init.xavier_normal_(m.weight.data, gain=1.414)
    def message(self, edges):
        """Function for computing messages from source nodes

        Parameters
        ----------
        edges : EdgeBatch
            Edges over which we want to send messages

        Returns
        -------
        dict
            Stores message in key 'm'
        """
        return {
            'm':
                torch.matmul(edges.src['h'].unsqueeze(1),
@@ -66,6 +82,17 @@ class NNConvLayer(nn.Module):
        }

    def apply_node_func(self, nodes):
        """Function for updating node features directly

        Parameters
        ----------
        nodes : NodeBatch

        Returns
        -------
        dict
            Stores updated node features in 'h'
        """
        aggr_out = nodes.data['aggr_out']
        if self.root is not None:
            aggr_out = torch.mm(nodes.data['h'], self.root) + aggr_out
@@ -76,7 +103,23 @@ class NNConvLayer(nn.Module):
        return {'h': aggr_out}
    def forward(self, g, h, e):
        """Propagate messages and aggregate results for updating
        atom representations

        Parameters
        ----------
        g : DGLGraph
            DGLGraph(s) for molecules
        h : tensor
            Input atom representations
        e : tensor
            Input bond representations

        Returns
        -------
        tensor
            Aggregated atom information
        """
        h = h.unsqueeze(-1) if h.dim() == 1 else h
        e = e.unsqueeze(-1) if e.dim() == 1 else e
class MPNNModel(nn.Module): class MPNNModel(nn.Module):
""" """
MPNN model from: MPNN from
Gilmer, Justin, et al. `Neural Message Passing for Quantum Chemistry <https://arxiv.org/abs/1704.01212>`__
Neural message passing for quantum chemistry.
Parameters
----------
node_input_dim : int
Dimension of input node feature, default to be 15.
edge_input_dim : int
Dimension of input edge feature, default to be 15.
output_dim : int
Dimension of prediction, default to be 12.
node_hidden_dim : int
Dimension of node feature in hidden layers, default to be 64.
edge_hidden_dim : int
Dimension of edge feature in hidden layers, default to be 128.
num_step_message_passing : int
Number of message passing steps, default to be 6.
num_step_set2set : int
Number of set2set steps
num_layer_set2set : int
Number of set2set layers
""" """
    def __init__(self,
                 node_input_dim=15,
                 edge_input_dim=5,
@@ -104,20 +164,8 @@ class MPNNModel(nn.Module):
                 num_step_message_passing=6,
                 num_step_set2set=6,
                 num_layer_set2set=3):
        super(MPNNModel, self).__init__()
        self.name = "MPNN"
        self.num_step_message_passing = num_step_message_passing
        self.lin0 = nn.Linear(node_input_dim, node_hidden_dim)
@@ -130,12 +178,22 @@ class MPNNModel(nn.Module):
                                root_weight=False)
        self.gru = nn.GRU(node_hidden_dim, node_hidden_dim)

        self.set2set = Set2Set(node_hidden_dim, num_step_set2set, num_layer_set2set)
        self.lin1 = nn.Linear(2 * node_hidden_dim, node_hidden_dim)
        self.lin2 = nn.Linear(node_hidden_dim, output_dim)

    def forward(self, g):
        """Predict molecule labels

        Parameters
        ----------
        g : DGLGraph
            Input DGLGraph for molecule(s)

        Returns
        -------
        res : Predicted labels
        """
        h = g.ndata['n_feat']
        out = F.relu(self.lin0(h))
        h = out.unsqueeze(0)
......
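Note that Set2Set returns a graph-level vector of twice the hidden size, which is why lin1 maps from 2 * node_hidden_dim back down before the final projection. A hedged usage sketch (the import path and the edge feature field name are assumptions):

```python
from dgl import model_zoo

model = model_zoo.chem.MPNNModel(node_input_dim=15, edge_input_dim=5,
                                 output_dim=12)
# preds = model(g)  # g: DGLGraph(s) with 'n_feat' node features and
#                   # (assumed) 'e_feat' edge features prepared
```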
@@ -2,9 +2,9 @@
import torch
from rdkit import Chem

from . import DGLJTNNVAE
from .classifiers import GCNClassifier, GATClassifier
from .dgmg import DGMG
from .mgcn import MGCNModel
from .mpnn import MPNNModel
from .sch import SchNetModel
@@ -12,6 +12,7 @@ from ...data.utils import _get_dgl_url, download, get_download_dir
URL = {
    'GCN_Tox21' : 'pre_trained/gcn_tox21.pth',
    'GAT_Tox21' : 'pre_trained/gat_tox21.pth',
    'MGCN_Alchemy': 'pre_trained/mgcn_alchemy.pth',
    'SCHNET_Alchemy': 'pre_trained/schnet_alchemy.pth',
    'MPNN_Alchemy': 'pre_trained/mpnn_alchemy.pth',
@@ -73,8 +74,15 @@ def load_pretrained(model_name, log=True):
    if model_name == 'GCN_Tox21':
        model = GCNClassifier(in_feats=74,
                              gcn_hidden_feats=[64, 64],
                              classifier_hidden_feats=64,
                              n_tasks=12)
    elif model_name == 'GAT_Tox21':
        model = GATClassifier(in_feats=74,
                              gat_hidden_feats=[32, 32],
                              num_heads=[4, 4],
                              classifier_hidden_feats=64,
                              n_tasks=12)
    elif model_name.startswith('DGMG'):
        if model_name.startswith('DGMG_ChEMBL'):
......
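With the registry entry and the branch above in place, the newly added GAT classifier can be fetched like the existing models. A hedged sketch; the forward signature of the returned classifier is an assumption:

```python
from dgl import model_zoo

model = model_zoo.chem.load_pretrained('GAT_Tox21')  # downloads gat_tox21.pth
model.eval()
# logits = model(bg, feats)  # assumed signature: batched graph + atom features
```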
# -*- coding:utf-8 -*-
# pylint: disable=C0103, C0111, W0621
"""Implementation of SchNet model."""
import torch as th
import torch.nn as nn

from .layers import AtomEmbedding, Interaction, ShiftSoftplus, RBFLayer
from ...batched_graph import sum_nodes


class SchNetModel(nn.Module):
    """
    SchNet from `SchNet: A continuous-filter convolutional neural network
    for modeling quantum interactions (NIPS 2017) <https://arxiv.org/abs/1706.08566>`__

    Parameters
    ----------
    dim : int
        Dimension of features, default to be 64
    cutoff : float
        Radius cutoff for RBF, default to be 5.0
    output_dim : int
        Dimension of prediction, default to be 1
    width : int
        Width in RBF, default to be 1
    n_conv : int
        Number of conv (interaction) layers, default to be 1
    norm : bool
        Whether to normalize the output atom representations, default to be False.
    atom_ref : Atom embeddings or None
        If not None, they will be used to compute a per-atom reference term
        added to the prediction. Default to be None.
    pre_train : Atom embeddings or None
        If None, random representation initialization will be used. Otherwise,
        they will be used to initialize atom representations. Default to be None.
    """
    def __init__(self,
                 dim=64,
                 cutoff=5.0,
@@ -25,17 +43,6 @@ class SchNetModel(nn.Module):
                 norm=False,
                 atom_ref=None,
                 pre_train=None):
        super().__init__()
        self.name = "SchNet"
        self._dim = dim
@@ -60,11 +67,32 @@ class SchNetModel(nn.Module):
        self.atom_dense_layer2 = nn.Linear(64, output_dim)

    def set_mean_std(self, mean, std, device="cpu"):
        """Set the mean and std of atom representations for normalization.

        Parameters
        ----------
        mean : list or numpy array
            The mean of labels
        std : list or numpy array
            The std of labels
        device : str or torch.device
            Device for storing the mean and std
        """
        self.mean_per_atom = th.tensor(mean, device=device)
        self.std_per_atom = th.tensor(std, device=device)

    def forward(self, g):
        """Predict molecule labels

        Parameters
        ----------
        g : DGLGraph
            Input DGLGraph for molecule(s)

        Returns
        -------
        res : Predicted labels
        """
        self.embedding_layer(g)
        if self.atom_ref is not None:
            self.e0(g, "e0")
@@ -81,7 +109,6 @@ class SchNetModel(nn.Module):
            g.ndata["res"] = g.ndata["res"] + g.ndata["e0"]

        if self.norm:
            g.ndata["res"] = g.ndata["res"] * self.std_per_atom + self.mean_per_atom
        res = sum_nodes(g, "res")
        return res
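When norm is set, the per-atom outputs are un-standardized with the stored statistics before being summed into a molecule-level prediction, mirroring MGCN above. A hedged usage sketch (import path assumed):

```python
from dgl import model_zoo

model = model_zoo.chem.SchNetModel(dim=64, output_dim=1, norm=True)
model.set_mean_std(mean=[0.0], std=[1.0], device="cpu")
# preds = model(g)  # g: DGLGraph(s) for molecules with prepared features
```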
@@ -12,7 +12,7 @@ from ...batched_graph import sum_nodes, mean_nodes, max_nodes, broadcast_nodes,\
__all__ = ['SumPooling', 'AvgPooling', 'MaxPooling', 'SortPooling',
           'GlobalAttentionPooling', 'Set2Set',
           'SetTransformerEncoder', 'SetTransformerDecoder', 'WeightAndSum']


class SumPooling(nn.Module):
    r"""Apply sum pooling over the nodes in the graph.
@@ -668,3 +668,43 @@ class SetTransformerDecoder(nn.Module):
            return feat.view(graph.batch_size, self.k * self.d_model)
        else:
            return feat.view(self.k * self.d_model)

class WeightAndSum(nn.Module):
    """Compute importance weights for atoms and perform a weighted sum.

    Parameters
    ----------
    in_feats : int
        Input atom feature size
    """
    def __init__(self, in_feats):
        super(WeightAndSum, self).__init__()
        self.in_feats = in_feats
        self.atom_weighting = nn.Sequential(
            nn.Linear(in_feats, 1),
            nn.Sigmoid()
        )

    def forward(self, bg, feats):
        """Compute molecule representations out of atom representations

        Parameters
        ----------
        bg : BatchedDGLGraph
            Batched DGLGraphs for processing multiple molecules in parallel
        feats : FloatTensor of shape (N, self.in_feats)
            Representations for all atoms in the molecules
            * N is the total number of atoms in all molecules

        Returns
        -------
        FloatTensor of shape (B, self.in_feats)
            Representations for B molecules
            * B is the number of molecules in the batch
        """
        with bg.local_scope():
            bg.ndata['h'] = feats
            bg.ndata['w'] = self.atom_weighting(bg.ndata['h'])
            h_g_sum = sum_nodes(bg, 'h', 'w')

        return h_g_sum
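Since the readout is just a per-atom sigmoid gate followed by a weighted sum, it can be exercised on its own. A minimal sketch with a toy batch of two molecules; the graph-construction calls follow the DGL API of this release and are an assumption here:

```python
import dgl
import torch

g1, g2 = dgl.DGLGraph(), dgl.DGLGraph()
g1.add_nodes(3)                      # a 3-atom molecule
g2.add_nodes(2)                      # a 2-atom molecule
bg = dgl.batch([g1, g2])             # B = 2 graphs, N = 5 atoms in total

readout = WeightAndSum(in_feats=4)   # the class defined above
feats = torch.randn(5, 4)
print(readout(bg, feats).shape)      # torch.Size([2, 4])
```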