Unverified commit b76ac11c, authored by Mufei Li, committed by GitHub
Browse files

[DGL-LifeSci] Release Preparation (CI, Docker, Conda build) (#1399)



* Update

* Update

* Update

* Update

* Update

* Update

* Update

* Update

* Update

* Update

* Update

* Update

* Update

* Update

* add docs

* Fix style

* Fix lint

* Bug fix

* Fix test

* Update

* Update

* Update

* Update
Co-authored-by: Minjie Wang <wmjlyjemaine@gmail.com>
parent e4cc8185
# pylint: disable=C0111, C0103, E1101, W0611, W0612
# pylint: disable=C0111, C0103, E1101, W0611, W0612, W0221
import numpy as np
import torch
import torch.nn as nn
......
# pylint: disable=C0111, C0103, E1101, W0611, W0612, C0200
# pylint: disable=C0111, C0103, E1101, W0611, W0612, C0200, W0221, E1102
import copy
import rdkit.Chem as Chem
import torch
......
# pylint: disable=C0111, C0103, E1101, W0611, W0612
# pylint: disable=C0111, C0103, E1101, W0611, W0612, I1101, W0221
# pylint: disable=redefined-outer-name
import rdkit.Chem as Chem
import torch
......
# pylint: disable=C0111, C0103, E1101, W0611, W0612
# pylint: disable=C0111, C0103, E1101, W0611, W0612, W0221
import os
import torch
import torch.nn as nn
......
"""MGCN"""
# pylint: disable= no-member, arguments-differ, invalid-name
import torch.nn as nn
from ..gnn import MGCNGNN
......@@ -6,6 +7,7 @@ from ..readout import MLPNodeReadout
__all__ = ['MGCNPredictor']
# pylint: disable=W0221
class MGCNPredictor(nn.Module):
"""MGCN for regression and classification on graphs.
......
"""MLP for prediction on the output of readout."""
# pylint: disable= no-member, arguments-differ, invalid-name
import torch.nn as nn
# pylint: disable=W0221
class MLPPredictor(nn.Module):
"""Two-layer MLP for regression or soft classification
over multiple tasks from graph representations.
......
"""MPNN"""
# pylint: disable= no-member, arguments-differ, invalid-name
import torch.nn as nn
from dgl.nn.pytorch import Set2Set
......@@ -7,6 +8,7 @@ from ..gnn import MPNNGNN
__all__ = ['MPNNPredictor']
# pylint: disable=W0221
class MPNNPredictor(nn.Module):
"""MPNN for regression and classification on graphs.
......
"""SchNet"""
# pylint: disable= no-member, arguments-differ, invalid-name
import torch.nn as nn
from dgl.nn.pytorch.conv.cfconv import ShiftedSoftplus
......@@ -8,6 +9,7 @@ from ..readout import MLPNodeReadout
__all__ = ['SchNetPredictor']
# pylint: disable=W0221
class SchNetPredictor(nn.Module):
"""SchNet for regression and classification on graphs.
......
"""Weisfeiler-Lehman Network (WLN) for Reaction Center Prediction."""
# pylint: disable= no-member, arguments-differ, invalid-name
import dgl.function as fn
import torch
import torch.nn as nn
......@@ -7,6 +8,7 @@ from ..gnn.wln import WLNLinear, WLN
__all__ = ['WLNReactionCenter']
# pylint: disable=W0221, E1101
class WLNContext(nn.Module):
"""Attention-based context computation for each node.
......
"""Utilities for using pretrained models."""
# pylint: disable= no-member, arguments-differ, invalid-name
import os
import torch
import torch.nn.functional as F
......@@ -58,6 +59,7 @@ def download_and_load_checkpoint(model_name, model, model_postfix,
return model
# pylint: disable=I1101
def load_pretrained(model_name, log=True):
"""Load a pretrained model
......
"""Readout for AttentiveFP"""
# pylint: disable= no-member, arguments-differ, invalid-name
import dgl
import torch
import torch.nn as nn
......@@ -6,6 +7,7 @@ import torch.nn.functional as F
__all__ = ['AttentiveFPReadout']
# pylint: disable=W0221
class GlobalPool(nn.Module):
"""One-step readout in AttentiveFP
......@@ -88,7 +90,7 @@ class AttentiveFPReadout(nn.Module):
super(AttentiveFPReadout, self).__init__()
self.readouts = nn.ModuleList()
for t in range(num_timesteps):
for _ in range(num_timesteps):
self.readouts.append(GlobalPool(feat_size, dropout))
def forward(self, g, node_feats, get_node_weight=False):
......
"""Readout for SchNet"""
# pylint: disable= no-member, arguments-differ, invalid-name
import dgl
import torch.nn as nn
__all__ = ['MLPNodeReadout']
# pylint: disable=W0221
class MLPNodeReadout(nn.Module):
"""MLP-based Readout.
......
"""Apply weighted sum and max pooling to the node representations and concatenate the results."""
# pylint: disable= no-member, arguments-differ, invalid-name
import dgl
import torch
import torch.nn as nn
......@@ -7,6 +8,7 @@ from dgl.nn.pytorch import WeightAndSum
__all__ = ['WeightedSumAndMax']
# pylint: disable=W0221
class WeightedSumAndMax(nn.Module):
r"""Apply weighted sum and max pooling to the node
representations and concatenate the results.
......
"""Convert complexes into DGLHeteroGraphs"""
# pylint: disable= no-member, arguments-differ, invalid-name
import dgl.backend as F
import numpy as np
......@@ -50,6 +51,7 @@ def get_atomic_numbers(mol, indices):
atomic_numbers.append(atom.GetAtomicNum())
return atomic_numbers
# pylint: disable=C0326
def ACNN_graph_construction_and_featurization(ligand_mol,
protein_mol,
ligand_coordinates,
......@@ -216,7 +218,7 @@ def ACNN_graph_construction_and_featurization(ligand_mol,
protein_atomic_numbers = np.concatenate([
protein_atomic_numbers, np.zeros(num_protein_atoms - len(protein_atom_indices_left))])
g.nodes['ligand_atom'].data['atomic_number'] = F.reshape(F.zerocopy_from_numpy(
g.nodes['ligand_atom'].data['atomic_number'] = F.reshape(F.zerocopy_from_numpy(
ligand_atomic_numbers.astype(np.float32)), (-1, 1))
g.nodes['protein_atom'].data['atomic_number'] = F.reshape(F.zerocopy_from_numpy(
protein_atomic_numbers.astype(np.float32)), (-1, 1))
......
"""Early stopping"""
# pylint: disable= no-member, arguments-differ, invalid-name
import datetime
import torch
__all__ = ['EarlyStopping']
# pylint: disable=C0103
class EarlyStopping(object):
"""Early stop tracker
......@@ -56,7 +58,7 @@ class EarlyStopping(object):
bool
Whether the new score is higher than the previous best score.
"""
return (score > prev_best_score)
return score > prev_best_score
def _check_lower(self, score, prev_best_score):
"""Check if the new score is lower than the previous best score.
......@@ -73,7 +75,7 @@ class EarlyStopping(object):
bool
Whether the new score is lower than the previous best score.
"""
return (score < prev_best_score)
return score < prev_best_score
def step(self, score, model):
"""Update based on a new score.
......
"""Evaluation of model performance."""
# pylint: disable= no-member, arguments-differ, invalid-name
import numpy as np
import torch
import torch.nn.functional as F
......@@ -8,6 +9,7 @@ from sklearn.metrics import roc_auc_score
__all__ = ['Meter']
# pylint: disable=E1101
class Meter(object):
"""Track and summarize model performance on a dataset for (multi-label) prediction.
......@@ -252,12 +254,12 @@ class Meter(object):
"""
if metric_name == 'r2':
return self.pearson_r2(reduction)
if metric_name == 'mae':
elif metric_name == 'mae':
return self.mae(reduction)
if metric_name == 'rmse':
elif metric_name == 'rmse':
return self.rmse(reduction)
if metric_name == 'roc_auc_score':
elif metric_name == 'roc_auc_score':
return self.roc_auc_score(reduction)
else:
raise ValueError('Expect metric_name to be "r2" or "mae" or "rmse" '
'or "roc_auc_score", got {}'.format(metric_name))
"""Node and edge featurization for molecular graphs."""
import dgl.backend as F
# pylint: disable= no-member, arguments-differ, invalid-name
import itertools
from collections import defaultdict
import dgl.backend as F
import numpy as np
from collections import defaultdict
from rdkit import Chem
__all__ = ['one_hot_encoding',
......@@ -302,6 +303,7 @@ def atom_implicit_valence(atom):
"""
return [atom.GetImplicitValence()]
# pylint: disable=I1101
def atom_hybridization_one_hot(atom, allowable_set=None, encode_unknown=False):
"""One hot encoding for the hybridization of an atom.
......@@ -601,7 +603,7 @@ class BaseAtomFeaturizer(object):
Examples
--------
>>> from dgl.data.life_sci import BaseAtomFeaturizer, atom_mass, atom_degree_one_hot
>>> from dgl.data.dgllife import BaseAtomFeaturizer, atom_mass, atom_degree_one_hot
>>> from rdkit import Chem
>>> mol = Chem.MolFromSmiles('CCO')
......@@ -856,17 +858,17 @@ class BaseBondFeaturizer(object):
Examples
--------
>>> from dgl.data.life_sci import BaseBondFeaturizer, bond_type_one_hot, bond_is_in_ring
>>> from dgl.data.dgllife import BaseBondFeaturizer, bond_type_one_hot, bond_is_in_ring
>>> from rdkit import Chem
>>> mol = Chem.MolFromSmiles('CCO')
>>> bond_featurizer = BaseBondFeaturizer({'bond_type': bond_type_one_hot, 'in_ring': bond_is_in_ring})
>>> bond_featurizer = BaseBondFeaturizer({'type': bond_type_one_hot, 'ring': bond_is_in_ring})
>>> bond_featurizer(mol)
{'bond_type': tensor([[1., 0., 0., 0.],
[1., 0., 0., 0.],
[1., 0., 0., 0.],
[1., 0., 0., 0.]]),
'in_ring': tensor([[0.], [0.], [0.], [0.]])}
{'type': tensor([[1., 0., 0., 0.],
[1., 0., 0., 0.],
[1., 0., 0., 0.],
[1., 0., 0., 0.]]),
'ring': tensor([[0.], [0.], [0.], [0.]])}
"""
def __init__(self, featurizer_funcs, feat_sizes=None):
self.featurizer_funcs = featurizer_funcs
......
"""Convert molecules into DGLGraphs."""
# pylint: disable= no-member, arguments-differ, invalid-name
from functools import partial
import torch
from dgl import DGLGraph
from functools import partial
from rdkit import Chem
from rdkit.Chem import rdmolfiles, rdmolops
from sklearn.neighbors import NearestNeighbors
......@@ -16,6 +17,7 @@ __all__ = ['mol_to_graph',
'mol_to_nearest_neighbor_graph',
'smiles_to_nearest_neighbor_graph']
# pylint: disable=I1101
def mol_to_graph(mol, graph_constructor, node_featurizer, edge_featurizer, canonical_atom_order):
"""Convert an RDKit molecule object into a DGLGraph and featurize for it.
......@@ -225,7 +227,8 @@ def mol_to_complete_graph(mol, add_self_loop=False,
g : DGLGraph
Complete DGLGraph for the molecule
"""
return mol_to_graph(mol, partial(construct_complete_graph_from_mol, add_self_loop=add_self_loop),
return mol_to_graph(mol,
partial(construct_complete_graph_from_mol, add_self_loop=add_self_loop),
node_featurizer, edge_featurizer, canonical_atom_order)
def smiles_to_complete_graph(smiles, add_self_loop=False,
......@@ -321,6 +324,7 @@ def k_nearest_neighbors(coordinates, neighbor_cutoff, max_num_neighbors=None,
return srcs, dsts, dists
# pylint: disable=E1102
def mol_to_nearest_neighbor_graph(mol,
coordinates,
neighbor_cutoff,
......
"""Utils for RDKit, mostly adapted from DeepChem
(https://github.com/deepchem/deepchem/blob/master/deepchem)."""
# pylint: disable= no-member, arguments-differ, invalid-name
import warnings
from functools import partial
......@@ -7,16 +8,12 @@ from multiprocessing import Pool
from rdkit import Chem
from rdkit.Chem import AllChem
try:
from StringIO import StringIO
except ImportError:
from io import StringIO
__all__ = ['get_mol_3D_coordinates',
__all__ = ['get_mol_3d_coordinates',
'load_molecule',
'multiprocess_load_molecules']
def get_mol_3D_coordinates(mol):
# pylint: disable=W0702
def get_mol_3d_coordinates(mol):
"""Get 3D coordinates of the molecule.
Parameters
......@@ -42,6 +39,7 @@ def get_mol_3D_coordinates(mol):
warnings.warn('Unable to get conformation of the molecule.')
return None
# pylint: disable=E1101
def load_molecule(molecule_file, sanitize=False, calc_charges=False,
remove_hs=False, use_conformation=True):
"""Load a molecule from a file.
......@@ -80,8 +78,8 @@ def load_molecule(molecule_file, sanitize=False, calc_charges=False,
supplier = Chem.SDMolSupplier(molecule_file, sanitize=False, removeHs=False)
mol = supplier[0]
elif molecule_file.endswith('.pdbqt'):
with open(molecule_file) as f:
pdbqt_data = f.readlines()
with open(molecule_file) as file:
pdbqt_data = file.readlines()
pdb_block = ''
for line in pdbqt_data:
pdb_block += '{}\n'.format(line[:66])
......@@ -109,7 +107,7 @@ def load_molecule(molecule_file, sanitize=False, calc_charges=False,
return None, None
if use_conformation:
coordinates = get_mol_3D_coordinates(mol)
coordinates = get_mol_3d_coordinates(mol)
else:
coordinates = None
......@@ -151,7 +149,7 @@ def multiprocess_load_molecules(files, sanitize=False, calc_charges=False,
"""
if num_processes == 1:
mols_loaded = []
for i, f in enumerate(files):
for f in files:
mols_loaded.append(load_molecule(
f, sanitize=sanitize, calc_charges=calc_charges,
remove_hs=remove_hs, use_conformation=use_conformation))
......
......@@ -3,10 +3,8 @@
We mostly adapt them from deepchem
(https://github.com/deepchem/deepchem/blob/master/deepchem/splits/splitters.py).
"""
import dgl.backend as F
import numpy as np
from dgl.data.utils import split_dataset, Subset
# pylint: disable= no-member, arguments-differ, invalid-name
# pylint: disable=E0611
from collections import defaultdict
from functools import partial
from itertools import accumulate, chain
......@@ -15,6 +13,10 @@ from rdkit.Chem import rdMolDescriptors
from rdkit.Chem.rdmolops import FastFindRings
from rdkit.Chem.Scaffolds import MurckoScaffold
import dgl.backend as F
import numpy as np
from dgl.data.utils import split_dataset, Subset
__all__ = ['ConsecutiveSplitter',
'RandomSplitter',
'MolecularWeightSplitter',
......@@ -304,6 +306,7 @@ class RandomSplitter(object):
return base_k_fold_split(partial(indices_split, indices=indices), dataset, k, log)
# pylint: disable=I1101
class MolecularWeightSplitter(object):
"""Sort molecules based on their weights and then split them."""
......@@ -335,7 +338,7 @@ class MolecularWeightSplitter(object):
for i, mol in enumerate(molecules):
count_and_log('Computing molecular weight for compound',
i, len(molecules), log_every_n)
mws.append(Chem.rdMolDescriptors.CalcExactMolWt(mol))
mws.append(rdMolDescriptors.CalcExactMolWt(mol))
return np.argsort(mws)
......@@ -427,6 +430,7 @@ class MolecularWeightSplitter(object):
return base_k_fold_split(partial(indices_split, indices=sorted_indices), dataset, k,
log=(log_every_n is not None))
# pylint: disable=W0702
class ScaffoldSplitter(object):
"""Group molecules based on their Bemis-Murcko scaffolds and then split the groups.
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment