Unverified Commit 36c7b771 authored by Mufei Li's avatar Mufei Li Committed by GitHub
Browse files

[LifeSci] Move to Independent Repo (#1592)

* Move LifeSci

* Remove doc
parent 94c67203
#!/usr/bin/env python
# coding: utf-8
# pylint: disable=C0103, C0111, E1101, W0612
"""Implementation of MPNN model."""
import torch.nn as nn
import torch.nn.functional as F
from ...contrib.deprecation import deprecated
from ...nn.pytorch import Set2Set, NNConv
class MPNNModel(nn.Module):
    """MPNN for molecular property prediction.

    MPNN is introduced in `Neural Message Passing for Quantum Chemistry
    <https://arxiv.org/abs/1704.01212>`__

    Parameters
    ----------
    node_input_dim : int
        Dimension of input node feature, default to be 15.
    edge_input_dim : int
        Dimension of input edge feature, default to be 5.
    output_dim : int
        Dimension of prediction, default to be 12.
    node_hidden_dim : int
        Dimension of node feature in hidden layers, default to be 64.
    edge_hidden_dim : int
        Dimension of edge feature in hidden layers, default to be 128.
    num_step_message_passing : int
        Number of message passing steps, default to be 6.
    num_step_set2set : int
        Number of set2set steps, default to be 6.
    num_layer_set2set : int
        Number of set2set layers, default to be 3.
    """
    @deprecated('Import MPNNPredictor from dgllife.model instead.', 'class')
    def __init__(self,
                 node_input_dim=15,
                 edge_input_dim=5,
                 output_dim=12,
                 node_hidden_dim=64,
                 edge_hidden_dim=128,
                 num_step_message_passing=6,
                 num_step_set2set=6,
                 num_layer_set2set=3):
        super(MPNNModel, self).__init__()
        self.num_step_message_passing = num_step_message_passing
        # Project raw node features into the hidden size.
        self.lin0 = nn.Linear(node_input_dim, node_hidden_dim)
        # Maps each edge feature to a flattened (node_hidden_dim x node_hidden_dim)
        # weight matrix, consumed by NNConv for computing messages.
        edge_network = nn.Sequential(
            nn.Linear(edge_input_dim, edge_hidden_dim), nn.ReLU(),
            nn.Linear(edge_hidden_dim, node_hidden_dim * node_hidden_dim))
        self.conv = NNConv(in_feats=node_hidden_dim,
                           out_feats=node_hidden_dim,
                           edge_func=edge_network,
                           aggregator_type='sum')
        # GRU carries node states across message passing rounds.
        self.gru = nn.GRU(node_hidden_dim, node_hidden_dim)
        # Set2Set readout outputs features of size 2 * node_hidden_dim,
        # hence the input size of lin1 below.
        self.set2set = Set2Set(node_hidden_dim, num_step_set2set, num_layer_set2set)
        self.lin1 = nn.Linear(2 * node_hidden_dim, node_hidden_dim)
        self.lin2 = nn.Linear(node_hidden_dim, output_dim)

    def forward(self, g, n_feat, e_feat):
        """Predict molecule labels

        Parameters
        ----------
        g : DGLGraph
            Input DGLGraph for molecule(s)
        n_feat : tensor of dtype float32 and shape (B1, D1)
            Node features. B1 for number of nodes and D1 for
            the node feature size.
        e_feat : tensor of dtype float32 and shape (B2, D2)
            Edge features. B2 for number of edges and D2 for
            the edge feature size.

        Returns
        -------
        res : Predicted labels
        """
        out = F.relu(self.lin0(n_feat))            # (B1, H1)
        h = out.unsqueeze(0)                       # (1, B1, H1), initial GRU state
        for _ in range(self.num_step_message_passing):
            m = F.relu(self.conv(g, out, e_feat))  # (B1, H1)
            out, h = self.gru(m.unsqueeze(0), h)
            out = out.squeeze(0)
        out = self.set2set(g, out)                 # (B, 2 * H1), graph-level readout
        out = F.relu(self.lin1(out))
        out = self.lin2(out)
        return out
"""Utilities for using pretrained models."""
import os
import torch
from rdkit import Chem
from . import DGLJTNNVAE
from .classifiers import GCNClassifier, GATClassifier
from .dgmg import DGMG
from .mgcn import MGCNModel
from .mpnn import MPNNModel
from .schnet import SchNet
from .attentive_fp import AttentiveFP
from ...data.utils import _get_dgl_url, download, get_download_dir, extract_archive
from ...contrib.deprecation import deprecated
# Mapping from each supported pretrained model name to the relative path of
# its checkpoint on the DGL download server. The path is resolved to a full
# URL via _get_dgl_url when the checkpoint is downloaded.
URL = {
    'GCN_Tox21': 'pre_trained/gcn_tox21.pth',
    'GAT_Tox21': 'pre_trained/gat_tox21.pth',
    'MGCN_Alchemy': 'pre_trained/mgcn_alchemy.pth',
    'SCHNET_Alchemy': 'pre_trained/schnet_alchemy.pth',
    'MPNN_Alchemy': 'pre_trained/mpnn_alchemy.pth',
    'AttentiveFP_Aromaticity': 'pre_trained/attentivefp_aromaticity.pth',
    'DGMG_ChEMBL_canonical': 'pre_trained/dgmg_ChEMBL_canonical.pth',
    'DGMG_ChEMBL_random': 'pre_trained/dgmg_ChEMBL_random.pth',
    'DGMG_ZINC_canonical': 'pre_trained/dgmg_ZINC_canonical.pth',
    'DGMG_ZINC_random': 'pre_trained/dgmg_ZINC_random.pth',
    'JTNN_ZINC': 'pre_trained/JTNN_ZINC.pth'
}
def download_and_load_checkpoint(model_name, model, model_postfix,
                                 local_pretrained_path='pre_trained.pth', log=True):
    """Fetch a pretrained checkpoint and load it into ``model``.

    The checkpoint is always loaded onto CPU.

    Parameters
    ----------
    model_name : str
        Name of the model
    model : nn.Module
        Instantiated model instance
    model_postfix : str
        Postfix for pretrained model checkpoint
    local_pretrained_path : str
        Local name for the downloaded model checkpoint
    log : bool
        Whether to print progress for model loading

    Returns
    -------
    model : nn.Module
        Pretrained model
    """
    checkpoint_url = _get_dgl_url(model_postfix)
    # Prefix with the model name so checkpoints of different models
    # downloaded into the same directory do not collide.
    local_pretrained_path = '_'.join([model_name, local_pretrained_path])
    download(checkpoint_url, path=local_pretrained_path, log=log)
    state = torch.load(local_pretrained_path, map_location='cpu')
    model.load_state_dict(state['model_state_dict'])
    if log:
        print('Pretrained model loaded')
    return model
@deprecated('Import it from dgllife.model instead.')
def load_pretrained(model_name, log=True):
    """Load a pretrained model

    Parameters
    ----------
    model_name : str
        Currently supported options include

        * ``'GCN_Tox21'``
        * ``'GAT_Tox21'``
        * ``'MGCN_Alchemy'``
        * ``'SCHNET_Alchemy'``
        * ``'MPNN_Alchemy'``
        * ``'AttentiveFP_Aromaticity'``
        * ``'DGMG_ChEMBL_canonical'``
        * ``'DGMG_ChEMBL_random'``
        * ``'DGMG_ZINC_canonical'``
        * ``'DGMG_ZINC_random'``
        * ``'JTNN_ZINC'``
    log : bool
        Whether to print progress for model loading

    Returns
    -------
    model
    """
    # Reject unknown names up front; every branch below is then exhaustive.
    if model_name not in URL:
        raise RuntimeError("Cannot find a pretrained model with name {}".format(model_name))

    if model_name.startswith('DGMG'):
        # The atom vocabulary differs between the ChEMBL and ZINC checkpoints.
        if model_name.startswith('DGMG_ChEMBL'):
            atom_types = ['O', 'Cl', 'C', 'S', 'F', 'Br', 'N']
        else:
            atom_types = ['Br', 'S', 'C', 'P', 'N', 'O', 'F', 'Cl', 'I']
        model = DGMG(atom_types=atom_types,
                     bond_types=[Chem.rdchem.BondType.SINGLE,
                                 Chem.rdchem.BondType.DOUBLE,
                                 Chem.rdchem.BondType.TRIPLE],
                     node_hidden_size=128,
                     num_prop_rounds=2,
                     dropout=0.2)
    elif model_name == 'GCN_Tox21':
        model = GCNClassifier(in_feats=74,
                              gcn_hidden_feats=[64, 64],
                              classifier_hidden_feats=64,
                              n_tasks=12)
    elif model_name == 'GAT_Tox21':
        model = GATClassifier(in_feats=74,
                              gat_hidden_feats=[32, 32],
                              num_heads=[4, 4],
                              classifier_hidden_feats=64,
                              n_tasks=12)
    elif model_name == 'MGCN_Alchemy':
        model = MGCNModel(norm=True, output_dim=12)
    elif model_name == 'SCHNET_Alchemy':
        model = SchNet(norm=True, output_dim=12)
    elif model_name == 'MPNN_Alchemy':
        model = MPNNModel(output_dim=12)
    elif model_name == 'AttentiveFP_Aromaticity':
        model = AttentiveFP(node_feat_size=39,
                            edge_feat_size=10,
                            num_layers=2,
                            num_timesteps=2,
                            graph_feat_size=200,
                            output_size=1,
                            dropout=0.2)
    elif model_name == 'JTNN_ZINC':
        # The JTNN model needs a vocabulary file; download it if absent.
        download_dir = get_download_dir()
        vocab_path = '{}/jtnn/{}.txt'.format(download_dir, 'vocab')
        if not os.path.exists(vocab_path):
            archive_path = '{}/jtnn.zip'.format(download_dir)
            download(_get_dgl_url('dgllife/jtnn.zip'), path=archive_path)
            extract_archive(archive_path, '{}/jtnn'.format(download_dir))
        model = DGLJTNNVAE(vocab_file=vocab_path,
                           depth=3,
                           hidden_size=450,
                           latent_size=56)

    return download_and_load_checkpoint(model_name, model, URL[model_name], log=log)
# -*- coding:utf-8 -*-
# pylint: disable=C0103, C0111, W0621
"""Implementation of SchNet model."""
import torch
import torch.nn as nn
from .layers import AtomEmbedding, Interaction, ShiftSoftplus, RBFLayer
from ...contrib.deprecation import deprecated
from ...nn.pytorch import SumPooling
class SchNet(nn.Module):
    """
    `SchNet: A continuous-filter convolutional neural network for modeling
    quantum interactions. (NIPS'2017) <https://arxiv.org/abs/1706.08566>`__

    Parameters
    ----------
    dim : int
        Size for atom embeddings, default to be 64.
    cutoff : float
        Radius cutoff for RBF, default to be 5.0.
    output_dim : int
        Number of target properties to predict, default to be 1.
    width : int
        Width in RBF, default to 1.
    n_conv : int
        Number of conv (interaction) layers, default to be 3.
    norm : bool
        Whether to normalize the output atom representations, default to be False.
    atom_ref : Atom embeddings or None
        If None, random representation initialization will be used. Otherwise,
        they will be used to initialize atom representations. Default to be None.
    pre_train : Atom embeddings or None
        If None, random representation initialization will be used. Otherwise,
        they will be used to initialize atom representations. Default to be None.
    """
    @deprecated('Import SchNetPredictor from dgllife.model instead.')
    def __init__(self,
                 dim=64,
                 cutoff=5.0,
                 output_dim=1,
                 width=1,
                 n_conv=3,
                 norm=False,
                 atom_ref=None,
                 pre_train=None):
        super(SchNet, self).__init__()
        self._dim = dim
        self.cutoff = cutoff
        self.width = width
        self.n_conv = n_conv
        self.atom_ref = atom_ref
        self.norm = norm
        # Optional per-atom-type reference energies, added to the
        # predictions in forward.
        if atom_ref is not None:
            self.e0 = AtomEmbedding(1, pre_train=atom_ref)
        if pre_train is None:
            self.embedding_layer = AtomEmbedding(dim)
        else:
            self.embedding_layer = AtomEmbedding(pre_train=pre_train)
        # Radial basis expansion of edge distances over [0, cutoff].
        self.rbf_layer = RBFLayer(0, cutoff, width)
        self.conv_layers = nn.ModuleList(
            [Interaction(self.rbf_layer._fan_out, dim) for _ in range(n_conv)])
        self.atom_update = nn.Sequential(
            nn.Linear(dim, 64),
            ShiftSoftplus(),
            nn.Linear(64, output_dim)
        )
        self.readout = SumPooling()

    def set_mean_std(self, mean, std, device="cpu"):
        """Set the mean and std of atom representations for normalization.

        Only used when ``norm`` is True; see forward.

        Parameters
        ----------
        mean : list or numpy array
            The mean of labels
        std : list or numpy array
            The std of labels
        device : str or torch.device
            Device for storing the mean and std
        """
        self.mean_per_node = torch.tensor(mean, device=device)
        self.std_per_node = torch.tensor(std, device=device)

    def forward(self, g, atom_types, edge_distances):
        """Predict molecule labels

        Parameters
        ----------
        g : DGLGraph
            Input DGLGraph for molecule(s)
        atom_types : int64 tensor of shape (B1)
            Types for atoms in the graph(s), B1 for the number of atoms.
        edge_distances : float32 tensor of shape (B2, 1)
            Edge distances, B2 for the number of edges.

        Returns
        -------
        prediction : float32 tensor of shape (B, output_dim)
            Model prediction for the batch of graphs, B for the number
            of graphs, output_dim for the prediction size.
        """
        h = self.embedding_layer(atom_types)
        rbf_out = self.rbf_layer(edge_distances)
        for idx in range(self.n_conv):
            h = self.conv_layers[idx](g, h, rbf_out)
        h = self.atom_update(h)
        if self.atom_ref is not None:
            # Shift per-atom predictions by the reference energies.
            h_ref = self.e0(atom_types)
            h = h + h_ref
        if self.norm:
            # Un-standardize using the statistics set via set_mean_std;
            # requires set_mean_std to have been called beforehand.
            h = h * self.std_per_node + self.mean_per_node
        return self.readout(g, h)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment