Unverified commit 828a5e5b authored by Mufei Li, committed by GitHub

[DGL-LifeSci] Migration and Refactor (#1226)

* First commit

* Update

* Update splitters

* Update

* Update

* Update

* Update

* Update

* Update

* Migrate ACNN

* Fix

* Fix

* Update

* Update

* Update

* Update

* Update

* Update

* Finish classification

* Update

* Fix

* Update

* Update

* Update

* Fix

* Fix

* Fix

* Update

* Update

* Update

* trigger CI

* Fix CI

* Update

* Update

* Update

* Add default values

* Rename

* Update deprecation message
parent e4948c5c
@@ -5,6 +5,8 @@ import warnings
from functools import partial
from multiprocessing import Pool
from ....contrib.deprecation import deprecated
try:
import pdbfixer
import simtk
@@ -20,6 +22,7 @@ __all__ = ['add_hydrogens_to_mol',
'load_molecule',
'multiprocess_load_molecules']
@deprecated('')
def add_hydrogens_to_mol(mol):
"""Add hydrogens to an RDKit molecule instance.
@@ -57,6 +60,7 @@ def add_hydrogens_to_mol(mol):
warnings.warn('Failed to add hydrogens to the molecule.')
return mol
@deprecated('Import it from dgllife.utils.rdkit_utils instead.')
def get_mol_3D_coordinates(mol):
"""Get 3D coordinates of the molecule.
@@ -83,6 +87,7 @@ def get_mol_3D_coordinates(mol):
warnings.warn('Unable to get conformation of the molecule.')
return None
@deprecated('Import it from dgllife.utils.rdkit_utils instead.')
def load_molecule(molecule_file, add_hydrogens=False, sanitize=False, calc_charges=False,
remove_hs=False, use_conformation=True):
"""Load a molecule from a file.
@@ -161,6 +166,7 @@ def load_molecule(molecule_file, add_hydrogens=False, sanitize=False, calc_charg
return mol, coordinates
@deprecated('Import it from dgllife.utils.rdkit_utils instead.')
def multiprocess_load_molecules(files, add_hydrogens=False, sanitize=False, calc_charges=False,
remove_hs=False, use_conformation=True, num_processes=2):
"""Load molecules from files with multiprocessing.
......
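Every public helper in this file now carries a deprecated decorator pointing users at the new dgllife package. The decorator itself lives in dgl.contrib.deprecation and is not shown in this diff; below is a minimal sketch of what such a helper might look like, assuming it only needs to emit a DeprecationWarning with the migration hint. The kind argument, passed as 'class' in later hunks, presumably adjusts the wording for methods of deprecated classes.

# Illustrative sketch only -- not the actual dgl.contrib.deprecation code.
import functools
import warnings

def deprecated(extra_msg='', kind='function'):
    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            # Use the qualified name so methods report their owning class.
            name = func.__qualname__ if kind == 'class' else func.__name__
            warnings.warn('{} is deprecated. {}'.format(name, extra_msg),
                          DeprecationWarning, stacklevel=2)
            return func(*args, **kwargs)
        return wrapper
    return decorator

With this shape, @deprecated('Import it from dgllife.utils.rdkit_utils instead.') warns on every call and then delegates to the wrapped function unchanged.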
@@ -11,6 +11,7 @@ from itertools import accumulate, chain
from ...utils import split_dataset, Subset
from .... import backend as F
from ....contrib.deprecation import deprecated
try:
from rdkit import Chem
@@ -187,7 +188,9 @@ class ConsecutiveSplitter(object):
The dataset is split without permutation, so the splitting is deterministic.
"""
@staticmethod
@deprecated('Import ConsecutiveSplitter from dgllife.utils.splitters instead.', 'class')
def train_val_test_split(dataset, frac_train=0.8, frac_val=0.1, frac_test=0.1):
"""Split the dataset into three consecutive chunks for training, validation and test.
@@ -214,6 +217,7 @@ class ConsecutiveSplitter(object):
return split_dataset(dataset, frac_list=[frac_train, frac_val, frac_test], shuffle=False)
@staticmethod
@deprecated('Import ConsecutiveSplitter from dgllife.utils.splitters instead.', 'class')
def k_fold_split(dataset, k=5, log=True):
"""Split the dataset for k-fold cross validation by taking consecutive chunks.
@@ -240,6 +244,7 @@ class RandomSplitter(object):
The dataset is split with permutation, so the splitting is random.
"""
@staticmethod
@deprecated('Import RandomSplitter from dgllife.utils.splitters instead.', 'class')
def train_val_test_split(dataset, frac_train=0.8, frac_val=0.1,
frac_test=0.1, random_state=None):
"""Randomly permute the dataset and then split it into
@@ -275,6 +280,7 @@ class RandomSplitter(object):
shuffle=True, random_state=random_state)
@staticmethod
@deprecated('Import RandomSplitter from dgllife.utils.splitters instead.', 'class')
def k_fold_split(dataset, k=5, random_state=None, log=True):
"""Randomly permute the dataset and then split it
for k-fold cross validation by taking consecutive chunks.
@@ -309,6 +315,7 @@ class RandomSplitter(object):
class MolecularWeightSplitter(object):
"""Sort molecules based on their weights and then split them."""
@staticmethod
@deprecated('Import MolecularWeightSplitter from dgllife.utils.splitters instead.', 'class')
def molecular_weight_indices(molecules, log_every_n):
"""Reorder molecules based on molecular weights.
@@ -341,6 +348,7 @@ class MolecularWeightSplitter(object):
return np.argsort(mws)
@staticmethod
@deprecated('Import MolecularWeightSplitter from dgllife.utils.splitters instead.', 'class')
def train_val_test_split(dataset, mols=None, sanitize=True, frac_train=0.8,
frac_val=0.1, frac_test=0.1, log_every_n=1000):
"""Sort molecules based on their weights and then split them into
@@ -390,6 +398,7 @@ class MolecularWeightSplitter(object):
return indices_split(dataset, frac_train, frac_val, frac_test, sorted_indices)
@staticmethod
@deprecated('Import MolecularWeightSplitter from dgllife.utils.splitters instead.', 'class')
def k_fold_split(dataset, mols=None, sanitize=True, k=5, log_every_n=1000):
"""Sort molecules based on their weights and then split them
for k-fold cross validation by taking consecutive chunks.
@@ -439,7 +448,9 @@ class ScaffoldSplitter(object):
Bemis, G. W.; Murcko, M. A. “The Properties of Known Drugs.
1. Molecular Frameworks.” J. Med. Chem. 39:2887-93 (1996).
"""
@staticmethod
@deprecated('Import ScaffoldSplitter from dgllife.utils.splitters instead.', 'class')
def get_ordered_scaffold_sets(molecules, include_chirality, log_every_n):
"""Group molecules based on their Bemis-Murcko scaffolds and
order these groups based on their sizes.
@@ -494,6 +505,7 @@ class ScaffoldSplitter(object):
return scaffold_sets
@staticmethod
@deprecated('Import ScaffoldSplitter from dgllife.utils.splitters instead.', 'class')
def train_val_test_split(dataset, mols=None, sanitize=True, include_chirality=False,
frac_train=0.8, frac_val=0.1, frac_test=0.1, log_every_n=1000):
"""Split the dataset into training, validation and test set based on molecular scaffolds.
@@ -564,6 +576,7 @@ class ScaffoldSplitter(object):
Subset(dataset, test_indices)]
@staticmethod
@deprecated('Import ScaffoldSplitter from dgllife.utils.splitters instead.', 'class')
def k_fold_split(dataset, mols=None, sanitize=True,
include_chirality=False, k=5, log_every_n=1000):
"""Group molecules based on their scaffolds and sort groups based on their sizes.
@@ -636,6 +649,8 @@ class SingleTaskStratifiedSplitter(object):
take buckets of datapoints to augment the training, validation and test subsets.
"""
@staticmethod
@deprecated('Import SingleTaskStratifiedSplitter from '
'dgllife.utils.splitters instead.', 'class')
def train_val_test_split(dataset, labels, task_id, frac_train=0.8, frac_val=0.1,
frac_test=0.1, bucket_size=10, random_state=None):
"""Split the dataset into training, validation and test subsets as stated above.
@@ -707,6 +722,8 @@ class SingleTaskStratifiedSplitter(object):
Subset(dataset, test_indices)]
@staticmethod
@deprecated('Import SingleTaskStratifiedSplitter from '
'dgllife.utils.splitters instead.', 'class')
def k_fold_split(dataset, labels, task_id, k=5, log=True):
"""Sort molecules based on their label values for a task and then split them
for k-fold cross validation by taking consecutive chunks.
......
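All splitter classes above expose the same static-method interface, so call sites only swap the class name. A hedged usage sketch follows; the list is a stand-in for a molecule dataset (anything with __len__ and __getitem__ that dgl's split_dataset and Subset accept will do), and k_fold_split is assumed to yield (train, validation) pairs, as the consecutive-chunk docstrings suggest.

# Both calls now emit deprecation warnings; the same API is available from
# dgllife.utils.splitters after the migration.
dataset = list(range(100))  # placeholder dataset

train_set, val_set, test_set = RandomSplitter.train_val_test_split(
    dataset, frac_train=0.8, frac_val=0.1, frac_test=0.1, random_state=0)

for fold_train, fold_val in ConsecutiveSplitter.k_fold_split(dataset, k=5):
    print(len(fold_train), len(fold_val))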
@@ -5,6 +5,7 @@ import torch
import torch.nn as nn
from ...nn.pytorch import AtomicConv
from ...contrib.deprecation import deprecated
def truncated_normal_(tensor, mean=0., std=1.):
"""Fills the given tensor in-place with elements sampled from the truncated normal
@@ -146,6 +147,7 @@ class ACNN(nn.Module):
num_tasks : int
Number of output tasks.
"""
@deprecated('Import ACNN from dgllife.model instead.')
def __init__(self, hidden_sizes, weight_init_stddevs, dropouts,
features_to_use=None, radial=None, num_tasks=1):
super(ACNN, self).__init__()
......
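truncated_normal_ above mirrors TensorFlow's truncated-normal initializer, which the ACNN reference implementation uses for weight initialization; the diff elides its body. Here is a simple rejection-sampling sketch of the idea, as an illustration rather than the code in this commit.

import torch

def truncated_normal_(tensor, mean=0., std=1.):
    # Redraw entries until all lie within two standard deviations of the mean.
    with torch.no_grad():
        tensor.normal_(mean, std)
        invalid = (tensor - mean).abs() > 2 * std
        while invalid.any():
            tensor[invalid] = torch.empty_like(tensor[invalid]).normal_(mean, std)
            invalid = (tensor - mean).abs() > 2 * std
    return tensor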
@@ -7,6 +7,7 @@ import torch.nn as nn
import torch.nn.functional as F
from ... import function as fn
from ...contrib.deprecation import deprecated
from ...nn.pytorch.softmax import edge_softmax
class AttentiveGRU1(nn.Module):
@@ -296,6 +297,7 @@ class AttentiveFP(nn.Module):
dropout : float
The probability for performing dropout.
"""
@deprecated('Import AttentiveFPPredictor from dgllife.model instead.', 'class')
def __init__(self,
node_feat_size,
edge_feat_size,
......
@@ -2,11 +2,11 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
-import dgl
-from dgl import BatchedDGLGraph
from .gnn import GCNLayer, GATLayer
+from ...batched_graph import BatchedDGLGraph, max_nodes
from ...nn.pytorch import WeightAndSum
+from ...contrib.deprecation import deprecated
class MLPBinaryClassifier(nn.Module):
"""MLP for soft binary classification over multiple tasks from molecule representations.
@@ -98,7 +98,7 @@ class BaseGNNClassifier(nn.Module):
with bg.local_scope():
bg.ndata['h'] = feats
-h_g_max = dgl.max_nodes(bg, 'h')
+h_g_max = max_nodes(bg, 'h')
if not isinstance(bg, BatchedDGLGraph):
h_g_sum = h_g_sum.unsqueeze(0)
@@ -127,6 +127,7 @@ class GCNClassifier(BaseGNNClassifier):
The probability for dropout. Defaults to 0, i.e. no dropout is performed.
"""
@deprecated('Import GCNPredictor from dgllife.model instead.', 'class')
def __init__(self, in_feats, gcn_hidden_feats, n_tasks,
classifier_hidden_feats=128, dropout=0.):
super(GCNClassifier, self).__init__(gnn_out_feats=gcn_hidden_feats[-1],
@@ -148,6 +149,7 @@ class GATClassifier(BaseGNNClassifier):
in_feats : int
Number of input atom features
"""
@deprecated('Import GATPredictor from dgllife.model instead.', 'class')
def __init__(self, in_feats, gat_hidden_feats, num_heads,
n_tasks, classifier_hidden_feats=128, dropout=0):
super(GATClassifier, self).__init__(gnn_out_feats=gat_hidden_feats[-1],
......
@@ -13,6 +13,7 @@ from torch.distributions import Categorical
import dgl
from dgl import DGLGraph
from dgl.contrib.deprecation import deprecated
try:
from rdkit import Chem
@@ -647,6 +648,7 @@ class DGMG(nn.Module):
dropout : float
Probability for dropout
"""
@deprecated('Import DGMG from dgllife.model instead.', 'class')
def __init__(self, atom_types, bond_types, node_hidden_size, num_prop_rounds, dropout):
super(DGMG, self).__init__()
......
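For reference, constructing a DGMG instance now raises the deprecation warning before training or sampling proceeds as before. A hypothetical instantiation follows; the atom and bond vocabularies below are illustrative, not the ones used by the shipped checkpoints.

from rdkit import Chem

model = DGMG(atom_types=['C', 'N', 'O'],
             bond_types=[Chem.rdchem.BondType.SINGLE,
                         Chem.rdchem.BondType.DOUBLE,
                         Chem.rdchem.BondType.TRIPLE],
             node_hidden_size=128,
             num_prop_rounds=2,
             dropout=0.2)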
"""JTNN Module
"""
"""JTNN Module"""
from .chemutils import decode_stereo
from .jtnn_vae import DGLJTNNVAE
from .mol_tree import Vocab
......
@@ -7,8 +7,9 @@ import torch.nn.functional as F
import rdkit.Chem as Chem
-from dgl import batch, unbatch
-from dgl.data.utils import get_download_dir
+from ....batched_graph import batch, unbatch
+from ....contrib.deprecation import deprecated
+from ....data.utils import get_download_dir
from .chemutils import (attach_mols_nx, copy_edit_mol, decode_stereo,
enum_assemble_nx, set_atommap)
@@ -27,8 +28,10 @@ class DGLJTNNVAE(nn.Module):
`Junction Tree Variational Autoencoder for Molecular Graph Generation
<https://arxiv.org/abs/1802.04364>`__
"""
@deprecated('Import DGLJTNNVAE from dgllife.model instead.', 'class')
def __init__(self, hidden_size, latent_size, depth, vocab=None, vocab_file=None):
super(DGLJTNNVAE, self).__init__()
if vocab is None:
    if vocab_file is None:
        vocab_file = '{}/jtnn/{}.txt'.format(
......
@@ -7,6 +7,7 @@ import torch.nn as nn
from .layers import AtomEmbedding, RBFLayer, EdgeEmbedding, \
MultiLevelInteraction
from ...nn.pytorch import SumPooling
from ...contrib.deprecation import deprecated
class MGCNModel(nn.Module):
@@ -37,6 +38,7 @@ class MGCNModel(nn.Module):
If None, random representation initialization will be used. Otherwise,
they will be used to initialize atom representations. Defaults to None.
"""
@deprecated('Import MGCNPredictor from dgllife.model instead.', 'class')
def __init__(self,
dim=128,
width=1,
......
@@ -5,6 +5,7 @@
import torch.nn as nn
import torch.nn.functional as F
from ...contrib.deprecation import deprecated
from ...nn.pytorch import Set2Set, NNConv
class MPNNModel(nn.Module):
@@ -31,6 +32,7 @@ class MPNNModel(nn.Module):
num_layer_set2set : int
Number of set2set layers
"""
@deprecated('Import MPNNPredictor from dgllife.model instead.', 'class')
def __init__(self,
node_input_dim=15,
edge_input_dim=5,
......
"""Utilities for using pretrained models."""
import os
-import numpy as np
import torch
from rdkit import Chem
@@ -11,8 +10,8 @@ from .mgcn import MGCNModel
from .mpnn import MPNNModel
from .schnet import SchNet
from .attentive_fp import AttentiveFP
-from .acnn import ACNN
from ...data.utils import _get_dgl_url, download, get_download_dir, extract_archive
+from ...contrib.deprecation import deprecated
URL = {
'GCN_Tox21': 'pre_trained/gcn_tox21.pth',
@@ -63,6 +62,7 @@ def download_and_load_checkpoint(model_name, model, model_postfix,
return model
@deprecated('Import it from dgllife.model instead.')
def load_pretrained(model_name, log=True):
"""Load a pretrained model
@@ -82,14 +82,6 @@ def load_pretrained(model_name, log=True):
* ``'DGMG_ZINC_canonical'``
* ``'DGMG_ZINC_random'``
* ``'JTNN_ZINC'``
-* ``'ACNN_PDBBind_core_pocket_random'``
-* ``'ACNN_PDBBind_core_pocket_scaffold'``
-* ``'ACNN_PDBBind_core_pocket_stratified'``
-* ``'ACNN_PDBBind_core_pocket_temporal'``
-* ``'ACNN_PDBBind_refined_pocket_random'``
-* ``'ACNN_PDBBind_refined_pocket_scaffold'``
-* ``'ACNN_PDBBind_refined_pocket_stratified'``
-* ``'ACNN_PDBBind_refined_pocket_temporal'``
log : bool
Whether to print progress for model loading
@@ -152,21 +144,11 @@ def load_pretrained(model_name, log=True):
vocab_file = '{}/jtnn/{}.txt'.format(default_dir, 'vocab')
if not os.path.exists(vocab_file):
zip_file_path = '{}/jtnn.zip'.format(default_dir)
-download('https://s3-ap-southeast-1.amazonaws.com/dgl-data-cn/dataset/jtnn.zip',
-         path=zip_file_path)
+download(_get_dgl_url('dgllife/jtnn.zip'), path=zip_file_path)
extract_archive(zip_file_path, '{}/jtnn'.format(default_dir))
model = DGLJTNNVAE(vocab_file=vocab_file,
depth=3,
hidden_size=450,
latent_size=56)
-elif model_name.startswith('ACNN_PDBBind_core_pocket'):
-    model = ACNN(hidden_sizes=[32, 32, 16],
-                 weight_init_stddevs=[1. / float(np.sqrt(32)), 1. / float(np.sqrt(32)),
-                                      1. / float(np.sqrt(16)), 0.01],
-                 dropouts=[0., 0., 0.],
-                 features_to_use=torch.tensor([
-                     1., 6., 7., 8., 9., 11., 12., 15., 16., 17., 20., 25., 30., 35., 53.]),
-                 radial=[[12.0], [0.0, 4.0, 8.0], [4.0]])
return download_and_load_checkpoint(model_name, model, URL[model_name], log=log)
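The ACNN entries disappear from both the docstring and the dispatch above, and the JTNN vocabulary now comes from the DGL data bucket via _get_dgl_url instead of a hard-coded S3 URL. Loading any remaining model is unchanged; for example (a sketch, with 'GCN_Tox21' being one of the names registered in URL above):

model = load_pretrained('GCN_Tox21')  # downloads the checkpoint on first use
model.eval()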
@@ -5,6 +5,7 @@ import torch
import torch.nn as nn
from .layers import AtomEmbedding, Interaction, ShiftSoftplus, RBFLayer
from ...contrib.deprecation import deprecated
from ...nn.pytorch import SumPooling
@@ -34,6 +35,7 @@ class SchNet(nn.Module):
If None, random representation initialization will be used. Otherwise,
they will be used to initialize atom representations. Defaults to None.
"""
@deprecated('Import SchNetPredictor from dgllife.model instead.')
def __init__(self,
dim=64,
cutoff=5.0,
......
@@ -19,8 +19,9 @@ from .densechebconv import DenseChebConv
from .densegraphconv import DenseGraphConv
from .densesageconv import DenseSAGEConv
from .atomicconv import AtomicConv
+from .cfconv import CFConv
__all__ = ['GraphConv', 'GATConv', 'TAGConv', 'RelGraphConv', 'SAGEConv',
'SGConv', 'APPNPConv', 'GINConv', 'GatedGraphConv', 'GMMConv',
'ChebConv', 'AGNNConv', 'NNConv', 'DenseGraphConv', 'DenseSAGEConv',
-'DenseChebConv', 'EdgeConv', 'AtomicConv']
+'DenseChebConv', 'EdgeConv', 'AtomicConv', 'CFConv']
"""Torch Module for Atomic Convolution Layer"""
# pylint: disable= no-member, arguments-differ, invalid-name
import numpy as np
import torch as th
import torch.nn as nn
......
"""Torch modules for interaction blocks in SchNet"""
# pylint: disable= no-member, arguments-differ, invalid-name
import numpy as np
import torch.nn as nn
from .... import function as fn
class ShiftedSoftplus(nn.Module):
r"""Applies the element-wise function:
.. math::
\text{SSP}(x) = \frac{1}{\beta} * \log(1 + \exp(\beta * x)) - \log(\text{shift})
Parameters
----------
beta : int
:math:`\beta` value for the mathematical formulation. Defaults to 1.
shift : int
:math:`\text{shift}` value for the mathematical formulation. Defaults to 2.
"""
def __init__(self, beta=1, shift=2, threshold=20):
super(ShiftedSoftplus, self).__init__()
self.shift = shift
self.softplus = nn.Softplus(beta=beta, threshold=threshold)
def forward(self, inputs):
"""Applies the activation function.
Parameters
----------
inputs : float32 tensor of shape (N, *)
* denotes any number of additional dimensions.
Returns
-------
float32 tensor of shape (N, *)
Result of applying the activation function to the input.
"""
return self.softplus(inputs) - np.log(float(self.shift))
class CFConv(nn.Module):
r"""CFConv in SchNet.
SchNet is introduced in `SchNet: A continuous-filter convolutional neural network for
modeling quantum interactions <https://arxiv.org/abs/1706.08566>`__.
It combines node and edge features in message passing and updates node representations.
Parameters
----------
node_in_feats : int
Size for the input node features.
edge_in_feats : int
Size for the input edge features.
hidden_feats : int
Size for the hidden representations.
out_feats : int
Size for the output representations.
"""
def __init__(self, node_in_feats, edge_in_feats, hidden_feats, out_feats):
super(CFConv, self).__init__()
self.project_edge = nn.Sequential(
nn.Linear(edge_in_feats, hidden_feats),
ShiftedSoftplus(),
nn.Linear(hidden_feats, hidden_feats),
ShiftedSoftplus()
)
self.project_node = nn.Linear(node_in_feats, hidden_feats)
self.project_out = nn.Sequential(
nn.Linear(hidden_feats, out_feats),
ShiftedSoftplus()
)
def forward(self, g, node_feats, edge_feats):
"""Performs message passing and updates node representations.
Parameters
----------
g : DGLGraph
The graph.
node_feats : float32 tensor of shape (V, node_in_feats)
Input node features, V for the number of nodes.
edge_feats : float32 tensor of shape (E, edge_in_feats)
Input edge features, E for the number of edges.
Returns
-------
float32 tensor of shape (V, out_feats)
Updated node representations.
"""
g = g.local_var()
g.ndata['hv'] = self.project_node(node_feats)
g.edata['he'] = self.project_edge(edge_feats)
g.update_all(fn.u_mul_e('hv', 'he', 'm'), fn.sum('m', 'h'))
return self.project_out(g.ndata['h'])
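A quick sanity check on the activation: with the defaults beta=1 and shift=2, SSP(0) = log(1 + exp(0)) - log(2) = 0, so the activation passes through the origin. A minimal sketch:

import torch

ssp = ShiftedSoftplus()          # beta=1, shift=2
print(ssp(torch.zeros(3)))       # approximately tensor([0., 0., 0.])

CFConv itself is exercised end-to-end by test_cf_conv in the test hunk below.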
@@ -585,6 +585,41 @@ def test_sequential():
n_feat = net([g1, g2, g3], n_feat)
assert n_feat.shape == (4, 4)
def test_atomic_conv():
g = dgl.DGLGraph(sp.sparse.random(100, 100, density=0.1), readonly=True)
aconv = nn.AtomicConv(interaction_cutoffs=F.tensor([12.0, 12.0]),
rbf_kernel_means=F.tensor([0.0, 2.0]),
rbf_kernel_scaling=F.tensor([4.0, 4.0]),
features_to_use=F.tensor([6.0, 8.0]))
ctx = F.ctx()
if F.gpu_ctx():
aconv = aconv.to(ctx)
feat = F.randn((100, 1))
dist = F.randn((g.number_of_edges(), 1))
h = aconv(g, feat, dist)
# currently we only do a shape check
assert h.shape[-1] == 4
def test_cf_conv():
g = dgl.DGLGraph(sp.sparse.random(100, 100, density=0.1), readonly=True)
cfconv = nn.CFConv(node_in_feats=2,
edge_in_feats=3,
hidden_feats=2,
out_feats=3)
ctx = F.ctx()
if F.gpu_ctx():
cfconv = cfconv.to(ctx)
node_feats = F.randn((100, 2))
edge_feats = F.randn((g.number_of_edges(), 3))
h = cfconv(g, node_feats, edge_feats)
# currently we only do a shape check
assert h.shape[-1] == 3
if __name__ == '__main__':
test_graph_conv()
test_edge_softmax()
@@ -608,4 +643,5 @@ if __name__ == '__main__':
test_dense_sage_conv()
test_dense_cheb_conv()
test_sequential()
test_atomic_conv()
test_cf_conv()