Unverified Commit 189c2c09 authored by Mufei Li's avatar Mufei Li Committed by GitHub

[Model Zoo] Refactor Model Zoo for Chemistry (#839)

* Update

* Update

* Update

* Update fix

* Update

* Update

* Refactor

* Update

* Update

* Update

* Update

* Update

* Update

* Fix style
parent 5b417683
......@@ -5,13 +5,17 @@
[Documentation](https://docs.dgl.ai) | [DGL at a glance](https://docs.dgl.ai/tutorials/basics/1_first.html#sphx-glr-tutorials-basics-1-first-py) |
[Model Tutorials](https://docs.dgl.ai/tutorials/models/index.html) | [Discussion Forum](https://discuss.dgl.ai)
Model Zoos: [Chemistry](https://github.com/dmlc/dgl/tree/master/examples/pytorch/model_zoo) | [Citation Networks](https://github.com/dmlc/dgl/tree/master/examples/pytorch/model_zoo/citation_network)
DGL is a Python package that interfaces between existing tensor libraries and data being expressed as
graphs.
It makes implementing graph neural networks (including Graph Convolution Networks, TreeLSTM, and many others) easy while
maintaining high computation efficiency.
A summary of model accuracy and training speed with the PyTorch backend (on an Amazon EC2 p3.2x instance with a V100 GPU), as compared with the best open-source implementations:
All model examples can be found [here](https://github.com/dmlc/dgl/tree/master/examples).
A summary of the accuracy and training speed of a subset of the models with the PyTorch backend (on an Amazon EC2 p3.2x instance with a V100 GPU), as compared with the best open-source implementations:
| Model | Reported <br> Accuracy | DGL <br> Accuracy | Author's training speed (epoch time) | DGL speed (epoch time) | Improvement |
| ----- | ----------------- | ------------ | ------------------------------------ | ---------------------- | ----------- |
......
......@@ -15,6 +15,10 @@ Utils
utils.download
utils.check_sha1
utils.extract_archive
utils.split_dataset
.. autoclass:: dgl.data.utils.Subset
:members: __getitem__, __len__
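For example, a dataset can be split as follows (a minimal sketch, assuming the
Tox21 dataset described below)::

    from dgl.data.chem import Tox21
    from dgl.data.utils import split_dataset

    dataset = Tox21()
    train_set, val_set, test_set = split_dataset(dataset, frac_list=[0.8, 0.1, 0.1])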
Dataset Classes
---------------
......@@ -57,3 +61,54 @@ Protein-Protein Interaction dataset
.. autoclass:: PPIDataset
:members: __getitem__, __len__
Molecular Graphs
----------------
To work on molecular graphs, make sure you have installed `RDKit 2018.09.3 <https://www.rdkit.org/docs/Install.html>`__.
Featurization
`````````````
To apply graph neural networks, we need to featurize nodes (atoms) and edges (bonds). Below we list some
featurization methods/utilities:
.. autosummary::
:toctree: ../../generated/
chem.one_hot_encoding
chem.BaseAtomFeaturizer
chem.CanonicalAtomFeaturizer
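For example, featurizing the atoms of a molecule (a minimal sketch; the
``atom_data_field='h'`` keyword is an assumption based on common usage in this
repository)::

    from rdkit import Chem
    from dgl.data.chem import CanonicalAtomFeaturizer

    mol = Chem.MolFromSmiles('CCO')
    atom_featurizer = CanonicalAtomFeaturizer(atom_data_field='h')
    feats = atom_featurizer(mol)['h']  # one feature vector per atom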
Graph Construction
``````````````````
Several methods for constructing DGLGraphs from SMILES/RDKit molecule objects are listed below:
.. autosummary::
:toctree: ../../generated/
chem.mol_to_graph
chem.smile_to_bigraph
chem.mol_to_bigraph
chem.smile_to_complete_graph
chem.mol_to_complete_graph
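For example, a sketch of constructing a molecular graph from a SMILES string::

    from dgl.data.chem import smile_to_bigraph

    g = smile_to_bigraph('CCO')  # a DGLGraph with one node per atom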
Dataset Classes
```````````````
If your dataset is stored in a ``.csv`` file, you may find it helpful to use
.. autoclass:: dgl.data.chem.CSVDataset
:members: __getitem__, __len__
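For example (a sketch; ``data.csv`` is a hypothetical file with a ``smiles``
column and one column per task)::

    import pandas as pd
    from dgl.data.chem import CSVDataset

    df = pd.read_csv('data.csv')
    dataset = CSVDataset(df, smile_column='smiles')
    smiles, g, labels, mask = dataset[0]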
Currently two datasets are supported:
* Tox21
* TencentAlchemyDataset
.. autoclass:: dgl.data.chem.Tox21
:members: __getitem__, __len__, task_pos_weights
.. autoclass:: dgl.data.chem.TencentAlchemyDataset
:members: __getitem__, __len__, set_mean_and_std
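Both datasets support list-style indexing, e.g. (a sketch; note that ``Tox21``
datapoints carry an additional binary mask)::

    from dgl.data.chem import Tox21, TencentAlchemyDataset

    tox21 = Tox21()
    smiles, g, label, mask = tox21[0]

    alchemy = TencentAlchemyDataset(mode='dev')
    smiles, g, label = alchemy[0]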
......@@ -19,3 +19,4 @@ API Reference
graph_store
nodeflow
random
model_zoo
.. _apimodelzoo:
Model Zoo
=========
.. currentmodule:: dgl.model_zoo
Chemistry
---------
Utils
`````
.. autosummary::
:toctree: ../../generated/
chem.load_pretrained
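For example, a sketch of loading a pre-trained model by name; the model name
is an assumption following the experiment names in ``configure.py``::

    from dgl import model_zoo

    model = model_zoo.chem.load_pretrained('GCN_Tox21')
    model.eval()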
Property Prediction
```````````````````
Currently supported model architectures:
* GCNClassifier
* GATClassifier
* MPNN
* SchNet
* MGCN
.. autoclass:: dgl.model_zoo.chem.GCNClassifier
:members: forward
.. autoclass:: dgl.model_zoo.chem.GATClassifier
:members: forward
.. autoclass:: dgl.model_zoo.chem.MPNNModel
:members: forward
.. autoclass:: dgl.model_zoo.chem.SchNet
:members: forward
.. autoclass:: dgl.model_zoo.chem.MGCNModel
:members: forward
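A sketch of running a classifier on a single molecule; the feature field
``'h'`` and the ``model(g, feats)`` calling convention are assumptions based
on the defaults in this repository::

    from dgl.data.chem import Tox21
    from dgl import model_zoo

    dataset = Tox21()
    model = model_zoo.chem.load_pretrained('GCN_Tox21')
    model.eval()

    smiles, g, label, mask = dataset[0]
    logits = model(g, g.ndata.pop('h'))  # one logit per task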
Generative Models
`````````````````
Currently supported model architectures:
* DGMG
* JTNN
.. autoclass:: dgl.model_zoo.chem.DGMG
:members: forward
.. autoclass:: dgl.model_zoo.chem.DGLJTNNVAE
:members: forward
......@@ -86,6 +86,7 @@ are also two accompanying review papers that are well written [7], [8].
### Models
- **Deep Generative Models of Graphs (DGMG)** [11]: A very general framework for graph distribution learning by
progressively adding atoms and bonds.
- **Junction Tree Variational Autoencoder for Molecular Graph Generation (JTNN)** [13]:
### Example Usage of Pre-trained Models
......@@ -143,3 +144,6 @@ Machine Learning* JMLR. 1263-1272.
[11] Li et al. (2018) Learning Deep Generative Models of Graphs. *arXiv preprint arXiv:1803.03324*.
[12] Goh et al. (2017) Deep learning for computational chemistry. *Journal of Computational Chemistry* 16, 1291-1307.
[13] Jin et al. (2018) Junction Tree Variational Autoencoder for Molecular Graph Generation.
*Proceedings of the 35th International Conference on Machine Learning (ICML)*, 2323-2332.
......@@ -15,7 +15,7 @@ into training, validation and test set with a 80/10/10 ratio. By default we foll
- **Graph Convolutional Network** [2], [3]. Graph Convolutional Networks (GCN) have been one of the most popular graph neural
networks and they can be easily extended for graph level prediction. MoleculeNet [1] reports baseline results of graph
convolutions over multiple datasets.
- **Graph Attention Networks** [7]: Graph Attention Networks (GATs) incorporate multi-head attention into GCNs,
- **Graph Attention Networks** [7]. Graph Attention Networks (GATs) incorporate multi-head attention into GCNs,
explicitly modeling the interactions between adjacent atoms.
### Usage
......@@ -49,16 +49,11 @@ a real difference.
| ---------------- | ---------------------- |
| Pretrained model | 0.827 |
## Dataset Customization
To customize your own dataset, see the instructions
[here](https://github.com/dmlc/dgl/tree/master/python/dgl/data/chem).
## Regression
Regression tasks require assigning continuous labels to a molecule, e.g. molecular energy.
### Dataset
### Datasets
- **Alchemy**. The [Alchemy Dataset](https://alchemy.tencent.com/) is introduced by Tencent Quantum Lab to facilitate the development of new
machine learning models useful for chemistry and materials science. The dataset lists 12 quantum mechanical properties of 130,000+ organic
......@@ -68,29 +63,39 @@ These properties have been calculated using the open-source computational chemis
### Models
- **SchNet**: SchNet is a deep learning architecture for modeling quantum interactions in molecules, which utilizes continuous-filter
convolutional layers [4].
- **Multilevel Graph Convolutional neural Network**: Multilevel Graph Convolutional neural Network (MGCN) is a hierarchical
graph neural network that directly extracts features from the conformation and spatial information, followed by multilevel interactions [5].
- **Message Passing Neural Network**: Message Passing Neural Network (MPNN) is a network with an edge network (enn) as the front end
and Set2Set for output prediction [6].
- **Message Passing Neural Network** [6]. Message Passing Neural Networks (MPNNs) have reached the best performance on
the QM9 dataset for some time.
- **SchNet** [4]. SchNet employs continuous filter convolutional layers to model quantum interactions in molecules
without requiring them to lie on grids.
- **Multilevel Graph Convolutional Neural Network** [5]. Multilevel Graph Convolutional Neural Networks (MGCN) are
hierarchical graph neural networks that extract features from the conformation and spatial information followed by the
multilevel interactions.
### Usage
```py
python regression.py --model sch --epoch 200
```

The model option must be one of 'sch', 'mgcn' or 'mpnn'.

Use `regression.py` with arguments

```
-m {MPNN,SCHNET,MGCN}, Model to use
-d {Alchemy}, Dataset to use
```
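For instance, to train an MPNN on Alchemy:

```
python regression.py -m MPNN -d Alchemy
```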
### Performance
#### Alchemy
|Model |Mean Absolute Error (MAE)|
|-------------|-------------------------|
|SchNet[4] |0.065|
|MGCN[5] |0.050|
|MPNN[6] |0.056|
The Alchemy contest is still ongoing. Before the test set is fully released, we only include performance numbers
on the training and validation sets for reference.
| Model | Training MAE | Validation MAE |
| ---------- | ------------ | -------------- |
| SchNet [4] | 0.2665 | 0.6139 |
| MGCN [5] | 0.2395 | 0.6463 |
| MPNN [6] | 0.2452 | 0.6259 |
## Dataset Customization
To customize your own dataset, see the instructions
[here](https://github.com/dmlc/dgl/tree/master/python/dgl/data/chem).
## References
[1] Wu et al. (2017) MoleculeNet: a benchmark for molecular machine learning. *Chemical Science* 9, 513-530.
......
from dgl.data.utils import split_dataset
from dgl import model_zoo
import torch
from torch.nn import BCEWithLogitsLoss
from torch.optim import Adam
from torch.utils.data import DataLoader
from utils import Meter, EarlyStopping, collate_molgraphs, set_random_seed
from dgl import model_zoo
from dgl.data.utils import split_dataset
from utils import Meter, EarlyStopping, collate_molgraphs_for_classification, set_random_seed
def run_a_train_epoch(args, epoch, model, data_loader, loss_criterion, optimizer):
model.train()
......@@ -45,15 +46,18 @@ def main(args):
args['device'] = "cuda" if torch.cuda.is_available() else "cpu"
set_random_seed()
# Interchangeable with other Dataset
# Interchangeable with other datasets
if args['dataset'] == 'Tox21':
from dgl.data.chem import Tox21
dataset = Tox21()
trainset, valset, testset = split_dataset(dataset, args['train_val_test_split'])
train_loader = DataLoader(trainset, batch_size=args['batch_size'], collate_fn=collate_molgraphs)
val_loader = DataLoader(valset, batch_size=args['batch_size'], collate_fn=collate_molgraphs)
test_loader = DataLoader(testset, batch_size=args['batch_size'], collate_fn=collate_molgraphs)
train_loader = DataLoader(trainset, batch_size=args['batch_size'],
collate_fn=collate_molgraphs_for_classification)
val_loader = DataLoader(valset, batch_size=args['batch_size'],
collate_fn=collate_molgraphs_for_classification)
test_loader = DataLoader(testset, batch_size=args['batch_size'],
collate_fn=collate_molgraphs_for_classification)
if args['pre_trained']:
args['num_epochs'] = 0
......
......@@ -23,9 +23,38 @@ GAT_Tox21 = {
'patience': 10
}
MPNN_Alchemy = {
'batch_size': 16,
'num_epochs': 250,
'output_dim': 12,
'lr': 0.0001,
'patience': 50
}
SCHNET_Alchemy = {
'batch_size': 16,
'num_epochs': 250,
'norm': True,
'output_dim': 12,
'lr': 0.0001,
'patience': 50
}
MGCN_Alchemy = {
'batch_size': 16,
'num_epochs': 250,
'norm': True,
'output_dim': 12,
'lr': 0.0001,
'patience': 50
}
experiment_configures = {
'GCN_Tox21': GCN_Tox21,
'GAT_Tox21': GAT_Tox21
'GAT_Tox21': GAT_Tox21,
'MPNN_Alchemy': MPNN_Alchemy,
'SCHNET_Alchemy': SCHNET_Alchemy,
'MGCN_Alchemy': MGCN_Alchemy
}
def get_exp_configure(exp_name):
......
import argparse
import torch
import torch.nn as nn
from dgl.data.chem import alchemy
from dgl import model_zoo
from torch.utils.data import DataLoader
def train(model="sch",
epochs=80,
device=torch.device("cpu"),
training_set_size=0.8):
print("start")
alchemy_dataset = alchemy.TencentAlchemyDataset()
train_set, test_set = alchemy_dataset.split(train_size=training_set_size)
train_loader = DataLoader(dataset=train_set,
batch_size=20,
collate_fn=alchemy.batcher_dev,
shuffle=False,
num_workers=0)
test_loader = DataLoader(dataset=test_set,
batch_size=20,
collate_fn=alchemy.batcher_dev,
shuffle=False,
num_workers=0)
if model == "sch":
model = model_zoo.chem.SchNetModel(norm=True, output_dim=12)
model.set_mean_std(alchemy_dataset.mean, alchemy_dataset.std, device)
elif model == "mgcn":
model = model_zoo.chem.MGCNModel(norm=True, output_dim=12)
model.set_mean_std(alchemy_dataset.mean, alchemy_dataset.std, device)
elif model == "mpnn":
model = model_zoo.chem.MPNNModel(output_dim=12)
model.to(device)
loss_fn = nn.MSELoss()
MAE_fn = nn.L1Loss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)
for epoch in range(epochs):
from dgl import model_zoo
w_loss, w_mae = 0, 0
from utils import set_random_seed, collate_molgraphs_for_regression, EarlyStopping
def regress(args, model, bg):
if args['model'] == 'MPNN':
h = bg.ndata.pop('n_feat')
e = bg.edata.pop('e_feat')
h, e = h.to(args['device']), e.to(args['device'])
return model(bg, h, e)
else:
node_types = bg.ndata.pop('node_type')
edge_distances = bg.edata.pop('distance')
node_types, edge_distances = node_types.to(args['device']), \
edge_distances.to(args['device'])
return model(bg, node_types, edge_distances)
def run_a_train_epoch(args, epoch, model, data_loader,
loss_criterion, score_criterion, optimizer):
model.train()
for idx, batch in enumerate(train_loader):
batch.graph.to(device)
batch.label = batch.label.to(device)
res = model(batch.graph)
loss = loss_fn(res, batch.label)
mae = MAE_fn(res, batch.label)
total_loss, total_score = 0, 0
for batch_id, batch_data in enumerate(data_loader):
smiles, bg, labels = batch_data
labels = labels.to(args['device'])
prediction = regress(args, model, bg)
loss = loss_criterion(prediction, labels)
score = score_criterion(prediction, labels)
optimizer.zero_grad()
loss.backward()
optimizer.step()
total_loss += loss.detach().item() * bg.batch_size
total_score += score.detach().item() * bg.batch_size
total_loss /= len(data_loader.dataset)
total_score /= len(data_loader.dataset)
print('epoch {:d}/{:d}, training loss {:.4f}, training score {:.4f}'.format(
epoch + 1, args['num_epochs'], total_loss, total_score))
def run_an_eval_epoch(args, model, data_loader, score_criterion):
model.eval()
total_score = 0
with torch.no_grad():
for batch_id, batch_data in enumerate(data_loader):
smiles, bg, labels = batch_data
labels = labels.to(args['device'])
prediction = regress(args, model, bg)
score = score_criterion(prediction, labels)
total_score += score.detach().item() * bg.batch_size
total_score /= len(data_loader.dataset)
return total_score
def main(args):
args['device'] = "cuda" if torch.cuda.is_available() else "cpu"
set_random_seed()
# Interchangeable with other datasets
if args['dataset'] == 'Alchemy':
from dgl.data.chem import TencentAlchemyDataset
train_set = TencentAlchemyDataset(mode='dev')
val_set = TencentAlchemyDataset(mode='valid')
w_mae += mae.detach().item()
w_loss += loss.detach().item()
w_mae /= idx + 1
print("Epoch {:2d}, loss: {:.7f}, MAE: {:.7f}".format(
epoch, w_loss, w_mae))
train_loader = DataLoader(dataset=train_set,
batch_size=args['batch_size'],
collate_fn=collate_molgraphs_for_regression)
val_loader = DataLoader(dataset=val_set,
batch_size=args['batch_size'],
collate_fn=collate_molgraphs_for_regression)
if args['model'] == 'MPNN':
model = model_zoo.chem.MPNNModel(output_dim=args['output_dim'])
elif args['model'] == 'SCHNET':
model = model_zoo.chem.SchNet(norm=args['norm'], output_dim=args['output_dim'])
model.set_mean_std(train_set.mean, train_set.std, args['device'])
elif args['model'] == 'MGCN':
model = model_zoo.chem.MGCNModel(norm=args['norm'], output_dim=args['output_dim'])
model.set_mean_std(train_set.mean, train_set.std, args['device'])
model.to(args['device'])
w_loss, w_mae = 0, 0
model.eval()
loss_fn = nn.MSELoss()
score_fn = nn.L1Loss()
optimizer = torch.optim.Adam(model.parameters(), lr=args['lr'])
stopper = EarlyStopping(mode='lower', patience=args['patience'])
for epoch in range(args['num_epochs']):
# Train
run_a_train_epoch(args, epoch, model, train_loader, loss_fn, score_fn, optimizer)
# Validation and early stop
val_score = run_an_eval_epoch(args, model, val_loader, score_fn)
early_stop = stopper.step(val_score, model)
print('epoch {:d}/{:d}, validation score {:.4f}, best validation score {:.4f}'.format(
epoch + 1, args['num_epochs'], val_score, stopper.best_score))
if early_stop:
break
for idx, batch in enumerate(test_loader):
batch.graph.to(device)
batch.label = batch.label.to(device)
if __name__ == "__main__":
import argparse
res = model(batch.graph)
mae = MAE_fn(res, batch.label)
from configure import get_exp_configure
w_mae += mae.detach().item()
w_loss += loss.detach().item()
w_mae /= idx + 1
print("MAE (test set): {:.7f}".format(w_mae))
parser = argparse.ArgumentParser(description='Molecule Regression')
parser.add_argument('-m', '--model', type=str, choices=['MPNN', 'SCHNET', 'MGCN'],
help='Model to use')
parser.add_argument('-d', '--dataset', type=str, choices=['Alchemy'],
help='Dataset to use')
args = parser.parse_args().__dict__
args['exp'] = '_'.join([args['model'], args['dataset']])
args.update(get_exp_configure(args['exp']))
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("-M",
"--model",
help="model name (sch, mgcn, mpnn)",
choices=['sch', 'mgcn', 'mpnn'],
default="sch")
parser.add_argument("--epochs",
help="number of epochs",
default=250,
type=int)
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
args = parser.parse_args()
train(args.model, args.epochs, device)
main(args)
......@@ -6,6 +6,13 @@ import torch
from sklearn.metrics import roc_auc_score
def set_random_seed(seed=0):
"""Set random seed.
Parameters
----------
seed : int
Random seed to use
"""
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
......@@ -13,19 +20,39 @@ def set_random_seed(seed=0):
torch.cuda.manual_seed(seed)
class Meter(object):
"""Track and summarize model performance on a dataset for
(multi-label) binary classification."""
def __init__(self):
self.mask = []
self.y_pred = []
self.y_true = []
def update(self, y_pred, y_true, mask):
self.y_pred.append(y_pred)
self.y_true.append(y_true)
self.mask.append(mask)
"""Update for the result of an iteration
Parameters
----------
y_pred : float32 tensor
Predicted molecule labels with shape (B, T),
B for batch size and T for the number of tasks
y_true : float32 tensor
Ground truth molecule labels with shape (B, T)
mask : float32 tensor
Mask for indicating the existence of ground
truth labels with shape (B, T)
"""
self.y_pred.append(y_pred.detach().cpu())
self.y_true.append(y_true.detach().cpu())
self.mask.append(mask.detach().cpu())
# Todo: Allow different evaluation metrics
def roc_auc_averaged_over_tasks(self):
"""Compute roc-auc score for each task and return the average."""
"""Compute roc-auc score for each task and return the average.
Returns
-------
float
roc-auc score averaged over all tasks
"""
mask = torch.cat(self.mask, dim=0)
y_pred = torch.cat(self.y_pred, dim=0)
y_true = torch.cat(self.y_true, dim=0)
......@@ -36,40 +63,64 @@ class Meter(object):
total_score = 0
for task in range(n_tasks):
task_w = mask[:, task]
task_y_true = y_true[:, task][task_w != 0].cpu().numpy()
task_y_pred = y_pred[:, task][task_w != 0].cpu().detach().numpy()
task_y_true = y_true[:, task][task_w != 0].numpy()
task_y_pred = y_pred[:, task][task_w != 0].numpy()
total_score += roc_auc_score(task_y_true, task_y_pred)
return total_score / n_tasks
class EarlyStopping(object):
def __init__(self, patience=10, filename=None):
"""Stop training early when a monitored metric stops improving.
Parameters
----------
mode : str
* 'higher': Higher metric suggests a better model
* 'lower': Lower metric suggests a better model
patience : int
Number of epochs to wait before early stopping
if the metric stops improving
filename : str or None
Filename for storing the model checkpoint
"""
def __init__(self, mode='higher', patience=10, filename=None):
if filename is None:
dt = datetime.datetime.now()
filename = 'early_stop_{}_{:02d}-{:02d}-{:02d}.pth'.format(
dt.date(), dt.hour, dt.minute, dt.second)
assert mode in ['higher', 'lower']
self.mode = mode
if self.mode == 'higher':
self._check = self._check_higher
else:
self._check = self._check_lower
self.patience = patience
self.counter = 0
self.filename = filename
self.best_score = None
self.early_stop = False
def step(self, acc, model):
score = acc
def _check_higher(self, score, prev_best_score):
return (score > prev_best_score)
def _check_lower(self, score, prev_best_score):
return (score < prev_best_score)
def step(self, score, model):
if self.best_score is None:
self.best_score = score
self.save_checkpoint(model)
# Todo: this is not true for all metrics.
elif score < self.best_score:
elif self._check(score, self.best_score):
self.best_score = score
self.save_checkpoint(model)
self.counter = 0
else:
self.counter += 1
print(
f'EarlyStopping counter: {self.counter} out of {self.patience}')
if self.counter >= self.patience:
self.early_stop = True
else:
self.best_score = score
self.save_checkpoint(model)
self.counter = 0
return self.early_stop
def save_checkpoint(self, model):
......@@ -80,14 +131,14 @@ class EarlyStopping(object):
'''Load model saved with early stopping.'''
model.load_state_dict(torch.load(self.filename)['model_state_dict'])
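# Example usage (a sketch, mirroring regression.py in this PR): track a
# validation score where lower is better and stop once it plateaus.
#
#   stopper = EarlyStopping(mode='lower', patience=10)
#   for epoch in range(num_epochs):
#       val_score = run_an_eval_epoch(args, model, val_loader, score_fn)
#       if stopper.step(val_score, model):
#           break
#   stopper.load_checkpoint(model)  # restore the best checkpoint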
def collate_molgraphs(data):
"""Batching a list of datapoints for dataloader
def collate_molgraphs_for_classification(data):
"""Batching a list of datapoints for dataloader in classification tasks.
Parameters
----------
data : list of 4-tuples
Each tuple is for a single datapoint, consisting of
A SMILE, a DGLGraph, all-task labels and all-task weights
a SMILES string, a DGLGraph, all-task labels and all-task weights
Returns
-------
......@@ -109,3 +160,29 @@ def collate_molgraphs(data):
labels = torch.stack(labels, dim=0)
mask = torch.stack(mask, dim=0)
return smiles, bg, labels, mask
def collate_molgraphs_for_regression(data):
"""Batching a list of datapoints for dataloader in regression tasks.
Parameters
----------
data : list of 3-tuples
Each tuple is for a single datapoint, consisting of
a SMILES string, a DGLGraph and all-task labels.
Returns
-------
smiles : list
List of smiles
bg : BatchedDGLGraph
Batched DGLGraphs
labels : Tensor of dtype float32 and shape (B, T)
Batched datapoint labels. B is len(data) and
T is the number of total tasks.
"""
smiles, graphs, labels = map(list, zip(*data))
bg = dgl.batch(graphs)
bg.set_n_initializer(dgl.init.zero_initializer)
bg.set_e_initializer(dgl.init.zero_initializer)
labels = torch.stack(labels, dim=0)
return smiles, bg, labels
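# Example usage (a sketch, mirroring regression.py in this PR): the collate
# functions plug directly into a PyTorch DataLoader.
#
#   from torch.utils.data import DataLoader
#   from dgl.data.chem import TencentAlchemyDataset
#
#   train_set = TencentAlchemyDataset(mode='dev')
#   train_loader = DataLoader(train_set, batch_size=16,
#                             collate_fn=collate_molgraphs_for_regression)
#   for smiles, bg, labels in train_loader:
#       ...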
from .utils import *
from .csv_dataset import CSVDataset
from .tox21 import Tox21
from .alchemy import TencentAlchemyDataset
......@@ -11,8 +11,7 @@ import zipfile
from collections import defaultdict
from .utils import mol_to_complete_graph
from ..utils import download, get_download_dir
from ...batched_graph import batch
from ..utils import download, get_download_dir, _get_dgl_url
from ... import backend as F
try:
......@@ -23,111 +22,7 @@ try:
except ImportError:
pass
_urls = {'Alchemy': 'https://alchemy.tencent.com/data/dgl/'}
class AlchemyBatcher(object):
"""Data structure for holding a batch of data.
Parameters
----------
graph : dgl.BatchedDGLGraph
A batch of DGLGraphs for B molecules
labels : tensor
Labels for B molecules
"""
def __init__(self, graph=None, label=None):
self.graph = graph
self.label = label
def batcher_dev(batch_data):
"""Batch datapoints
Parameters
----------
batch_data : list
batch[i][0] gives the DGLGraph for the ith datapoint,
and batch[i][1] gives the label for the ith datapoint.
Returns
-------
AlchemyBatcher
An object holding the batch of data
"""
graphs, labels = zip(*batch_data)
batch_graphs = batch(graphs)
labels = F.stack(labels, 0)
return AlchemyBatcher(graph=batch_graphs, label=labels)
class TencentAlchemyDataset(object):
"""`Tencent Alchemy Dataset <https://arxiv.org/abs/1906.09427>`__
Parameters
----------
mode : str
'dev', 'valid' or 'test', default to be 'dev'
transform : transform operation on DGLGraphs
Default to be None.
from_raw : bool
Whether to process dataset from scratch or use a
processed one for faster speed. Default to be False.
"""
def __init__(self, mode='dev', transform=None, from_raw=False):
assert mode in ['dev', 'valid', 'test'], "mode should be dev/valid/test"
self.mode = mode
self.transform = transform
# Construct DGLGraphs from raw data or use the preprocessed data
self.from_raw = from_raw
file_dir = osp.join(get_download_dir(), './Alchemy_data')
if not from_raw:
file_name = "%s_processed" % (mode)
else:
file_name = "%s_single_sdf" % (mode)
self.file_dir = pathlib.Path(file_dir, file_name)
self.zip_file_path = pathlib.Path(file_dir, file_name + '.zip')
download(_urls['Alchemy'] + file_name + '.zip',
path=str(self.zip_file_path))
if not os.path.exists(str(self.file_dir)):
archive = zipfile.ZipFile(self.zip_file_path)
archive.extractall(file_dir)
archive.close()
self._load()
def _load(self):
if self.mode == 'dev':
if not self.from_raw:
with open(osp.join(self.file_dir, "dev_graphs.pkl"), "rb") as f:
self.graphs = pickle.load(f)
with open(osp.join(self.file_dir, "dev_labels.pkl"), "rb") as f:
self.labels = pickle.load(f)
else:
target_file = pathlib.Path(self.file_dir, "dev_target.csv")
self.target = pd.read_csv(
target_file,
index_col=0,
usecols=['gdb_idx',] + ['property_%d' % x for x in range(12)])
self.target = self.target[['property_%d' % x for x in range(12)]]
self.graphs, self.labels = [], []
supp = Chem.SDMolSupplier(
osp.join(self.file_dir, self.mode + ".sdf"))
cnt = 0
for sdf, label in zip(supp, self.target.iterrows()):
graph = mol_to_complete_graph(sdf, atom_featurizer=self.alchemy_nodes,
bond_featurizer=self.alchemy_edges)
cnt += 1
self.graphs.append(graph)
label = F.tensor(np.array(label[1].tolist()).astype(np.float32))
self.labels.append(label)
self.normalize()
print(len(self.graphs), "loaded!")
def alchemy_nodes(self, mol):
def alchemy_nodes(mol):
"""Featurization for all atoms in a molecule. The atom indices
will be preserved.
......@@ -150,7 +45,6 @@ class TencentAlchemyDataset(object):
mol_feats = mol_featurizer.GetFeaturesForMol(mol)
mol_conformers = mol.GetConformers()
assert len(mol_conformers) == 1
geom = mol_conformers[0].GetPositions()
for i in range(len(mol_feats)):
if mol_feats[i].GetFamily() == 'Donor':
......@@ -170,13 +64,10 @@ class TencentAlchemyDataset(object):
aromatic = atom.GetIsAromatic()
hybridization = atom.GetHybridization()
num_h = atom.GetTotalNumHs()
atom_feats_dict['pos'].append(F.tensor(geom[u].astype(np.float32)))
atom_feats_dict['node_type'].append(atom_type)
h_u = []
h_u += [
int(symbol == x) for x in ['H', 'C', 'N', 'O', 'F', 'S', 'Cl']
]
h_u += [int(symbol == x) for x in ['H', 'C', 'N', 'O', 'F', 'S', 'Cl']]
h_u.append(atom_type)
h_u.append(is_acceptor[u])
h_u.append(is_donor[u])
......@@ -191,13 +82,12 @@ class TencentAlchemyDataset(object):
atom_feats_dict['n_feat'].append(F.tensor(np.array(h_u).astype(np.float32)))
atom_feats_dict['n_feat'] = F.stack(atom_feats_dict['n_feat'], dim=0)
atom_feats_dict['pos'] = F.stack(atom_feats_dict['pos'], dim=0)
atom_feats_dict['node_type'] = F.tensor(np.array(
atom_feats_dict['node_type']).astype(np.int64))
return atom_feats_dict
def alchemy_edges(self, mol, self_loop=False):
def alchemy_edges(mol, self_loop=False):
"""Featurization for all bonds in a molecule.
The bond indices will be preserved.
......@@ -205,6 +95,8 @@ class TencentAlchemyDataset(object):
----------
mol : rdkit.Chem.rdchem.Mol
RDKit molecule object
self_loop : bool
Whether to add self loops. Default to be False.
Returns
-------
......@@ -245,80 +137,138 @@ class TencentAlchemyDataset(object):
return bond_feats_dict
def normalize(self, mean=None, std=None):
"""Set mean and std or compute from labels for future normalization.
class TencentAlchemyDataset(object):
"""
Developed by the Tencent Quantum Lab, the dataset lists 12 quantum mechanical
properties of 130,000+ organic molecules, each comprising up to 12 heavy atoms
(C, N, O, S, F and Cl), sampled from the GDBMedChem database. These properties
have been calculated using the open-source computational chemistry program
Python-based Simulation of Chemistry Framework (PySCF).
For more details, check the `paper <https://arxiv.org/abs/1906.09427>`__.
Parameters
----------
mean : int or float
Default to be None.
std : int or float
Default to be None.
mode : str
'dev', 'valid' or 'test', separately for training, validation and test.
Default to be 'dev'. Note that 'test' is not available as the Alchemy
contest is ongoing.
from_raw : bool
Whether to process the dataset from scratch or use a
processed one for faster speed. Default to be False.
"""
labels = np.array([i.numpy() for i in self.labels])
if mean is None:
mean = np.mean(labels, axis=0)
if std is None:
std = np.std(labels, axis=0)
self.mean = mean
self.std = std
def __init__(self, mode='dev', from_raw=False):
if mode == 'test':
raise ValueError('The test mode is not supported before '
'the Alchemy contest finishes.')
def __len__(self):
return len(self.graphs)
assert mode in ['dev', 'valid', 'test'], \
'Expect mode to be dev, valid or test, got {}.'.format(mode)
def __getitem__(self, idx):
g, l = self.graphs[idx], self.labels[idx]
if self.transform:
g = self.transform(g)
return g, l
self.mode = mode
def split(self, train_size=0.8):
"""Split the dataset into two AlchemySubset for train&test.
# Construct DGLGraphs from raw data or use the preprocessed data
self.from_raw = from_raw
file_dir = osp.join(get_download_dir(), 'Alchemy_data')
if not from_raw:
file_name = "%s_processed" % (mode)
else:
file_name = "%s_single_sdf" % (mode)
self.file_dir = pathlib.Path(file_dir, file_name)
self._url = 'dataset/alchemy/'
self.zip_file_path = pathlib.Path(file_dir, file_name + '.zip')
download(_get_dgl_url(self._url + file_name + '.zip'), path=str(self.zip_file_path))
if not os.path.exists(str(self.file_dir)):
archive = zipfile.ZipFile(self.zip_file_path)
archive.extractall(file_dir)
archive.close()
self._load()
def _load(self):
if not self.from_raw:
with open(osp.join(self.file_dir, "%s_graphs.pkl" % self.mode), "rb") as f:
self.graphs = pickle.load(f)
with open(osp.join(self.file_dir, "%s_labels.pkl" % self.mode), "rb") as f:
self.labels = pickle.load(f)
else:
print('Start preprocessing dataset...')
target_file = pathlib.Path(self.file_dir, "%s_target.csv" % self.mode)
self.target = pd.read_csv(
target_file,
index_col=0,
usecols=['gdb_idx',] + ['property_%d' % x for x in range(12)])
self.target = self.target[['property_%d' % x for x in range(12)]]
self.graphs, self.labels = [], []
supp = Chem.SDMolSupplier(osp.join(self.file_dir, self.mode + ".sdf"))
cnt = 0
dataset_size = len(self.target)
for mol, label in zip(supp, self.target.iterrows()):
cnt += 1
print('Processing molecule {:d}/{:d}'.format(cnt, dataset_size))
graph = mol_to_complete_graph(mol, atom_featurizer=alchemy_nodes,
bond_featurizer=alchemy_edges)
smile = Chem.MolToSmiles(mol)
graph.smile = smile
self.graphs.append(graph)
label = F.tensor(np.array(label[1].tolist()).astype(np.float32))
self.labels.append(label)
with open(osp.join(self.file_dir, "%s_graphs.pkl" % self.mode), "wb") as f:
pickle.dump(self.graphs, f)
with open(osp.join(self.file_dir, "%s_labels.pkl" % self.mode), "wb") as f:
pickle.dump(self.labels, f)
self.set_mean_and_std()
print(len(self.graphs), "loaded!")
def __getitem__(self, item):
"""Get datapoint with index
Parameters
----------
train_size : float
Proportion of dataset to use for training. Default to be 0.8.
item : int
Datapoint index
Returns
-------
train_set : AlchemySubset
Dataset for training
test_set : AlchemySubset
Dataset for test
str
SMILES for the ith datapoint
DGLGraph
DGLGraph for the ith datapoint
Tensor of dtype float32
Labels of the datapoint for all tasks
"""
assert 0 < train_size < 1
train_num = int(len(self.graphs) * train_size)
train_set = AlchemySubset(self.graphs[:train_num],
self.labels[:train_num], self.mean, self.std,
self.transform)
test_set = AlchemySubset(self.graphs[train_num:],
self.labels[train_num:], self.mean, self.std,
self.transform)
return train_set, test_set
class AlchemySubset(TencentAlchemyDataset):
g, l = self.graphs[item], self.labels[item]
return g.smile, g, l
def __len__(self):
"""Length of the dataset
Returns
-------
int
Length of Dataset
"""
Sub-dataset split from TencentAlchemyDataset.
Used to construct the training & test set.
return len(self.graphs)
def set_mean_and_std(self, mean=None, std=None):
"""Set mean and std or compute from labels for future normalization.
Parameters
----------
graphs : list of DGLGraphs
DGLGraphs for datapoints in the subset
labels : list of tensors
Labels for datapoints in the subset
mean : int or float
Mean of labels in the subset
Default to be None.
std : int or float
Std of labels in the subset
transform : transform operation on DGLGraphs
Default to be None.
"""
def __init__(self, graphs, labels, mean=0, std=1, transform=None):
super(AlchemySubset, self).__init__()
self.graphs = graphs
self.labels = labels
labels = np.array([i.numpy() for i in self.labels])
if mean is None:
mean = np.mean(labels, axis=0)
if std is None:
std = np.std(labels, axis=0)
self.mean = mean
self.std = std
self.transform = transform
......@@ -27,18 +27,14 @@ class CSVDataset(object):
Dataframe including smiles and labels. Can be loaded by pandas.read_csv(file_path).
One column includes smiles and other columns for labels.
Column names other than smiles column would be considered as task names.
smile_to_graph: callable, str -> DGLGraph
A function that turns a SMILES string into a DGLGraph. A default one can be found
at python/dgl/data/chem/utils.py, named smile_to_bigraph.
smile_column: str
Column name containing SMILES strings
cache_file_path: str
Path to store the preprocessed data
"""
def __init__(self, df, smile_to_graph=smile_to_bigraph, smile_column='smiles',
cache_file_path="csvdata_dglgraph.pkl"):
if 'rdkit' not in sys.modules:
......@@ -59,6 +55,11 @@ class CSVDataset(object):
and featurize their atoms
* Set missing labels to be 0 and use a binary masking
matrix to mask them
Parameters
----------
smile_to_graph : callable, SMILES -> DGLGraph
Function for converting a SMILES (str) into a DGLGraph
"""
if os.path.exists(self.cache_file_path):
# DGLGraphs have been constructed before, reload them
......@@ -76,7 +77,12 @@ class CSVDataset(object):
self.mask = (~np.isnan(_label_values)).astype(np.float32)
def __getitem__(self, item):
"""Get the ith datapoint
"""Get datapoint with index
Parameters
----------
item : int
Datapoint index
Returns
-------
......@@ -87,17 +93,17 @@ class CSVDataset(object):
Tensor of dtype float32
Labels of the datapoint for all tasks
Tensor of dtype float32
Weights of the datapoint for all tasks
Binary masks indicating the existence of labels for all tasks
"""
return self.smiles[item], self.graphs[item], \
F.zerocopy_from_numpy(self.labels[item]), \
F.zerocopy_from_numpy(self.mask[item])
def __len__(self):
"""Length of Dataset
"""Length of the dataset
Return
------
Returns
-------
int
Length of Dataset
"""
......
......@@ -11,9 +11,6 @@ except ImportError:
pass
class Tox21(CSVDataset):
_url = 'dataset/tox21.csv.gz'
"""Tox21 dataset.
The Toxicology in the 21st Century (https://tripod.nih.gov/tox21/challenge/)
......@@ -41,6 +38,7 @@ class Tox21(CSVDataset):
from ...base import dgl_warning
dgl_warning("Please install pandas")
self._url = 'dataset/tox21.csv.gz'
data_path = get_download_dir() + '/tox21.csv.gz'
download(_get_dgl_url(self._url), path=data_path)
df = pd.read_csv(data_path)
......@@ -80,7 +78,7 @@ class Tox21(CSVDataset):
Returns
-------
list
numpy.ndarray
Numpy array giving the weight of positive samples for all tasks
"""
return self._task_pos_weights
......@@ -10,6 +10,10 @@ try:
except ImportError:
pass
__all__ = ['one_hot_encoding', 'BaseAtomFeaturizer', 'CanonicalAtomFeaturizer',
'mol_to_graph', 'smile_to_bigraph', 'mol_to_bigraph',
'smile_to_complete_graph', 'mol_to_complete_graph']
def one_hot_encoding(x, allowable_set):
"""One-hot encoding.
......
......@@ -29,13 +29,37 @@ def _get_dgl_url(file_url):
def split_dataset(dataset, frac_list=None, shuffle=False, random_state=None):
"""Split dataset into training, validation and test set.
Parameters
----------
dataset
We assume ``len(dataset)`` gives the number of datapoints and ``dataset[i]``
gives the ith datapoint.
frac_list : list or None, optional
A list of length 3 containing the fraction to use for training,
validation and test. If None, we will use [0.8, 0.1, 0.1].
shuffle : bool, optional
By default we perform a consecutive split of the dataset. If True,
we will first randomly shuffle the dataset.
random_state : None, int or array_like, optional
Random seed used to initialize the pseudo-random number generator.
Can be any integer between 0 and 2**32 - 1 inclusive, an array
(or other sequence) of such integers, or None (the default).
If seed is None, then RandomState will try to read data from /dev/urandom
(or the Windows analogue) if available or seed from the clock otherwise.
Returns
-------
list of length 3
Subsets for training, validation and test.
"""
from itertools import accumulate
if frac_list is None:
frac_list = [0.8, 0.1, 0.1]
frac_list = np.array(frac_list)
assert np.allclose(np.sum(frac_list), 1.), \
'Expect frac_list sum to 1, got {:.4f}'.format(
np.sum(frac_list))
'Expect frac_list sum to 1, got {:.4f}'.format(np.sum(frac_list))
num_data = len(dataset)
lengths = (num_data * frac_list).astype(int)
lengths[-1] = num_data - np.sum(lengths[:-1])
......
......@@ -2,7 +2,7 @@
"""Model Zoo Package"""
from .classifiers import GCNClassifier, GATClassifier
from .sch import SchNetModel
from .schnet import SchNet
from .mgcn import MGCNModel
from .mpnn import MPNNModel
from .dgmg import DGMG
......
......@@ -629,6 +629,8 @@ def dgmg_message_weight_init(m):
class DGMG(nn.Module):
"""DGMG model
`Learning Deep Generative Models of Graphs <https://arxiv.org/abs/1803.03324>`__
Users only need to initialize an instance of this class.
Parameters
......
......@@ -27,7 +27,10 @@ from .nnutils import create_var, cuda, move_dgl_to_cuda
class DGLJTNNVAE(nn.Module):
"""
`Junction Tree Variational Autoencoder for Molecular Graph Generation
<https://arxiv.org/abs/1802.04364>`__
"""
def __init__(self, hidden_size, latent_size, depth, vocab=None, vocab_file=None):
super(DGLJTNNVAE, self).__init__()
if vocab is None:
......
......@@ -3,14 +3,13 @@
"""
The implementation of neural network layers used in SchNet and MGCN.
"""
import torch as th
import torch
import torch.nn as nn
from torch.nn import Softplus
import numpy as np
from ... import function as fn
class AtomEmbedding(nn.Module):
"""
Convert the atom (node) list to atom embeddings.
......@@ -19,7 +18,7 @@ class AtomEmbedding(nn.Module):
Parameters
----------
dim : int
Dim of embeddings, default to be 128.
Size of embeddings, default to be 128.
type_num : int
The largest atomic number of atoms in the dataset, default to be 100.
pre_train : None or pre-trained embeddings
......@@ -27,6 +26,7 @@ class AtomEmbedding(nn.Module):
"""
def __init__(self, dim=128, type_num=100, pre_train=None):
super(AtomEmbedding, self).__init__()
self._dim = dim
self._type_num = type_num
if pre_train is not None:
......@@ -34,19 +34,19 @@ class AtomEmbedding(nn.Module):
else:
self.embedding = nn.Embedding(type_num, dim, padding_idx=0)
def forward(self, g, p_name="node"):
def forward(self, atom_types):
"""
Parameters
----------
g : DGLGraph
Input DGLGraph(s)
p_name : str
Name for storing atom embeddings
"""
atom_list = g.ndata["node_type"]
g.ndata[p_name] = self.embedding(atom_list)
return g.ndata[p_name]
atom_types : int64 tensor of shape (B1)
Types for atoms in the graph(s), B1 for the number of atoms.
Returns
-------
float32 tensor of shape (B1, self._dim)
Atom embeddings.
"""
return self.embedding(atom_types)
class EdgeEmbedding(nn.Module):
"""
......@@ -56,14 +56,15 @@ class EdgeEmbedding(nn.Module):
Parameters
----------
dim : int
Dim of edge embeddings, default to be 128.
Size of embeddings, default to be 128.
edge_num : int
Maximum number of edge types, default to be 128.
Maximum number of edge types allowed, default to be 3000.
pre_train : Edge embeddings or None
Pre-trained edge embeddings
Pre-trained edge embeddings, default to be None.
"""
def __init__(self, dim=128, edge_num=3000, pre_train=None):
super(EdgeEmbedding, self).__init__()
self._dim = dim
self._edge_num = edge_num
if pre_train is not None:
......@@ -74,12 +75,13 @@ class EdgeEmbedding(nn.Module):
def generate_edge_type(self, edges):
"""Generate edge type.
The edge type is based on the type of the src&dst atom.
Note that C-O and O-C are the same edge type.
The edge type is based on the type of the src & dst atom.
Note that directions are not distinguished, e.g. C-O and O-C are the same edge type.
To map a pair of nodes to one number, we use an unordered pairing function here.
See more details in this discussion:
https://math.stackexchange.com/questions/23503/create-unique-number-from-2-numbers
Note that, the edge_num should be larger than the square of maximum atomic number
Note that the edge_num should be larger than the square of the maximum atomic number
in the dataset.
Parameters
......@@ -92,45 +94,51 @@ class EdgeEmbedding(nn.Module):
dict
Stores the edge types in "type"
"""
atom_type_x = edges.src["node_type"]
atom_type_y = edges.dst["node_type"]
atom_type_x = edges.src['ntype']
atom_type_y = edges.dst['ntype']
return {
"type": atom_type_x * atom_type_y +
(th.abs(atom_type_x - atom_type_y) - 1)**2 / 4
'etype': atom_type_x * atom_type_y + \
(torch.abs(atom_type_x - atom_type_y) - 1) ** 2 / 4
}
def forward(self, g, p_name="edge_f"):
def forward(self, g, atom_types):
"""Compute edge embeddings
Parameters
----------
g : DGLGraph
The graph to compute edge embeddings
p_name : str
atom_types : int64 tensor of shape (B1)
Types for atoms in the graph(s), B1 for the number of atoms.
Returns
-------
computed edge embeddings
float32 tensor of shape (B2, self._dim)
Computed edge embeddings
"""
g = g.local_var()
g.ndata['ntype'] = atom_types
g.apply_edges(self.generate_edge_type)
g.edata[p_name] = self.embedding(g.edata["type"])
return g.edata[p_name]
return self.embedding(g.edata.pop('etype'))
class ShiftSoftplus(Softplus):
class ShiftSoftplus(nn.Module):
"""
Shiftsoft plus activation function:
ShiftSoftplus activation function:
1/beta * (log(1 + exp**(beta * x)) - log(shift))
Parameters
----------
beta : int
Default to be 1.
shift : int
Default to be 2.
threshold : int
Default to be 20.
"""
def __init__(self, beta=1, shift=2, threshold=20):
super(ShiftSoftplus, self).__init__(beta, threshold)
super(ShiftSoftplus, self).__init__()
self.shift = shift
self.softplus = Softplus(beta, threshold)
......@@ -138,43 +146,56 @@ class ShiftSoftplus(Softplus):
"""Applies the activation function"""
return self.softplus(x) - np.log(float(self.shift))
class RBFLayer(nn.Module):
"""
Radial basis functions Layer.
e(d) = exp(- gamma * ||d - mu_k||^2)
default settings:
gamma = 10
0 <= mu_k <= 30 for k=1~300
With the default parameters below, we use the following settings:
* gamma = 10
* 0 <= mu_k <= 30 for k=1~300
Parameters
----------
low : int
Smallest value to take for mu_k, default to be 0.
high : int
Largest value to take for mu_k, default to be 30.
gap : float
Difference between two consecutive values for mu_k, default to be 0.1.
dim : int
Output size for each center, default to be 1.
"""
def __init__(self, low=0, high=30, gap=0.1, dim=1):
super(RBFLayer, self).__init__()
self._low = low
self._high = high
self._gap = gap
self._dim = dim
self._n_centers = int(np.ceil((high - low) / gap))
centers = np.linspace(low, high, self._n_centers)
self.centers = th.tensor(centers, dtype=th.float, requires_grad=False)
self.centers = torch.tensor(centers, dtype=torch.float, requires_grad=False)
self.centers = nn.Parameter(self.centers, requires_grad=False)
self._fan_out = self._dim * self._n_centers
self._gap = centers[1] - centers[0]
def dis2rbf(self, edges):
"""Convert distance matrix to RBF tensor."""
dist = edges.data["distance"]
radial = dist - self.centers
coef = -1 / self._gap
rbf = th.exp(coef * (radial**2))
return {"rbf": rbf}
def forward(self, g):
"""Convert distance scalar to rbf vector"""
g.apply_edges(self.dis2rbf)
return g.edata["rbf"]
def forward(self, edge_distances):
"""
Parameters
----------
edge_distances : float32 tensor of shape (B, 1)
Edge distances, B for the number of edges.
Returns
-------
float32 tensor of shape (B, self._fan_out)
Computed RBF results
"""
radial = edge_distances - self.centers
coef = -1 / self._gap
return torch.exp(coef * (radial ** 2))
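# Example (a sketch): with the defaults low=0, high=30, gap=0.1, there are
# ceil((30 - 0) / 0.1) = 300 centers, so distances of shape (B, 1) are
# expanded to RBF features of shape (B, 300).
#
#   rbf = RBFLayer(low=0, high=30, gap=0.1)
#   dists = torch.rand(5, 1) * 30  # distances for 5 edges
#   feats = rbf(dists)             # shape (5, 300)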
class CFConv(nn.Module):
"""
......@@ -185,38 +206,45 @@ class CFConv(nn.Module):
rbf_dim : int
Dimension of the RBF layer output
dim : int
Dimension of linear layers, default to be 64
act : str or activation function
Dimension of output, default to be 64
act : activation function or None.
Activation function, default to be shifted softplus
"""
def __init__(self, rbf_dim, dim=64, act=None):
super(CFConv, self).__init__()
self._rbf_dim = rbf_dim
self._dim = dim
self.linear_layer1 = nn.Linear(self._rbf_dim, self._dim)
self.linear_layer2 = nn.Linear(self._dim, self._dim)
if act is None:
self.activation = nn.Softplus(beta=0.5, threshold=14)
activation = nn.Softplus(beta=0.5, threshold=14)
else:
self.activation = act
activation = act
def update_edge(self, edges):
"""Update the edge features with two FC layers."""
rbf = edges.data["rbf"]
h = self.linear_layer1(rbf)
h = self.activation(h)
h = self.linear_layer2(h)
return {"h": h}
def forward(self, g):
"""Forward CFConv"""
g.apply_edges(self.update_edge)
g.update_all(message_func=fn.u_mul_e('new_node', 'h', 'neighbor_info'),
reduce_func=fn.sum('neighbor_info', 'new_node'))
return g.ndata["new_node"]
self.project = nn.Sequential(
nn.Linear(self._rbf_dim, self._dim),
activation,
nn.Linear(self._dim, self._dim)
)
def forward(self, g, node_weight, rbf_out):
"""
Parameters
----------
g : DGLGraph
The graph for performing convolution
node_weight : float32 tensor of shape (B1, D1)
The weight of nodes in message passing, B1 for number of nodes and
D1 for node weight size.
rbf_out : float32 tensor of shape (B2, D2)
The output of RBFLayer, B2 for number of edges and D2 for rbf out size.
"""
g = g.local_var()
e = self.project(rbf_out)
g.ndata['node_weight'] = node_weight
g.edata['e'] = e
g.update_all(fn.u_mul_e('node_weight', 'e', 'm'), fn.sum('m', 'h'))
return g.ndata.pop('h')
class Interaction(nn.Module):
"""
......@@ -231,118 +259,101 @@ class Interaction(nn.Module):
"""
def __init__(self, rbf_dim, dim):
super(Interaction, self).__init__()
self._node_dim = dim
self.activation = nn.Softplus(beta=0.5, threshold=14)
self.node_layer1 = nn.Linear(dim, dim, bias=False)
self.cfconv = CFConv(rbf_dim, dim, act=self.activation)
self.node_layer2 = nn.Linear(dim, dim)
self.node_layer3 = nn.Linear(dim, dim)
def forward(self, g):
self._dim = dim
self.node_layer1 = nn.Linear(dim, dim, bias=False)
self.cfconv = CFConv(rbf_dim, dim, Softplus(beta=0.5, threshold=14))
self.node_layer2 = nn.Sequential(
nn.Linear(dim, dim),
Softplus(beta=0.5, threshold=14),
nn.Linear(dim, dim)
)
def forward(self, g, n_feat, rbf_out):
"""
Parameters
----------
g : DGLGraph
The graph for performing convolution
n_feat : float32 tensor of shape (B1, D1)
Node features, B1 for number of nodes and D1 for feature size.
rbf_out : float32 tensor of shape (B2, D2)
The output of RBFLayer, B2 for number of edges and D2 for rbf out size.
Returns
-------
tensor
Updated atom representations
float32 tensor of shape (B1, D1)
Updated node representations
"""
g.ndata["new_node"] = self.node_layer1(g.ndata["node"])
cf_node = self.cfconv(g)
cf_node_1 = self.node_layer2(cf_node)
cf_node_1a = self.activation(cf_node_1)
new_node = self.node_layer3(cf_node_1a)
g.ndata["node"] = g.ndata["node"] + new_node
return g.ndata["node"]
n_weight = self.node_layer1(n_feat)
new_n_feat = self.cfconv(g, n_weight, rbf_out)
new_n_feat = self.node_layer2(new_n_feat)
return n_feat + new_n_feat
class VEConv(nn.Module):
"""
The Vertex-Edge convolution layer in MGCN which takes edge & vertex features
in consideration at the same time.
The Vertex-Edge convolution layer in MGCN, which takes both edge & vertex features
into consideration.
Parameters
----------
rbf_dim : int
Dimension of the RBF layer output
Size of the RBF layer output
dim : int
Dimension of intermediate representations, default to be 64.
Size of intermediate representations, default to be 64.
update_edge : bool
Whether to apply a linear layer to update edge representations, default to be True.
"""
def __init__(self, rbf_dim, dim=64, update_edge=True):
super(VEConv, self).__init__()
self._rbf_dim = rbf_dim
self._dim = dim
self._update_edge = update_edge
self.linear_layer1 = nn.Linear(self._rbf_dim, self._dim)
self.linear_layer2 = nn.Linear(self._dim, self._dim)
self.linear_layer3 = nn.Linear(self._dim, self._dim)
self.activation = nn.Softplus(beta=0.5, threshold=14)
def update_rbf(self, edges):
"""Update the RBF features
Parameters
----------
edges : EdgeBatch
Returns
-------
dict
Stores updated features in 'h'
"""
rbf = edges.data["rbf"]
h = self.linear_layer1(rbf)
h = self.activation(h)
h = self.linear_layer2(h)
return {"h": h}
def update_edge(self, edges):
"""Update the edge features.
Parameters
----------
edges : EdgeBatch
self.update_rbf = nn.Sequential(
nn.Linear(self._rbf_dim, self._dim),
nn.Softplus(beta=0.5, threshold=14),
nn.Linear(self._dim, self._dim)
)
self.update_efeat = nn.Linear(self._dim, self._dim)
Returns
-------
dict
Stores updated features in 'edge_f'
def forward(self, g, n_feat, e_feat, rbf_out):
"""
edge_f = edges.data["edge_f"]
h = self.linear_layer3(edge_f)
return {"edge_f": h}
def forward(self, g):
"""VEConv layer forward
Parameters
----------
g : DGLGraph
The graph for performing convolution
n_feat : float32 tensor of shape (B1, D1)
Node features, B1 for number of nodes and D1 for feature size.
e_feat : float32 tensor of shape (B2, D2)
Edge features. B2 for number of edges and D2 for
the edge feature size.
rbf_out : float32 tensor of shape (B2, D3)
The output of RBFLayer, B2 for number of edges and D3 for rbf out size.
Returns
-------
tensor
Updated atom representations
n_feat : float32 tensor
Updated node features.
e_feat : float32 tensor
(Potentially updated) edge features
"""
g.apply_edges(self.update_rbf)
if self._update_edge:
g.apply_edges(self.update_edge)
rbf_out = self.update_rbf(rbf_out)
g.update_all(message_func=[fn.u_mul_e("new_node", "h", "m_0"),
fn.copy_e("edge_f", "m_1")],
reduce_func=[fn.sum("m_0", "new_node_0"),
fn.sum("m_1", "new_node_1")])
g.ndata["new_node"] = g.ndata.pop("new_node_0") + \
g.ndata.pop("new_node_1")
if self._update_edge:
e_feat = self.update_efeat(e_feat)
return g.ndata["new_node"]
g = g.local_var()
g.ndata.update({'n_feat': n_feat})
g.edata.update({'rbf_out': rbf_out, 'e_feat': e_feat})
g.update_all(message_func=[fn.u_mul_e('n_feat', 'rbf_out', 'm_0'),
fn.copy_e('e_feat', 'm_1')],
reduce_func=[fn.sum('m_0', 'n_feat_0'),
fn.sum('m_1', 'n_feat_1')])
n_feat = g.ndata.pop('n_feat_0') + g.ndata.pop('n_feat_1')
return n_feat, e_feat
class MultiLevelInteraction(nn.Module):
"""
......@@ -359,35 +370,43 @@ class MultiLevelInteraction(nn.Module):
super(MultiLevelInteraction, self).__init__()
self._atom_dim = dim
self.activation = nn.Softplus(beta=0.5, threshold=14)
self.node_layer1 = nn.Linear(dim, dim, bias=True)
self.edge_layer1 = nn.Linear(dim, dim, bias=True)
self.conv_layer = VEConv(rbf_dim, dim)
self.node_layer2 = nn.Linear(dim, dim)
self.node_layer3 = nn.Linear(dim, dim)
self.activation = nn.Softplus(beta=0.5, threshold=14)
self.edge_layer1 = nn.Linear(dim, dim, bias=True)
self.node_out = nn.Sequential(
nn.Linear(dim, dim),
nn.Softplus(beta=0.5, threshold=14),
nn.Linear(dim, dim)
)
def forward(self, g, level=1):
def forward(self, g, n_feat, e_feat, rbf_out):
"""
Parameters
----------
g : DGLGraph
level : int
Level of interaction
The graph for performing convolution
n_feat : float32 tensor of shape (B1, D1)
Node features, B1 for number of nodes and D1 for feature size.
e_feat : float32 tensor of shape (B2, D2)
Edge features. B2 for number of edges and D2 for
the edge feature size.
rbf_out : float32 tensor of shape (B2, D3)
The output of RBFLayer, B2 for number of edges and D3 for rbf out size.
Returns
-------
tensor
Updated atom representations
n_feat : float32 tensor
Updated node representations
e_feat : float32 tensor
Updated edge representations
"""
g.ndata["new_node"] = self.node_layer1(
g.ndata["node_%s" % (level - 1)])
node = self.conv_layer(g)
g.edata["edge_f"] = self.activation(
self.edge_layer1(g.edata["edge_f"]))
node_1 = self.node_layer2(node)
node_1a = self.activation(node_1)
new_node = self.node_layer3(node_1a)
g.ndata["node_%s" % (level)] = g.ndata["node_%s" % (level - 1)] + new_node
return g.ndata["node_%s" % (level)]
new_n_feat = self.node_layer1(n_feat)
new_n_feat, e_feat = self.conv_layer(g, new_n_feat, e_feat, rbf_out)
new_n_feat = self.node_out(new_n_feat)
n_feat = n_feat + new_n_feat
e_feat = self.activation(self.edge_layer1(e_feat))
return n_feat, e_feat