Unverified commit 189c2c09 authored by Mufei Li, committed by GitHub

[Model Zoo] Refactor Model Zoo for Chemistry (#839)

* Update

* Update

* Update

* Update fix

* Update

* Update

* Refactor

* Update

* Update

* Update

* Update

* Update

* Update

* Fix style
parent 5b417683
# -*- coding:utf-8 -*-
# pylint: disable=C0103, C0111, W0621
"""Implementation of MGCN model"""
-import torch as th
+import torch
import torch.nn as nn

from .layers import AtomEmbedding, RBFLayer, EdgeEmbedding, \
    MultiLevelInteraction
-from ...batched_graph import sum_nodes
+from ...nn.pytorch import SumPooling


class MGCNModel(nn.Module):
    """
-    `Molecular Property Prediction: A Multilevel
+    MGCN from `Molecular Property Prediction: A Multilevel
    Quantum Interactions Modeling Perspective <https://arxiv.org/abs/1906.11081>`__

    Parameters
    ----------
    dim : int
-        Dimension of feature maps, default to be 128.
-    out_put_dim: int
-        Number of target properties to predict, default to be 1.
-    edge_dim : int
-        Dimension of edge feature, default to be 128.
-    cutoff : float
-        The maximum distance between nodes, default to be 5.0.
+        Size for embeddings, default to be 128.
    width : int
        Width in the RBF layer, default to be 1.
+    cutoff : float
+        The maximum distance between nodes, default to be 5.0.
+    edge_dim : int
+        Size for edge embedding, default to be 128.
+    out_put_dim: int
+        Number of target properties to predict, default to be 1.
    n_conv : int
        Number of convolutional layers, default to be 3.
    norm : bool
@@ -39,16 +39,16 @@ class MGCNModel(nn.Module):
    """
    def __init__(self,
                 dim=128,
-                 output_dim=1,
-                 edge_dim=128,
-                 cutoff=5.0,
                 width=1,
+                 cutoff=5.0,
+                 edge_dim=128,
+                 output_dim=1,
                 n_conv=3,
                 norm=False,
                 atom_ref=None,
                 pre_train=None):
        super(MGCNModel, self).__init__()
        self.name = "MGCN"
        self._dim = dim
        self.output_dim = output_dim
        self.edge_dim = edge_dim
@@ -58,27 +58,29 @@ class MGCNModel(nn.Module):
        self.atom_ref = atom_ref
        self.norm = norm
-        self.activation = nn.Softplus(beta=1, threshold=20)

-        if atom_ref is not None:
-            self.e0 = AtomEmbedding(1, pre_train=atom_ref)
        if pre_train is None:
            self.embedding_layer = AtomEmbedding(dim)
        else:
            self.embedding_layer = AtomEmbedding(pre_train=pre_train)
-        self.rbf_layer = RBFLayer(0, cutoff, width)
        self.edge_embedding_layer = EdgeEmbedding(dim=edge_dim)
+        self.rbf_layer = RBFLayer(0, cutoff, width)
+        if atom_ref is not None:
+            self.e0 = AtomEmbedding(1, pre_train=atom_ref)

        self.conv_layers = nn.ModuleList([
            MultiLevelInteraction(self.rbf_layer._fan_out, dim)
            for i in range(n_conv)
        ])

-        self.node_dense_layer1 = nn.Linear(dim * (self.n_conv + 1), 64)
-        self.node_dense_layer2 = nn.Linear(64, output_dim)
+        self.out_project = nn.Sequential(
+            nn.Linear(dim * (self.n_conv + 1), 64),
+            nn.Softplus(beta=1, threshold=20),
+            nn.Linear(64, output_dim)
+        )
+        self.readout = SumPooling()

-    def set_mean_std(self, mean, std, device):
+    def set_mean_std(self, mean, std, device="cpu"):
        """Set the mean and std of atom representations for normalization.

        Parameters
@@ -90,46 +92,45 @@ class MGCNModel(nn.Module):
        device : str or torch.device
            Device for storing the mean and std
        """
-        self.mean_per_node = th.tensor(mean, device=device)
-        self.std_per_node = th.tensor(std, device=device)

-    def forward(self, g):
+        self.mean_per_node = torch.tensor(mean, device=device)
+        self.std_per_node = torch.tensor(std, device=device)

+    def forward(self, g, atom_types, edge_distances):
        """Predict molecule labels

        Parameters
        ----------
        g : DGLGraph
            Input DGLGraph for molecule(s)
+        atom_types : int64 tensor of shape (B1)
+            Types for atoms in the graph(s), B1 for the number of atoms.
+        edge_distances : float32 tensor of shape (B2, 1)
+            Edge distances, B2 for the number of edges.

        Returns
        -------
-        res : Predicted labels
+        prediction : float32 tensor of shape (B, output_dim)
+            Model prediction for the batch of graphs, B for the number
+            of graphs, output_dim for the prediction size.
        """
-        self.embedding_layer(g, "node_0")
-        if self.atom_ref is not None:
-            self.e0(g, "e0")
-        self.rbf_layer(g)
-        self.edge_embedding_layer(g)
+        h = self.embedding_layer(atom_types)
+        e = self.edge_embedding_layer(g, atom_types)
+        rbf_out = self.rbf_layer(edge_distances)

+        all_layer_h = [h]
        for idx in range(self.n_conv):
-            self.conv_layers[idx](g, idx + 1)
-
-        node_embeddings = tuple(g.ndata["node_%d" % (i)]
-                                for i in range(self.n_conv + 1))
-        g.ndata["node"] = th.cat(node_embeddings, 1)
+            h, e = self.conv_layers[idx](g, h, e, rbf_out)
+            all_layer_h.append(h)

        # concat multilevel representations
-        node = self.node_dense_layer1(g.ndata["node"])
-        node = self.activation(node)
-        res = self.node_dense_layer2(node)
-        g.ndata["res"] = res
+        h = torch.cat(all_layer_h, dim=1)
+        h = self.out_project(h)

        if self.atom_ref is not None:
-            g.ndata["res"] = g.ndata["res"] + g.ndata["e0"]
+            h_ref = self.e0(atom_types)
+            h = h + h_ref

        if self.norm:
-            g.ndata["res"] = g.ndata[
-                "res"] * self.std_per_node + self.mean_per_node
-        res = sum_nodes(g, "res")
-        return res
+            h = h * self.std_per_node + self.mean_per_node

+        return self.readout(g, h)
@@ -2,134 +2,10 @@
# coding: utf-8
# pylint: disable=C0103, C0111, E1101, W0612
"""Implementation of MPNN model."""
import torch
import torch.nn as nn
import torch.nn.functional as F
-from torch.nn import Parameter

-from ... import function as fn
-from ...nn.pytorch import Set2Set


-class NNConvLayer(nn.Module):
-    """
-    MPNN Conv Layer from Section 5 of
-    `Neural Message Passing for Quantum Chemistry <https://arxiv.org/abs/1704.01212>`__
-
-    Parameters
-    ----------
-    in_channels : int
-        Number of input channels
-    out_channels : int
-        Number of output channels
-    edge_net : Module processing edge information
-    root_weight : bool
-        Whether to add the root node feature to output
-    bias : bool
-        Whether to add bias to the output
-    """
-    def __init__(self,
-                 in_channels,
-                 out_channels,
-                 edge_net,
-                 root_weight=True,
-                 bias=True):
-        super(NNConvLayer, self).__init__()
-
-        self.in_channels = in_channels
-        self.out_channels = out_channels
-        self.edge_net = edge_net
-
-        if root_weight:
-            self.root = Parameter(torch.Tensor(in_channels, out_channels))
-        else:
-            self.register_parameter('root', None)
-        if bias:
-            self.bias = Parameter(torch.Tensor(out_channels))
-        else:
-            self.register_parameter('bias', None)
-
-        self.reset_parameters()
-
-    def reset_parameters(self):
-        """Reinitialize model parameters"""
-        if self.root is not None:
-            nn.init.xavier_normal_(self.root.data, gain=1.414)
-        if self.bias is not None:
-            self.bias.data.zero_()
-        for m in self.edge_net.modules():
-            if isinstance(m, nn.Linear):
-                nn.init.xavier_normal_(m.weight.data, gain=1.414)
-
-    def message(self, edges):
-        """Function for computing messages from source nodes
-
-        Parameters
-        ----------
-        edges : EdgeBatch
-            Edges over which we want to send messages
-
-        Returns
-        -------
-        dict
-            Stores message in key 'm'
-        """
-        return {
-            'm':
-            torch.matmul(edges.src['h'].unsqueeze(1),
-                         edges.data['w']).squeeze(1)
-        }
-
-    def apply_node_func(self, nodes):
-        """Function for updating node features directly
-
-        Parameters
-        ----------
-        nodes : NodeBatch
-
-        Returns
-        -------
-        dict
-            Stores updated node features in 'h'
-        """
-        aggr_out = nodes.data['aggr_out']
-        if self.root is not None:
-            aggr_out = torch.mm(nodes.data['h'], self.root) + aggr_out
-        if self.bias is not None:
-            aggr_out = aggr_out + self.bias
-        return {'h': aggr_out}
-
-    def forward(self, g, h, e):
-        """Propagate messages and aggregate results for updating
-        atom representations
-
-        Parameters
-        ----------
-        g : DGLGraph
-            DGLGraph(s) for molecules
-        h : tensor
-            Input atom representations
-        e : tensor
-            Input bond representations
-
-        Returns
-        -------
-        tensor
-            Aggregated atom information
-        """
-        h = h.unsqueeze(-1) if h.dim() == 1 else h
-        e = e.unsqueeze(-1) if e.dim() == 1 else e
-
-        g.ndata['h'] = h
-        g.edata['w'] = self.edge_net(e).view(-1, self.in_channels,
-                                             self.out_channels)
-        g.update_all(self.message, fn.sum("m", "aggr_out"),
-                     self.apply_node_func)
-        return g.ndata.pop('h')
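For reference, the removed `message` function computes one vector-matrix product per edge, `h_src @ W_e`, where `W_e` is the `(in_channels, out_channels)` matrix that `edge_net` predicts from the edge feature; this messaging plus sum aggregation is what the built-in `NNConv` with `aggregator_type='sum'` reproduces below. A standalone sketch of the per-edge product in plain PyTorch, with made-up sizes:

import torch

E, in_c, out_c = 4, 8, 8                    # edges, input/output channels
h_src = torch.randn(E, in_c)                # source-node feature per edge
w = torch.randn(E, in_c, out_c)             # edge_net output, reshaped per edge
# batched vector-matrix product: one (1, in_c) @ (in_c, out_c) per edge
m = torch.matmul(h_src.unsqueeze(1), w).squeeze(1)
assert m.shape == (E, out_c)                # messages, later sum-reduced per node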
+from ...nn.pytorch import Set2Set, NNConv


class MPNNModel(nn.Module):
    """
@@ -166,40 +42,44 @@ class MPNNModel(nn.Module):
                 num_layer_set2set=3):
        super(MPNNModel, self).__init__()

        self.name = "MPNN"
        self.num_step_message_passing = num_step_message_passing
        self.lin0 = nn.Linear(node_input_dim, node_hidden_dim)
        edge_network = nn.Sequential(
            nn.Linear(edge_input_dim, edge_hidden_dim), nn.ReLU(),
            nn.Linear(edge_hidden_dim, node_hidden_dim * node_hidden_dim))
-        self.conv = NNConvLayer(in_channels=node_hidden_dim,
-                                out_channels=node_hidden_dim,
-                                edge_net=edge_network,
-                                root_weight=False)
+        self.conv = NNConv(in_feats=node_hidden_dim,
+                           out_feats=node_hidden_dim,
+                           edge_func=edge_network,
+                           aggregator_type='sum')
        self.gru = nn.GRU(node_hidden_dim, node_hidden_dim)

        self.set2set = Set2Set(node_hidden_dim, num_step_set2set, num_layer_set2set)
        self.lin1 = nn.Linear(2 * node_hidden_dim, node_hidden_dim)
        self.lin2 = nn.Linear(node_hidden_dim, output_dim)

-    def forward(self, g):
+    def forward(self, g, n_feat, e_feat):
        """Predict molecule labels

        Parameters
        ----------
        g : DGLGraph
            Input DGLGraph for molecule(s)
+        n_feat : tensor of dtype float32 and shape (B1, D1)
+            Node features. B1 for number of nodes and D1 for
+            the node feature size.
+        e_feat : tensor of dtype float32 and shape (B2, D2)
+            Edge features. B2 for number of edges and D2 for
+            the edge feature size.

        Returns
        -------
        res : Predicted labels
        """
-        h = g.ndata['n_feat']
-        out = F.relu(self.lin0(h))
-        h = out.unsqueeze(0)
+        out = F.relu(self.lin0(n_feat))  # (B1, H1)
+        h = out.unsqueeze(0)             # (1, B1, H1)

        for i in range(self.num_step_message_passing):
-            m = F.relu(self.conv(g, out, g.edata['e_feat']))
+            m = F.relu(self.conv(g, out, e_feat))  # (B1, H1)
            out, h = self.gru(m.unsqueeze(0), h)
            out = out.squeeze(0)
......
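A minimal usage sketch for the refactored MPNNModel (not part of this commit). The feature sizes are illustrative, and the keyword names `node_input_dim` and `edge_input_dim` are read off the constructor body above, since the full signature is elided in this diff:

import dgl
import torch
from dgl.model_zoo.chem import MPNNModel

g = dgl.DGLGraph()
g.add_nodes(3)
g.add_edges([0, 1, 2], [1, 2, 0])
n_feat = torch.randn(g.number_of_nodes(), 15)   # (B1, D1) node features
e_feat = torch.randn(g.number_of_edges(), 5)    # (B2, D2) edge features

model = MPNNModel(node_input_dim=15, edge_input_dim=5, output_dim=12)
prediction = model(g, n_feat, e_feat)           # one row of size output_dim per graph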
@@ -7,7 +7,7 @@ from .classifiers import GCNClassifier, GATClassifier
from .dgmg import DGMG
from .mgcn import MGCNModel
from .mpnn import MPNNModel
-from .sch import SchNetModel
+from .schnet import SchNet
from ...data.utils import _get_dgl_url, download, get_download_dir

URL = {
@@ -61,6 +61,19 @@ def load_pretrained(model_name, log=True):
    Parameters
    ----------
    model_name : str
+        Currently supported options include
+
+        * ``'GCN_Tox21'``
+        * ``'GAT_Tox21'``
+        * ``'MGCN_Alchemy'``
+        * ``'SCHNET_Alchemy'``
+        * ``'MPNN_Alchemy'``
+        * ``'DGMG_ChEMBL_canonical'``
+        * ``'DGMG_ChEMBL_random'``
+        * ``'DGMG_ZINC_canonical'``
+        * ``'DGMG_ZINC_random'``
+        * ``'JTNN_ZINC'``
    log : bool
        Whether to print progress for model loading
@@ -103,7 +116,7 @@ def load_pretrained(model_name, log=True):
        model = MGCNModel(norm=True, output_dim=12)
    elif model_name == 'SCHNET_Alchemy':
-        model = SchNetModel(norm=True, output_dim=12)
+        model = SchNet(norm=True, output_dim=12)
    elif model_name == 'MPNN_Alchemy':
        model = MPNNModel(output_dim=12)
......
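Any of the options listed above can be passed directly to `load_pretrained`; a short sketch, assuming the `dgl.model_zoo.chem` namespace and network access for the one-time checkpoint download:

from dgl import model_zoo

model = model_zoo.chem.load_pretrained('MGCN_Alchemy')  # fetches pretrained weights
model.eval()                                            # inference mode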
# -*- coding:utf-8 -*-
# pylint: disable=C0103, C0111, W0621
"""Implementation of SchNet model."""
-import torch as th
+import torch
import torch.nn as nn

from .layers import AtomEmbedding, Interaction, ShiftSoftplus, RBFLayer
-from ...batched_graph import sum_nodes
+from ...nn.pytorch import SumPooling


-class SchNetModel(nn.Module):
+class SchNet(nn.Module):
    """
    `SchNet: A continuous-filter convolutional neural network for modeling
    quantum interactions. (NIPS'2017) <https://arxiv.org/abs/1706.08566>`__
@@ -16,15 +16,15 @@ class SchNetModel(nn.Module):
    Parameters
    ----------
    dim : int
-        Dimension of features, default to be 64
+        Size for atom embeddings, default to be 64.
    cutoff : float
-        Radius cutoff for RBF, default to be 5.0
+        Radius cutoff for RBF, default to be 5.0.
    output_dim : int
-        Dimension of prediction, default to be 1
+        Number of target properties to predict, default to be 1.
    width : int
-        Width in RBF, default to 1
+        Width in RBF, default to be 1.
    n_conv : int
-        Number of conv (interaction) layers, default to be 1
+        Number of conv (interaction) layers, default to be 1.
    norm : bool
        Whether to normalize the output atom representations, default to be False.
    atom_ref : Atom embeddings or None
@@ -43,28 +43,32 @@ class SchNetModel(nn.Module):
                 norm=False,
                 atom_ref=None,
                 pre_train=None):
-        super().__init__()
-        self.name = "SchNet"
+        super(SchNet, self).__init__()
        self._dim = dim
        self.cutoff = cutoff
        self.width = width
        self.n_conv = n_conv
        self.atom_ref = atom_ref
        self.norm = norm
-        self.activation = ShiftSoftplus()

        if atom_ref is not None:
            self.e0 = AtomEmbedding(1, pre_train=atom_ref)
        if pre_train is None:
            self.embedding_layer = AtomEmbedding(dim)
        else:
            self.embedding_layer = AtomEmbedding(pre_train=pre_train)
        self.rbf_layer = RBFLayer(0, cutoff, width)
        self.conv_layers = nn.ModuleList(
-            [Interaction(self.rbf_layer._fan_out, dim) for i in range(n_conv)])
+            [Interaction(self.rbf_layer._fan_out, dim) for _ in range(n_conv)])

-        self.atom_dense_layer1 = nn.Linear(dim, 64)
-        self.atom_dense_layer2 = nn.Linear(64, output_dim)
+        self.atom_update = nn.Sequential(
+            nn.Linear(dim, 64),
+            ShiftSoftplus(),
+            nn.Linear(64, output_dim)
+        )
+        self.readout = SumPooling()

    def set_mean_std(self, mean, std, device="cpu"):
        """Set the mean and std of atom representations for normalization.

@@ -78,37 +82,38 @@ class SchNetModel(nn.Module):
        device : str or torch.device
            Device for storing the mean and std
        """
-        self.mean_per_atom = th.tensor(mean, device=device)
-        self.std_per_atom = th.tensor(std, device=device)

-    def forward(self, g):
+        self.mean_per_node = torch.tensor(mean, device=device)
+        self.std_per_node = torch.tensor(std, device=device)

+    def forward(self, g, atom_types, edge_distances):
        """Predict molecule labels

        Parameters
        ----------
        g : DGLGraph
            Input DGLGraph for molecule(s)
+        atom_types : int64 tensor of shape (B1)
+            Types for atoms in the graph(s), B1 for the number of atoms.
+        edge_distances : float32 tensor of shape (B2, 1)
+            Edge distances, B2 for the number of edges.

        Returns
        -------
-        res : Predicted labels
+        prediction : float32 tensor of shape (B, output_dim)
+            Model prediction for the batch of graphs, B for the number
+            of graphs, output_dim for the prediction size.
        """
-        self.embedding_layer(g)
-        if self.atom_ref is not None:
-            self.e0(g, "e0")
-        self.rbf_layer(g)
+        h = self.embedding_layer(atom_types)
+        rbf_out = self.rbf_layer(edge_distances)
        for idx in range(self.n_conv):
-            self.conv_layers[idx](g)
+            h = self.conv_layers[idx](g, h, rbf_out)

-        atom = self.atom_dense_layer1(g.ndata["node"])
-        atom = self.activation(atom)
-        res = self.atom_dense_layer2(atom)
-        g.ndata["res"] = res
+        h = self.atom_update(h)
        if self.atom_ref is not None:
-            g.ndata["res"] = g.ndata["res"] + g.ndata["e0"]
+            h_ref = self.e0(atom_types)
+            h = h + h_ref
        if self.norm:
-            g.ndata["res"] = g.ndata["res"] * self.std_per_atom + self.mean_per_atom
-        res = sum_nodes(g, "res")
-        return res
+            h = h * self.std_per_node + self.mean_per_node
+        return self.readout(g, h)
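A usage sketch for the renamed SchNet, mirroring the MGCN example earlier (not part of this commit; the toy inputs and the `dgl.model_zoo.chem` import path are assumptions):

import dgl
import torch
from dgl.model_zoo.chem import SchNet

g = dgl.DGLGraph()
g.add_nodes(3)
g.add_edges([0, 0, 1, 1, 2, 2], [1, 2, 0, 2, 0, 1])
atom_types = torch.LongTensor([0, 1, 1])              # (B1,) atom type indices
edge_distances = torch.rand(g.number_of_edges(), 1)   # (B2, 1) pair distances

model = SchNet(norm=True, output_dim=1)
model.set_mean_std(mean=0., std=1.)                   # stats of training labels
prediction = model(g, atom_types, edge_distances)     # one row of size output_dim per graph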