Commit 3bc73931 authored by lunar, committed by Mufei Li

[Model Zoo] Molecule Regression (#779)

* [Model] MPNN

* [Model] MPNN 🔨 reorganize the mpnn/sch/mgcn model & alchemy dataset

* [Model] MPNN alchemy dataloader refactoring

* [Model] Chem model zoo minor change

* [Model] Chem Model Zoo 🔥 remove old samples

* [Model Zoo] molecule regression minor change

* Fix dataset import

* Fix dataset import

* [Model Zoo] molecule regression test set

* [Model Zoo] molecule prediction MPNN model hyperparameter tuning

* [Model Zoo] molecule prediction MPNN performance update
parent 2c234118
......@@ -41,8 +41,52 @@ match exact results for this model, please use the pre-trained model as in the u
To customize your own dataset, see the instructions
[here](https://github.com/dmlc/dgl/tree/master/python/dgl/data/chem).
### References
## Regression
Regression tasks require assigning continuous labels to a molecule, e.g. molecular energy.
### Dataset
- **Alchemy**. The [Alchemy Dataset](https://alchemy.tencent.com/) was introduced by Tencent Quantum Lab to facilitate the development of new machine learning models useful for chemistry and materials science.
The dataset lists 12 quantum mechanical properties of 130,000+ organic molecules comprising up to 12 heavy atoms (C, N, O, S, F and Cl), sampled from the [GDBMedChem](http://gdb.unibe.ch/downloads/) database.
These properties have been calculated using the open-source computational chemistry program Python-based Simulation of Chemistry Framework ([PySCF](https://github.com/pyscf/pyscf)).
The Alchemy dataset expands on the volume and diversity of existing molecular datasets such as QM9.
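The loader added in this PR can be used directly; below is a minimal sketch (interface as in `regression.py`, with the 80/20 split used by the training script):
```py
from torch.utils.data import DataLoader
from dgl.data.chem import alchemy

# Downloads the preprocessed graphs on first use.
dataset = alchemy.TencentAlchemyDataset()
train_set, test_set = dataset.split(train_size=0.8)
train_loader = DataLoader(train_set, batch_size=20, collate_fn=alchemy.batcher())
```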
### Models
- **SchNet**: SchNet is a deep learning architecture that models quantum interactions in molecules with continuous-filter convolutional layers [3].
- **Multilevel Graph Convolutional neural Network**: The Multilevel Graph Convolutional neural Network (MGCN) is a hierarchical graph neural network that extracts features directly from the conformation and spatial information and then models multilevel interactions [4].
- **Message Passing Neural Network**: The Message Passing Neural Network (MPNN) uses an edge network (enn) as its front end and Set2Set to produce predictions [5].
### Usage
```py
python regression.py --model sch --epochs 200
```
The model option must be one of 'sch', 'mgcn' or 'mpnn'.
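The models can also be constructed programmatically from the model zoo; a minimal sketch mirroring what `regression.py` does internally:
```py
import torch as th
from dgl import model_zoo
from dgl.data.chem import alchemy

device = th.device('cuda:0' if th.cuda.is_available() else 'cpu')
dataset = alchemy.TencentAlchemyDataset()

# One of SchNetModel, MGCNModel or MPNNModel; each predicts the 12 properties.
model = model_zoo.chem.MPNNModel(output_dim=12)
model.set_mean_std(dataset.mean, dataset.std, device)
model.to(device)
```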
### Performance
#### Alchemy
|Model |Mean Absolute Error (MAE)|
|-------------|-------------------------|
|SchNet[3] |0.065|
|MGCN[4] |0.050|
|MPNN[5] |0.056|
## References
[1] Wu et al. (2017) MoleculeNet: a benchmark for molecular machine learning. *Chemical Science* 9, 513-530.
[2] Kipf et al. (2017) Semi-Supervised Classification with Graph Convolutional Networks.
*The International Conference on Learning Representations (ICLR)*.
[3] Schütt et al. (2017) SchNet: A continuous-filter convolutional neural network for modeling quantum interactions.
*Advances in Neural Information Processing Systems (NeurIPS)*, 992-1002.
[4] Lu et al. (2019) Molecular Property Prediction: A Multilevel Quantum Interactions Modeling Perspective.
*The 33rd AAAI Conference on Artificial Intelligence*.
[5] Gilmer et al. (2017) Neural Message Passing for Quantum Chemistry. *Proceedings of the 34th International Conference on
Machine Learning*, JMLR, 1263-1272.
......@@ -5,27 +5,39 @@
import argparse
import torch as th
import torch.nn as nn
from sch import SchNetModel
from mgcn import MGCNModel
from dgl.data.chem import alchemy
from dgl import model_zoo
from torch.utils.data import DataLoader
from Alchemy_dataset import TencentAlchemyDataset, batcher
def train(model="sch", epochs=80, device=th.device("cpu")):
def train(model="sch",
epochs=80,
device=th.device("cpu"),
training_set_size=0.8):
print("start")
alchemy_dataset = TencentAlchemyDataset()
alchemy_loader = DataLoader(dataset=alchemy_dataset,
batch_size=20,
collate_fn=batcher(),
shuffle=False,
num_workers=0)
alchemy_dataset = alchemy.TencentAlchemyDataset()
train_set, test_set = alchemy_dataset.split(train_size=0.8)
train_loader = DataLoader(dataset=train_set,
batch_size=20,
collate_fn=alchemy.batcher(),
shuffle=False,
num_workers=0)
test_loader = DataLoader(dataset=test_set,
batch_size=20,
collate_fn=alchemy.batcher(),
shuffle=False,
num_workers=0)
if model == "sch":
model = SchNetModel(norm=True, output_dim=12)
model = model_zoo.chem.SchNetModel(norm=True, output_dim=12)
model.set_mean_std(alchemy_dataset.mean, alchemy_dataset.std, device)
elif model == "mgcn":
model = MGCNModel(norm=True, output_dim=12)
model = model_zoo.chem.MGCNModel(norm=True, output_dim=12)
model.set_mean_std(alchemy_dataset.mean, alchemy_dataset.std, device)
elif model == "mpnn":
model = model_zoo.chem.MPNNModel(output_dim=12)
model.set_mean_std(alchemy_dataset.mean, alchemy_dataset.std, device)
model.to(device)
loss_fn = nn.MSELoss()
......@@ -37,7 +49,7 @@ def train(model="sch", epochs=80, device=th.device("cpu")):
w_loss, w_mae = 0, 0
model.train()
for idx, batch in enumerate(alchemy_loader):
for idx, batch in enumerate(train_loader):
batch.graph.to(device)
batch.label = batch.label.to(device)
......@@ -53,17 +65,38 @@ def train(model="sch", epochs=80, device=th.device("cpu")):
w_loss += loss.detach().item()
w_mae /= idx + 1
print("Epoch {:2d}, loss: {:.7f}, mae: {:.7f}".format(
print("Epoch {:2d}, loss: {:.7f}, MAE: {:.7f}".format(
epoch, w_loss, w_mae))
w_mae = 0
model.eval()
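# Evaluate on the held-out split; only MAE is reported for the test set.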
for idx, batch in enumerate(test_loader):
batch.graph.to(device)
batch.label = batch.label.to(device)
res = model(batch.graph)
mae = MAE_fn(res, batch.label)
w_mae += mae.detach().item()
w_mae /= idx + 1
print("MAE (test set): {:.7f}".format(w_mae))
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("-M",
"--model",
help="model name (sch or mgcn)",
help="model name (sch, mgcn, mpnn)",
choices=['sch', 'mgcn', 'mpnn'],
default="sch")
parser.add_argument("--epochs", help="number of epochs", default=250)
parser.add_argument("--epochs",
help="number of epochs",
default=250,
type=int)
device = th.device('cuda:0' if th.cuda.is_available() else 'cpu')
args = parser.parse_args()
train(args.model, int(args.epochs), device)
assert args.model in ['sch', 'mgcn',
'mpnn'], "model name must be sch, mgcn or mpnn"
train(args.model, args.epochs, device)
SchNet & MGCN
============
- K.T. Schütt. P.-J. Kindermans, H. E. Sauceda, S. Chmiela, A. Tkatchenko, K.-R. Müller.
SchNet: A continuous-filter convolutional neural network for modeling quantum interactions. Advances in Neural Information Processing Systems 30, pp. 992-1002 (2017) [link](http://papers.nips.cc/paper/6700-schnet-a-continuous-filter-convolutional-neural-network-for-modeling-quantum-interactions)
- C. Lu, Q. Liu, C. Wang, Z. Huang, P. Lin, L. He, Molecular Property Prediction: A Multilevel Quantum Interactions Modeling Perspective. The 33rd AAAI Conference on Artificial Intelligence (2019) [link](https://arxiv.org/abs/1906.11081)
Dependencies
------------
- PyTorch 1.0+
- dgl 0.3+
- RDKit (if using the [Alchemy dataset](https://arxiv.org/abs/1906.09427))
Usage
-----
Example usage on Alchemy dataset:
+ SchNet: expected MAE 0.065
```py
python train.py --model sch --epochs 250
```
+ MGCN: expected MAE 0.050
```py
python train.py --model mgcn --epochs 250
```
*On a Tesla V100, SchNet takes 80s/epoch and MGCN takes 110s/epoch.*
Codes
-----
The folder contains five Python files:
- `sch.py`: the implementation of the SchNet model.
- `mgcn.py`: the implementation of the Multilevel Graph Convolutional Network (MGCN).
- `layers.py`: layers shared by the two models above.
- `Alchemy_dataset.py`: an example dataloader for the [Tencent Alchemy](https://alchemy.tencent.com) dataset.
- `train.py`: example training code.
Modify `train.py` to switch between different implementations.
......@@ -11,14 +11,20 @@ from .reddit import RedditDataset
from .ppi import PPIDataset
from .tu import TUDataset
from .gindt import GINDataset
from .chem import Tox21
# from .chem import Tox21, alchemy
def register_data_args(parser):
parser.add_argument("--dataset", type=str, required=False,
help="The input dataset. Can be cora, citeseer, pubmed, syn(synthetic dataset) or reddit")
parser.add_argument(
"--dataset",
type=str,
required=False,
help=
"The input dataset. Can be cora, citeseer, pubmed, syn(synthetic dataset) or reddit"
)
citegrh.register_args(parser)
def load_data(args):
if args.dataset == 'cora':
return citegrh.load_cora()
......
......@@ -2,22 +2,28 @@
"""Example dataloader of Tencent Alchemy Dataset
https://alchemy.tencent.com/
"""
import os
import zipfile
import os.path as osp
from rdkit import Chem
from rdkit.Chem import ChemicalFeatures
from rdkit import RDConfig
import zipfile
import dgl
from dgl.data.utils import download
import pickle
import torch
from collections import defaultdict
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import pathlib
import pandas as pd
import numpy as np
_urls = {'Alchemy': 'https://alchemy.tencent.com/data/'}
import pathlib
from ..utils import get_download_dir, download, _get_dgl_url
try:
import pandas as pd
from rdkit import Chem
from rdkit.Chem import ChemicalFeatures
from rdkit import RDConfig
except ImportError:
pass
_urls = {'Alchemy': 'https://alchemy.tencent.com/data/dgl/'}
class AlchemyBatcher:
......@@ -155,18 +161,16 @@ class TencentAlchemyDataset(Dataset):
return bond_feats_dict
def sdf_to_dgl(self, sdf_file, self_loop=False):
def mol_to_dgl(self, mol, self_loop=False):
"""
Read sdf file and convert to dgl_graph
Convert RDKit molecule object to DGLGraph
Args:
sdf_file: path of sdf file
mol: Chem.rdchem.Mol read from sdf
self_loop: Whether to add self loops
Returns:
g: DGLGraph
"""
sdf = open(str(sdf_file)).read()
mol = Chem.MolFromMolBlock(sdf, removeHs=False)
g = dgl.DGLGraph()
......@@ -175,13 +179,6 @@ class TencentAlchemyDataset(Dataset):
atom_feats = self.alchemy_nodes(mol)
g.add_nodes(num=num_atoms, data=atom_feats)
# add edges
# The model we were interested assumes a complete graph.
# If this is not the case, do the code below instead
#
# for bond in mol.GetBonds():
# u = bond.GetBeginAtomIdx()
# v = bond.GetEndAtomIdx()
if self_loop:
g.add_edges(
[i for i in range(num_atoms) for j in range(num_atoms)],
......@@ -196,46 +193,69 @@ class TencentAlchemyDataset(Dataset):
bond_feats = self.alchemy_edges(mol, self_loop)
g.edata.update(bond_feats)
# for val/test set, labels are molecule ID
l = torch.FloatTensor(self.target.loc[int(sdf_file.stem)].tolist()) \
if self.mode == 'dev' else torch.LongTensor([int(sdf_file.stem)])
return (g, l)
return g
def __init__(self, mode='dev', transform=None):
def __init__(self, mode='dev', transform=None, from_raw=False):
assert mode in ['dev', 'valid',
'test'], "mode should be dev/valid/test"
self.mode = mode
self.transform = transform
self.file_dir = pathlib.Path('./Alchemy_data', mode)
self.zip_file_path = pathlib.Path('./Alchemy_data', '%s.zip' % mode)
download(_urls['Alchemy'] + "%s.zip" % mode,
# Construct the dgl graph from raw data or use the preprocessed data directly
self.from_raw = from_raw
file_dir = osp.join(get_download_dir(), './Alchemy_data')
if not from_raw:
file_name = "%s_processed" % (mode)
else:
file_name = "%s_single_sdf" % (mode)
self.file_dir = pathlib.Path(file_dir, file_name)
self.zip_file_path = pathlib.Path(file_dir, file_name + '.zip')
download(_urls['Alchemy'] + file_name + '.zip',
path=str(self.zip_file_path))
if not os.path.exists(str(self.file_dir)):
archive = zipfile.ZipFile(str(self.zip_file_path))
archive.extractall('./Alchemy_data')
archive = zipfile.ZipFile(self.zip_file_path)
archive.extractall(file_dir)
archive.close()
self._load()
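# Typical usage (a sketch): TencentAlchemyDataset(mode='dev') loads the
# preprocessed pickles downloaded above; from_raw=True instead parses the
# raw SDF file with RDKit through mol_to_dgl().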
def _load(self):
if self.mode == 'dev':
target_file = pathlib.Path(self.file_dir, "dev_target.csv")
self.target = pd.read_csv(target_file,
index_col=0,
usecols=[
'gdb_idx',
] +
['property_%d' % x for x in range(12)])
self.target = self.target[['property_%d' % x for x in range(12)]]
sdf_dir = pathlib.Path(self.file_dir, "sdf")
self.graphs, self.labels = [], []
for sdf_file in sdf_dir.glob("**/*.sdf"):
result = self.sdf_to_dgl(sdf_file)
if result is None:
continue
self.graphs.append(result[0])
self.labels.append(result[1])
if not self.from_raw:
with open(osp.join(self.file_dir, "dev_graphs.pkl"),
"rb") as f:
self.graphs = pickle.load(f)
with open(osp.join(self.file_dir, "dev_labels.pkl"),
"rb") as f:
self.labels = pickle.load(f)
else:
target_file = pathlib.Path(self.file_dir, "dev_target.csv")
self.target = pd.read_csv(
target_file,
index_col=0,
usecols=[
'gdb_idx',
] + ['property_%d' % x for x in range(12)])
self.target = self.target[[
'property_%d' % x for x in range(12)
]]
self.graphs, self.labels = [], []
sdf_dir = pathlib.Path(self.file_dir, "sdf")
supp = Chem.SDMolSupplier(
osp.join(self.file_dir, self.mode + ".sdf"))
cnt = 0
for sdf, label in zip(supp, self.target.iterrows()):
graph = self.mol_to_dgl(sdf)
cnt += 1
self.graphs.append(graph)
label = torch.FloatTensor(label[1].tolist())
self.labels.append(label)
self.normalize()
print(len(self.graphs), "loaded!")
......@@ -257,24 +277,32 @@ class TencentAlchemyDataset(Dataset):
g = self.transform(g)
return g, l
def split(self, train_size=0.8):
"""
Split the dataset into two AlchemySubsets for train & test.
"""
assert 0 < train_size < 1
train_num = int(len(self.graphs) * train_size)
train_set = AlchemySubset(self.graphs[:train_num],
self.labels[:train_num], self.mean, self.std,
self.transform)
test_set = AlchemySubset(self.graphs[train_num:],
self.labels[train_num:], self.mean, self.std,
self.transform)
return train_set, test_set
class AlchemySubset(TencentAlchemyDataset):
"""
Sub-dataset split from TencentAlchemyDataset.
Used to construct the training & test set.
"""
if __name__ == '__main__':
alchemy_dataset = TencentAlchemyDataset()
device = torch.device('cpu')
# To speed up the training with multi-process data loader,
# the num_workers could be set to > 1.
alchemy_loader = DataLoader(dataset=alchemy_dataset,
batch_size=20,
collate_fn=batcher(),
shuffle=False,
num_workers=0)
for step, batch in enumerate(alchemy_loader):
print("bs =", batch.graph.batch_size)
print('feature size =', batch.graph.ndata['n_feat'].size())
print('pos size =', batch.graph.ndata['pos'].size())
print('edge feature size =', batch.graph.edata['e_feat'].size())
print('edge distance size =', batch.graph.edata['distance'].size())
print('label size=', batch.label.size())
print(dgl.sum_nodes(batch.graph, 'n_feat').size())
break
def __init__(self, graphs, labels, mean=0, std=1, transform=None):
self.graphs = graphs
self.labels = labels
self.mean = mean
self.std = std
self.transform = transform
......@@ -40,9 +40,16 @@ mostly developed based on molecule fingerprints.
Graph neural networks enable data-driven representations of molecules, built from the atoms, bonds and
molecular graph topology, which may be viewed as learned fingerprints [3].
### Models
### Models
- **Graph Convolutional Network**: Graph Convolutional Networks (GCN) have been one of the most popular graph neural
networks and they can be easily extended for graph level prediction.
networks and they can be easily extended for graph level prediction.
- **SchNet**: SchNet is a deep learning architecture that models quantum interactions in molecules with
continuous-filter convolutional layers [4].
- **Multilevel Graph Convolutional neural Network**: The Multilevel Graph Convolutional neural Network (MGCN) is a
hierarchical graph neural network that extracts features directly from the conformation and spatial information
and then models multilevel interactions [5].
- **Message Passing Neural Network**: The Message Passing Neural Network (MPNN) uses an edge network (enn) as its front end and Set2Set to produce predictions [6].
## Generative Models
......@@ -56,7 +63,7 @@ working with alternative representations like SMILES.
Generative models are known to be difficult for evaluation. [GuacaMol](https://github.com/BenevolentAI/guacamol) and
[MOSES](https://github.com/molecularsets/moses) have been two recent efforts to benchmark generative models. There
are also two accompanying review papers that are well written [4], [5].
are also two accompanying review papers that are well written [7], [8].
### Models
- **Deep Generative Models of Graphs (DGMG)**: A very general framework for graph distribution learning by progressively
......@@ -72,7 +79,16 @@ adding atoms and bonds.
[3] Duvenaud et al. (2015) Convolutional networks on graphs for learning molecular fingerprints. *Advances in neural
information processing systems (NeurIPS)*, 2224-2232.
[4] Brown et al. (2019) GuacaMol: Benchmarking Models for de Novo Molecular Design. *J. Chem. Inf. Model*, 2019, 59, 3,
[4] Schütt et al. (2017) SchNet: A continuous-filter convolutional neural network for modeling quantum interactions.
*Advances in Neural Information Processing Systems (NeurIPS)*, 992-1002.
[5] Lu et al. (2019) Molecular Property Prediction: A Multilevel Quantum Interactions Modeling Perspective.
*The 33rd AAAI Conference on Artificial Intelligence*.
[6] Gilmer et al. (2017) Neural Message Passing for Quantum Chemistry. *Proceedings of the 34th International Conference on
Machine Learning*, JMLR, 1263-1272.
[7] Brown et al. (2019) GuacaMol: Benchmarking Models for de Novo Molecular Design. *J. Chem. Inf. Model*, 2019, 59, 3,
1096-1108.
[5] Polykovskiy et al. (2019) Molecular Sets (MOSES): A Benchmarking Platform for Molecular Generation Models. *arXiv*.
[8] Polykovskiy et al. (2019) Molecular Sets (MOSES): A Benchmarking Platform for Molecular Generation Models. *arXiv*.
......@@ -2,5 +2,8 @@
"""Model Zoo Package"""
from .gcn import GCNClassifier
from .sch import SchNetModel
from .mgcn import MGCNModel
from .mpnn import MPNNModel
from .dgmg import DGMG
from .pretrain import load_pretrained
# -*- coding: utf-8 -*-
# pylint: disable=C0103, E1101, C0111
"""
The implementation of neural network layers used in SchNet and MGCN.
"""
import torch as th
import numpy as np
import torch.nn as nn
import dgl.function as fn
from torch.nn import Softplus
import numpy as np
import dgl.function as fn
class AtomEmbedding(nn.Module):
"""
Convert the atom(node) list to atom embeddings.
The atom with the same element share the same initial embeddding.
The atoms with the same element share the same initial embedding.
"""
def __init__(self, dim=128, type_num=100, pre_train=None):
......@@ -66,7 +71,7 @@ class EdgeEmbedding(nn.Module):
To map a pair of nodes to one number, we use an unordered pairing function here.
See more details in this discussion:
https://math.stackexchange.com/questions/23503/create-unique-number-from-2-numbers
Note that, the edge_num should larger than the square of maximum atomic number
Note that edge_num should be larger than the square of the maximum atomic number
in the dataset.
"""
atom_type_x = edges.src["node_type"]
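# For illustration, one classic unordered pairing from the linked discussion
# (an assumption -- not necessarily the exact formula used below):
#   pair(x, y) = x * y + (abs(x - y) - 1) ** 2 // 4
# e.g. pair(6, 8) == pair(8, 6) == 48; for atomic numbers up to Z the codes
# stay within about Z ** 2, hence the edge_num requirement above.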
......@@ -95,8 +100,8 @@ class ShiftSoftplus(Softplus):
self.shift = shift
self.softplus = Softplus(beta, threshold)
def forward(self, input):
return self.softplus(input) - np.log(float(self.shift))
def forward(self, x):
return self.softplus(x) - np.log(float(self.shift))
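# Assuming the default shift of 2, this is the shifted softplus used in
# SchNet, ssp(x) = ln(1 + e^x) - ln 2, which is zero at x = 0.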
class RBFLayer(nn.Module):
......@@ -124,6 +129,7 @@ class RBFLayer(nn.Module):
self._gap = centers[1] - centers[0]
def dis2rbf(self, edges):
"""Convert distance matrix to RBF tensor."""
dist = edges.data["distance"]
radial = dist - self.centers
coef = -1 / self._gap
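# centers are evenly spaced, so radial and coef parameterize a
# Gaussian-style envelope per center (SchNet-style expansion),
# roughly rbf_k = exp(-(d - mu_k) ** 2 / gap).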
......@@ -163,6 +169,7 @@ class CFConv(nn.Module):
self.activation = act
def update_edge(self, edges):
"""Update the edge features with two FC layers."""
rbf = edges.data["rbf"]
h = self.linear_layer1(rbf)
h = self.activation(h)
......@@ -170,6 +177,7 @@ class CFConv(nn.Module):
return {"h": h}
def forward(self, g):
"""Forward CFConv"""
g.apply_edges(self.update_edge)
g.update_all(message_func=fn.u_mul_e('new_node', 'h', 'neighbor_info'),
reduce_func=fn.sum('neighbor_info', 'new_node'))
......@@ -191,7 +199,7 @@ class Interaction(nn.Module):
self.node_layer3 = nn.Linear(dim, dim)
def forward(self, g):
"""Interaction layer forward."""
g.ndata["new_node"] = self.node_layer1(g.ndata["node"])
cf_node = self.cfconv(g)
cf_node_1 = self.node_layer2(cf_node)
......@@ -203,8 +211,8 @@ class Interaction(nn.Module):
class VEConv(nn.Module):
"""
The Vertex-Edge convolution layer in MGCN which take edge & vertex features
in consideratoin at the same time.
The Vertex-Edge convolution layer in MGCN, which takes both edge and vertex features
into consideration at the same time.
"""
def __init__(self, rbf_dim, dim=64, update_edge=True):
......@@ -226,6 +234,7 @@ class VEConv(nn.Module):
self.activation = nn.Softplus(beta=0.5, threshold=14)
def update_rbf(self, edges):
"""Update the RBF features."""
rbf = edges.data["rbf"]
h = self.linear_layer1(rbf)
h = self.activation(h)
......@@ -233,22 +242,25 @@ class VEConv(nn.Module):
return {"h": h}
def update_edge(self, edges):
"""Update the edge features."""
edge_f = edges.data["edge_f"]
h = self.linear_layer3(edge_f)
return {"edge_f": h}
def forward(self, g):
"""VEConv layer forward"""
g.apply_edges(self.update_rbf)
if self._update_edge:
g.apply_edges(self.update_edge)
g.update_all(
message_func=[
fn.u_mul_e("new_node", "h", "m_0"),
fn.copy_e("edge_f", "m_1")],
reduce_func=[
fn.sum("m_0", "new_node_0"),
fn.sum("m_1", "new_node_1")])
g.update_all(message_func=[
fn.u_mul_e("new_node", "h", "m_0"),
fn.copy_e("edge_f", "m_1")
],
reduce_func=[
fn.sum("m_0", "new_node_0"),
fn.sum("m_1", "new_node_1")
])
g.ndata["new_node"] = g.ndata.pop("new_node_0") + g.ndata.pop(
"new_node_1")
......@@ -274,6 +286,13 @@ class MultiLevelInteraction(nn.Module):
self.node_layer3 = nn.Linear(dim, dim)
def forward(self, g, level=1):
"""
MultiLevel Interaction Layer forward.
Args:
g: DGLGraph
level: current level of this layer
"""
g.ndata["new_node"] = self.node_layer1(g.ndata["node_%s" %
(level - 1)])
node = self.conv_layer(g)
......
# -*- coding:utf-8 -*-
# pylint: disable=C0103, C0111, W0621
"""Implementation of MGCN model"""
import dgl
import torch as th
import torch.nn as nn
from layers import AtomEmbedding, RBFLayer, EdgeEmbedding, \
from .layers import AtomEmbedding, RBFLayer, EdgeEmbedding, \
MultiLevelInteraction
......@@ -104,14 +106,3 @@ class MGCNModel(nn.Module):
"res"] * self.std_per_node + self.mean_per_node
res = dgl.sum_nodes(g, "res")
return res
if __name__ == "__main__":
g = dgl.DGLGraph()
g.add_nodes(2)
g.add_edges([0, 0, 1, 1], [1, 0, 1, 0])
g.edata["distance"] = th.tensor([1.0, 3.0, 2.0, 4.0]).reshape(-1, 1)
g.ndata["node_type"] = th.LongTensor([1, 2])
model = MGCNModel(dim=2, edge_dim=2)
node = model(g)
print(node)
#!/usr/bin/env python
# coding: utf-8
# pylint: disable=C0103, C0111, E1101, W0612
"""Implementation of MPNN model."""
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Parameter
import dgl.function as fn
import dgl.nn.pytorch as dgl_nn
class NNConvLayer(nn.Module):
"""
MPNN Conv Layer from Section 5 of the paper "Neural Message Passing for Quantum Chemistry."
"""
def __init__(self,
in_channels,
out_channels,
edge_net,
root_weight=True,
bias=True):
"""
Args:
in_channels: number of input channels
out_channels: number of output channels
edge_net: the network module that processes the edge features
root_weight: whether to add the transformed root node feature to the output
bias: whether to add a bias to the output
"""
super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.edge_net = edge_net
if root_weight:
self.root = Parameter(torch.Tensor(in_channels, out_channels))
else:
self.register_parameter('root', None)
if bias:
self.bias = Parameter(torch.Tensor(out_channels))
else:
self.register_parameter('bias', None)
self.reset_parameters()
def reset_parameters(self):
if self.root is not None:
nn.init.xavier_normal_(self.root.data, gain=1.414)
if self.bias is not None:
self.bias.data.zero_()
for m in self.edge_net.modules():
if isinstance(m, nn.Linear):
nn.init.xavier_normal_(m.weight.data, gain=1.414)
def message(self, edges):
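# edges.src['h'] has shape (E, in_channels) and edges.data['w'] has shape
# (E, in_channels, out_channels); the batched matmul computes h_u @ W_e per
# edge, yielding messages of shape (E, out_channels).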
return {
'm':
torch.matmul(edges.src['h'].unsqueeze(1),
edges.data['w']).squeeze(1)
}
def apply_node_func(self, nodes):
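# Combine the summed incoming messages with a learned transform of the
# node's own features (root weight) and a bias, when configured.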
aggr_out = nodes.data['aggr_out']
if self.root is not None:
aggr_out = torch.mm(nodes.data['h'], self.root) + aggr_out
if self.bias is not None:
aggr_out = aggr_out + self.bias
return {'h': aggr_out}
def forward(self, g, h, e):
"""MPNN Conv layer forward."""
h = h.unsqueeze(-1) if h.dim() == 1 else h
e = e.unsqueeze(-1) if e.dim() == 1 else e
g.ndata['h'] = h
g.edata['w'] = self.edge_net(e).view(-1, self.in_channels,
self.out_channels)
g.update_all(self.message, fn.sum("m", "aggr_out"),
self.apply_node_func)
return g.ndata.pop('h')
class MPNNModel(nn.Module):
"""
MPNN model from:
Gilmer, Justin, et al.
Neural message passing for quantum chemistry.
"""
def __init__(self,
node_input_dim=15,
edge_input_dim=5,
output_dim=12,
node_hidden_dim=64,
edge_hidden_dim=128,
num_step_message_passing=6,
num_step_set2set=6,
num_layer_set2set=3):
"""model parameters setting
Args:
node_input_dim: dimension of input node feature
edge_input_dim: dimension of input edge feature
output_dim: dimension of prediction
node_hidden_dim: dimension of node feature in hidden layers
edge_hidden_dim: dimension of edge feature in hidden layers
num_step_message_passing: number of message passing steps
num_step_set2set: number of set2set steps
num_layer_set2set: number of set2set layers
"""
super().__init__()
self.name = "MPNN"
self.num_step_message_passing = num_step_message_passing
self.lin0 = nn.Linear(node_input_dim, node_hidden_dim)
edge_network = nn.Sequential(
nn.Linear(edge_input_dim, edge_hidden_dim), nn.ReLU(),
nn.Linear(edge_hidden_dim, node_hidden_dim * node_hidden_dim))
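# edge_network maps each edge's raw features to a flattened
# (node_hidden_dim x node_hidden_dim) matrix, reshaped by NNConvLayer into
# the per-edge message weight.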
self.conv = NNConvLayer(in_channels=node_hidden_dim,
out_channels=node_hidden_dim,
edge_net=edge_network,
root_weight=False)
self.gru = nn.GRU(node_hidden_dim, node_hidden_dim)
self.set2set = dgl_nn.glob.Set2Set(node_hidden_dim, num_step_set2set,
num_layer_set2set)
self.lin1 = nn.Linear(2 * node_hidden_dim, node_hidden_dim)
self.lin2 = nn.Linear(node_hidden_dim, output_dim)
def forward(self, g):
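# Embed node features, run num_step_message_passing rounds of NNConv + GRU
# (the GRU hidden state carries node features across rounds), then pool
# with Set2Set and regress to output_dim targets.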
h = g.ndata['n_feat']
out = F.relu(self.lin0(h))
h = out.unsqueeze(0)
for i in range(self.num_step_message_passing):
m = F.relu(self.conv(g, out, g.edata['e_feat']))
out, h = self.gru(m.unsqueeze(0), h)
out = out.squeeze(0)
out = self.set2set(out, g)
out = F.relu(self.lin1(out))
out = self.lin2(out)
return out
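# A minimal smoke-test sketch, analogous to the __main__ demos in sch.py and
# mgcn.py; feature sizes assume the constructor defaults (15-d node features,
# 5-d edge features).
if __name__ == "__main__":
    import dgl
    g = dgl.DGLGraph()
    g.add_nodes(2)
    g.add_edges([0, 0, 1, 1], [1, 0, 1, 0])
    g.ndata['n_feat'] = torch.randn(2, 15)
    g.edata['e_feat'] = torch.randn(4, 5)
    model = MPNNModel()
    print(model(g))  # expected shape: (1, 12)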
# -*- coding:utf-8 -*-
# pylint: disable=C0103, C0111, W0621
"""Implementation of SchNet model."""
import dgl
import torch as th
import torch.nn as nn
from layers import AtomEmbedding, Interaction, ShiftSoftplus, RBFLayer
from .layers import AtomEmbedding, Interaction, ShiftSoftplus, RBFLayer
class SchNetModel(nn.Module):
......@@ -62,8 +64,7 @@ class SchNetModel(nn.Module):
self.std_per_atom = th.tensor(std, device=device)
def forward(self, g):
"""g is the DGL.graph"""
"""g is the DGLGraph"""
self.embedding_layer(g)
if self.atom_ref is not None:
self.e0(g, "e0")
......@@ -84,13 +85,3 @@ class SchNetModel(nn.Module):
"res"] * self.std_per_atom + self.mean_per_atom
res = dgl.sum_nodes(g, "res")
return res
if __name__ == "__main__":
g = dgl.DGLGraph()
g.add_nodes(2)
g.add_edges([0, 0, 1, 1], [1, 0, 1, 0])
g.edata["distance"] = th.tensor([1.0, 3.0, 2.0, 4.0]).reshape(-1, 1)
g.ndata["node_type"] = th.LongTensor([1, 2])
model = SchNetModel(dim=2)
atom = model(g)