Unverified commit 189c2c09, authored by Mufei Li, committed by GitHub

[Model Zoo] Refactor Model Zoo for Chemistry (#839)

* Update

* Update

* Update

* Update fix

* Update

* Update

* Refactor

* Update

* Update

* Update

* Update

* Update

* Update

* Fix style
parent 5b417683
 # -*- coding:utf-8 -*-
 # pylint: disable=C0103, C0111, W0621
 """Implementation of MGCN model"""
-import torch as th
+import torch
 import torch.nn as nn
 from .layers import AtomEmbedding, RBFLayer, EdgeEmbedding, \
     MultiLevelInteraction
-from ...batched_graph import sum_nodes
+from ...nn.pytorch import SumPooling


 class MGCNModel(nn.Module):
     """
-    MGCN from `Molecular Property Prediction: A Multilevel
+    `Molecular Property Prediction: A Multilevel
     Quantum Interactions Modeling Perspective <https://arxiv.org/abs/1906.11081>`__

     Parameters
     ----------
     dim : int
-        Dimension of feature maps, default to be 128.
-    output_dim : int
-        Number of target properties to predict, default to be 1.
-    edge_dim : int
-        Dimension of edge feature, default to be 128.
-    cutoff : float
-        The maximum distance between nodes, default to be 5.0.
+        Size for embeddings, default to be 128.
     width : int
         Width in the RBF layer, default to be 1.
+    cutoff : float
+        The maximum distance between nodes, default to be 5.0.
+    edge_dim : int
+        Size for edge embedding, default to be 128.
+    output_dim : int
+        Number of target properties to predict, default to be 1.
     n_conv : int
         Number of convolutional layers, default to be 3.
     norm : bool
@@ -39,16 +39,16 @@ class MGCNModel(nn.Module):
     """
     def __init__(self,
                  dim=128,
-                 output_dim=1,
-                 edge_dim=128,
-                 cutoff=5.0,
                  width=1,
+                 cutoff=5.0,
+                 edge_dim=128,
+                 output_dim=1,
                  n_conv=3,
                  norm=False,
                  atom_ref=None,
                  pre_train=None):
         super(MGCNModel, self).__init__()
-        self.name = "MGCN"
         self._dim = dim
         self.output_dim = output_dim
         self.edge_dim = edge_dim
@@ -58,27 +58,29 @@ class MGCNModel(nn.Module):
         self.atom_ref = atom_ref
         self.norm = norm
-        self.activation = nn.Softplus(beta=1, threshold=20)
-        if atom_ref is not None:
-            self.e0 = AtomEmbedding(1, pre_train=atom_ref)
         if pre_train is None:
             self.embedding_layer = AtomEmbedding(dim)
         else:
             self.embedding_layer = AtomEmbedding(pre_train=pre_train)
+        self.rbf_layer = RBFLayer(0, cutoff, width)
         self.edge_embedding_layer = EdgeEmbedding(dim=edge_dim)
-        self.rbf_layer = RBFLayer(0, cutoff, width)
+        if atom_ref is not None:
+            self.e0 = AtomEmbedding(1, pre_train=atom_ref)
         self.conv_layers = nn.ModuleList([
             MultiLevelInteraction(self.rbf_layer._fan_out, dim)
             for i in range(n_conv)
         ])
-        self.node_dense_layer1 = nn.Linear(dim * (self.n_conv + 1), 64)
-        self.node_dense_layer2 = nn.Linear(64, output_dim)
+        self.out_project = nn.Sequential(
+            nn.Linear(dim * (self.n_conv + 1), 64),
+            nn.Softplus(beta=1, threshold=20),
+            nn.Linear(64, output_dim)
+        )
+        self.readout = SumPooling()

-    def set_mean_std(self, mean, std, device):
+    def set_mean_std(self, mean, std, device="cpu"):
         """Set the mean and std of atom representations for normalization.

         Parameters
@@ -90,46 +92,45 @@ class MGCNModel(nn.Module):
         device : str or torch.device
             Device for storing the mean and std
         """
-        self.mean_per_node = th.tensor(mean, device=device)
-        self.std_per_node = th.tensor(std, device=device)
+        self.mean_per_node = torch.tensor(mean, device=device)
+        self.std_per_node = torch.tensor(std, device=device)

-    def forward(self, g):
+    def forward(self, g, atom_types, edge_distances):
         """Predict molecule labels

         Parameters
         ----------
         g : DGLGraph
             Input DGLGraph for molecule(s)
+        atom_types : int64 tensor of shape (B1)
+            Types for atoms in the graph(s), B1 for the number of atoms.
+        edge_distances : float32 tensor of shape (B2, 1)
+            Edge distances, B2 for the number of edges.

         Returns
         -------
-        res : Predicted labels
+        prediction : float32 tensor of shape (B, output_dim)
+            Model prediction for the batch of graphs, B for the number
+            of graphs, output_dim for the prediction size.
         """
-        self.embedding_layer(g, "node_0")
-        if self.atom_ref is not None:
-            self.e0(g, "e0")
-        self.rbf_layer(g)
-        self.edge_embedding_layer(g)
+        h = self.embedding_layer(atom_types)
+        e = self.edge_embedding_layer(g, atom_types)
+        rbf_out = self.rbf_layer(edge_distances)

+        all_layer_h = [h]
         for idx in range(self.n_conv):
-            self.conv_layers[idx](g, idx + 1)
+            h, e = self.conv_layers[idx](g, h, e, rbf_out)
+            all_layer_h.append(h)

-        node_embeddings = tuple(g.ndata["node_%d" % (i)]
-                                for i in range(self.n_conv + 1))
-        g.ndata["node"] = th.cat(node_embeddings, 1)
         # concat multilevel representations
-        node = self.node_dense_layer1(g.ndata["node"])
-        node = self.activation(node)
-        res = self.node_dense_layer2(node)
-        g.ndata["res"] = res
+        h = torch.cat(all_layer_h, dim=1)
+        h = self.out_project(h)

         if self.atom_ref is not None:
-            g.ndata["res"] = g.ndata["res"] + g.ndata["e0"]
+            h_ref = self.e0(atom_types)
+            h = h + h_ref

         if self.norm:
-            g.ndata["res"] = g.ndata[
-                "res"] * self.std_per_node + self.mean_per_node
-        res = sum_nodes(g, "res")
-        return res
+            h = h * self.std_per_node + self.mean_per_node
+        return self.readout(g, h)
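After this refactor, MGCNModel no longer reads inputs from `g.ndata`/`g.edata`; atom types and edge distances are passed to `forward` explicitly. A minimal sketch of the new calling convention — the toy graph, the placeholder inputs, and the `dgl.model_zoo.chem` import path are assumptions for illustration, not taken from this diff:

```python
import dgl
import torch
from dgl.model_zoo.chem import MGCNModel  # assumed public import path

# Toy 3-atom "molecule" with bidirectional edges; all values are placeholders.
g = dgl.DGLGraph()
g.add_nodes(3)
g.add_edges([0, 1, 1, 2], [1, 0, 2, 1])

atom_types = torch.LongTensor([6, 8, 1])             # (B1,) placeholder atom type indices
edge_distances = torch.rand(g.number_of_edges(), 1)  # (B2, 1) placeholder distances

model = MGCNModel(dim=128, output_dim=1)
prediction = model(g, atom_types, edge_distances)    # shape (1, output_dim)
```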
@@ -2,134 +2,10 @@
 # coding: utf-8
 # pylint: disable=C0103, C0111, E1101, W0612
 """Implementation of MPNN model."""
-import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from torch.nn import Parameter
-
-from ... import function as fn
-from ...nn.pytorch import Set2Set
-
-
-class NNConvLayer(nn.Module):
-    """
-    MPNN Conv Layer from Section 5 of
-    `Neural Message Passing for Quantum Chemistry <https://arxiv.org/abs/1704.01212>`__
-
-    Parameters
-    ----------
-    in_channels : int
-        Number of input channels
-    out_channels : int
-        Number of output channels
-    edge_net : Module processing edge information
-    root_weight : bool
-        Whether to add the root node feature to output
-    bias : bool
-        Whether to add bias to the output
-    """
-    def __init__(self,
-                 in_channels,
-                 out_channels,
-                 edge_net,
-                 root_weight=True,
-                 bias=True):
-        super(NNConvLayer, self).__init__()
-
-        self.in_channels = in_channels
-        self.out_channels = out_channels
-        self.edge_net = edge_net
-
-        if root_weight:
-            self.root = Parameter(torch.Tensor(in_channels, out_channels))
-        else:
-            self.register_parameter('root', None)
-
-        if bias:
-            self.bias = Parameter(torch.Tensor(out_channels))
-        else:
-            self.register_parameter('bias', None)
-
-        self.reset_parameters()
-
-    def reset_parameters(self):
-        """Reinitialize model parameters"""
-        if self.root is not None:
-            nn.init.xavier_normal_(self.root.data, gain=1.414)
-        if self.bias is not None:
-            self.bias.data.zero_()
-        for m in self.edge_net.modules():
-            if isinstance(m, nn.Linear):
-                nn.init.xavier_normal_(m.weight.data, gain=1.414)
-
-    def message(self, edges):
-        """Function for computing messages from source nodes
-
-        Parameters
-        ----------
-        edges : EdgeBatch
-            Edges over which we want to send messages
-
-        Returns
-        -------
-        dict
-            Stores message in key 'm'
-        """
-        return {
-            'm':
-            torch.matmul(edges.src['h'].unsqueeze(1),
-                         edges.data['w']).squeeze(1)
-        }
-
-    def apply_node_func(self, nodes):
-        """Function for updating node features directly
-
-        Parameters
-        ----------
-        nodes : NodeBatch
-
-        Returns
-        -------
-        dict
-            Stores updated node features in 'h'
-        """
-        aggr_out = nodes.data['aggr_out']
-        if self.root is not None:
-            aggr_out = torch.mm(nodes.data['h'], self.root) + aggr_out
-        if self.bias is not None:
-            aggr_out = aggr_out + self.bias
-        return {'h': aggr_out}
-
-    def forward(self, g, h, e):
-        """Propagate messages and aggregate results for updating
-        atom representations
-
-        Parameters
-        ----------
-        g : DGLGraph
-            DGLGraph(s) for molecules
-        h : tensor
-            Input atom representations
-        e : tensor
-            Input bond representations
-
-        Returns
-        -------
-        tensor
-            Aggregated atom information
-        """
-        h = h.unsqueeze(-1) if h.dim() == 1 else h
-        e = e.unsqueeze(-1) if e.dim() == 1 else e
-
-        g.ndata['h'] = h
-        g.edata['w'] = self.edge_net(e).view(-1, self.in_channels,
-                                             self.out_channels)
-        g.update_all(self.message, fn.sum("m", "aggr_out"),
-                     self.apply_node_func)
-        return g.ndata.pop('h')
+from ...nn.pytorch import Set2Set, NNConv


 class MPNNModel(nn.Module):
     """
@@ -166,40 +42,44 @@ class MPNNModel(nn.Module):
                  num_layer_set2set=3):
         super(MPNNModel, self).__init__()
-        self.name = "MPNN"
         self.num_step_message_passing = num_step_message_passing
         self.lin0 = nn.Linear(node_input_dim, node_hidden_dim)
         edge_network = nn.Sequential(
             nn.Linear(edge_input_dim, edge_hidden_dim), nn.ReLU(),
             nn.Linear(edge_hidden_dim, node_hidden_dim * node_hidden_dim))
-        self.conv = NNConvLayer(in_channels=node_hidden_dim,
-                                out_channels=node_hidden_dim,
-                                edge_net=edge_network,
-                                root_weight=False)
+        self.conv = NNConv(in_feats=node_hidden_dim,
+                           out_feats=node_hidden_dim,
+                           edge_func=edge_network,
+                           aggregator_type='sum')
         self.gru = nn.GRU(node_hidden_dim, node_hidden_dim)

         self.set2set = Set2Set(node_hidden_dim, num_step_set2set, num_layer_set2set)
         self.lin1 = nn.Linear(2 * node_hidden_dim, node_hidden_dim)
         self.lin2 = nn.Linear(node_hidden_dim, output_dim)

-    def forward(self, g):
+    def forward(self, g, n_feat, e_feat):
         """Predict molecule labels

         Parameters
         ----------
         g : DGLGraph
             Input DGLGraph for molecule(s)
+        n_feat : tensor of dtype float32 and shape (B1, D1)
+            Node features. B1 for number of nodes and D1 for
+            the node feature size.
+        e_feat : tensor of dtype float32 and shape (B2, D2)
+            Edge features. B2 for number of edges and D2 for
+            the edge feature size.

         Returns
         -------
         res : Predicted labels
         """
-        h = g.ndata['n_feat']
-        out = F.relu(self.lin0(h))
-        h = out.unsqueeze(0)
+        out = F.relu(self.lin0(n_feat))  # (B1, H1)
+        h = out.unsqueeze(0)             # (1, B1, H1)

         for i in range(self.num_step_message_passing):
-            m = F.relu(self.conv(g, out, g.edata['e_feat']))
+            m = F.relu(self.conv(g, out, e_feat))  # (B1, H1)
             out, h = self.gru(m.unsqueeze(0), h)
             out = out.squeeze(0)
...
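The hand-written `NNConvLayer` above is deleted in favor of the library's `NNConv` from `dgl.nn.pytorch`; with `aggregator_type='sum'` and no residual connection this should correspond to the deleted layer's `root_weight=False` configuration. A sketch of the new `forward(g, n_feat, e_feat)` signature — the feature sizes and import path below are assumptions for illustration:

```python
import dgl
import torch
from dgl.model_zoo.chem import MPNNModel  # assumed public import path

g = dgl.DGLGraph()
g.add_nodes(3)
g.add_edges([0, 1, 1, 2], [1, 0, 2, 1])

# Assumed feature sizes; D1/D2 must match node_input_dim/edge_input_dim.
node_feats = torch.randn(g.number_of_nodes(), 15)  # (B1, D1)
edge_feats = torch.randn(g.number_of_edges(), 5)   # (B2, D2)

model = MPNNModel(node_input_dim=15, edge_input_dim=5, output_dim=1)
prediction = model(g, node_feats, edge_feats)
```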
@@ -7,7 +7,7 @@ from .classifiers import GCNClassifier, GATClassifier
 from .dgmg import DGMG
 from .mgcn import MGCNModel
 from .mpnn import MPNNModel
-from .sch import SchNetModel
+from .schnet import SchNet
 from ...data.utils import _get_dgl_url, download, get_download_dir

 URL = {
@@ -61,6 +61,19 @@ def load_pretrained(model_name, log=True):
     Parameters
     ----------
     model_name : str
+        Currently supported options include
+
+        * ``'GCN_Tox21'``
+        * ``'GAT_Tox21'``
+        * ``'MGCN_Alchemy'``
+        * ``'SCHNET_Alchemy'``
+        * ``'MPNN_Alchemy'``
+        * ``'DGMG_ChEMBL_canonical'``
+        * ``'DGMG_ChEMBL_random'``
+        * ``'DGMG_ZINC_canonical'``
+        * ``'DGMG_ZINC_random'``
+        * ``'JTNN_ZINC'``
     log : bool
         Whether to print progress for model loading
@@ -103,7 +116,7 @@ def load_pretrained(model_name, log=True):
         model = MGCNModel(norm=True, output_dim=12)
     elif model_name == 'SCHNET_Alchemy':
-        model = SchNetModel(norm=True, output_dim=12)
+        model = SchNet(norm=True, output_dim=12)
     elif model_name == 'MPNN_Alchemy':
         model = MPNNModel(output_dim=12)
...
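With the renamed import in place, loading a pretrained model is unchanged for callers. A short sketch, assuming the function is exposed as `dgl.model_zoo.chem.load_pretrained`:

```python
from dgl import model_zoo

# Downloads the checkpoint on first use, then loads the weights.
model = model_zoo.chem.load_pretrained('MGCN_Alchemy')
model.eval()
```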
 # -*- coding:utf-8 -*-
 # pylint: disable=C0103, C0111, W0621
 """Implementation of SchNet model."""
-import torch as th
+import torch
 import torch.nn as nn
 from .layers import AtomEmbedding, Interaction, ShiftSoftplus, RBFLayer
-from ...batched_graph import sum_nodes
+from ...nn.pytorch import SumPooling


-class SchNetModel(nn.Module):
+class SchNet(nn.Module):
     """
     `SchNet: A continuous-filter convolutional neural network for modeling
     quantum interactions. (NIPS'2017) <https://arxiv.org/abs/1706.08566>`__
@@ -16,15 +16,15 @@ class SchNetModel(nn.Module):
     Parameters
     ----------
     dim : int
-        Dimension of features, default to be 64
+        Size for atom embeddings, default to be 64.
     cutoff : float
-        Radius cutoff for RBF, default to be 5.0
+        Radius cutoff for RBF, default to be 5.0.
     output_dim : int
-        Dimension of prediction, default to be 1
+        Number of target properties to predict, default to be 1.
     width : int
-        Width in RBF, default to 1
+        Width in RBF, default to 1.
     n_conv : int
-        Number of conv (interaction) layers, default to be 1
+        Number of conv (interaction) layers, default to be 1.
     norm : bool
         Whether to normalize the output atom representations, default to be False.
     atom_ref : Atom embeddings or None
@@ -43,28 +43,32 @@ class SchNetModel(nn.Module):
                  norm=False,
                  atom_ref=None,
                  pre_train=None):
-        super().__init__()
-        self.name = "SchNet"
+        super(SchNet, self).__init__()
         self._dim = dim
         self.cutoff = cutoff
         self.width = width
         self.n_conv = n_conv
         self.atom_ref = atom_ref
         self.norm = norm
-        self.activation = ShiftSoftplus()

         if atom_ref is not None:
             self.e0 = AtomEmbedding(1, pre_train=atom_ref)
         if pre_train is None:
             self.embedding_layer = AtomEmbedding(dim)
         else:
             self.embedding_layer = AtomEmbedding(pre_train=pre_train)
         self.rbf_layer = RBFLayer(0, cutoff, width)
         self.conv_layers = nn.ModuleList(
-            [Interaction(self.rbf_layer._fan_out, dim) for i in range(n_conv)])
+            [Interaction(self.rbf_layer._fan_out, dim) for _ in range(n_conv)])

-        self.atom_dense_layer1 = nn.Linear(dim, 64)
-        self.atom_dense_layer2 = nn.Linear(64, output_dim)
+        self.atom_update = nn.Sequential(
+            nn.Linear(dim, 64),
+            ShiftSoftplus(),
+            nn.Linear(64, output_dim)
+        )
+        self.readout = SumPooling()

     def set_mean_std(self, mean, std, device="cpu"):
         """Set the mean and std of atom representations for normalization.
@@ -78,37 +82,38 @@ class SchNetModel(nn.Module):
         device : str or torch.device
             Device for storing the mean and std
         """
-        self.mean_per_atom = th.tensor(mean, device=device)
-        self.std_per_atom = th.tensor(std, device=device)
+        self.mean_per_node = torch.tensor(mean, device=device)
+        self.std_per_node = torch.tensor(std, device=device)

-    def forward(self, g):
+    def forward(self, g, atom_types, edge_distances):
         """Predict molecule labels

         Parameters
         ----------
         g : DGLGraph
             Input DGLGraph for molecule(s)
+        atom_types : int64 tensor of shape (B1)
+            Types for atoms in the graph(s), B1 for the number of atoms.
+        edge_distances : float32 tensor of shape (B2, 1)
+            Edge distances, B2 for the number of edges.

         Returns
         -------
-        res : Predicted labels
+        prediction : float32 tensor of shape (B, output_dim)
+            Model prediction for the batch of graphs, B for the number
+            of graphs, output_dim for the prediction size.
         """
-        self.embedding_layer(g)
-        if self.atom_ref is not None:
-            self.e0(g, "e0")
-        self.rbf_layer(g)
+        h = self.embedding_layer(atom_types)
+        rbf_out = self.rbf_layer(edge_distances)
         for idx in range(self.n_conv):
-            self.conv_layers[idx](g)
+            h = self.conv_layers[idx](g, h, rbf_out)
+        h = self.atom_update(h)

-        atom = self.atom_dense_layer1(g.ndata["node"])
-        atom = self.activation(atom)
-        res = self.atom_dense_layer2(atom)
-        g.ndata["res"] = res

         if self.atom_ref is not None:
-            g.ndata["res"] = g.ndata["res"] + g.ndata["e0"]
+            h_ref = self.e0(atom_types)
+            h = h + h_ref

         if self.norm:
-            g.ndata["res"] = g.ndata["res"] * self.std_per_atom + self.mean_per_atom
-        res = sum_nodes(g, "res")
-        return res
+            h = h * self.std_per_node + self.mean_per_node
+        return self.readout(g, h)
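Because `forward` now de-normalizes atom contributions with `self.std_per_node`/`self.mean_per_node`, a model built with `norm=True` needs `set_mean_std` called before inference. A minimal sketch with placeholder inputs and statistics; the import path is an assumption:

```python
import dgl
import torch
from dgl.model_zoo.chem import SchNet  # assumed public import path

g = dgl.DGLGraph()
g.add_nodes(3)
g.add_edges([0, 1, 1, 2], [1, 0, 2, 1])

atom_types = torch.LongTensor([6, 8, 1])             # (B1,) placeholder atom type indices
edge_distances = torch.rand(g.number_of_edges(), 1)  # (B2, 1) placeholder distances

model = SchNet(dim=64, output_dim=1, norm=True)
model.set_mean_std(mean=0.0, std=1.0, device="cpu")  # placeholder statistics
prediction = model(g, atom_types, edge_distances)    # shape (1, output_dim)
```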