Unverified Commit b76ac11c authored by Mufei Li's avatar Mufei Li Committed by GitHub
Browse files

[DGL-LifeSci] Release Preparation (CI, Docker, Conda build) (#1399)



* Update

* Update

* Update

* Update

* Update

* Update

* Update

* Update

* Update

* Update

* Update

* Update

* Update

* Update

* add docs

* Fix style

* Fix lint

* Bug fix

* Fix test

* Update

* Update

* Update

* Update
Co-authored-by: default avatarMinjie Wang <wmjlyjemaine@gmail.com>
parent e4cc8185
# pylint: disable=C0111, C0103, E1101, W0611, W0612 # pylint: disable=C0111, C0103, E1101, W0611, W0612, W0221
import numpy as np import numpy as np
import torch import torch
import torch.nn as nn import torch.nn as nn
......
# pylint: disable=C0111, C0103, E1101, W0611, W0612, C0200 # pylint: disable=C0111, C0103, E1101, W0611, W0612, C0200, W0221, E1102
import copy import copy
import rdkit.Chem as Chem import rdkit.Chem as Chem
import torch import torch
......
# pylint: disable=C0111, C0103, E1101, W0611, W0612 # pylint: disable=C0111, C0103, E1101, W0611, W0612, I1101, W0221
# pylint: disable=redefined-outer-name # pylint: disable=redefined-outer-name
import rdkit.Chem as Chem import rdkit.Chem as Chem
import torch import torch
......
# pylint: disable=C0111, C0103, E1101, W0611, W0612 # pylint: disable=C0111, C0103, E1101, W0611, W0612, W0221
import os import os
import torch import torch
import torch.nn as nn import torch.nn as nn
......
"""MGCN""" """MGCN"""
# pylint: disable= no-member, arguments-differ, invalid-name
import torch.nn as nn import torch.nn as nn
from ..gnn import MGCNGNN from ..gnn import MGCNGNN
...@@ -6,6 +7,7 @@ from ..readout import MLPNodeReadout ...@@ -6,6 +7,7 @@ from ..readout import MLPNodeReadout
__all__ = ['MGCNPredictor'] __all__ = ['MGCNPredictor']
# pylint: disable=W0221
class MGCNPredictor(nn.Module): class MGCNPredictor(nn.Module):
"""MGCN for for regression and classification on graphs. """MGCN for for regression and classification on graphs.
......
"""MLP for prediction on the output of readout.""" """MLP for prediction on the output of readout."""
# pylint: disable= no-member, arguments-differ, invalid-name
import torch.nn as nn import torch.nn as nn
# pylint: disable=W0221
class MLPPredictor(nn.Module): class MLPPredictor(nn.Module):
"""Two-layer MLP for regression or soft classification """Two-layer MLP for regression or soft classification
over multiple tasks from graph representations. over multiple tasks from graph representations.
......
"""MPNN""" """MPNN"""
# pylint: disable= no-member, arguments-differ, invalid-name
import torch.nn as nn import torch.nn as nn
from dgl.nn.pytorch import Set2Set from dgl.nn.pytorch import Set2Set
...@@ -7,6 +8,7 @@ from ..gnn import MPNNGNN ...@@ -7,6 +8,7 @@ from ..gnn import MPNNGNN
__all__ = ['MPNNPredictor'] __all__ = ['MPNNPredictor']
# pylint: disable=W0221
class MPNNPredictor(nn.Module): class MPNNPredictor(nn.Module):
"""MPNN for regression and classification on graphs. """MPNN for regression and classification on graphs.
......
"""SchNet""" """SchNet"""
# pylint: disable= no-member, arguments-differ, invalid-name
import torch.nn as nn import torch.nn as nn
from dgl.nn.pytorch.conv.cfconv import ShiftedSoftplus from dgl.nn.pytorch.conv.cfconv import ShiftedSoftplus
...@@ -8,6 +9,7 @@ from ..readout import MLPNodeReadout ...@@ -8,6 +9,7 @@ from ..readout import MLPNodeReadout
__all__ = ['SchNetPredictor'] __all__ = ['SchNetPredictor']
# pylint: disable=W0221
class SchNetPredictor(nn.Module): class SchNetPredictor(nn.Module):
"""SchNet for regression and classification on graphs. """SchNet for regression and classification on graphs.
......
"""Weisfeiler-Lehman Network (WLN) for Reaction Center Prediction.""" """Weisfeiler-Lehman Network (WLN) for Reaction Center Prediction."""
# pylint: disable= no-member, arguments-differ, invalid-name
import dgl.function as fn import dgl.function as fn
import torch import torch
import torch.nn as nn import torch.nn as nn
...@@ -7,6 +8,7 @@ from ..gnn.wln import WLNLinear, WLN ...@@ -7,6 +8,7 @@ from ..gnn.wln import WLNLinear, WLN
__all__ = ['WLNReactionCenter'] __all__ = ['WLNReactionCenter']
# pylint: disable=W0221, E1101
class WLNContext(nn.Module): class WLNContext(nn.Module):
"""Attention-based context computation for each node. """Attention-based context computation for each node.
......
"""Utilities for using pretrained models.""" """Utilities for using pretrained models."""
# pylint: disable= no-member, arguments-differ, invalid-name
import os import os
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
...@@ -58,6 +59,7 @@ def download_and_load_checkpoint(model_name, model, model_postfix, ...@@ -58,6 +59,7 @@ def download_and_load_checkpoint(model_name, model, model_postfix,
return model return model
# pylint: disable=I1101
def load_pretrained(model_name, log=True): def load_pretrained(model_name, log=True):
"""Load a pretrained model """Load a pretrained model
......
"""Readout for AttentiveFP""" """Readout for AttentiveFP"""
# pylint: disable= no-member, arguments-differ, invalid-name
import dgl import dgl
import torch import torch
import torch.nn as nn import torch.nn as nn
...@@ -6,6 +7,7 @@ import torch.nn.functional as F ...@@ -6,6 +7,7 @@ import torch.nn.functional as F
__all__ = ['AttentiveFPReadout'] __all__ = ['AttentiveFPReadout']
# pylint: disable=W0221
class GlobalPool(nn.Module): class GlobalPool(nn.Module):
"""One-step readout in AttentiveFP """One-step readout in AttentiveFP
...@@ -88,7 +90,7 @@ class AttentiveFPReadout(nn.Module): ...@@ -88,7 +90,7 @@ class AttentiveFPReadout(nn.Module):
super(AttentiveFPReadout, self).__init__() super(AttentiveFPReadout, self).__init__()
self.readouts = nn.ModuleList() self.readouts = nn.ModuleList()
for t in range(num_timesteps): for _ in range(num_timesteps):
self.readouts.append(GlobalPool(feat_size, dropout)) self.readouts.append(GlobalPool(feat_size, dropout))
def forward(self, g, node_feats, get_node_weight=False): def forward(self, g, node_feats, get_node_weight=False):
......
"""Readout for SchNet""" """Readout for SchNet"""
# pylint: disable= no-member, arguments-differ, invalid-name
import dgl import dgl
import torch.nn as nn import torch.nn as nn
__all__ = ['MLPNodeReadout'] __all__ = ['MLPNodeReadout']
# pylint: disable=W0221
class MLPNodeReadout(nn.Module): class MLPNodeReadout(nn.Module):
"""MLP-based Readout. """MLP-based Readout.
......
"""Apply weighted sum and max pooling to the node representations and concatenate the results.""" """Apply weighted sum and max pooling to the node representations and concatenate the results."""
# pylint: disable= no-member, arguments-differ, invalid-name
import dgl import dgl
import torch import torch
import torch.nn as nn import torch.nn as nn
...@@ -7,6 +8,7 @@ from dgl.nn.pytorch import WeightAndSum ...@@ -7,6 +8,7 @@ from dgl.nn.pytorch import WeightAndSum
__all__ = ['WeightedSumAndMax'] __all__ = ['WeightedSumAndMax']
# pylint: disable=W0221
class WeightedSumAndMax(nn.Module): class WeightedSumAndMax(nn.Module):
r"""Apply weighted sum and max pooling to the node r"""Apply weighted sum and max pooling to the node
representations and concatenate the results. representations and concatenate the results.
......
"""Convert complexes into DGLHeteroGraphs""" """Convert complexes into DGLHeteroGraphs"""
# pylint: disable= no-member, arguments-differ, invalid-name
import dgl.backend as F import dgl.backend as F
import numpy as np import numpy as np
...@@ -50,6 +51,7 @@ def get_atomic_numbers(mol, indices): ...@@ -50,6 +51,7 @@ def get_atomic_numbers(mol, indices):
atomic_numbers.append(atom.GetAtomicNum()) atomic_numbers.append(atom.GetAtomicNum())
return atomic_numbers return atomic_numbers
# pylint: disable=C0326
def ACNN_graph_construction_and_featurization(ligand_mol, def ACNN_graph_construction_and_featurization(ligand_mol,
protein_mol, protein_mol,
ligand_coordinates, ligand_coordinates,
...@@ -216,7 +218,7 @@ def ACNN_graph_construction_and_featurization(ligand_mol, ...@@ -216,7 +218,7 @@ def ACNN_graph_construction_and_featurization(ligand_mol,
protein_atomic_numbers = np.concatenate([ protein_atomic_numbers = np.concatenate([
protein_atomic_numbers, np.zeros(num_protein_atoms - len(protein_atom_indices_left))]) protein_atomic_numbers, np.zeros(num_protein_atoms - len(protein_atom_indices_left))])
g.nodes['ligand_atom'].data['atomic_number'] = F.reshape(F.zerocopy_from_numpy( g.nodes['ligand_atom'].data['atomic_number'] = F.reshape(F.zerocopy_from_numpy(
ligand_atomic_numbers.astype(np.float32)), (-1, 1)) ligand_atomic_numbers.astype(np.float32)), (-1, 1))
g.nodes['protein_atom'].data['atomic_number'] = F.reshape(F.zerocopy_from_numpy( g.nodes['protein_atom'].data['atomic_number'] = F.reshape(F.zerocopy_from_numpy(
protein_atomic_numbers.astype(np.float32)), (-1, 1)) protein_atomic_numbers.astype(np.float32)), (-1, 1))
......
"""Early stopping""" """Early stopping"""
# pylint: disable= no-member, arguments-differ, invalid-name
import datetime import datetime
import torch import torch
__all__ = ['EarlyStopping'] __all__ = ['EarlyStopping']
# pylint: disable=C0103
class EarlyStopping(object): class EarlyStopping(object):
"""Early stop tracker """Early stop tracker
...@@ -56,7 +58,7 @@ class EarlyStopping(object): ...@@ -56,7 +58,7 @@ class EarlyStopping(object):
bool bool
Whether the new score is higher than the previous best score. Whether the new score is higher than the previous best score.
""" """
return (score > prev_best_score) return score > prev_best_score
def _check_lower(self, score, prev_best_score): def _check_lower(self, score, prev_best_score):
"""Check if the new score is lower than the previous best score. """Check if the new score is lower than the previous best score.
...@@ -73,7 +75,7 @@ class EarlyStopping(object): ...@@ -73,7 +75,7 @@ class EarlyStopping(object):
bool bool
Whether the new score is lower than the previous best score. Whether the new score is lower than the previous best score.
""" """
return (score < prev_best_score) return score < prev_best_score
def step(self, score, model): def step(self, score, model):
"""Update based on a new score. """Update based on a new score.
......
"""Evaluation of model performance.""" """Evaluation of model performance."""
# pylint: disable= no-member, arguments-differ, invalid-name
import numpy as np import numpy as np
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
...@@ -8,6 +9,7 @@ from sklearn.metrics import roc_auc_score ...@@ -8,6 +9,7 @@ from sklearn.metrics import roc_auc_score
__all__ = ['Meter'] __all__ = ['Meter']
# pylint: disable=E1101
class Meter(object): class Meter(object):
"""Track and summarize model performance on a dataset for (multi-label) prediction. """Track and summarize model performance on a dataset for (multi-label) prediction.
...@@ -252,12 +254,12 @@ class Meter(object): ...@@ -252,12 +254,12 @@ class Meter(object):
""" """
if metric_name == 'r2': if metric_name == 'r2':
return self.pearson_r2(reduction) return self.pearson_r2(reduction)
elif metric_name == 'mae':
if metric_name == 'mae':
return self.mae(reduction) return self.mae(reduction)
elif metric_name == 'rmse':
if metric_name == 'rmse':
return self.rmse(reduction) return self.rmse(reduction)
elif metric_name == 'roc_auc_score':
if metric_name == 'roc_auc_score':
return self.roc_auc_score(reduction) return self.roc_auc_score(reduction)
else:
raise ValueError('Expect metric_name to be "r2" or "mae" or "rmse" '
'or "roc_auc_score", got {}'.format(metric_name))
"""Node and edge featurization for molecular graphs.""" """Node and edge featurization for molecular graphs."""
import dgl.backend as F # pylint: disable= no-member, arguments-differ, invalid-name
import itertools import itertools
from collections import defaultdict
import dgl.backend as F
import numpy as np import numpy as np
from collections import defaultdict
from rdkit import Chem from rdkit import Chem
__all__ = ['one_hot_encoding', __all__ = ['one_hot_encoding',
...@@ -302,6 +303,7 @@ def atom_implicit_valence(atom): ...@@ -302,6 +303,7 @@ def atom_implicit_valence(atom):
""" """
return [atom.GetImplicitValence()] return [atom.GetImplicitValence()]
# pylint: disable=I1101
def atom_hybridization_one_hot(atom, allowable_set=None, encode_unknown=False): def atom_hybridization_one_hot(atom, allowable_set=None, encode_unknown=False):
"""One hot encoding for the hybridization of an atom. """One hot encoding for the hybridization of an atom.
...@@ -601,7 +603,7 @@ class BaseAtomFeaturizer(object): ...@@ -601,7 +603,7 @@ class BaseAtomFeaturizer(object):
Examples Examples
-------- --------
>>> from dgl.data.life_sci import BaseAtomFeaturizer, atom_mass, atom_degree_one_hot >>> from dgl.data.dgllife import BaseAtomFeaturizer, atom_mass, atom_degree_one_hot
>>> from rdkit import Chem >>> from rdkit import Chem
>>> mol = Chem.MolFromSmiles('CCO') >>> mol = Chem.MolFromSmiles('CCO')
...@@ -856,17 +858,17 @@ class BaseBondFeaturizer(object): ...@@ -856,17 +858,17 @@ class BaseBondFeaturizer(object):
Examples Examples
-------- --------
>>> from dgl.data.life_sci import BaseBondFeaturizer, bond_type_one_hot, bond_is_in_ring >>> from dgl.data.dgllife import BaseBondFeaturizer, bond_type_one_hot, bond_is_in_ring
>>> from rdkit import Chem >>> from rdkit import Chem
>>> mol = Chem.MolFromSmiles('CCO') >>> mol = Chem.MolFromSmiles('CCO')
>>> bond_featurizer = BaseBondFeaturizer({'bond_type': bond_type_one_hot, 'in_ring': bond_is_in_ring}) >>> bond_featurizer = BaseBondFeaturizer({'type': bond_type_one_hot, 'ring': bond_is_in_ring})
>>> bond_featurizer(mol) >>> bond_featurizer(mol)
{'bond_type': tensor([[1., 0., 0., 0.], {'type': tensor([[1., 0., 0., 0.],
[1., 0., 0., 0.], [1., 0., 0., 0.],
[1., 0., 0., 0.], [1., 0., 0., 0.],
[1., 0., 0., 0.]]), [1., 0., 0., 0.]]),
'in_ring': tensor([[0.], [0.], [0.], [0.]])} 'ring': tensor([[0.], [0.], [0.], [0.]])}
""" """
def __init__(self, featurizer_funcs, feat_sizes=None): def __init__(self, featurizer_funcs, feat_sizes=None):
self.featurizer_funcs = featurizer_funcs self.featurizer_funcs = featurizer_funcs
......
"""Convert molecules into DGLGraphs.""" """Convert molecules into DGLGraphs."""
# pylint: disable= no-member, arguments-differ, invalid-name
from functools import partial
import torch import torch
from dgl import DGLGraph from dgl import DGLGraph
from functools import partial
from rdkit import Chem from rdkit import Chem
from rdkit.Chem import rdmolfiles, rdmolops from rdkit.Chem import rdmolfiles, rdmolops
from sklearn.neighbors import NearestNeighbors from sklearn.neighbors import NearestNeighbors
...@@ -16,6 +17,7 @@ __all__ = ['mol_to_graph', ...@@ -16,6 +17,7 @@ __all__ = ['mol_to_graph',
'mol_to_nearest_neighbor_graph', 'mol_to_nearest_neighbor_graph',
'smiles_to_nearest_neighbor_graph'] 'smiles_to_nearest_neighbor_graph']
# pylint: disable=I1101
def mol_to_graph(mol, graph_constructor, node_featurizer, edge_featurizer, canonical_atom_order): def mol_to_graph(mol, graph_constructor, node_featurizer, edge_featurizer, canonical_atom_order):
"""Convert an RDKit molecule object into a DGLGraph and featurize for it. """Convert an RDKit molecule object into a DGLGraph and featurize for it.
...@@ -225,7 +227,8 @@ def mol_to_complete_graph(mol, add_self_loop=False, ...@@ -225,7 +227,8 @@ def mol_to_complete_graph(mol, add_self_loop=False,
g : DGLGraph g : DGLGraph
Complete DGLGraph for the molecule Complete DGLGraph for the molecule
""" """
return mol_to_graph(mol, partial(construct_complete_graph_from_mol, add_self_loop=add_self_loop), return mol_to_graph(mol,
partial(construct_complete_graph_from_mol, add_self_loop=add_self_loop),
node_featurizer, edge_featurizer, canonical_atom_order) node_featurizer, edge_featurizer, canonical_atom_order)
def smiles_to_complete_graph(smiles, add_self_loop=False, def smiles_to_complete_graph(smiles, add_self_loop=False,
...@@ -321,6 +324,7 @@ def k_nearest_neighbors(coordinates, neighbor_cutoff, max_num_neighbors=None, ...@@ -321,6 +324,7 @@ def k_nearest_neighbors(coordinates, neighbor_cutoff, max_num_neighbors=None,
return srcs, dsts, dists return srcs, dsts, dists
# pylint: disable=E1102
def mol_to_nearest_neighbor_graph(mol, def mol_to_nearest_neighbor_graph(mol,
coordinates, coordinates,
neighbor_cutoff, neighbor_cutoff,
......
"""Utils for RDKit, mostly adapted from DeepChem """Utils for RDKit, mostly adapted from DeepChem
(https://github.com/deepchem/deepchem/blob/master/deepchem).""" (https://github.com/deepchem/deepchem/blob/master/deepchem)."""
# pylint: disable= no-member, arguments-differ, invalid-name
import warnings import warnings
from functools import partial from functools import partial
...@@ -7,16 +8,12 @@ from multiprocessing import Pool ...@@ -7,16 +8,12 @@ from multiprocessing import Pool
from rdkit import Chem from rdkit import Chem
from rdkit.Chem import AllChem from rdkit.Chem import AllChem
try: __all__ = ['get_mol_3d_coordinates',
from StringIO import StringIO
except ImportError:
from io import StringIO
__all__ = ['get_mol_3D_coordinates',
'load_molecule', 'load_molecule',
'multiprocess_load_molecules'] 'multiprocess_load_molecules']
def get_mol_3D_coordinates(mol): # pylint: disable=W0702
def get_mol_3d_coordinates(mol):
"""Get 3D coordinates of the molecule. """Get 3D coordinates of the molecule.
Parameters Parameters
...@@ -42,6 +39,7 @@ def get_mol_3D_coordinates(mol): ...@@ -42,6 +39,7 @@ def get_mol_3D_coordinates(mol):
warnings.warn('Unable to get conformation of the molecule.') warnings.warn('Unable to get conformation of the molecule.')
return None return None
# pylint: disable=E1101
def load_molecule(molecule_file, sanitize=False, calc_charges=False, def load_molecule(molecule_file, sanitize=False, calc_charges=False,
remove_hs=False, use_conformation=True): remove_hs=False, use_conformation=True):
"""Load a molecule from a file. """Load a molecule from a file.
...@@ -80,8 +78,8 @@ def load_molecule(molecule_file, sanitize=False, calc_charges=False, ...@@ -80,8 +78,8 @@ def load_molecule(molecule_file, sanitize=False, calc_charges=False,
supplier = Chem.SDMolSupplier(molecule_file, sanitize=False, removeHs=False) supplier = Chem.SDMolSupplier(molecule_file, sanitize=False, removeHs=False)
mol = supplier[0] mol = supplier[0]
elif molecule_file.endswith('.pdbqt'): elif molecule_file.endswith('.pdbqt'):
with open(molecule_file) as f: with open(molecule_file) as file:
pdbqt_data = f.readlines() pdbqt_data = file.readlines()
pdb_block = '' pdb_block = ''
for line in pdbqt_data: for line in pdbqt_data:
pdb_block += '{}\n'.format(line[:66]) pdb_block += '{}\n'.format(line[:66])
...@@ -109,7 +107,7 @@ def load_molecule(molecule_file, sanitize=False, calc_charges=False, ...@@ -109,7 +107,7 @@ def load_molecule(molecule_file, sanitize=False, calc_charges=False,
return None, None return None, None
if use_conformation: if use_conformation:
coordinates = get_mol_3D_coordinates(mol) coordinates = get_mol_3d_coordinates(mol)
else: else:
coordinates = None coordinates = None
...@@ -151,7 +149,7 @@ def multiprocess_load_molecules(files, sanitize=False, calc_charges=False, ...@@ -151,7 +149,7 @@ def multiprocess_load_molecules(files, sanitize=False, calc_charges=False,
""" """
if num_processes == 1: if num_processes == 1:
mols_loaded = [] mols_loaded = []
for i, f in enumerate(files): for f in files:
mols_loaded.append(load_molecule( mols_loaded.append(load_molecule(
f, sanitize=sanitize, calc_charges=calc_charges, f, sanitize=sanitize, calc_charges=calc_charges,
remove_hs=remove_hs, use_conformation=use_conformation)) remove_hs=remove_hs, use_conformation=use_conformation))
......
...@@ -3,10 +3,8 @@ ...@@ -3,10 +3,8 @@
We mostly adapt them from deepchem We mostly adapt them from deepchem
(https://github.com/deepchem/deepchem/blob/master/deepchem/splits/splitters.py). (https://github.com/deepchem/deepchem/blob/master/deepchem/splits/splitters.py).
""" """
import dgl.backend as F # pylint: disable= no-member, arguments-differ, invalid-name
import numpy as np # pylint: disable=E0611
from dgl.data.utils import split_dataset, Subset
from collections import defaultdict from collections import defaultdict
from functools import partial from functools import partial
from itertools import accumulate, chain from itertools import accumulate, chain
...@@ -15,6 +13,10 @@ from rdkit.Chem import rdMolDescriptors ...@@ -15,6 +13,10 @@ from rdkit.Chem import rdMolDescriptors
from rdkit.Chem.rdmolops import FastFindRings from rdkit.Chem.rdmolops import FastFindRings
from rdkit.Chem.Scaffolds import MurckoScaffold from rdkit.Chem.Scaffolds import MurckoScaffold
import dgl.backend as F
import numpy as np
from dgl.data.utils import split_dataset, Subset
__all__ = ['ConsecutiveSplitter', __all__ = ['ConsecutiveSplitter',
'RandomSplitter', 'RandomSplitter',
'MolecularWeightSplitter', 'MolecularWeightSplitter',
...@@ -304,6 +306,7 @@ class RandomSplitter(object): ...@@ -304,6 +306,7 @@ class RandomSplitter(object):
return base_k_fold_split(partial(indices_split, indices=indices), dataset, k, log) return base_k_fold_split(partial(indices_split, indices=indices), dataset, k, log)
# pylint: disable=I1101
class MolecularWeightSplitter(object): class MolecularWeightSplitter(object):
"""Sort molecules based on their weights and then split them.""" """Sort molecules based on their weights and then split them."""
...@@ -335,7 +338,7 @@ class MolecularWeightSplitter(object): ...@@ -335,7 +338,7 @@ class MolecularWeightSplitter(object):
for i, mol in enumerate(molecules): for i, mol in enumerate(molecules):
count_and_log('Computing molecular weight for compound', count_and_log('Computing molecular weight for compound',
i, len(molecules), log_every_n) i, len(molecules), log_every_n)
mws.append(Chem.rdMolDescriptors.CalcExactMolWt(mol)) mws.append(rdMolDescriptors.CalcExactMolWt(mol))
return np.argsort(mws) return np.argsort(mws)
...@@ -427,6 +430,7 @@ class MolecularWeightSplitter(object): ...@@ -427,6 +430,7 @@ class MolecularWeightSplitter(object):
return base_k_fold_split(partial(indices_split, indices=sorted_indices), dataset, k, return base_k_fold_split(partial(indices_split, indices=sorted_indices), dataset, k,
log=(log_every_n is not None)) log=(log_every_n is not None))
# pylint: disable=W0702
class ScaffoldSplitter(object): class ScaffoldSplitter(object):
"""Group molecules based on their Bemis-Murcko scaffolds and then split the groups. """Group molecules based on their Bemis-Murcko scaffolds and then split the groups.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment