"vscode:/vscode.git/clone" did not exist on "6dbb0d81d9fade9ba668eb93f12153f0fe2f6dc7"
Unverified commit 828a5e5b authored by Mufei Li, committed by GitHub

[DGL-LifeSci] Migration and Refactor (#1226)

* First commit

* Update

* Update splitters

* Update

* Update

* Update

* Update

* Update

* Update

* Migrate ACNN

* Fix

* Fix

* Update

* Update

* Update

* Update

* Update

* Update

* Finish classification

* Update

* Fix

* Update

* Update

* Update

* Fix

* Fix

* Fix

* Update

* Update

* Update

* trigger CI

* Fix CI

* Update

* Update

* Update

* Add default values

* Rename

* Update deprecation message
parent e4948c5c
@@ -5,6 +5,8 @@ import warnings
 from functools import partial
 from multiprocessing import Pool

+from ....contrib.deprecation import deprecated
+
 try:
     import pdbfixer
     import simtk
@@ -20,6 +22,7 @@ __all__ = ['add_hydrogens_to_mol',
            'load_molecule',
            'multiprocess_load_molecules']

+@deprecated('')
 def add_hydrogens_to_mol(mol):
     """Add hydrogens to an RDKit molecule instance.
@@ -57,6 +60,7 @@ def add_hydrogens_to_mol(mol):
         warnings.warn('Failed to add hydrogens to the molecule.')
     return mol

+@deprecated('Import it from dgllife.utils.rdkit_utils instead.')
 def get_mol_3D_coordinates(mol):
     """Get 3D coordinates of the molecule.
@@ -83,6 +87,7 @@ def get_mol_3D_coordinates(mol):
         warnings.warn('Unable to get conformation of the molecule.')
         return None

+@deprecated('Import it from dgllife.utils.rdkit_utils instead.')
 def load_molecule(molecule_file, add_hydrogens=False, sanitize=False, calc_charges=False,
                   remove_hs=False, use_conformation=True):
     """Load a molecule from a file.
@@ -161,6 +166,7 @@ def load_molecule(molecule_file, add_hydrogens=False, sanitize=False, calc_charg
     return mol, coordinates

+@deprecated('Import it from dgllife.utils.rdkit_utils instead.')
 def multiprocess_load_molecules(files, add_hydrogens=False, sanitize=False, calc_charges=False,
                                 remove_hs=False, use_conformation=True, num_processes=2):
     """Load molecules from files with multiprocessing.
......
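The refactor routes every deprecation through a single `deprecated` helper imported from `contrib.deprecation`, whose implementation is not part of this diff. A minimal sketch of what such a decorator could look like, assuming it only emits a DeprecationWarning with the supplied guidance before calling the wrapped function, and that the optional second argument (`'class'` in later files) merely adjusts the wording when an entire class is deprecated:

import functools
import warnings

def deprecated(message, kind='func'):
    # Hypothetical re-implementation for illustration only; the real
    # dgl.contrib.deprecation.deprecated is not shown in this commit.
    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            target = 'class' if kind == 'class' else 'function'
            warnings.warn('This {} is deprecated. {}'.format(target, message),
                          DeprecationWarning, stacklevel=2)
            return func(*args, **kwargs)
        return wrapper
    return decorator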
@@ -11,6 +11,7 @@ from itertools import accumulate, chain
 from ...utils import split_dataset, Subset
 from .... import backend as F
+from ....contrib.deprecation import deprecated

 try:
     from rdkit import Chem
@@ -187,7 +188,9 @@ class ConsecutiveSplitter(object):
     The dataset is split without permutation, so the splitting is deterministic.
     """
     @staticmethod
+    @deprecated('Import ConsecutiveSplitter from dgllife.utils.splitters instead.', 'class')
     def train_val_test_split(dataset, frac_train=0.8, frac_val=0.1, frac_test=0.1):
         """Split the dataset into three consecutive chunks for training, validation and test.
@@ -214,6 +217,7 @@ class ConsecutiveSplitter(object):
         return split_dataset(dataset, frac_list=[frac_train, frac_val, frac_test], shuffle=False)

     @staticmethod
+    @deprecated('Import ConsecutiveSplitter from dgllife.utils.splitters instead.', 'class')
     def k_fold_split(dataset, k=5, log=True):
         """Split the dataset for k-fold cross validation by taking consecutive chunks.
@@ -240,6 +244,7 @@ class RandomSplitter(object):
     The dataset is split with permutation and the splitting is hence random.
     """
     @staticmethod
+    @deprecated('Import RandomSplitter from dgllife.utils.splitters instead.', 'class')
     def train_val_test_split(dataset, frac_train=0.8, frac_val=0.1,
                              frac_test=0.1, random_state=None):
         """Randomly permute the dataset and then split it into
@@ -275,6 +280,7 @@ class RandomSplitter(object):
                              shuffle=True, random_state=random_state)

     @staticmethod
+    @deprecated('Import RandomSplitter from dgllife.utils.splitters instead.', 'class')
     def k_fold_split(dataset, k=5, random_state=None, log=True):
         """Randomly permute the dataset and then split it
         for k-fold cross validation by taking consecutive chunks.
@@ -309,6 +315,7 @@ class RandomSplitter(object):
 class MolecularWeightSplitter(object):
     """Sort molecules based on their weights and then split them."""
     @staticmethod
+    @deprecated('Import MolecularWeightSplitter from dgllife.utils.splitters instead.', 'class')
     def molecular_weight_indices(molecules, log_every_n):
         """Reorder molecules based on molecular weights.
@@ -341,6 +348,7 @@ class MolecularWeightSplitter(object):
         return np.argsort(mws)

     @staticmethod
+    @deprecated('Import MolecularWeightSplitter from dgllife.utils.splitters instead.', 'class')
     def train_val_test_split(dataset, mols=None, sanitize=True, frac_train=0.8,
                              frac_val=0.1, frac_test=0.1, log_every_n=1000):
         """Sort molecules based on their weights and then split them into
@@ -390,6 +398,7 @@ class MolecularWeightSplitter(object):
         return indices_split(dataset, frac_train, frac_val, frac_test, sorted_indices)

     @staticmethod
+    @deprecated('Import MolecularWeightSplitter from dgllife.utils.splitters instead.', 'class')
     def k_fold_split(dataset, mols=None, sanitize=True, k=5, log_every_n=1000):
         """Sort molecules based on their weights and then split them
         for k-fold cross validation by taking consecutive chunks.
@@ -439,7 +448,9 @@ class ScaffoldSplitter(object):
     Bemis, G. W.; Murcko, M. A. “The Properties of Known Drugs.
     1. Molecular Frameworks.” J. Med. Chem. 39:2887-93 (1996).
     """
     @staticmethod
+    @deprecated('Import ScaffoldSplitter from dgllife.utils.splitters instead.', 'class')
     def get_ordered_scaffold_sets(molecules, include_chirality, log_every_n):
         """Group molecules based on their Bemis-Murcko scaffolds and
         order these groups based on their sizes.
@@ -494,6 +505,7 @@ class ScaffoldSplitter(object):
         return scaffold_sets

     @staticmethod
+    @deprecated('Import ScaffoldSplitter from dgllife.utils.splitters instead.', 'class')
     def train_val_test_split(dataset, mols=None, sanitize=True, include_chirality=False,
                              frac_train=0.8, frac_val=0.1, frac_test=0.1, log_every_n=1000):
         """Split the dataset into training, validation and test set based on molecular scaffolds.
@@ -564,6 +576,7 @@ class ScaffoldSplitter(object):
                 Subset(dataset, test_indices)]

     @staticmethod
+    @deprecated('Import ScaffoldSplitter from dgllife.utils.splitters instead.', 'class')
     def k_fold_split(dataset, mols=None, sanitize=True,
                      include_chirality=False, k=5, log_every_n=1000):
         """Group molecules based on their scaffolds and sort groups based on their sizes.
@@ -636,6 +649,8 @@ class SingleTaskStratifiedSplitter(object):
     take buckets of datapoints to augment the training, validation and test subsets.
     """
     @staticmethod
+    @deprecated('Import SingleTaskStratifiedSplitter from '
+                'dgllife.utils.splitters instead.', 'class')
     def train_val_test_split(dataset, labels, task_id, frac_train=0.8, frac_val=0.1,
                              frac_test=0.1, bucket_size=10, random_state=None):
         """Split the dataset into training, validation and test subsets as stated above.
@@ -707,6 +722,8 @@ class SingleTaskStratifiedSplitter(object):
                 Subset(dataset, test_indices)]

     @staticmethod
+    @deprecated('Import SingleTaskStratifiedSplitter from '
+                'dgllife.utils.splitters instead.', 'class')
     def k_fold_split(dataset, labels, task_id, k=5, log=True):
         """Sort molecules based on their label values for a task and then split them
         for k-fold cross validation by taking consecutive chunks.
......
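All splitter methods now warn and point to dgllife.utils.splitters. A migration sketch for the simplest case, ConsecutiveSplitter, assuming the dgllife copy keeps the static-method signature shown above (the old in-tree import path in the comment is likewise an assumption based on the dgl.data.chem layout):

# Deprecated: the splitter shipped with dgl now raises a DeprecationWarning.
# from dgl.data.chem import ConsecutiveSplitter

# Replacement location named in the deprecation message.
from dgllife.utils.splitters import ConsecutiveSplitter

dataset = list(range(100))  # any object supporting len() and indexing
train_set, val_set, test_set = ConsecutiveSplitter.train_val_test_split(
    dataset, frac_train=0.8, frac_val=0.1, frac_test=0.1)
print(len(train_set), len(val_set), len(test_set))  # 80 10 10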
@@ -5,6 +5,7 @@ import torch
 import torch.nn as nn

 from ...nn.pytorch import AtomicConv
+from ...contrib.deprecation import deprecated

 def truncated_normal_(tensor, mean=0., std=1.):
     """Fills the given tensor in-place with elements sampled from the truncated normal
@@ -146,6 +147,7 @@ class ACNN(nn.Module):
     num_tasks : int
         Number of output tasks.
     """
+    @deprecated('Import ACNN from dgllife.model instead.')
     def __init__(self, hidden_sizes, weight_init_stddevs, dropouts,
                  features_to_use=None, radial=None, num_tasks=1):
         super(ACNN, self).__init__()
......
@@ -7,6 +7,7 @@ import torch.nn as nn
 import torch.nn.functional as F

 from ... import function as fn
+from ...contrib.deprecation import deprecated
 from ...nn.pytorch.softmax import edge_softmax

 class AttentiveGRU1(nn.Module):
@@ -296,6 +297,7 @@ class AttentiveFP(nn.Module):
     dropout : float
         The probability for performing dropout.
     """
+    @deprecated('Import AttentiveFPPredictor from dgllife.model instead.', 'class')
     def __init__(self,
                  node_feat_size,
                  edge_feat_size,
......
@@ -2,11 +2,11 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-import dgl
-from dgl import BatchedDGLGraph

 from .gnn import GCNLayer, GATLayer
+from ...batched_graph import BatchedDGLGraph, max_nodes
 from ...nn.pytorch import WeightAndSum
+from ...contrib.deprecation import deprecated

 class MLPBinaryClassifier(nn.Module):
     """MLP for soft binary classification over multiple tasks from molecule representations.
@@ -98,7 +98,7 @@ class BaseGNNClassifier(nn.Module):
         with bg.local_scope():
             bg.ndata['h'] = feats
-            h_g_max = dgl.max_nodes(bg, 'h')
+            h_g_max = max_nodes(bg, 'h')
             if not isinstance(bg, BatchedDGLGraph):
                 h_g_sum = h_g_sum.unsqueeze(0)
@@ -127,6 +127,7 @@ class GCNClassifier(BaseGNNClassifier):
         The probability for dropout. Default to be 0., i.e. no
         dropout is performed.
     """
+    @deprecated('Import GCNPredictor from dgllife.model instead.', 'class')
     def __init__(self, in_feats, gcn_hidden_feats, n_tasks,
                  classifier_hidden_feats=128, dropout=0.):
         super(GCNClassifier, self).__init__(gnn_out_feats=gcn_hidden_feats[-1],
@@ -148,6 +149,7 @@ class GATClassifier(BaseGNNClassifier):
     in_feats : int
         Number of input atom features
     """
+    @deprecated('Import GATPredictor from dgllife.model instead.', 'class')
     def __init__(self, in_feats, gat_hidden_feats, num_heads,
                  n_tasks, classifier_hidden_feats=128, dropout=0):
         super(GATClassifier, self).__init__(gnn_out_feats=gat_hidden_feats[-1],
......
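The classifier models keep working but now advertise their dgllife replacements, where Classifier is renamed to Predictor. A hedged sketch of constructing the deprecated model with the constructor signature shown above; the dgl.model_zoo.chem import path is assumed from the 0.4-era layout, and the argument values are purely illustrative:

# Assumed old import path; constructing the model triggers the DeprecationWarning
# added by this commit and points to dgllife.model.GCNPredictor instead.
from dgl.model_zoo.chem import GCNClassifier

model = GCNClassifier(in_feats=74,                 # e.g. canonical atom featurizer size
                      gcn_hidden_feats=[64, 64],
                      n_tasks=12,                  # e.g. the 12 Tox21 tasks
                      classifier_hidden_feats=128,
                      dropout=0.)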
@@ -13,6 +13,7 @@ from torch.distributions import Categorical
 import dgl
 from dgl import DGLGraph
+from dgl.contrib.deprecation import deprecated

 try:
     from rdkit import Chem
@@ -647,6 +648,7 @@ class DGMG(nn.Module):
     dropout : float
         Probability for dropout
     """
+    @deprecated('Import DGMG from dgllife.model instead.', 'class')
     def __init__(self, atom_types, bond_types, node_hidden_size, num_prop_rounds, dropout):
         super(DGMG, self).__init__()
......
"""JTNN Module """JTNN Module"""
"""
from .chemutils import decode_stereo from .chemutils import decode_stereo
from .jtnn_vae import DGLJTNNVAE from .jtnn_vae import DGLJTNNVAE
from .mol_tree import Vocab from .mol_tree import Vocab
......
@@ -7,8 +7,9 @@ import torch.nn.functional as F
 import rdkit.Chem as Chem

-from dgl import batch, unbatch
-from dgl.data.utils import get_download_dir
+from ....batched_graph import batch, unbatch
+from ....contrib.deprecation import deprecated
+from ....data.utils import get_download_dir

 from .chemutils import (attach_mols_nx, copy_edit_mol, decode_stereo,
                         enum_assemble_nx, set_atommap)
@@ -27,8 +28,10 @@ class DGLJTNNVAE(nn.Module):
     `Junction Tree Variational Autoencoder for Molecular Graph Generation
     <https://arxiv.org/abs/1802.04364>`__
     """
+    @deprecated('Import DGLJTNNVAE from dgllife.model instead.', 'class')
     def __init__(self, hidden_size, latent_size, depth, vocab=None, vocab_file=None):
         super(DGLJTNNVAE, self).__init__()
         if vocab is None:
             if vocab_file is None:
                 vocab_file = '{}/jtnn/{}.txt'.format(
......
@@ -7,6 +7,7 @@ import torch.nn as nn
 from .layers import AtomEmbedding, RBFLayer, EdgeEmbedding, \
     MultiLevelInteraction
 from ...nn.pytorch import SumPooling
+from ...contrib.deprecation import deprecated

 class MGCNModel(nn.Module):
@@ -37,6 +38,7 @@ class MGCNModel(nn.Module):
         If None, random representation initialization will be used. Otherwise,
         they will be used to initialize atom representations. Default to be None.
     """
+    @deprecated('Import MGCNPredictor from dgllife.model instead.', 'class')
     def __init__(self,
                  dim=128,
                  width=1,
......
@@ -5,6 +5,7 @@
 import torch.nn as nn
 import torch.nn.functional as F

+from ...contrib.deprecation import deprecated
 from ...nn.pytorch import Set2Set, NNConv

 class MPNNModel(nn.Module):
@@ -31,6 +32,7 @@ class MPNNModel(nn.Module):
     num_layer_set2set : int
         Number of set2set layers
     """
+    @deprecated('Import MPNNPredictor from dgllife.model instead.', 'class')
     def __init__(self,
                  node_input_dim=15,
                  edge_input_dim=5,
......
"""Utilities for using pretrained models.""" """Utilities for using pretrained models."""
import os import os
import numpy as np
import torch import torch
from rdkit import Chem from rdkit import Chem
...@@ -11,8 +10,8 @@ from .mgcn import MGCNModel ...@@ -11,8 +10,8 @@ from .mgcn import MGCNModel
from .mpnn import MPNNModel from .mpnn import MPNNModel
from .schnet import SchNet from .schnet import SchNet
from .attentive_fp import AttentiveFP from .attentive_fp import AttentiveFP
from .acnn import ACNN
from ...data.utils import _get_dgl_url, download, get_download_dir, extract_archive from ...data.utils import _get_dgl_url, download, get_download_dir, extract_archive
from ...contrib.deprecation import deprecated
URL = { URL = {
'GCN_Tox21': 'pre_trained/gcn_tox21.pth', 'GCN_Tox21': 'pre_trained/gcn_tox21.pth',
...@@ -63,6 +62,7 @@ def download_and_load_checkpoint(model_name, model, model_postfix, ...@@ -63,6 +62,7 @@ def download_and_load_checkpoint(model_name, model, model_postfix,
return model return model
@deprecated('Import it from dgllife.model instead.')
def load_pretrained(model_name, log=True): def load_pretrained(model_name, log=True):
"""Load a pretrained model """Load a pretrained model
...@@ -82,14 +82,6 @@ def load_pretrained(model_name, log=True): ...@@ -82,14 +82,6 @@ def load_pretrained(model_name, log=True):
* ``'DGMG_ZINC_canonical'`` * ``'DGMG_ZINC_canonical'``
* ``'DGMG_ZINC_random'`` * ``'DGMG_ZINC_random'``
* ``'JTNN_ZINC'`` * ``'JTNN_ZINC'``
* ``'ACNN_PDBBind_core_pocket_random'``
* ``'ACNN_PDBBind_core_pocket_scaffold'``
* ``'ACNN_PDBBind_core_pocket_stratified'``
* ``'ACNN_PDBBind_core_pocket_temporal'``
* ``'ACNN_PDBBind_refined_pocket_random'``
* ``'ACNN_PDBBind_refined_pocket_scaffold'``
* ``'ACNN_PDBBind_refined_pocket_stratified'``
* ``'ACNN_PDBBind_refined_pocket_temporal'``
log : bool log : bool
Whether to print progress for model loading Whether to print progress for model loading
...@@ -152,21 +144,11 @@ def load_pretrained(model_name, log=True): ...@@ -152,21 +144,11 @@ def load_pretrained(model_name, log=True):
vocab_file = '{}/jtnn/{}.txt'.format(default_dir, 'vocab') vocab_file = '{}/jtnn/{}.txt'.format(default_dir, 'vocab')
if not os.path.exists(vocab_file): if not os.path.exists(vocab_file):
zip_file_path = '{}/jtnn.zip'.format(default_dir) zip_file_path = '{}/jtnn.zip'.format(default_dir)
download('https://s3-ap-southeast-1.amazonaws.com/dgl-data-cn/dataset/jtnn.zip', download(_get_dgl_url('dgllife/jtnn.zip'), path=zip_file_path)
path=zip_file_path)
extract_archive(zip_file_path, '{}/jtnn'.format(default_dir)) extract_archive(zip_file_path, '{}/jtnn'.format(default_dir))
model = DGLJTNNVAE(vocab_file=vocab_file, model = DGLJTNNVAE(vocab_file=vocab_file,
depth=3, depth=3,
hidden_size=450, hidden_size=450,
latent_size=56) latent_size=56)
elif model_name.startswith('ACNN_PDBBind_core_pocket'):
model = ACNN(hidden_sizes=[32, 32, 16],
weight_init_stddevs=[1. / float(np.sqrt(32)), 1. / float(np.sqrt(32)),
1. / float(np.sqrt(16)), 0.01],
dropouts=[0., 0., 0.],
features_to_use=torch.tensor([
1., 6., 7., 8., 9., 11., 12., 15., 16., 17., 20., 25., 30., 35., 53.]),
radial=[[12.0], [0.0, 4.0, 8.0], [4.0]])
return download_and_load_checkpoint(model_name, model, URL[model_name], log=log) return download_and_load_checkpoint(model_name, model, URL[model_name], log=log)
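The pretrained-model loader loses its ACNN entries and is itself deprecated in favour of the dgllife copy. A minimal usage sketch with one of the names that remains in the docstring above; the import path comes from the deprecation message, and the first call downloads the corresponding checkpoint ('pre_trained/gcn_tox21.pth' for this name):

# New home per the deprecation message above.
from dgllife.model import load_pretrained

model = load_pretrained('GCN_Tox21')  # fetches and loads the checkpoint on first use
model.eval()                          # switch to inference mode for property prediction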
@@ -5,6 +5,7 @@ import torch
 import torch.nn as nn

 from .layers import AtomEmbedding, Interaction, ShiftSoftplus, RBFLayer
+from ...contrib.deprecation import deprecated
 from ...nn.pytorch import SumPooling
@@ -34,6 +35,7 @@ class SchNet(nn.Module):
         If None, random representation initialization will be used. Otherwise,
         they will be used to initialize atom representations. Default to be None.
     """
+    @deprecated('Import SchNetPredictor from dgllife.model instead.')
     def __init__(self,
                  dim=64,
                  cutoff=5.0,
......
@@ -19,8 +19,9 @@ from .densechebconv import DenseChebConv
 from .densegraphconv import DenseGraphConv
 from .densesageconv import DenseSAGEConv
 from .atomicconv import AtomicConv
+from .cfconv import CFConv

 __all__ = ['GraphConv', 'GATConv', 'TAGConv', 'RelGraphConv', 'SAGEConv',
            'SGConv', 'APPNPConv', 'GINConv', 'GatedGraphConv', 'GMMConv',
            'ChebConv', 'AGNNConv', 'NNConv', 'DenseGraphConv', 'DenseSAGEConv',
-           'DenseChebConv', 'EdgeConv', 'AtomicConv']
+           'DenseChebConv', 'EdgeConv', 'AtomicConv', 'CFConv']
"""Torch Module for Atomic Convolution Layer""" """Torch Module for Atomic Convolution Layer"""
# pylint: disable= no-member, arguments-differ, invalid-name
import numpy as np import numpy as np
import torch as th import torch as th
import torch.nn as nn import torch.nn as nn
......
"""Torch modules for interaction blocks in SchNet"""
# pylint: disable= no-member, arguments-differ, invalid-name
import numpy as np
import torch.nn as nn
from .... import function as fn
class ShiftedSoftplus(nn.Module):
r"""Applies the element-wise function:
.. math::
\text{SSP}(x) = \frac{1}{\beta} * \log(1 + \exp(\beta * x)) - \log(\text{shift})
Parameters
----------
beta : int
:math:`\beta` value for the mathematical formulation. Default to 1.
shift : int
:math:`\text{shift}` value for the mathematical formulation. Default to 2.
"""
def __init__(self, beta=1, shift=2, threshold=20):
super(ShiftedSoftplus, self).__init__()
self.shift = shift
self.softplus = nn.Softplus(beta=beta, threshold=threshold)
def forward(self, inputs):
"""Applies the activation function.
Parameters
----------
inputs : float32 tensor of shape (N, *)
* denotes any number of additional dimensions.
Returns
-------
float32 tensor of shape (N, *)
Result of applying the activation function to the input.
"""
return self.softplus(inputs) - np.log(float(self.shift))
class CFConv(nn.Module):
r"""CFConv in SchNet.
SchNet is introduced in `SchNet: A continuous-filter convolutional neural network for
modeling quantum interactions <https://arxiv.org/abs/1706.08566>`__.
It combines node and edge features in message passing and updates node representations.
Parameters
----------
node_in_feats : int
Size for the input node features.
edge_in_feats : int
Size for the input edge features.
hidden_feats : int
Size for the hidden representations.
out_feats : int
Size for the output representations.
"""
def __init__(self, node_in_feats, edge_in_feats, hidden_feats, out_feats):
super(CFConv, self).__init__()
self.project_edge = nn.Sequential(
nn.Linear(edge_in_feats, hidden_feats),
ShiftedSoftplus(),
nn.Linear(hidden_feats, hidden_feats),
ShiftedSoftplus()
)
self.project_node = nn.Linear(node_in_feats, hidden_feats)
self.project_out = nn.Sequential(
nn.Linear(hidden_feats, out_feats),
ShiftedSoftplus()
)
def forward(self, g, node_feats, edge_feats):
"""Performs message passing and updates node representations.
Parameters
----------
g : DGLGraph
The graph.
node_feats : float32 tensor of shape (V, node_in_feats)
Input node features, V for the number of nodes.
edge_feats : float32 tensor of shape (E, edge_in_feats)
Input edge features, E for the number of edges.
Returns
-------
float32 tensor of shape (V, out_feats)
Updated node representations.
"""
g = g.local_var()
g.ndata['hv'] = self.project_node(node_feats)
g.edata['he'] = self.project_edge(edge_feats)
g.update_all(fn.u_mul_e('hv', 'he', 'm'), fn.sum('m', 'h'))
return self.project_out(g.ndata['h'])
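Read as an equation, the forward pass above first projects node features with the linear layer W = project_node and filters edge features with the two-layer shifted-softplus network MLP_e = project_edge; the u_mul_e/sum message passing then computes, for every node v with incoming edges (u, v),

.. math::
    h_v = \mathrm{MLP}_{out}\Big(\sum_{u \in \mathcal{N}(v)} (W x_u) \odot \mathrm{MLP}_{e}(e_{uv})\Big)

where :math:`\odot` denotes element-wise multiplication and MLP_out = project_out applies the final linear layer and ShiftedSoftplus.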
@@ -585,6 +585,41 @@ def test_sequential():
     n_feat = net([g1, g2, g3], n_feat)
     assert n_feat.shape == (4, 4)

+def test_atomic_conv():
+    g = dgl.DGLGraph(sp.sparse.random(100, 100, density=0.1), readonly=True)
+    aconv = nn.AtomicConv(interaction_cutoffs=F.tensor([12.0, 12.0]),
+                          rbf_kernel_means=F.tensor([0.0, 2.0]),
+                          rbf_kernel_scaling=F.tensor([4.0, 4.0]),
+                          features_to_use=F.tensor([6.0, 8.0]))
+    ctx = F.ctx()
+    if F.gpu_ctx():
+        aconv = aconv.to(ctx)
+
+    feat = F.randn((100, 1))
+    dist = F.randn((g.number_of_edges(), 1))
+
+    h = aconv(g, feat, dist)
+    # currently we only check the output shape
+    assert h.shape[-1] == 4
+
+def test_cf_conv():
+    g = dgl.DGLGraph(sp.sparse.random(100, 100, density=0.1), readonly=True)
+    cfconv = nn.CFConv(node_in_feats=2,
+                       edge_in_feats=3,
+                       hidden_feats=2,
+                       out_feats=3)
+    ctx = F.ctx()
+    if F.gpu_ctx():
+        cfconv = cfconv.to(ctx)
+
+    node_feats = F.randn((100, 2))
+    edge_feats = F.randn((g.number_of_edges(), 3))
+
+    h = cfconv(g, node_feats, edge_feats)
+    # currently we only check the output shape
+    assert h.shape[-1] == 3
+
 if __name__ == '__main__':
     test_graph_conv()
     test_edge_softmax()
@@ -608,4 +643,5 @@ if __name__ == '__main__':
     test_dense_sage_conv()
     test_dense_cheb_conv()
     test_sequential()
+    test_atomic_conv()
+    test_cf_conv()