Unverified commit 189c2c09, authored by Mufei Li and committed by GitHub

[Model Zoo] Refactor Model Zoo for Chemistry (#839)

* Update

* Update

* Update

* Update fix

* Update

* Update

* Refactor

* Update

* Update

* Update

* Update

* Update

* Update

* Fix style
parent 5b417683
@@ -5,13 +5,17 @@
[Documentation](https://docs.dgl.ai) | [DGL at a glance](https://docs.dgl.ai/tutorials/basics/1_first.html#sphx-glr-tutorials-basics-1-first-py) |
[Model Tutorials](https://docs.dgl.ai/tutorials/models/index.html) | [Discussion Forum](https://discuss.dgl.ai)
Model Zoos: [Chemistry](https://github.com/dmlc/dgl/tree/master/examples/pytorch/model_zoo) | [Citation Networks](https://github.com/dmlc/dgl/tree/master/examples/pytorch/model_zoo/citation_network)

DGL is a Python package that interfaces between existing tensor libraries and data being expressed as
graphs.
It makes implementing graph neural networks (including Graph Convolution Networks, TreeLSTM, and many others) easy while
maintaining high computation efficiency.

All model examples can be found [here](https://github.com/dmlc/dgl/tree/master/examples).

A summary of model accuracy and training speed for a subset of the models with the PyTorch backend (on an Amazon EC2 p3.2x instance with a V100 GPU), compared with the best open-source implementations:

| Model | Reported <br> Accuracy | DGL <br> Accuracy | Author's training speed (epoch time) | DGL speed (epoch time) | Improvement |
| ----- | ----------------- | ------------ | ------------------------------------ | ---------------------- | ----------- |
@@ -15,6 +15,10 @@ Utils
    utils.download
    utils.check_sha1
    utils.extract_archive
    utils.split_dataset
.. autoclass:: dgl.data.utils.Subset
    :members: __getitem__, __len__
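As a minimal sketch of how these utilities fit together (assuming ``dataset`` is any
object supporting ``len(dataset)`` and ``dataset[i]``), each returned split is a
``dgl.data.utils.Subset``:

.. code:: python

    from dgl.data.utils import split_dataset

    # Consecutive 80/10/10 split of the datapoints
    train_set, val_set, test_set = split_dataset(dataset, frac_list=[0.8, 0.1, 0.1])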
Dataset Classes
---------------

@@ -57,3 +61,54 @@ Protein-Protein Interaction dataset
.. autoclass:: PPIDataset
    :members: __getitem__, __len__
Molecular Graphs
----------------
To work on molecular graphs, make sure you have installed `RDKit 2018.09.3 <https://www.rdkit.org/docs/Install.html>`__.
Featurization
`````````````
To apply graph neural networks, we need to featurize nodes (atoms) and edges (bonds). Below we list some
featurization methods and utilities:
.. autosummary::
    :toctree: ../../generated/

    chem.one_hot_encoding
    chem.BaseAtomFeaturizer
    chem.CanonicalAtomFeaturizer
Graph Construction
``````````````````
Several methods for constructing DGLGraphs from SMILES/RDKit molecule objects are listed below:
.. autosummary::
    :toctree: ../../generated/

    chem.mol_to_graph
    chem.smile_to_bigraph
    chem.mol_to_bigraph
    chem.smile_to_complete_graph
    chem.mol_to_complete_graph
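As a hedged example of the utilities above (``'CCO'`` is just an illustrative SMILES
string, and the exact keyword arguments may differ across versions):

.. code:: python

    from dgl.data import chem

    # One node per atom; each bond becomes a pair of directed edges
    g = chem.smile_to_bigraph('CCO')
    print(g.number_of_nodes(), g.number_of_edges())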
Dataset Classes
```````````````
If your dataset is stored in a ``.csv`` file, you may find it helpful to use
.. autoclass:: dgl.data.chem.CSVDataset
    :members: __getitem__, __len__
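For instance, a sketch assuming a hypothetical file ``my_data.csv`` with a ``smiles``
column and one additional column per task:

.. code:: python

    import pandas as pd
    from dgl.data.chem import CSVDataset

    df = pd.read_csv('my_data.csv')  # hypothetical file
    dataset = CSVDataset(df, smile_column='smiles')
    # Each datapoint is (smiles, DGLGraph, labels, mask)
    smiles, g, labels, mask = dataset[0]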
Currently two datasets are supported:
* Tox21
* TencentAlchemyDataset
.. autoclass:: dgl.data.chem.Tox21
    :members: __getitem__, __len__, task_pos_weights

.. autoclass:: dgl.data.chem.TencentAlchemyDataset
    :members: __getitem__, __len__, set_mean_and_std
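Both datasets can be indexed directly once constructed. A small sketch for Tox21, which
downloads the data on first use and follows the ``CSVDataset`` field layout:

.. code:: python

    from dgl.data.chem import Tox21

    dataset = Tox21()
    smiles, g, labels, mask = dataset[0]
    print(len(dataset))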
@@ -19,3 +19,4 @@ API Reference
graph_store
nodeflow
random
model_zoo
.. _apimodelzoo:
Model Zoo
=========
.. currentmodule:: dgl.model_zoo
Chemistry
---------
Utils
`````
.. autosummary::
    :toctree: ../../generated/

    chem.load_pretrained
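A sketch of loading a pre-trained model. ``'GCN_Tox21'`` is an assumed model name; check
the documentation of ``chem.load_pretrained`` for the names actually supported:

.. code:: python

    from dgl import model_zoo

    model = model_zoo.chem.load_pretrained('GCN_Tox21')  # assumed model name
    model.eval()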
Property Prediction
```````````````````
Currently supported model architectures:
* GCNClassifier
* GATClassifier
* MPNN
* SchNet
* MGCN
.. autoclass:: dgl.model_zoo.chem.GCNClassifier
    :members: forward

.. autoclass:: dgl.model_zoo.chem.GATClassifier
    :members: forward

.. autoclass:: dgl.model_zoo.chem.MPNNModel
    :members: forward

.. autoclass:: dgl.model_zoo.chem.SchNet
    :members: forward

.. autoclass:: dgl.model_zoo.chem.MGCNModel
    :members: forward
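The forward signatures differ across architectures, as the example training scripts show:
MPNN consumes node features ``'n_feat'`` and edge features ``'e_feat'``, while SchNet and
MGCN consume atom types and pairwise distances. A sketch mirroring those scripts, where
``bg`` is a batched molecular graph carrying those fields:

.. code:: python

    from dgl import model_zoo

    model = model_zoo.chem.MPNNModel(output_dim=12)
    prediction = model(bg, bg.ndata.pop('n_feat'), bg.edata.pop('e_feat'))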
Generative Models
`````````````````
Currently supported model architectures:
* DGMG
* JTNN
.. autoclass:: dgl.model_zoo.chem.DGMG
    :members: forward

.. autoclass:: dgl.model_zoo.chem.DGLJTNNVAE
    :members: forward
@@ -86,6 +86,7 @@ are also two accompanying review papers that are well written [7], [8].

### Models

- **Deep Generative Models of Graphs (DGMG)** [11]: A very general framework for graph distribution learning by
progressively adding atoms and bonds.
- **Junction Tree Variational Autoencoder for Molecular Graph Generation (JTNN)** [13]: Generates molecular graphs by
first building a junction tree of chemical substructures and then assembling the substructures into a molecule.
### Example Usage of Pre-trained Models

@@ -143,3 +144,6 @@ Machine Learning* JMLR. 1263-1272.
[11] Li et al. (2018) Learning Deep Generative Models of Graphs. *arXiv preprint arXiv:1803.03324*.

[12] Goh et al. (2017) Deep learning for computational chemistry. *Journal of Computational Chemistry* 16, 1291-1307.
[13] Jin et al. (2018) Junction Tree Variational Autoencoder for Molecular Graph Generation.
*Proceedings of the 35th International Conference on Machine Learning (ICML)*, 2323-2332.
@@ -15,7 +15,7 @@ into training, validation and test set with a 80/10/10 ratio. By default we foll
- **Graph Convolutional Network** [2], [3]. Graph Convolutional Networks (GCN) have been one of the most popular graph neural
networks and they can be easily extended for graph level prediction. MoleculeNet [1] reports baseline results of graph
convolutions over multiple datasets.
- **Graph Attention Networks** [7]. Graph Attention Networks (GATs) incorporate multi-head attention into GCNs,
explicitly modeling the interactions between adjacent atoms.

### Usage

@@ -49,16 +49,11 @@ a real difference.
| ---------------- | ---------------------- |
| Pretrained model | 0.827 |
## Regression

Regression tasks require assigning continuous labels to a molecule, e.g. molecular energy.

### Datasets

- **Alchemy**. The [Alchemy Dataset](https://alchemy.tencent.com/) is introduced by Tencent Quantum Lab to facilitate the development of new
machine learning models useful for chemistry and materials science. The dataset lists 12 quantum mechanical properties of 130,000+ organic
@@ -68,29 +63,39 @@ These properties have been calculated using the open-source computational chemis
### Models

- **Message Passing Neural Network** [6]. Message Passing Neural Networks (MPNNs) have reached the best performance on
the QM9 dataset for some time.
- **SchNet** [4]. SchNet employs continuous filter convolutional layers to model quantum interactions in molecules
without requiring them to lie on grids.
- **Multilevel Graph Convolutional Neural Network** [5]. Multilevel Graph Convolutional Neural Networks (MGCN) are
hierarchical graph neural networks that extract features from the conformation and spatial information, followed by
multilevel interactions.
### Usage

Use `regression.py` with arguments

```
-m {MPNN,SCHNET,MGCN}, Model to use
-d {Alchemy}, Dataset to use
```
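For example, the following trains an MPNN on Alchemy, with the remaining hyperparameters
filled in from `configure.py` via the experiment name:

```
python regression.py -m MPNN -d Alchemy
```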
### Performance

#### Alchemy

The Alchemy contest is still ongoing. Before the test set is fully released, we only include the performance numbers
on the training and validation set for reference.

| Model      | Training MAE | Validation MAE |
| ---------- | ------------ | -------------- |
| SchNet [4] | 0.2665       | 0.6139         |
| MGCN [5]   | 0.2395       | 0.6463         |
| MPNN [6]   | 0.2452       | 0.6259         |
## Dataset Customization
To customize your own dataset, see the instructions
[here](https://github.com/dmlc/dgl/tree/master/python/dgl/data/chem).
## References

[1] Wu et al. (2017) MoleculeNet: a benchmark for molecular machine learning. *Chemical Science* 9, 513-530.
import torch

from torch.nn import BCEWithLogitsLoss
from torch.optim import Adam
from torch.utils.data import DataLoader

from dgl import model_zoo
from dgl.data.utils import split_dataset

from utils import Meter, EarlyStopping, collate_molgraphs_for_classification, set_random_seed

def run_a_train_epoch(args, epoch, model, data_loader, loss_criterion, optimizer):
    model.train()
@@ -45,15 +46,18 @@ def main(args):
    args['device'] = "cuda" if torch.cuda.is_available() else "cpu"
    set_random_seed()

    # Interchangeable with other datasets
    if args['dataset'] == 'Tox21':
        from dgl.data.chem import Tox21
        dataset = Tox21()

    trainset, valset, testset = split_dataset(dataset, args['train_val_test_split'])
    train_loader = DataLoader(trainset, batch_size=args['batch_size'],
                              collate_fn=collate_molgraphs_for_classification)
    val_loader = DataLoader(valset, batch_size=args['batch_size'],
                            collate_fn=collate_molgraphs_for_classification)
    test_loader = DataLoader(testset, batch_size=args['batch_size'],
                             collate_fn=collate_molgraphs_for_classification)

    if args['pre_trained']:
        args['num_epochs'] = 0
@@ -23,9 +23,38 @@ GAT_Tox21 = {
    'patience': 10
}

MPNN_Alchemy = {
    'batch_size': 16,
    'num_epochs': 250,
    'output_dim': 12,
    'lr': 0.0001,
    'patience': 50
}

SCHNET_Alchemy = {
    'batch_size': 16,
    'num_epochs': 250,
    'norm': True,
    'output_dim': 12,
    'lr': 0.0001,
    'patience': 50
}

MGCN_Alchemy = {
    'batch_size': 16,
    'num_epochs': 250,
    'norm': True,
    'output_dim': 12,
    'lr': 0.0001,
    'patience': 50
}

experiment_configures = {
    'GCN_Tox21': GCN_Tox21,
    'GAT_Tox21': GAT_Tox21,
    'MPNN_Alchemy': MPNN_Alchemy,
    'SCHNET_Alchemy': SCHNET_Alchemy,
    'MGCN_Alchemy': MGCN_Alchemy
}

def get_exp_configure(exp_name):
import torch
import torch.nn as nn

from torch.utils.data import DataLoader

from dgl import model_zoo

from utils import set_random_seed, collate_molgraphs_for_regression, EarlyStopping

def regress(args, model, bg):
    if args['model'] == 'MPNN':
        h = bg.ndata.pop('n_feat')
        e = bg.edata.pop('e_feat')
        h, e = h.to(args['device']), e.to(args['device'])
        return model(bg, h, e)
    else:
        node_types = bg.ndata.pop('node_type')
        edge_distances = bg.edata.pop('distance')
        node_types, edge_distances = node_types.to(args['device']), \
                                     edge_distances.to(args['device'])
        return model(bg, node_types, edge_distances)

def run_a_train_epoch(args, epoch, model, data_loader,
                      loss_criterion, score_criterion, optimizer):
    model.train()
    total_loss, total_score = 0, 0
    for batch_id, batch_data in enumerate(data_loader):
        smiles, bg, labels = batch_data
        labels = labels.to(args['device'])
        prediction = regress(args, model, bg)
        loss = loss_criterion(prediction, labels)
        score = score_criterion(prediction, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.detach().item() * bg.batch_size
        total_score += score.detach().item() * bg.batch_size
    total_loss /= len(data_loader.dataset)
    total_score /= len(data_loader.dataset)
    print('epoch {:d}/{:d}, training loss {:.4f}, training score {:.4f}'.format(
        epoch + 1, args['num_epochs'], total_loss, total_score))

def run_an_eval_epoch(args, model, data_loader, score_criterion):
    model.eval()
    total_score = 0
    with torch.no_grad():
        for batch_id, batch_data in enumerate(data_loader):
            smiles, bg, labels = batch_data
            labels = labels.to(args['device'])
            prediction = regress(args, model, bg)
            score = score_criterion(prediction, labels)
            total_score += score.detach().item() * bg.batch_size
        total_score /= len(data_loader.dataset)
    return total_score

def main(args):
    args['device'] = "cuda" if torch.cuda.is_available() else "cpu"
    set_random_seed()

    # Interchangeable with other datasets
    if args['dataset'] == 'Alchemy':
        from dgl.data.chem import TencentAlchemyDataset
        train_set = TencentAlchemyDataset(mode='dev')
        val_set = TencentAlchemyDataset(mode='valid')

    train_loader = DataLoader(dataset=train_set,
                              batch_size=args['batch_size'],
                              collate_fn=collate_molgraphs_for_regression)
    val_loader = DataLoader(dataset=val_set,
                            batch_size=args['batch_size'],
                            collate_fn=collate_molgraphs_for_regression)

    if args['model'] == 'MPNN':
        model = model_zoo.chem.MPNNModel(output_dim=args['output_dim'])
    elif args['model'] == 'SCHNET':
        model = model_zoo.chem.SchNet(norm=args['norm'], output_dim=args['output_dim'])
        model.set_mean_std(train_set.mean, train_set.std, args['device'])
    elif args['model'] == 'MGCN':
        model = model_zoo.chem.MGCNModel(norm=args['norm'], output_dim=args['output_dim'])
        model.set_mean_std(train_set.mean, train_set.std, args['device'])
    model.to(args['device'])

    loss_fn = nn.MSELoss()
    score_fn = nn.L1Loss()
    optimizer = torch.optim.Adam(model.parameters(), lr=args['lr'])
    stopper = EarlyStopping(mode='lower', patience=args['patience'])

    for epoch in range(args['num_epochs']):
        # Train
        run_a_train_epoch(args, epoch, model, train_loader, loss_fn, score_fn, optimizer)

        # Validation and early stop
        val_score = run_an_eval_epoch(args, model, val_loader, score_fn)
        early_stop = stopper.step(val_score, model)
        print('epoch {:d}/{:d}, validation score {:.4f}, best validation score {:.4f}'.format(
            epoch + 1, args['num_epochs'], val_score, stopper.best_score))
        if early_stop:
            break

if __name__ == "__main__":
    import argparse

    from configure import get_exp_configure

    parser = argparse.ArgumentParser(description='Molecule Regression')
    parser.add_argument('-m', '--model', type=str, choices=['MPNN', 'SCHNET', 'MGCN'],
                        help='Model to use')
    parser.add_argument('-d', '--dataset', type=str, choices=['Alchemy'],
                        help='Dataset to use')
    args = parser.parse_args().__dict__
    args['exp'] = '_'.join([args['model'], args['dataset']])
    args.update(get_exp_configure(args['exp']))

    main(args)
@@ -6,6 +6,13 @@ import torch
from sklearn.metrics import roc_auc_score

def set_random_seed(seed=0):
    """Set random seed.

    Parameters
    ----------
    seed : int
        Random seed to use
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
@@ -13,19 +20,39 @@ def set_random_seed(seed=0):
        torch.cuda.manual_seed(seed)

class Meter(object):
    """Track and summarize model performance on a dataset for
    (multi-label) binary classification."""
    def __init__(self):
        self.mask = []
        self.y_pred = []
        self.y_true = []

    def update(self, y_pred, y_true, mask):
        """Update for the result of an iteration

        Parameters
        ----------
        y_pred : float32 tensor
            Predicted molecule labels with shape (B, T),
            B for batch size and T for the number of tasks
        y_true : float32 tensor
            Ground truth molecule labels with shape (B, T)
        mask : float32 tensor
            Mask for indicating the existence of ground
            truth labels with shape (B, T)
        """
        self.y_pred.append(y_pred.detach().cpu())
        self.y_true.append(y_true.detach().cpu())
        self.mask.append(mask.detach().cpu())

    # Todo: Allow different evaluation metrics
    def roc_auc_averaged_over_tasks(self):
        """Compute roc-auc score for each task and return the average.

        Returns
        -------
        float
            roc-auc score averaged over all tasks
        """
        mask = torch.cat(self.mask, dim=0)
        y_pred = torch.cat(self.y_pred, dim=0)
        y_true = torch.cat(self.y_true, dim=0)
@@ -36,40 +63,64 @@ class Meter(object):
        total_score = 0
        for task in range(n_tasks):
            task_w = mask[:, task]
            task_y_true = y_true[:, task][task_w != 0].numpy()
            task_y_pred = y_pred[:, task][task_w != 0].numpy()
            total_score += roc_auc_score(task_y_true, task_y_pred)
        return total_score / n_tasks

class EarlyStopping(object):
    """Early stopping for training

    Parameters
    ----------
    mode : str
        * 'higher': Higher metric suggests a better model
        * 'lower': Lower metric suggests a better model
    patience : int
        Number of epochs to wait before early stop
        if the metric stops getting improved
    filename : str or None
        Filename for storing the model checkpoint
    """
    def __init__(self, mode='higher', patience=10, filename=None):
        if filename is None:
            dt = datetime.datetime.now()
            filename = 'early_stop_{}_{:02d}-{:02d}-{:02d}.pth'.format(
                dt.date(), dt.hour, dt.minute, dt.second)

        assert mode in ['higher', 'lower']
        self.mode = mode
        if self.mode == 'higher':
            self._check = self._check_higher
        else:
            self._check = self._check_lower

        self.patience = patience
        self.counter = 0
        self.filename = filename
        self.best_score = None
        self.early_stop = False

    def _check_higher(self, score, prev_best_score):
        return (score > prev_best_score)

    def _check_lower(self, score, prev_best_score):
        return (score < prev_best_score)

    def step(self, score, model):
        if self.best_score is None:
            self.best_score = score
            self.save_checkpoint(model)
        elif self._check(score, self.best_score):
            self.best_score = score
            self.save_checkpoint(model)
            self.counter = 0
        else:
            self.counter += 1
            print(
                f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True
        return self.early_stop

    def save_checkpoint(self, model):
@@ -80,14 +131,14 @@ class EarlyStopping(object):
        '''Load model saved with early stopping.'''
        model.load_state_dict(torch.load(self.filename)['model_state_dict'])

def collate_molgraphs_for_classification(data):
    """Batching a list of datapoints for dataloader in classification tasks.

    Parameters
    ----------
    data : list of 4-tuples
        Each tuple is for a single datapoint, consisting of
        a SMILE, a DGLGraph, all-task labels and all-task weights

    Returns
    -------
@@ -109,3 +160,29 @@ def collate_molgraphs(data):
    labels = torch.stack(labels, dim=0)
    mask = torch.stack(mask, dim=0)
    return smiles, bg, labels, mask

def collate_molgraphs_for_regression(data):
    """Batching a list of datapoints for dataloader in regression tasks.

    Parameters
    ----------
    data : list of 3-tuples
        Each tuple is for a single datapoint, consisting of
        a SMILE, a DGLGraph and all-task labels.

    Returns
    -------
    smiles : list
        List of smiles
    bg : BatchedDGLGraph
        Batched DGLGraphs
    labels : Tensor of dtype float32 and shape (B, T)
        Batched datapoint labels. B is len(data) and
        T is the number of total tasks.
    """
    smiles, graphs, labels = map(list, zip(*data))
    bg = dgl.batch(graphs)
    bg.set_n_initializer(dgl.init.zero_initializer)
    bg.set_e_initializer(dgl.init.zero_initializer)
    labels = torch.stack(labels, dim=0)
    return smiles, bg, labels
from .utils import *
from .csv_dataset import CSVDataset
from .tox21 import Tox21
from .alchemy import TencentAlchemyDataset
@@ -11,8 +11,7 @@ import zipfile
from collections import defaultdict

from .utils import mol_to_complete_graph
from ..utils import download, get_download_dir, _get_dgl_url
from ... import backend as F

try:
@@ -23,63 +22,154 @@ try:
except ImportError:
    pass
def alchemy_nodes(mol):
    """Featurization for all atoms in a molecule. The atom indices
    will be preserved.

    Parameters
    ----------
    mol : rdkit.Chem.rdchem.Mol
        RDKit molecule object

    Returns
    -------
    atom_feats_dict : dict
        Dictionary for atom features
    """
    atom_feats_dict = defaultdict(list)
    is_donor = defaultdict(int)
    is_acceptor = defaultdict(int)

    fdef_name = osp.join(RDConfig.RDDataDir, 'BaseFeatures.fdef')
    mol_featurizer = ChemicalFeatures.BuildFeatureFactory(fdef_name)
    mol_feats = mol_featurizer.GetFeaturesForMol(mol)
    mol_conformers = mol.GetConformers()
    assert len(mol_conformers) == 1

    for i in range(len(mol_feats)):
        if mol_feats[i].GetFamily() == 'Donor':
            node_list = mol_feats[i].GetAtomIds()
            for u in node_list:
                is_donor[u] = 1
        elif mol_feats[i].GetFamily() == 'Acceptor':
            node_list = mol_feats[i].GetAtomIds()
            for u in node_list:
                is_acceptor[u] = 1

    num_atoms = mol.GetNumAtoms()
    for u in range(num_atoms):
        atom = mol.GetAtomWithIdx(u)
        symbol = atom.GetSymbol()
        atom_type = atom.GetAtomicNum()
        aromatic = atom.GetIsAromatic()
        hybridization = atom.GetHybridization()
        num_h = atom.GetTotalNumHs()
        atom_feats_dict['node_type'].append(atom_type)

        h_u = []
        h_u += [int(symbol == x) for x in ['H', 'C', 'N', 'O', 'F', 'S', 'Cl']]
        h_u.append(atom_type)
        h_u.append(is_acceptor[u])
        h_u.append(is_donor[u])
        h_u.append(int(aromatic))
        h_u += [
            int(hybridization == x)
            for x in (Chem.rdchem.HybridizationType.SP,
                      Chem.rdchem.HybridizationType.SP2,
                      Chem.rdchem.HybridizationType.SP3)
        ]
        h_u.append(num_h)
        atom_feats_dict['n_feat'].append(F.tensor(np.array(h_u).astype(np.float32)))

    atom_feats_dict['n_feat'] = F.stack(atom_feats_dict['n_feat'], dim=0)
    atom_feats_dict['node_type'] = F.tensor(np.array(
        atom_feats_dict['node_type']).astype(np.int64))

    return atom_feats_dict

def alchemy_edges(mol, self_loop=False):
    """Featurization for all bonds in a molecule.
    The bond indices will be preserved.

    Parameters
    ----------
    mol : rdkit.Chem.rdchem.Mol
        RDKit molecule object
    self_loop : bool
        Whether to add self loops. Default to be False.

    Returns
    -------
    bond_feats_dict : dict
        Dictionary for bond features
    """
    bond_feats_dict = defaultdict(list)

    mol_conformers = mol.GetConformers()
    assert len(mol_conformers) == 1
    geom = mol_conformers[0].GetPositions()

    num_atoms = mol.GetNumAtoms()
    for u in range(num_atoms):
        for v in range(num_atoms):
            if u == v and not self_loop:
                continue

            e_uv = mol.GetBondBetweenAtoms(u, v)
            if e_uv is None:
                bond_type = None
            else:
                bond_type = e_uv.GetBondType()
            bond_feats_dict['e_feat'].append([
                float(bond_type == x)
                for x in (Chem.rdchem.BondType.SINGLE,
                          Chem.rdchem.BondType.DOUBLE,
                          Chem.rdchem.BondType.TRIPLE,
                          Chem.rdchem.BondType.AROMATIC, None)
            ])
            bond_feats_dict['distance'].append(
                np.linalg.norm(geom[u] - geom[v]))

    bond_feats_dict['e_feat'] = F.tensor(
        np.array(bond_feats_dict['e_feat']).astype(np.float32))
    bond_feats_dict['distance'] = F.tensor(
        np.array(bond_feats_dict['distance']).astype(np.float32)).reshape(-1, 1)

    return bond_feats_dict
class TencentAlchemyDataset(object):
    """
    Developed by the Tencent Quantum Lab, the dataset lists 12 quantum mechanical
    properties of 130,000+ organic molecules, comprising up to 12 heavy atoms
    (C, N, O, S, F and Cl), sampled from the GDBMedChem database. These properties
    have been calculated using the open-source computational chemistry program
    Python-based Simulation of Chemistry Framework (PySCF).

    For more details, check the `paper <https://arxiv.org/abs/1906.09427>`__.

    Parameters
    ----------
    mode : str
        'dev', 'valid' or 'test', separately for training, validation and test.
        Default to be 'dev'. Note that 'test' is not available as the Alchemy
        contest is ongoing.
    from_raw : bool
        Whether to process the dataset from scratch or use a
        processed one for faster speed. Default to be False.
    """
    def __init__(self, mode='dev', from_raw=False):
        if mode == 'test':
            raise ValueError('The test mode is not supported before '
                             'the Alchemy contest finishes.')
        assert mode in ['dev', 'valid', 'test'], \
            'Expect mode to be dev, valid or test, got {}.'.format(mode)

        self.mode = mode
        # Construct DGLGraphs from raw data or use the preprocessed data
        self.from_raw = from_raw
        file_dir = osp.join(get_download_dir(), 'Alchemy_data')

        if not from_raw:
            file_name = "%s_processed" % (mode)
@@ -87,9 +177,9 @@ class TencentAlchemyDataset(object):
            file_name = "%s_single_sdf" % (mode)
        self.file_dir = pathlib.Path(file_dir, file_name)

        self._url = 'dataset/alchemy/'
        self.zip_file_path = pathlib.Path(file_dir, file_name + '.zip')
        download(_get_dgl_url(self._url + file_name + '.zip'), path=str(self.zip_file_path))
        if not os.path.exists(str(self.file_dir)):
            archive = zipfile.ZipFile(self.zip_file_path)
            archive.extractall(file_dir)
@@ -98,154 +188,74 @@ class TencentAlchemyDataset(object):
        self._load()
    def _load(self):
        if not self.from_raw:
            with open(osp.join(self.file_dir, "%s_graphs.pkl" % self.mode), "rb") as f:
                self.graphs = pickle.load(f)
            with open(osp.join(self.file_dir, "%s_labels.pkl" % self.mode), "rb") as f:
                self.labels = pickle.load(f)
        else:
            print('Start preprocessing dataset...')
            target_file = pathlib.Path(self.file_dir, "%s_target.csv" % self.mode)
            self.target = pd.read_csv(
                target_file,
                index_col=0,
                usecols=['gdb_idx',] + ['property_%d' % x for x in range(12)])
            self.target = self.target[['property_%d' % x for x in range(12)]]
            self.graphs, self.labels = [], []

            supp = Chem.SDMolSupplier(osp.join(self.file_dir, self.mode + ".sdf"))
            cnt = 0
            dataset_size = len(self.target)
            for mol, label in zip(supp, self.target.iterrows()):
                cnt += 1
                print('Processing molecule {:d}/{:d}'.format(cnt, dataset_size))
                graph = mol_to_complete_graph(mol, atom_featurizer=alchemy_nodes,
                                              bond_featurizer=alchemy_edges)
                smile = Chem.MolToSmiles(mol)
                graph.smile = smile
                self.graphs.append(graph)
                label = F.tensor(np.array(label[1].tolist()).astype(np.float32))
                self.labels.append(label)

            with open(osp.join(self.file_dir, "%s_graphs.pkl" % self.mode), "wb") as f:
                pickle.dump(self.graphs, f)
            with open(osp.join(self.file_dir, "%s_labels.pkl" % self.mode), "wb") as f:
                pickle.dump(self.labels, f)

        self.set_mean_and_std()
        print(len(self.graphs), "loaded!")
    def __getitem__(self, item):
        """Get datapoint with index

        Parameters
        ----------
        item : int
            Datapoint index

        Returns
        -------
        str
            SMILES for the ith datapoint
        DGLGraph
            DGLGraph for the ith datapoint
        Tensor of dtype float32
            Labels of the datapoint for all tasks
        """
        g, l = self.graphs[item], self.labels[item]
        return g.smile, g, l

    def __len__(self):
        """Length of the dataset

        Returns
        -------
        int
            Length of Dataset
        """
        return len(self.graphs)

    def set_mean_and_std(self, mean=None, std=None):
        """Set mean and std or compute from labels for future normalization.

        Parameters
@@ -262,63 +272,3 @@ class TencentAlchemyDataset(object):
            std = np.std(labels, axis=0)
        self.mean = mean
        self.std = std
@@ -24,21 +24,17 @@ class CSVDataset(object):
    Parameters
    ----------
    df: pandas.DataFrame
        Dataframe including smiles and labels. Can be loaded by pandas.read_csv(file_path).
        One column includes smiles and other columns for labels.
        Column names other than smiles column would be considered as task names.
    smile_to_graph: callable, str -> DGLGraph
        A function that turns smiles into a DGLGraph. The default one can be found
        at python/dgl/data/chem/utils.py, named smile_to_bigraph.
    smile_column: str
        Column name that includes smiles
    cache_file_path: str
        Path to store the preprocessed data
    """
    def __init__(self, df, smile_to_graph=smile_to_bigraph, smile_column='smiles',
                 cache_file_path="csvdata_dglgraph.pkl"):
        if 'rdkit' not in sys.modules:
@@ -59,6 +55,11 @@ class CSVDataset(object):
          and featurize their atoms
        * Set missing labels to be 0 and use a binary masking
          matrix to mask them

        Parameters
        ----------
        smile_to_graph : callable, SMILES -> DGLGraph
            Function for converting a SMILES (str) into a DGLGraph
        """
        if os.path.exists(self.cache_file_path):
            # DGLGraphs have been constructed before, reload them
@@ -76,7 +77,12 @@ class CSVDataset(object):
        self.mask = (~np.isnan(_label_values)).astype(np.float32)
    def __getitem__(self, item):
        """Get datapoint with index

        Parameters
        ----------
        item : int
            Datapoint index

        Returns
        -------
@@ -87,17 +93,17 @@ class CSVDataset(object):
        Tensor of dtype float32
            Labels of the datapoint for all tasks
        Tensor of dtype float32
            Binary masks indicating the existence of labels for all tasks
        """
        return self.smiles[item], self.graphs[item], \
               F.zerocopy_from_numpy(self.labels[item]), \
               F.zerocopy_from_numpy(self.mask[item])

    def __len__(self):
        """Length of the dataset

        Returns
        -------
        int
            Length of Dataset
        """
@@ -11,9 +11,6 @@ except ImportError:
    pass

class Tox21(CSVDataset):
    """Tox21 dataset.

    The Toxicology in the 21st Century (https://tripod.nih.gov/tox21/challenge/)
@@ -33,14 +30,15 @@ class Tox21(CSVDataset):

    Parameters
    ----------
    smile_to_graph: callable, str -> DGLGraph
        A function that turns smiles into a DGLGraph. The default one can be found
        at python/dgl/data/chem/utils.py, named smile_to_bigraph.
    """
    def __init__(self, smile_to_graph=smile_to_bigraph):
        if 'pandas' not in sys.modules:
            from ...base import dgl_warning
            dgl_warning("Please install pandas")

        self._url = 'dataset/tox21.csv.gz'
        data_path = get_download_dir() + '/tox21.csv.gz'
        download(_get_dgl_url(self._url), path=data_path)
        df = pd.read_csv(data_path)
@@ -80,7 +78,7 @@ class Tox21(CSVDataset):

        Returns
        -------
        numpy.ndarray
            Numpy array giving the weight of positive samples on all tasks
        """
        return self._task_pos_weights
@@ -10,6 +10,10 @@ try:
except ImportError:
    pass
__all__ = ['one_hot_encoding', 'BaseAtomFeaturizer', 'CanonicalAtomFeaturizer',
           'mol_to_graph', 'smile_to_bigraph', 'mol_to_bigraph',
           'smile_to_complete_graph', 'mol_to_complete_graph']
def one_hot_encoding(x, allowable_set):
    """One-hot encoding.
@@ -29,13 +29,37 @@ def _get_dgl_url(file_url):

def split_dataset(dataset, frac_list=None, shuffle=False, random_state=None):
"""Split dataset into training, validation and test set.
Parameters
----------
dataset
We assume ``len(dataset)`` gives the number of datapoints and ``dataset[i]``
gives the ith datapoint.
frac_list : list or None, optional
A list of length 3 containing the fraction to use for training,
validation and test. If None, we will use [0.8, 0.1, 0.1].
shuffle : bool, optional
By default we perform a consecutive split of the dataset. If True,
we will first randomly shuffle the dataset.
random_state : None, int or array_like, optional
Random seed used to initialize the pseudo-random number generator.
Can be any integer between 0 and 2**32 - 1 inclusive, an array
(or other sequence) of such integers, or None (the default).
If seed is None, then RandomState will try to read data from /dev/urandom
(or the Windows analogue) if available or seed from the clock otherwise.
Returns
-------
list of length 3
Subsets for training, validation and test.
"""
from itertools import accumulate from itertools import accumulate
if frac_list is None: if frac_list is None:
frac_list = [0.8, 0.1, 0.1] frac_list = [0.8, 0.1, 0.1]
frac_list = np.array(frac_list) frac_list = np.array(frac_list)
assert np.allclose(np.sum(frac_list), 1.), \ assert np.allclose(np.sum(frac_list), 1.), \
'Expect frac_list sum to 1, got {:.4f}'.format( 'Expect frac_list sum to 1, got {:.4f}'.format(np.sum(frac_list))
np.sum(frac_list))
num_data = len(dataset) num_data = len(dataset)
lengths = (num_data * frac_list).astype(int) lengths = (num_data * frac_list).astype(int)
lengths[-1] = num_data - np.sum(lengths[:-1]) lengths[-1] = num_data - np.sum(lengths[:-1])
@@ -2,7 +2,7 @@
"""Model Zoo Package"""
from .classifiers import GCNClassifier, GATClassifier
from .schnet import SchNet
from .mgcn import MGCNModel
from .mpnn import MPNNModel
from .dgmg import DGMG
@@ -629,6 +629,8 @@ def dgmg_message_weight_init(m):
class DGMG(nn.Module):
    """DGMG model

    `Learning Deep Generative Models of Graphs <https://arxiv.org/abs/1803.03324>`__

    Users only need to initialize an instance of this class.

    Parameters
@@ -27,7 +27,10 @@ from .nnutils import create_var, cuda, move_dgl_to_cuda

class DGLJTNNVAE(nn.Module):
    """
    `Junction Tree Variational Autoencoder for Molecular Graph Generation
    <https://arxiv.org/abs/1802.04364>`__
    """
    def __init__(self, hidden_size, latent_size, depth, vocab=None, vocab_file=None):
        super(DGLJTNNVAE, self).__init__()
        if vocab is None:
...@@ -3,14 +3,13 @@ ...@@ -3,14 +3,13 @@
""" """
The implementation of neural network layers used in SchNet and MGCN. The implementation of neural network layers used in SchNet and MGCN.
""" """
import torch as th import torch
import torch.nn as nn import torch.nn as nn
from torch.nn import Softplus from torch.nn import Softplus
import numpy as np import numpy as np
from ... import function as fn from ... import function as fn
class AtomEmbedding(nn.Module): class AtomEmbedding(nn.Module):
""" """
Convert the atom(node) list to atom embeddings. Convert the atom(node) list to atom embeddings.
...@@ -19,7 +18,7 @@ class AtomEmbedding(nn.Module): ...@@ -19,7 +18,7 @@ class AtomEmbedding(nn.Module):
Parameters Parameters
---------- ----------
dim : int dim : int
Dim of embeddings, default to be 128. Size of embeddings, default to be 128.
type_num : int type_num : int
The largest atomic number of atoms in the dataset, default to be 100. The largest atomic number of atoms in the dataset, default to be 100.
pre_train : None or pre-trained embeddings pre_train : None or pre-trained embeddings
...@@ -27,6 +26,7 @@ class AtomEmbedding(nn.Module): ...@@ -27,6 +26,7 @@ class AtomEmbedding(nn.Module):
""" """
def __init__(self, dim=128, type_num=100, pre_train=None): def __init__(self, dim=128, type_num=100, pre_train=None):
super(AtomEmbedding, self).__init__() super(AtomEmbedding, self).__init__()
self._dim = dim self._dim = dim
self._type_num = type_num self._type_num = type_num
if pre_train is not None: if pre_train is not None:
...@@ -34,19 +34,19 @@ class AtomEmbedding(nn.Module): ...@@ -34,19 +34,19 @@ class AtomEmbedding(nn.Module):
else: else:
self.embedding = nn.Embedding(type_num, dim, padding_idx=0) self.embedding = nn.Embedding(type_num, dim, padding_idx=0)
def forward(self, g, p_name="node"): def forward(self, atom_types):
""" """
Parameters Parameters
---------- ----------
g : DGLGraph atom_types : int64 tensor of shape (B1)
Input DGLGraph(s) Types for atoms in the graph(s), B1 for the number of atoms.
p_name : str
Name for storing atom embeddings
"""
atom_list = g.ndata["node_type"]
g.ndata[p_name] = self.embedding(atom_list)
return g.ndata[p_name]
Returns
-------
float32 tensor of shape (B1, self._dim)
Atom embeddings.
"""
return self.embedding(atom_types)
class EdgeEmbedding(nn.Module): class EdgeEmbedding(nn.Module):
""" """
...@@ -56,14 +56,15 @@ class EdgeEmbedding(nn.Module): ...@@ -56,14 +56,15 @@ class EdgeEmbedding(nn.Module):
Parameters Parameters
---------- ----------
dim : int dim : int
Dim of edge embeddings, default to be 128. Size of embeddings, default to be 128.
edge_num : int edge_num : int
Maximum number of edge types, default to be 128. Maximum number of edge types allowed, default to be 3000.
pre_train : Edge embeddings or None pre_train : Edge embeddings or None
Pre-trained edge embeddings Pre-trained edge embeddings, default to be None.
""" """
def __init__(self, dim=128, edge_num=3000, pre_train=None): def __init__(self, dim=128, edge_num=3000, pre_train=None):
super(EdgeEmbedding, self).__init__() super(EdgeEmbedding, self).__init__()
self._dim = dim self._dim = dim
self._edge_num = edge_num self._edge_num = edge_num
if pre_train is not None: if pre_train is not None:
...@@ -74,12 +75,13 @@ class EdgeEmbedding(nn.Module): ...@@ -74,12 +75,13 @@ class EdgeEmbedding(nn.Module):
def generate_edge_type(self, edges): def generate_edge_type(self, edges):
"""Generate edge type. """Generate edge type.
The edge type is based on the type of the src&dst atom. The edge type is based on the type of the src & dst atom.
Note that C-O and O-C are the same edge type. Note that directions are not distinguished, e.g. C-O and O-C are the same edge type.
To map a pair of nodes to one number, we use an unordered pairing function here To map a pair of nodes to one number, we use an unordered pairing function here
See more detail in this disscussion: See more detail in this disscussion:
https://math.stackexchange.com/questions/23503/create-unique-number-from-2-numbers https://math.stackexchange.com/questions/23503/create-unique-number-from-2-numbers
Note that, the edge_num should be larger than the square of maximum atomic number Note that the edge_num should be larger than the square of maximum atomic number
in the dataset. in the dataset.
Parameters Parameters
...@@ -92,45 +94,51 @@ class EdgeEmbedding(nn.Module): ...@@ -92,45 +94,51 @@ class EdgeEmbedding(nn.Module):
dict dict
Stores the edge types in "type" Stores the edge types in "type"
""" """
atom_type_x = edges.src["node_type"] atom_type_x = edges.src['ntype']
atom_type_y = edges.dst["node_type"] atom_type_y = edges.dst['ntype']
return { return {
"type": atom_type_x * atom_type_y + 'etype': atom_type_x * atom_type_y + \
(th.abs(atom_type_x - atom_type_y) - 1)**2 / 4 (torch.abs(atom_type_x - atom_type_y) - 1) ** 2 / 4
} }
def forward(self, g, p_name="edge_f"): def forward(self, g, atom_types):
"""Compute edge embeddings """Compute edge embeddings
Parameters Parameters
---------- ----------
g : DGLGraph g : DGLGraph
The graph to compute edge embeddings The graph to compute edge embeddings
p_name : str atom_types : int64 tensor of shape (B1)
Types for atoms in the graph(s), B1 for the number of atoms.
Returns Returns
------- -------
computed edge embeddings float32 tensor of shape (B2, self._dim)
Computed edge embeddings
""" """
g = g.local_var()
g.ndata['ntype'] = atom_types
g.apply_edges(self.generate_edge_type) g.apply_edges(self.generate_edge_type)
g.edata[p_name] = self.embedding(g.edata["type"]) return self.embedding(g.edata.pop('etype'))
return g.edata[p_name]
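To see why the pairing in `generate_edge_type` above is order-insensitive, and why `edge_num` must exceed the square of the largest atomic number, here is a small standalone check (a sketch; `/` performed integer division on int64 tensors in the PyTorch releases this code targeted, so `//` is used below for modern versions):

```python
import torch

def edge_type(x, y):
    # Unordered pairing: x * y + floor((|x - y| - 1) ** 2 / 4)
    return x * y + (torch.abs(x - y) - 1) ** 2 // 4

c, o = torch.tensor(6), torch.tensor(8)
assert edge_type(c, o) == edge_type(o, c)  # C-O and O-C share one type
print(edge_type(c, o))                     # tensor(48)

# Distinct unordered pairs map to distinct codes.
assert edge_type(torch.tensor(1), torch.tensor(4)) != edge_type(torch.tensor(2), torch.tensor(3))

# Codes grow roughly like max_atomic_number ** 2, so edge_num=3000
# accommodates atomic numbers up to about 54.
```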
class ShiftSoftplus(nn.Module):
    """
    ShiftSoftplus activation function:
        1/beta * log(1 + exp(beta * x)) - log(shift)

    Parameters
    ----------
    beta : int
        Default to be 1.
    shift : int
        Default to be 2.
    threshold : int
        Default to be 20.
    """
    def __init__(self, beta=1, shift=2, threshold=20):
        super(ShiftSoftplus, self).__init__()

        self.shift = shift
        self.softplus = Softplus(beta, threshold)

@@ -138,43 +146,56 @@ class ShiftSoftplus(Softplus):
    def forward(self, x):
        """Applies the activation function"""
        return self.softplus(x) - np.log(float(self.shift))
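A quick numeric sanity check of the formula (with the defaults beta=1, shift=2 it reduces to f(x) = log(1 + exp(x)) - log(2), so f(0) = 0 and f(x) approaches x - log(2) for large x):

```python
import numpy as np
import torch
from torch.nn import Softplus

softplus = Softplus(beta=1, threshold=20)
shifted = lambda x: softplus(x) - np.log(2.0)

x = torch.tensor([0.0, 10.0])
print(shifted(x))  # tensor([0.0000, 9.3069])
```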
class RBFLayer(nn.Module):
    """
    Radial basis function layer.

    e(d) = exp(- gamma * ||d - mu_k||^2)

    With the default parameters below, we have:

    * gamma = 10
    * 0 <= mu_k <= 30 for k=1~300

    Parameters
    ----------
    low : int
        Smallest value to take for mu_k, default to be 0.
    high : int
        Largest value to take for mu_k, default to be 30.
    gap : float
        Difference between two consecutive values for mu_k, default to be 0.1.
    dim : int
        Output size for each center, default to be 1.
    """
    def __init__(self, low=0, high=30, gap=0.1, dim=1):
        super(RBFLayer, self).__init__()

        self._low = low
        self._high = high
        self._dim = dim

        self._n_centers = int(np.ceil((high - low) / gap))
        centers = np.linspace(low, high, self._n_centers)
        self.centers = torch.tensor(centers, dtype=torch.float, requires_grad=False)
        self.centers = nn.Parameter(self.centers, requires_grad=False)
        self._fan_out = self._dim * self._n_centers
        self._gap = centers[1] - centers[0]

    def forward(self, edge_distances):
        """
        Parameters
        ----------
        edge_distances : float32 tensor of shape (B, 1)
            Edge distances, B for the number of edges.

        Returns
        -------
        float32 tensor of shape (B, self._fan_out)
            Computed RBF results
        """
        radial = edge_distances - self.centers
        coef = -1 / self._gap
        return torch.exp(coef * (radial ** 2))
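With the defaults there are ceil((30 - 0) / 0.1) = 300 centers, so each scalar distance is expanded into a 300-dimensional soft binning over distance values. A usage sketch against the class above:

```python
import torch

rbf = RBFLayer(low=0, high=30, gap=0.1, dim=1)
dists = torch.tensor([[1.2], [3.4], [0.9]])  # (B, 1) edge distances
out = rbf(dists)                             # broadcast against the 300 centers
print(out.shape)                             # torch.Size([3, 300])
```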
class CFConv(nn.Module):
    """
@@ -185,38 +206,45 @@ class CFConv(nn.Module):
    rbf_dim : int
        Dimension of the RBF layer output
    dim : int
        Dimension of output, default to be 64
    act : activation function or None
        Activation function, default to be shifted softplus
    """
    def __init__(self, rbf_dim, dim=64, act=None):
        super(CFConv, self).__init__()

        self._rbf_dim = rbf_dim
        self._dim = dim

        if act is None:
            activation = nn.Softplus(beta=0.5, threshold=14)
        else:
            activation = act

        self.project = nn.Sequential(
            nn.Linear(self._rbf_dim, self._dim),
            activation,
            nn.Linear(self._dim, self._dim)
        )

    def forward(self, g, node_weight, rbf_out):
        """
        Parameters
        ----------
        g : DGLGraph
            The graph for performing convolution
        node_weight : float32 tensor of shape (B1, D1)
            The weight of nodes in message passing, B1 for number of nodes and
            D1 for node weight size.
        rbf_out : float32 tensor of shape (B2, D2)
            The output of RBFLayer, B2 for number of edges and D2 for rbf out size.

        Returns
        -------
        float32 tensor of shape (B1, self._dim)
            Updated node representations
        """
        g = g.local_var()
        e = self.project(rbf_out)
        g.ndata['node_weight'] = node_weight
        g.edata['e'] = e
        g.update_all(fn.u_mul_e('node_weight', 'e', 'm'), fn.sum('m', 'h'))
        return g.ndata.pop('h')
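A toy run of the layer, where each destination node v receives sum_u node_weight[u] * project(rbf)[u->v] (a sketch; graph construction uses the DGL 0.4-era API this code targets, and the per-edge `rbf_out` rows are assumed to follow the graph's edge order):

```python
import dgl
import torch

g = dgl.DGLGraph()                 # DGL 0.4-era construction
g.add_nodes(3)
g.add_edges([0, 1, 2], [1, 2, 0])  # 3 nodes, 3 directed edges

conv = CFConv(rbf_dim=300, dim=64)
node_weight = torch.randn(3, 64)   # one row per node
rbf_out = torch.randn(3, 300)      # one row per edge
h = conv(g, node_weight, rbf_out)
print(h.shape)                     # torch.Size([3, 64])
```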
class Interaction(nn.Module):
    """
@@ -231,118 +259,101 @@ class Interaction(nn.Module):
    """
    def __init__(self, rbf_dim, dim):
        super(Interaction, self).__init__()

        self._dim = dim
        self.node_layer1 = nn.Linear(dim, dim, bias=False)
        self.cfconv = CFConv(rbf_dim, dim, Softplus(beta=0.5, threshold=14))
        self.node_layer2 = nn.Sequential(
            nn.Linear(dim, dim),
            Softplus(beta=0.5, threshold=14),
            nn.Linear(dim, dim)
        )

    def forward(self, g, n_feat, rbf_out):
        """
        Parameters
        ----------
        g : DGLGraph
            The graph for performing convolution
        n_feat : float32 tensor of shape (B1, D1)
            Node features, B1 for number of nodes and D1 for feature size.
        rbf_out : float32 tensor of shape (B2, D2)
            The output of RBFLayer, B2 for number of edges and D2 for rbf out size.

        Returns
        -------
        float32 tensor of shape (B1, D1)
            Updated node representations
        """
        n_weight = self.node_layer1(n_feat)
        new_n_feat = self.cfconv(g, n_weight, rbf_out)
        new_n_feat = self.node_layer2(new_n_feat)
        return n_feat + new_n_feat
class VEConv(nn.Module):
    """
    The Vertex-Edge convolution layer in MGCN which takes both edge & vertex features
    into consideration.

    Parameters
    ----------
    rbf_dim : int
        Size of the RBF layer output
    dim : int
        Size of intermediate representations, default to be 64.
    update_edge : bool
        Whether to apply a linear layer to update edge representations, default to be True.
    """
    def __init__(self, rbf_dim, dim=64, update_edge=True):
        super(VEConv, self).__init__()

        self._rbf_dim = rbf_dim
        self._dim = dim
        self._update_edge = update_edge

        self.update_rbf = nn.Sequential(
            nn.Linear(self._rbf_dim, self._dim),
            nn.Softplus(beta=0.5, threshold=14),
            nn.Linear(self._dim, self._dim)
        )
        self.update_efeat = nn.Linear(self._dim, self._dim)

    def forward(self, g, n_feat, e_feat, rbf_out):
        """
        Parameters
        ----------
        g : DGLGraph
            The graph for performing convolution
        n_feat : float32 tensor of shape (B1, D1)
            Node features, B1 for number of nodes and D1 for feature size.
        e_feat : float32 tensor of shape (B2, D2)
            Edge features, B2 for number of edges and D2 for the edge feature size.
        rbf_out : float32 tensor of shape (B2, D3)
            The output of RBFLayer, B2 for number of edges and D3 for rbf out size.

        Returns
        -------
        n_feat : float32 tensor
            Updated node features.
        e_feat : float32 tensor
            (Potentially updated) edge features
        """
        rbf_out = self.update_rbf(rbf_out)

        if self._update_edge:
            e_feat = self.update_efeat(e_feat)

        g = g.local_var()
        g.ndata.update({'n_feat': n_feat})
        g.edata.update({'rbf_out': rbf_out, 'e_feat': e_feat})
        g.update_all(message_func=[fn.u_mul_e('n_feat', 'rbf_out', 'm_0'),
                                   fn.copy_e('e_feat', 'm_1')],
                     reduce_func=[fn.sum('m_0', 'n_feat_0'),
                                  fn.sum('m_1', 'n_feat_1')])
        n_feat = g.ndata.pop('n_feat_0') + g.ndata.pop('n_feat_1')

        return n_feat, e_feat
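Per destination node v, VEConv sums two message streams: distance-gated neighbor features and (optionally transformed) incoming edge features, i.e. roughly sum_u n_feat[u] * MLP(rbf[u->v]) + sum_u W(e_feat[u->v]). A shape-level sketch (assumes a DGL version, like the 0.4.x releases this code targets, that accepts lists of built-in message/reduce functions in `update_all`):

```python
import dgl
import torch

g = dgl.DGLGraph()
g.add_nodes(3)
g.add_edges([0, 1, 2], [1, 2, 0])

conv = VEConv(rbf_dim=300, dim=64)
n_feat = torch.randn(3, 64)
e_feat = torch.randn(3, 64)
rbf_out = torch.randn(3, 300)
n_feat, e_feat = conv(g, n_feat, e_feat, rbf_out)
print(n_feat.shape, e_feat.shape)  # torch.Size([3, 64]) torch.Size([3, 64])
```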
class MultiLevelInteraction(nn.Module):
    """
@@ -359,35 +370,43 @@ class MultiLevelInteraction(nn.Module):
    def __init__(self, rbf_dim, dim):
        super(MultiLevelInteraction, self).__init__()

        self._atom_dim = dim

        self.node_layer1 = nn.Linear(dim, dim, bias=True)
        self.conv_layer = VEConv(rbf_dim, dim)
        self.activation = nn.Softplus(beta=0.5, threshold=14)
        self.edge_layer1 = nn.Linear(dim, dim, bias=True)
        self.node_out = nn.Sequential(
            nn.Linear(dim, dim),
            nn.Softplus(beta=0.5, threshold=14),
            nn.Linear(dim, dim)
        )

    def forward(self, g, n_feat, e_feat, rbf_out):
        """
        Parameters
        ----------
        g : DGLGraph
            The graph for performing convolution
        n_feat : float32 tensor of shape (B1, D1)
            Node features, B1 for number of nodes and D1 for feature size.
        e_feat : float32 tensor of shape (B2, D2)
            Edge features, B2 for number of edges and D2 for the edge feature size.
        rbf_out : float32 tensor of shape (B2, D3)
            The output of RBFLayer, B2 for number of edges and D3 for rbf out size.

        Returns
        -------
        n_feat : float32 tensor
            Updated node representations
        e_feat : float32 tensor
            Updated edge representations
        """
        new_n_feat = self.node_layer1(n_feat)
        new_n_feat, e_feat = self.conv_layer(g, new_n_feat, e_feat, rbf_out)
        new_n_feat = self.node_out(new_n_feat)
        n_feat = n_feat + new_n_feat

        e_feat = self.activation(self.edge_layer1(e_feat))

        return n_feat, e_feat