Unverified Commit cf9ba90f authored by Mufei Li's avatar Mufei Li Committed by GitHub
Browse files

[Chem] ACNN and various utilities (#1117)

* Add several splitting methods

* Update

* Update

* Update

* Update

* Update

* Fix

* Update

* Update

* Update

* Update

* Fix

* Fix

* Fix

* Fix

* Fix

* Update

* Update

* Update

* Update

* Update

* Update

* Update

* Finally

* CI
parent 6beef85b
"""Convert complexes into DGLHeteroGraphs"""
import numpy as np
from ..utils import k_nearest_neighbors
from .... import graph, bipartite, hetero_from_relations
from .... import backend as F
__all__ = ['ACNN_graph_construction_and_featurization']
def filter_out_hydrogens(mol):
"""Get indices for non-hydrogen atoms.
Parameters
----------
mol : rdkit.Chem.rdchem.Mol
RDKit molecule instance.
Returns
-------
indices_left : list of int
Indices of non-hydrogen atoms.
"""
indices_left = []
for i, atom in enumerate(mol.GetAtoms()):
atomic_num = atom.GetAtomicNum()
# Hydrogen atoms have an atomic number of 1.
if atomic_num != 1:
indices_left.append(i)
return indices_left
def get_atomic_numbers(mol, indices):
"""Get the atomic numbers for the specified atoms.
Parameters
----------
mol : rdkit.Chem.rdchem.Mol
RDKit molecule instance.
indices : list of int
Specifying atoms.
Returns
-------
list of int
Atomic numbers computed.
"""
atomic_numbers = []
for i in indices:
atom = mol.GetAtomWithIdx(i)
atomic_numbers.append(atom.GetAtomicNum())
return atomic_numbers
def ACNN_graph_construction_and_featurization(ligand_mol,
protein_mol,
ligand_coordinates,
protein_coordinates,
max_num_ligand_atoms=None,
max_num_protein_atoms=None,
neighbor_cutoff=12.,
max_num_neighbors=12,
strip_hydrogens=False):
"""Graph construction and featurization for `Atomic Convolutional Networks for
Predicting Protein-Ligand Binding Affinity <https://arxiv.org/abs/1703.10603>`__.
Parameters
----------
ligand_mol : rdkit.Chem.rdchem.Mol
RDKit molecule instance.
protein_mol : rdkit.Chem.rdchem.Mol
RDKit molecule instance.
ligand_coordinates : Float Tensor of shape (V1, 3)
Atom coordinates in a ligand.
protein_coordinates : Float Tensor of shape (V2, 3)
Atom coordinates in a protein.
max_num_ligand_atoms : int or None
Maximum number of atoms in ligands for zero padding.
If None, no zero padding will be performed. Default to None.
max_num_protein_atoms : int or None
Maximum number of atoms in proteins for zero padding.
If None, no zero padding will be performed. Default to None.
neighbor_cutoff : float
Distance cutoff to define 'neighboring'. Default to 12.
max_num_neighbors : int
Maximum number of neighbors allowed for each atom. Default to 12.
strip_hydrogens : bool
Whether to exclude hydrogen atoms. Default to False.
"""
assert ligand_coordinates is not None, 'Expect ligand_coordinates to be provided.'
assert protein_coordinates is not None, 'Expect protein_coordinates to be provided.'
if strip_hydrogens:
# Remove hydrogen atoms and their corresponding coordinates
ligand_atom_indices_left = filter_out_hydrogens(ligand_mol)
protein_atom_indices_left = filter_out_hydrogens(protein_mol)
ligand_coordinates = ligand_coordinates.take(ligand_atom_indices_left, axis=0)
protein_coordinates = protein_coordinates.take(protein_atom_indices_left, axis=0)
else:
ligand_atom_indices_left = list(range(ligand_mol.GetNumAtoms()))
protein_atom_indices_left = list(range(protein_mol.GetNumAtoms()))
# Compute number of nodes for each type
if max_num_ligand_atoms is None:
num_ligand_atoms = len(ligand_atom_indices_left)
else:
num_ligand_atoms = max_num_ligand_atoms
if max_num_protein_atoms is None:
num_protein_atoms = len(protein_atom_indices_left)
else:
num_protein_atoms = max_num_protein_atoms
# Construct graph for atoms in the ligand
ligand_srcs, ligand_dsts, ligand_dists = k_nearest_neighbors(
ligand_coordinates, neighbor_cutoff, max_num_neighbors)
ligand_graph = graph((ligand_srcs, ligand_dsts),
'ligand_atom', 'ligand', num_ligand_atoms)
ligand_graph.edata['distance'] = F.reshape(F.zerocopy_from_numpy(
np.array(ligand_dists).astype(np.float32)), (-1, 1))
# Construct graph for atoms in the protein
protein_srcs, protein_dsts, protein_dists = k_nearest_neighbors(
protein_coordinates, neighbor_cutoff, max_num_neighbors)
protein_graph = graph((protein_srcs, protein_dsts),
'protein_atom', 'protein', num_protein_atoms)
protein_graph.edata['distance'] = F.reshape(F.zerocopy_from_numpy(
np.array(protein_dists).astype(np.float32)), (-1, 1))
# Construct 4 graphs for complex representation, including the connection within
# protein atoms, the connection within ligand atoms and the connection between
# protein and ligand atoms.
complex_srcs, complex_dsts, complex_dists = k_nearest_neighbors(
np.concatenate([ligand_coordinates, protein_coordinates]),
neighbor_cutoff, max_num_neighbors)
complex_srcs = np.array(complex_srcs)
complex_dsts = np.array(complex_dsts)
complex_dists = np.array(complex_dists)
offset = num_ligand_atoms
# ('ligand_atom', 'complex', 'ligand_atom')
inter_ligand_indices = np.intersect1d(
(complex_srcs < offset).nonzero()[0],
(complex_dsts < offset).nonzero()[0],
assume_unique=True)
inter_ligand_graph = graph(
(complex_srcs[inter_ligand_indices].tolist(),
complex_dsts[inter_ligand_indices].tolist()),
'ligand_atom', 'complex', num_ligand_atoms)
inter_ligand_graph.edata['distance'] = F.reshape(F.zerocopy_from_numpy(
complex_dists[inter_ligand_indices].astype(np.float32)), (-1, 1))
# ('protein_atom', 'complex', 'protein_atom')
inter_protein_indices = np.intersect1d(
(complex_srcs >= offset).nonzero()[0],
(complex_dsts >= offset).nonzero()[0],
assume_unique=True)
inter_protein_graph = graph(
((complex_srcs[inter_protein_indices] - offset).tolist(),
(complex_dsts[inter_protein_indices] - offset).tolist()),
'protein_atom', 'complex', num_protein_atoms)
inter_protein_graph.edata['distance'] = F.reshape(F.zerocopy_from_numpy(
complex_dists[inter_protein_indices].astype(np.float32)), (-1, 1))
# ('ligand_atom', 'complex', 'protein_atom')
ligand_protein_indices = np.intersect1d(
(complex_srcs < offset).nonzero()[0],
(complex_dsts >= offset).nonzero()[0],
assume_unique=True)
ligand_protein_graph = bipartite(
(complex_srcs[ligand_protein_indices].tolist(),
(complex_dsts[ligand_protein_indices] - offset).tolist()),
'ligand_atom', 'complex', 'protein_atom',
(num_ligand_atoms, num_protein_atoms))
ligand_protein_graph.edata['distance'] = F.reshape(F.zerocopy_from_numpy(
complex_dists[ligand_protein_indices].astype(np.float32)), (-1, 1))
# ('protein_atom', 'complex', 'ligand_atom')
protein_ligand_indices = np.intersect1d(
(complex_srcs >= offset).nonzero()[0],
(complex_dsts < offset).nonzero()[0],
assume_unique=True)
protein_ligand_graph = bipartite(
((complex_srcs[protein_ligand_indices] - offset).tolist(),
complex_dsts[protein_ligand_indices].tolist()),
'protein_atom', 'complex', 'ligand_atom',
(num_protein_atoms, num_ligand_atoms))
protein_ligand_graph.edata['distance'] = F.reshape(F.zerocopy_from_numpy(
complex_dists[protein_ligand_indices].astype(np.float32)), (-1, 1))
# Merge the graphs
g = hetero_from_relations(
[protein_graph,
ligand_graph,
inter_ligand_graph,
inter_protein_graph,
ligand_protein_graph,
protein_ligand_graph]
)
# Get atomic numbers for all atoms left and set node features
ligand_atomic_numbers = np.array(get_atomic_numbers(ligand_mol, ligand_atom_indices_left))
# zero padding
ligand_atomic_numbers = np.concatenate([
ligand_atomic_numbers, np.zeros(num_ligand_atoms - len(ligand_atom_indices_left))])
protein_atomic_numbers = np.array(get_atomic_numbers(protein_mol, protein_atom_indices_left))
# zero padding
protein_atomic_numbers = np.concatenate([
protein_atomic_numbers, np.zeros(num_protein_atoms - len(protein_atom_indices_left))])
g.nodes['ligand_atom'].data['atomic_number'] = F.reshape(F.zerocopy_from_numpy(
ligand_atomic_numbers.astype(np.float32)), (-1, 1))
g.nodes['protein_atom'].data['atomic_number'] = F.reshape(F.zerocopy_from_numpy(
protein_atomic_numbers.astype(np.float32)), (-1, 1))
# Prepare mask indicating the existence of nodes
ligand_masks = np.zeros((num_ligand_atoms, 1))
ligand_masks[:len(ligand_atom_indices_left), :] = 1
g.nodes['ligand_atom'].data['mask'] = F.zerocopy_from_numpy(
ligand_masks.astype(np.float32))
protein_masks = np.zeros((num_protein_atoms, 1))
protein_masks[:len(protein_atom_indices_left), :] = 1
g.nodes['protein_atom'].data['mask'] = F.zerocopy_from_numpy(
protein_masks.astype(np.float32))
return g
import dgl.backend as F
import itertools import itertools
import numpy as np import numpy as np
from functools import partial
from collections import defaultdict from collections import defaultdict
from dgl import DGLGraph
from .... import backend as F
try: try:
from rdkit import Chem from rdkit import Chem
...@@ -12,18 +11,38 @@ try: ...@@ -12,18 +11,38 @@ try:
except ImportError: except ImportError:
pass pass
__all__ = ['one_hot_encoding', 'atom_type_one_hot', 'atomic_number_one_hot', 'atomic_number', __all__ = ['one_hot_encoding',
'atom_degree_one_hot', 'atom_degree', 'atom_total_degree_one_hot', 'atom_total_degree', 'atom_type_one_hot',
'atom_implicit_valence_one_hot', 'atom_implicit_valence', 'atom_hybridization_one_hot', 'atomic_number_one_hot',
'atom_total_num_H_one_hot', 'atom_total_num_H', 'atom_formal_charge_one_hot', 'atomic_number',
'atom_formal_charge', 'atom_num_radical_electrons_one_hot', 'atom_degree_one_hot',
'atom_num_radical_electrons', 'atom_is_aromatic_one_hot', 'atom_is_aromatic', 'atom_degree',
'atom_chiral_tag_one_hot', 'atom_mass', 'ConcatFeaturizer', 'BaseAtomFeaturizer', 'atom_total_degree_one_hot',
'CanonicalAtomFeaturizer', 'mol_to_graph', 'smiles_to_bigraph', 'atom_total_degree',
'mol_to_bigraph', 'smiles_to_complete_graph', 'mol_to_complete_graph', 'atom_implicit_valence_one_hot',
'bond_type_one_hot', 'bond_is_conjugated_one_hot', 'bond_is_conjugated', 'atom_implicit_valence',
'bond_is_in_ring_one_hot', 'bond_is_in_ring', 'bond_stereo_one_hot', 'atom_hybridization_one_hot',
'BaseBondFeaturizer', 'CanonicalBondFeaturizer'] 'atom_total_num_H_one_hot',
'atom_total_num_H',
'atom_formal_charge_one_hot',
'atom_formal_charge',
'atom_num_radical_electrons_one_hot',
'atom_num_radical_electrons',
'atom_is_aromatic_one_hot',
'atom_is_aromatic',
'atom_chiral_tag_one_hot',
'atom_mass',
'ConcatFeaturizer',
'BaseAtomFeaturizer',
'CanonicalAtomFeaturizer',
'bond_type_one_hot',
'bond_is_conjugated_one_hot',
'bond_is_conjugated',
'bond_is_in_ring_one_hot',
'bond_is_in_ring',
'bond_stereo_one_hot',
'BaseBondFeaturizer',
'CanonicalBondFeaturizer']
def one_hot_encoding(x, allowable_set, encode_unknown=False): def one_hot_encoding(x, allowable_set, encode_unknown=False):
"""One-hot encoding. """One-hot encoding.
...@@ -197,7 +216,8 @@ def atom_total_degree_one_hot(atom, allowable_set=None, encode_unknown=False): ...@@ -197,7 +216,8 @@ def atom_total_degree_one_hot(atom, allowable_set=None, encode_unknown=False):
return one_hot_encoding(atom.GetTotalDegree(), allowable_set, encode_unknown) return one_hot_encoding(atom.GetTotalDegree(), allowable_set, encode_unknown)
def atom_total_degree(atom): def atom_total_degree(atom):
""" """The degree of an atom including Hs.
See Also See Also
-------- --------
atom_degree atom_degree
...@@ -855,232 +875,3 @@ class CanonicalBondFeaturizer(BaseBondFeaturizer): ...@@ -855,232 +875,3 @@ class CanonicalBondFeaturizer(BaseBondFeaturizer):
bond_is_in_ring, bond_is_in_ring,
bond_stereo_one_hot] bond_stereo_one_hot]
)}) )})
#################################################################
# DGLGraph Construction
#################################################################
def mol_to_graph(mol, graph_constructor, atom_featurizer, bond_featurizer):
"""Convert an RDKit molecule object into a DGLGraph and featurize for it.
Parameters
----------
mol : rdkit.Chem.rdchem.Mol
RDKit molecule holder
graph_constructor : callable
Takes an RDKit molecule as input and returns a DGLGraph
atom_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict
Featurization for atoms in a molecule, which can be used to update
ndata for a DGLGraph.
bond_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict
Featurization for bonds in a molecule, which can be used to update
edata for a DGLGraph.
Returns
-------
g : DGLGraph
Converted DGLGraph for the molecule
"""
new_order = rdmolfiles.CanonicalRankAtoms(mol)
mol = rdmolops.RenumberAtoms(mol, new_order)
g = graph_constructor(mol)
if atom_featurizer is not None:
g.ndata.update(atom_featurizer(mol))
if bond_featurizer is not None:
g.edata.update(bond_featurizer(mol))
return g
def construct_bigraph_from_mol(mol, add_self_loop=False):
"""Construct a bi-directed DGLGraph with topology only for the molecule.
The **i** th atom in the molecule, i.e. ``mol.GetAtomWithIdx(i)``, corresponds to the
**i** th node in the returned DGLGraph.
The **i** th bond in the molecule, i.e. ``mol.GetBondWithIdx(i)``, corresponds to the
**(2i)**-th and **(2i+1)**-th edges in the returned DGLGraph. The **(2i)**-th and
**(2i+1)**-th edges will be separately from **u** to **v** and **v** to **u**, where
**u** is ``bond.GetBeginAtomIdx()`` and **v** is ``bond.GetEndAtomIdx()``.
If self loops are added, the last **n** edges will separately be self loops for
atoms ``0, 1, ..., n-1``.
Parameters
----------
mol : rdkit.Chem.rdchem.Mol
RDKit molecule holder
add_self_loop : bool
Whether to add self loops in DGLGraphs. Default to False.
Returns
-------
g : DGLGraph
Empty bigraph topology of the molecule
"""
g = DGLGraph()
# Add nodes
num_atoms = mol.GetNumAtoms()
g.add_nodes(num_atoms)
# Add edges
src_list = []
dst_list = []
num_bonds = mol.GetNumBonds()
for i in range(num_bonds):
bond = mol.GetBondWithIdx(i)
u = bond.GetBeginAtomIdx()
v = bond.GetEndAtomIdx()
src_list.extend([u, v])
dst_list.extend([v, u])
g.add_edges(src_list, dst_list)
if add_self_loop:
nodes = g.nodes()
g.add_edges(nodes, nodes)
return g
def mol_to_bigraph(mol, add_self_loop=False,
atom_featurizer=None,
bond_featurizer=None):
"""Convert an RDKit molecule object into a bi-directed DGLGraph and featurize for it.
Parameters
----------
mol : rdkit.Chem.rdchem.Mol
RDKit molecule holder
add_self_loop : bool
Whether to add self loops in DGLGraphs. Default to False.
atom_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict
Featurization for atoms in a molecule, which can be used to update
ndata for a DGLGraph. Default to None.
bond_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict
Featurization for bonds in a molecule, which can be used to update
edata for a DGLGraph. Default to None.
Returns
-------
g : DGLGraph
Bi-directed DGLGraph for the molecule
"""
return mol_to_graph(mol, partial(construct_bigraph_from_mol, add_self_loop=add_self_loop),
atom_featurizer, bond_featurizer)
def smiles_to_bigraph(smiles, add_self_loop=False,
atom_featurizer=None,
bond_featurizer=None):
"""Convert a SMILES into a bi-directed DGLGraph and featurize for it.
Parameters
----------
smiles : str
String of SMILES
add_self_loop : bool
Whether to add self loops in DGLGraphs. Default to False.
atom_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict
Featurization for atoms in a molecule, which can be used to update
ndata for a DGLGraph. Default to None.
bond_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict
Featurization for bonds in a molecule, which can be used to update
edata for a DGLGraph. Default to None.
Returns
-------
g : DGLGraph
Bi-directed DGLGraph for the molecule
"""
mol = Chem.MolFromSmiles(smiles)
return mol_to_bigraph(mol, add_self_loop, atom_featurizer, bond_featurizer)
def construct_complete_graph_from_mol(mol, add_self_loop=False):
"""Construct a complete graph with topology only for the molecule
The **i** th atom in the molecule, i.e. ``mol.GetAtomWithIdx(i)``, corresponds to the
**i** th node in the returned DGLGraph.
The edges are in the order of (0, 0), (1, 0), (2, 0), ... (0, 1), (1, 1), (2, 1), ...
If self loops are not created, we will not have (0, 0), (1, 1), ...
Parameters
----------
mol : rdkit.Chem.rdchem.Mol
RDKit molecule holder
add_self_loop : bool
Whether to add self loops in DGLGraphs. Default to False.
Returns
-------
g : DGLGraph
Empty complete graph topology of the molecule
"""
g = DGLGraph()
num_atoms = mol.GetNumAtoms()
g.add_nodes(num_atoms)
if add_self_loop:
g.add_edges(
[i for i in range(num_atoms) for j in range(num_atoms)],
[j for i in range(num_atoms) for j in range(num_atoms)])
else:
g.add_edges(
[i for i in range(num_atoms) for j in range(num_atoms - 1)], [
j for i in range(num_atoms)
for j in range(num_atoms) if i != j
])
return g
def mol_to_complete_graph(mol, add_self_loop=False,
atom_featurizer=None,
bond_featurizer=None):
"""Convert an RDKit molecule into a complete DGLGraph and featurize for it.
Parameters
----------
mol : rdkit.Chem.rdchem.Mol
RDKit molecule holder
add_self_loop : bool
Whether to add self loops in DGLGraphs. Default to False.
atom_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict
Featurization for atoms in a molecule, which can be used to update
ndata for a DGLGraph. Default to None.
bond_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict
Featurization for bonds in a molecule, which can be used to update
edata for a DGLGraph. Default to None.
Returns
-------
g : DGLGraph
Complete DGLGraph for the molecule
"""
return mol_to_graph(mol, partial(construct_complete_graph_from_mol, add_self_loop=add_self_loop),
atom_featurizer, bond_featurizer)
def smiles_to_complete_graph(smiles, add_self_loop=False,
atom_featurizer=None,
bond_featurizer=None):
"""Convert a SMILES into a complete DGLGraph and featurize for it.
Parameters
----------
smiles : str
String of SMILES
add_self_loop : bool
Whether to add self loops in DGLGraphs. Default to False.
atom_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict
Featurization for atoms in a molecule, which can be used to update
ndata for a DGLGraph. Default to None.
bond_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict
Featurization for bonds in a molecule, which can be used to update
edata for a DGLGraph. Default to None.
Returns
-------
g : DGLGraph
Complete DGLGraph for the molecule
"""
mol = Chem.MolFromSmiles(smiles)
return mol_to_complete_graph(mol, add_self_loop, atom_featurizer, bond_featurizer)
"""Convert molecules into DGLGraphs."""
import numpy as np
from functools import partial
from .... import DGLGraph
try:
import mdtraj
from rdkit import Chem
from rdkit.Chem import rdmolfiles, rdmolops
except ImportError:
pass
__all__ = ['mol_to_graph',
'smiles_to_bigraph',
'mol_to_bigraph',
'smiles_to_complete_graph',
'mol_to_complete_graph',
'k_nearest_neighbors']
def mol_to_graph(mol, graph_constructor, node_featurizer, edge_featurizer):
"""Convert an RDKit molecule object into a DGLGraph and featurize for it.
Parameters
----------
mol : rdkit.Chem.rdchem.Mol
RDKit molecule holder
graph_constructor : callable
Takes an RDKit molecule as input and returns a DGLGraph
node_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict
Featurization for nodes like atoms in a molecule, which can be used to
update ndata for a DGLGraph.
edge_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict
Featurization for edges like bonds in a molecule, which can be used to
update edata for a DGLGraph.
Returns
-------
g : DGLGraph
Converted DGLGraph for the molecule
"""
new_order = rdmolfiles.CanonicalRankAtoms(mol)
mol = rdmolops.RenumberAtoms(mol, new_order)
g = graph_constructor(mol)
if node_featurizer is not None:
g.ndata.update(node_featurizer(mol))
if edge_featurizer is not None:
g.edata.update(edge_featurizer(mol))
return g
def construct_bigraph_from_mol(mol, add_self_loop=False):
"""Construct a bi-directed DGLGraph with topology only for the molecule.
The **i** th atom in the molecule, i.e. ``mol.GetAtomWithIdx(i)``, corresponds to the
**i** th node in the returned DGLGraph.
The **i** th bond in the molecule, i.e. ``mol.GetBondWithIdx(i)``, corresponds to the
**(2i)**-th and **(2i+1)**-th edges in the returned DGLGraph. The **(2i)**-th and
**(2i+1)**-th edges will be separately from **u** to **v** and **v** to **u**, where
**u** is ``bond.GetBeginAtomIdx()`` and **v** is ``bond.GetEndAtomIdx()``.
If self loops are added, the last **n** edges will separately be self loops for
atoms ``0, 1, ..., n-1``.
Parameters
----------
mol : rdkit.Chem.rdchem.Mol
RDKit molecule holder
add_self_loop : bool
Whether to add self loops in DGLGraphs. Default to False.
Returns
-------
g : DGLGraph
Empty bigraph topology of the molecule
"""
g = DGLGraph()
# Add nodes
num_atoms = mol.GetNumAtoms()
g.add_nodes(num_atoms)
# Add edges
src_list = []
dst_list = []
num_bonds = mol.GetNumBonds()
for i in range(num_bonds):
bond = mol.GetBondWithIdx(i)
u = bond.GetBeginAtomIdx()
v = bond.GetEndAtomIdx()
src_list.extend([u, v])
dst_list.extend([v, u])
g.add_edges(src_list, dst_list)
if add_self_loop:
nodes = g.nodes()
g.add_edges(nodes, nodes)
return g
def mol_to_bigraph(mol, add_self_loop=False,
node_featurizer=None,
edge_featurizer=None):
"""Convert an RDKit molecule object into a bi-directed DGLGraph and featurize for it.
Parameters
----------
mol : rdkit.Chem.rdchem.Mol
RDKit molecule holder
add_self_loop : bool
Whether to add self loops in DGLGraphs. Default to False.
node_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict
Featurization for nodes like atoms in a molecule, which can be used to update
ndata for a DGLGraph. Default to None.
edge_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict
Featurization for edges like bonds in a molecule, which can be used to update
edata for a DGLGraph. Default to None.
Returns
-------
g : DGLGraph
Bi-directed DGLGraph for the molecule
"""
return mol_to_graph(mol, partial(construct_bigraph_from_mol, add_self_loop=add_self_loop),
node_featurizer, edge_featurizer)
def smiles_to_bigraph(smiles, add_self_loop=False,
node_featurizer=None,
edge_featurizer=None):
"""Convert a SMILES into a bi-directed DGLGraph and featurize for it.
Parameters
----------
smiles : str
String of SMILES
add_self_loop : bool
Whether to add self loops in DGLGraphs. Default to False.
node_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict
Featurization for nodes like atoms in a molecule, which can be used to update
ndata for a DGLGraph. Default to None.
edge_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict
Featurization for edges like bonds in a molecule, which can be used to update
edata for a DGLGraph. Default to None.
Returns
-------
g : DGLGraph
Bi-directed DGLGraph for the molecule
"""
mol = Chem.MolFromSmiles(smiles)
return mol_to_bigraph(mol, add_self_loop, node_featurizer, edge_featurizer)
def construct_complete_graph_from_mol(mol, add_self_loop=False):
"""Construct a complete graph with topology only for the molecule
The **i** th atom in the molecule, i.e. ``mol.GetAtomWithIdx(i)``, corresponds to the
**i** th node in the returned DGLGraph.
The edges are in the order of (0, 0), (1, 0), (2, 0), ... (0, 1), (1, 1), (2, 1), ...
If self loops are not created, we will not have (0, 0), (1, 1), ...
Parameters
----------
mol : rdkit.Chem.rdchem.Mol
RDKit molecule holder
add_self_loop : bool
Whether to add self loops in DGLGraphs. Default to False.
Returns
-------
g : DGLGraph
Empty complete graph topology of the molecule
"""
g = DGLGraph()
num_atoms = mol.GetNumAtoms()
g.add_nodes(num_atoms)
if add_self_loop:
g.add_edges(
[i for i in range(num_atoms) for j in range(num_atoms)],
[j for i in range(num_atoms) for j in range(num_atoms)])
else:
g.add_edges(
[i for i in range(num_atoms) for j in range(num_atoms - 1)], [
j for i in range(num_atoms)
for j in range(num_atoms) if i != j
])
return g
def mol_to_complete_graph(mol, add_self_loop=False,
node_featurizer=None,
edge_featurizer=None):
"""Convert an RDKit molecule into a complete DGLGraph and featurize for it.
Parameters
----------
mol : rdkit.Chem.rdchem.Mol
RDKit molecule holder
add_self_loop : bool
Whether to add self loops in DGLGraphs. Default to False.
node_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict
Featurization for nodes like atoms in a molecule, which can be used to update
ndata for a DGLGraph. Default to None.
edge_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict
Featurization for edges like bonds in a molecule, which can be used to update
edata for a DGLGraph. Default to None.
Returns
-------
g : DGLGraph
Complete DGLGraph for the molecule
"""
return mol_to_graph(mol, partial(construct_complete_graph_from_mol, add_self_loop=add_self_loop),
node_featurizer, edge_featurizer)
def smiles_to_complete_graph(smiles, add_self_loop=False,
node_featurizer=None,
edge_featurizer=None):
"""Convert a SMILES into a complete DGLGraph and featurize for it.
Parameters
----------
smiles : str
String of SMILES
add_self_loop : bool
Whether to add self loops in DGLGraphs. Default to False.
node_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict
Featurization for nodes like atoms in a molecule, which can be used to update
ndata for a DGLGraph. Default to None.
edge_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict
Featurization for edges like bonds in a molecule, which can be used to update
edata for a DGLGraph. Default to None.
Returns
-------
g : DGLGraph
Complete DGLGraph for the molecule
"""
mol = Chem.MolFromSmiles(smiles)
return mol_to_complete_graph(mol, add_self_loop, node_featurizer, edge_featurizer)
def k_nearest_neighbors(coordinates, neighbor_cutoff, max_num_neighbors):
"""Find k nearest neighbors for each atom based on the 3D coordinates.
Parameters
----------
coordinates : numpy.ndarray of shape (N, 3)
The 3D coordinates of atoms in the molecule. N for the number of atoms.
neighbor_cutoff : float
Distance cutoff to define 'neighboring'.
max_num_neighbors : int or None.
If not None, then this specifies the maximum number of closest neighbors
allowed for each atom.
Returns
-------
neighbor_list : dict(int -> list of ints)
Mapping atom indices to their k nearest neighbors.
"""
num_atoms = coordinates.shape[0]
traj = mdtraj.Trajectory(coordinates.reshape((1, num_atoms, 3)), None)
neighbors = mdtraj.geometry.compute_neighborlist(traj, neighbor_cutoff)
srcs, dsts, distances = [], [], []
for i in range(num_atoms):
delta = coordinates[i] - coordinates.take(neighbors[i], axis=0)
dist = np.linalg.norm(delta, axis=1)
if max_num_neighbors is not None and len(neighbors[i]) > max_num_neighbors:
sorted_neighbors = list(zip(dist, neighbors[i]))
# Sort neighbors based on distance from smallest to largest
sorted_neighbors.sort(key=lambda tup: tup[0])
dsts.extend([i for _ in range(max_num_neighbors)])
srcs.extend([int(sorted_neighbors[j][1]) for j in range(max_num_neighbors)])
distances.extend([float(sorted_neighbors[j][0]) for j in range(max_num_neighbors)])
else:
dsts.extend([i for _ in range(len(neighbors[i]))])
srcs.extend(neighbors[i].tolist())
distances.extend(dist.tolist())
return srcs, dsts, distances
"""Utils for RDKit, mostly adapted from DeepChem
(https://github.com/deepchem/deepchem/blob/master/deepchem)."""
import warnings
from functools import partial
from multiprocessing import Pool
try:
import pdbfixer
import simtk
from rdkit import Chem
from rdkit.Chem import AllChem
from StringIO import StringIO
except ImportError:
from io import StringIO
__all__ = ['add_hydrogens_to_mol',
'get_mol_3D_coordinates',
'load_molecule',
'multiprocess_load_molecules']
def add_hydrogens_to_mol(mol):
"""Add hydrogens to an RDKit molecule instance.
Parameters
----------
mol : rdkit.Chem.rdchem.Mol
RDKit molecule instance.
Returns
-------
mol : rdkit.Chem.rdchem.Mol
RDKit molecule instance with hydrogens added. For failures in adding hydrogens,
the original RDKit molecule instance will be returned.
"""
try:
pdbblock = Chem.MolToPDBBlock(mol)
pdb_stringio = StringIO()
pdb_stringio.write(pdbblock)
pdb_stringio.seek(0)
fixer = pdbfixer.PDBFixer(pdbfile=pdb_stringio)
fixer.findMissingResidues()
fixer.findMissingAtoms()
fixer.addMissingAtoms()
fixer.addMissingHydrogens(7.4)
hydrogenated_io = StringIO()
simtk.openmm.app.PDBFile.writeFile(fixer.topology, fixer.positions,
hydrogenated_io)
hydrogenated_io.seek(0)
mol = Chem.MolFromPDBBlock(hydrogenated_io.read(), sanitize=False, removeHs=False)
pdb_stringio.close()
hydrogenated_io.close()
except ValueError:
warnings.warn('Failed to add hydrogens to the molecule.')
return mol
def get_mol_3D_coordinates(mol):
"""Get 3D coordinates of the molecule.
Parameters
----------
mol : rdkit.Chem.rdchem.Mol
RDKit molecule instance.
Returns
-------
numpy.ndarray of shape (N, 3) or None
The 3D coordinates of atoms in the molecule. N for the number of atoms in
the molecule. For failures in getting the conformations, None will be returned.
"""
try:
conf = mol.GetConformer()
conf_num_atoms = conf.GetNumAtoms()
mol_num_atoms = mol.GetNumAtoms()
assert mol_num_atoms == conf_num_atoms, \
'Expect the number of atoms in the molecule and its conformation ' \
'to be the same, got {:d} and {:d}'.format(mol_num_atoms, conf_num_atoms)
return conf.GetPositions()
except:
warnings.warn('Unable to get conformation of the molecule.')
return None
def load_molecule(molecule_file, add_hydrogens=False, sanitize=False, calc_charges=False,
remove_hs=False, use_conformation=True):
"""Load a molecule from a file.
Parameters
----------
molecule_file : str
Path to file for storing a molecule, which can be of format '.mol2', '.sdf',
'.pdbqt', or '.pdb'.
add_hydrogens : bool
Whether to add hydrogens via pdbfixer. Default to False.
sanitize : bool
Whether sanitization is performed in initializing RDKit molecule instances. See
https://www.rdkit.org/docs/RDKit_Book.html for details of the sanitization.
Default to False.
calc_charges : bool
Whether to add Gasteiger charges via RDKit. Setting this to be True will enforce
``add_hydrogens`` and ``sanitize`` to be True. Default to False.
remove_hs : bool
Whether to remove hydrogens via RDKit. Note that removing hydrogens can be quite
slow for large molecules. Default to False.
use_conformation : bool
Whether we need to extract molecular conformation from proteins and ligands.
Default to True.
Returns
-------
mol : rdkit.Chem.rdchem.Mol
RDKit molecule instance for the loaded molecule.
coordinates : np.ndarray of shape (N, 3) or None
The 3D coordinates of atoms in the molecule. N for the number of atoms in
the molecule. None will be returned if ``use_conformation`` is False or
we failed to get conformation information.
"""
if molecule_file.endswith('.mol2'):
mol = Chem.MolFromMol2File(molecule_file, sanitize=False, removeHs=False)
elif molecule_file.endswith('.sdf'):
supplier = Chem.SDMolSupplier(molecule_file, sanitize=False, removeHs=False)
mol = supplier[0]
elif molecule_file.endswith('.pdbqt'):
with open(molecule_file) as f:
pdbqt_data = f.readlines()
pdb_block = ''
for line in pdbqt_data:
pdb_block += '{}\n'.format(line[:66])
mol = Chem.MolFromPDBBlock(pdb_block, sanitize=False, removeHs=False)
elif molecule_file.endswith('.pdb'):
mol = Chem.MolFromPDBFile(molecule_file, sanitize=False, removeHs=False)
else:
return ValueError('Expect the format of the molecule_file to be '
'one of .mol2, .sdf, .pdbqt and .pdb, got {}'.format(molecule_file))
try:
if add_hydrogens or calc_charges:
mol = add_hydrogens_to_mol(mol)
if sanitize or calc_charges:
Chem.SanitizeMol(mol)
if calc_charges:
# Compute Gasteiger charges on the molecule.
try:
AllChem.ComputeGasteigerCharges(mol)
except:
warnings.warn('Unable to compute charges for the molecule.')
if remove_hs:
mol = Chem.RemoveHs(mol)
except:
return None, None
if use_conformation:
coordinates = get_mol_3D_coordinates(mol)
else:
coordinates = None
return mol, coordinates
def multiprocess_load_molecules(files, add_hydrogens=False, sanitize=False, calc_charges=False,
remove_hs=False, use_conformation=True, num_processes=2):
"""Load molecules from files with multiprocessing.
Parameters
----------
files : list of str
Each element is a path to a file storing a molecule, which can be of format '.mol2',
'.sdf', '.pdbqt', or '.pdb'.
add_hydrogens : bool
Whether to add hydrogens via pdbfixer. Default to False.
sanitize : bool
Whether sanitization is performed in initializing RDKit molecule instances. See
https://www.rdkit.org/docs/RDKit_Book.html for details of the sanitization.
Default to False.
calc_charges : bool
Whether to add Gasteiger charges via RDKit. Setting this to be True will enforce
``add_hydrogens`` and ``sanitize`` to be True. Default to False.
remove_hs : bool
Whether to remove hydrogens via RDKit. Note that removing hydrogens can be quite
slow for large molecules. Default to False.
use_conformation : bool
Whether we need to extract molecular conformation from proteins and ligands.
Default to True.
num_processes : int or None
Number of worker processes to use. If None,
then we will use the number of CPUs in the systetm. Default to 2.
Returns
-------
list of 2-tuples
The first element of each 2-tuple is an RDKit molecule instance. The second element
of each 2-tuple is the 3D atom coordinates of the corresponding molecule if
use_conformation is True and the coordinates has been successfully loaded. Otherwise,
it will be None.
"""
if num_processes == 1:
mols_loaded = []
for i, f in enumerate(files):
mols_loaded.append(load_molecule(
f, add_hydrogens=add_hydrogens, sanitize=sanitize, calc_charges=calc_charges,
remove_hs=remove_hs, use_conformation=use_conformation))
else:
with Pool(processes=num_processes) as pool:
mols_loaded = pool.map_async(partial(
load_molecule, add_hydrogens=add_hydrogens, sanitize=sanitize,
calc_charges=calc_charges, remove_hs=remove_hs,
use_conformation=use_conformation), files)
mols_loaded = mols_loaded.get()
return mols_loaded
"""Various methods for splitting chemical datasets.
We mostly adapt them from deepchem
(https://github.com/deepchem/deepchem/blob/master/deepchem/splits/splitters.py).
"""
import numpy as np
from collections import defaultdict
from functools import partial
from itertools import accumulate, chain
from ...utils import split_dataset, Subset
from .... import backend as F
try:
from rdkit import Chem
from rdkit.Chem import rdMolDescriptors
from rdkit.Chem.rdmolops import FastFindRings
from rdkit.Chem.Scaffolds import MurckoScaffold
except ImportError:
pass
__all__ = ['ConsecutiveSplitter',
'RandomSplitter',
'MolecularWeightSplitter',
'ScaffoldSplitter',
'SingleTaskStratifiedSplitter']
def base_k_fold_split(split_method, dataset, k, log):
"""Split dataset for k-fold cross validation.
Parameters
----------
split_method : callable
Arbitrary method for splitting the dataset
into training, validation and test subsets.
dataset
We assume ``len(dataset)`` gives the size for the dataset and ``dataset[i]``
gives the ith datapoint.
k : int
Number of folds to use and should be no smaller than 2.
log : bool
Whether to print a message at the start of preparing each fold.
Returns
-------
all_folds : list of 2-tuples
Each element of the list represents a fold and is a 2-tuple (train_set, val_set).
"""
assert k >= 2, 'Expect the number of folds to be no smaller than 2, got {:d}'.format(k)
all_folds = []
frac_per_part = 1./ k
for i in range(k):
if log:
print('Processing fold {:d}/{:d}'.format(i+1, k))
# We are reusing the code for train-validation-test split.
train_set1, val_set, train_set2 = split_method(dataset,
frac_train=i * frac_per_part,
frac_val=frac_per_part,
frac_test=1. - (i + 1) * frac_per_part)
# For cross validation, each fold consists of only a train subset and
# a validation subset.
train_set = Subset(dataset, train_set1.indices + train_set2.indices)
all_folds.append((train_set, val_set))
return all_folds
def train_val_test_sanity_check(frac_train, frac_val, frac_test):
"""Sanity check for train-val-test split
Ensure that the fractions of the dataset to use for training,
validation and test add up to 1.
Parameters
----------
frac_train : float
Fraction of the dataset to use for training.
frac_val : float
Fraction of the dataset to use for validation.
frac_test : float
Fraction of the dataset to use for test.
"""
total_fraction = frac_train + frac_val + frac_test
assert np.allclose(total_fraction, 1.), \
'Expect the sum of fractions for training, validation and ' \
'test to be 1, got {:.4f}'.format(total_fraction)
def indices_split(dataset, frac_train, frac_val, frac_test, indices):
"""Reorder datapoints based on the specified indices and then take consecutive
chunks as subsets.
Parameters
----------
dataset
We assume ``len(dataset)`` gives the size for the dataset and ``dataset[i]``
gives the ith datapoint.
frac_train : float
Fraction of data to use for training.
frac_val : float
Fraction of data to use for validation.
frac_test : float
Fraction of data to use for test.
indices : list or ndarray
Indices specifying the order of datapoints.
Returns
-------
list of length 3
Subsets for training, validation and test, which are all :class:`Subset` instances.
"""
frac_list = np.array([frac_train, frac_val, frac_test])
assert np.allclose(np.sum(frac_list), 1.), \
'Expect frac_list sum to 1, got {:.4f}'.format(np.sum(frac_list))
num_data = len(dataset)
lengths = (num_data * frac_list).astype(int)
lengths[-1] = num_data - np.sum(lengths[:-1])
return [Subset(dataset, list(indices[offset - length:offset]))
for offset, length in zip(accumulate(lengths), lengths)]
def count_and_log(message, i, total, log_every_n):
"""Print a message to reflect the progress of processing once a while.
Parameters
----------
message : str
Message to print.
i : int
Current index.
total : int
Total count.
log_every_n : None or int
Molecule related computation can take a long time for a large dataset and we want
to learn the progress of processing. This can be done by printing a message whenever
a batch of ``log_every_n`` molecules have been processed. If None, no messages will
be printed.
"""
if (log_every_n is not None) and ((i+1) % log_every_n == 0):
print('{} {:d}/{:d}'.format(message, i+1, total))
def prepare_mols(dataset, mols, sanitize, log_every_n=1000):
"""Prepare RDKit molecule instances.
Parameters
----------
dataset
We assume ``len(dataset)`` gives the size for the dataset, ``dataset[i]``
gives the ith datapoint and ``dataset.smiles[i]`` gives the SMILES for the
ith datapoint.
mols : None or list of rdkit.Chem.rdchem.Mol
None or pre-computed RDKit molecule instances. If not None, we expect a
one-on-one correspondence between ``dataset.smiles`` and ``mols``, i.e.
``mols[i]`` corresponds to ``dataset.smiles[i]``.
sanitize : bool
This argument only comes into effect when ``mols`` is None and decides whether
sanitization is performed in initializing RDKit molecule instances. See
https://www.rdkit.org/docs/RDKit_Book.html for details of the sanitization.
log_every_n : None or int
Molecule related computation can take a long time for a large dataset and we want
to learn the progress of processing. This can be done by printing a message whenever
a batch of ``log_every_n`` molecules have been processed. If None, no messages will
be printed. Default to 1000.
Returns
-------
mols : list of rdkit.Chem.rdchem.Mol
RDkit molecule instances where there is a one-on-one correspondence between
``dataset.smiles`` and ``mols``, i.e. ``mols[i]`` corresponds to ``dataset.smiles[i]``.
"""
if mols is not None:
# Sanity check
assert len(mols) == len(dataset), \
'Expect mols to be of the same size as that of the dataset, ' \
'got {:d} and {:d}'.format(len(mols), len(dataset))
else:
if log_every_n is not None:
print('Start initializing RDKit molecule instances...')
mols = []
for i, s in enumerate(dataset.smiles):
count_and_log('Creating RDKit molecule instance',
i, len(dataset.smiles), log_every_n)
mols.append(Chem.MolFromSmiles(s, sanitize=sanitize))
return mols
class ConsecutiveSplitter(object):
"""Split datasets with the input order.
The dataset is split without permutation, so the splitting is deterministic.
"""
@staticmethod
def train_val_test_split(dataset, frac_train=0.8, frac_val=0.1, frac_test=0.1):
"""Split the dataset into three consecutive chunks for training, validation and test.
Parameters
----------
dataset
We assume ``len(dataset)`` gives the size for the dataset and ``dataset[i]``
gives the ith datapoint.
frac_train : float
Fraction of data to use for training. By default, we set this to be 0.8, i.e.
80% of the dataset is used for training.
frac_val : float
Fraction of data to use for validation. By default, we set this to be 0.1, i.e.
10% of the dataset is used for validation.
frac_test : float
Fraction of data to use for test. By default, we set this to be 0.1, i.e.
10% of the dataset is used for test.
Returns
-------
list of length 3
Subsets for training, validation and test, which are all :class:`Subset` instances.
"""
return split_dataset(dataset, frac_list=[frac_train, frac_val, frac_test], shuffle=False)
@staticmethod
def k_fold_split(dataset, k=5, log=True):
"""Split the dataset for k-fold cross validation by taking consecutive chunks.
Parameters
----------
dataset
We assume ``len(dataset)`` gives the size for the dataset and ``dataset[i]``
gives the ith datapoint.
k : int
Number of folds to use and should be no smaller than 2. Default to be 5.
log : bool
Whether to print a message at the start of preparing each fold.
Returns
-------
list of 2-tuples
Each element of the list represents a fold and is a 2-tuple (train_set, val_set).
"""
return base_k_fold_split(ConsecutiveSplitter.train_val_test_split, dataset, k, log)
class RandomSplitter(object):
"""Randomly reorder datasets and then split them.
The dataset is split with permutation and the splitting is hence random.
"""
@staticmethod
def train_val_test_split(dataset, frac_train=0.8, frac_val=0.1,
frac_test=0.1, random_state=None):
"""Randomly permute the dataset and then split it into
three consecutive chunks for training, validation and test.
Parameters
----------
dataset
We assume ``len(dataset)`` gives the size for the dataset and ``dataset[i]``
gives the ith datapoint.
frac_train : float
Fraction of data to use for training. By default, we set this to be 0.8, i.e.
80% of the dataset is used for training.
frac_val : float
Fraction of data to use for validation. By default, we set this to be 0.1, i.e.
10% of the dataset is used for validation.
frac_test : float
Fraction of data to use for test. By default, we set this to be 0.1, i.e.
10% of the dataset is used for test.
random_state : None, int or array_like, optional
Random seed used to initialize the pseudo-random number generator.
Can be any integer between 0 and 2**32 - 1 inclusive, an array
(or other sequence) of such integers, or None (the default).
If seed is None, then RandomState will try to read data from /dev/urandom
(or the Windows analogue) if available or seed from the clock otherwise.
Returns
-------
list of length 3
Subsets for training, validation and test.
"""
return split_dataset(dataset, frac_list=[frac_train, frac_val, frac_test],
shuffle=True, random_state=random_state)
@staticmethod
def k_fold_split(dataset, k=5, random_state=None, log=True):
"""Randomly permute the dataset and then split it
for k-fold cross validation by taking consecutive chunks.
Parameters
----------
dataset
We assume ``len(dataset)`` gives the size for the dataset and ``dataset[i]``
gives the ith datapoint.
k : int
Number of folds to use and should be no smaller than 2. Default to be 5.
random_state : None, int or array_like, optional
Random seed used to initialize the pseudo-random number generator.
Can be any integer between 0 and 2**32 - 1 inclusive, an array
(or other sequence) of such integers, or None (the default).
If seed is None, then RandomState will try to read data from /dev/urandom
(or the Windows analogue) if available or seed from the clock otherwise.
log : bool
Whether to print a message at the start of preparing each fold. Default to True.
Returns
-------
list of 2-tuples
Each element of the list represents a fold and is a 2-tuple (train_set, val_set).
"""
# Permute the dataset only once so that each datapoint
# will appear once in exactly one fold.
indices = np.random.RandomState(seed=random_state).permutation(len(dataset))
return base_k_fold_split(partial(indices_split, indices=indices), dataset, k, log)
class MolecularWeightSplitter(object):
"""Sort molecules based on their weights and then split them."""
@staticmethod
def molecular_weight_indices(molecules, log_every_n):
"""Reorder molecules based on molecular weights.
Parameters
----------
molecules : list of rdkit.Chem.rdchem.Mol
Pre-computed RDKit molecule instances. We expect a one-on-one
correspondence between ``dataset.smiles`` and ``mols``, i.e.
``mols[i]`` corresponds to ``dataset.smiles[i]``.
log_every_n : None or int
Molecule related computation can take a long time for a large dataset and we want
to learn the progress of processing. This can be done by printing a message whenever
a batch of ``log_every_n`` molecules have been processed. If None, no messages will
be printed.
Returns
-------
indices : list or ndarray
Indices specifying the order of datapoints, which are basically
argsort of the molecular weights.
"""
if log_every_n is not None:
print('Start computing molecular weights.')
mws = []
for i, mol in enumerate(molecules):
count_and_log('Computing molecular weight for compound',
i, len(molecules), log_every_n)
mws.append(Chem.rdMolDescriptors.CalcExactMolWt(mol))
return np.argsort(mws)
@staticmethod
def train_val_test_split(dataset, mols=None, sanitize=True, frac_train=0.8,
frac_val=0.1, frac_test=0.1, log_every_n=1000):
"""Sort molecules based on their weights and then split them into
three consecutive chunks for training, validation and test.
Parameters
----------
dataset
We assume ``len(dataset)`` gives the size for the dataset, ``dataset[i]``
gives the ith datapoint and ``dataset.smiles[i]`` gives the SMILES for the
ith datapoint.
mols : None or list of rdkit.Chem.rdchem.Mol
None or pre-computed RDKit molecule instances. If not None, we expect a
one-on-one correspondence between ``dataset.smiles`` and ``mols``, i.e.
``mols[i]`` corresponds to ``dataset.smiles[i]``. Default to None.
sanitize : bool
This argument only comes into effect when ``mols`` is None and decides whether
sanitization is performed in initializing RDKit molecule instances. See
https://www.rdkit.org/docs/RDKit_Book.html for details of the sanitization.
Default to be True.
frac_train : float
Fraction of data to use for training. By default, we set this to be 0.8, i.e.
80% of the dataset is used for training.
frac_val : float
Fraction of data to use for validation. By default, we set this to be 0.1, i.e.
10% of the dataset is used for validation.
frac_test : float
Fraction of data to use for test. By default, we set this to be 0.1, i.e.
10% of the dataset is used for test.
log_every_n : None or int
Molecule related computation can take a long time for a large dataset and we want
to learn the progress of processing. This can be done by printing a message whenever
a batch of ``log_every_n`` molecules have been processed. If None, no messages will
be printed. Default to 1000.
Returns
-------
list of length 3
Subsets for training, validation and test, which are all :class:`Subset` instances.
"""
# Perform sanity check first as molecule instance initialization and descriptor
# computation can take a long time.
train_val_test_sanity_check(frac_train, frac_val, frac_test)
molecules = prepare_mols(dataset, mols, sanitize, log_every_n)
sorted_indices = MolecularWeightSplitter.molecular_weight_indices(molecules, log_every_n)
return indices_split(dataset, frac_train, frac_val, frac_test, sorted_indices)
@staticmethod
def k_fold_split(dataset, mols=None, sanitize=True, k=5, log_every_n=1000):
"""Sort molecules based on their weights and then split them
for k-fold cross validation by taking consecutive chunks.
Parameters
----------
dataset
We assume ``len(dataset)`` gives the size for the dataset, ``dataset[i]``
gives the ith datapoint and ``dataset.smiles[i]`` gives the SMILES for the
ith datapoint.
mols : None or list of rdkit.Chem.rdchem.Mol
None or pre-computed RDKit molecule instances. If not None, we expect a
one-on-one correspondence between ``dataset.smiles`` and ``mols``, i.e.
``mols[i]`` corresponds to ``dataset.smiles[i]``. Default to None.
sanitize : bool
This argument only comes into effect when ``mols`` is None and decides whether
sanitization is performed in initializing RDKit molecule instances. See
https://www.rdkit.org/docs/RDKit_Book.html for details of the sanitization.
Default to be True.
k : int
Number of folds to use and should be no smaller than 2. Default to be 5.
log_every_n : None or int
Molecule related computation can take a long time for a large dataset and we want
to learn the progress of processing. This can be done by printing a message whenever
a batch of ``log_every_n`` molecules have been processed. If None, no messages will
be printed. Default to 1000.
Returns
-------
list of 2-tuples
Each element of the list represents a fold and is a 2-tuple (train_set, val_set).
"""
molecules = prepare_mols(dataset, mols, sanitize, log_every_n)
sorted_indices = MolecularWeightSplitter.molecular_weight_indices(molecules, log_every_n)
return base_k_fold_split(partial(indices_split, indices=sorted_indices), dataset, k,
log=(log_every_n is not None))
class ScaffoldSplitter(object):
"""Group molecules based on their Bemis-Murcko scaffolds and then split the groups.
Group molecules so that all molecules in a group have a same scaffold (see reference).
The dataset is then split at the level of groups.
References
----------
Bemis, G. W.; Murcko, M. A. “The Properties of Known Drugs.
1. Molecular Frameworks.” J. Med. Chem. 39:2887-93 (1996).
"""
@staticmethod
def get_ordered_scaffold_sets(molecules, include_chirality, log_every_n):
"""Group molecules based on their Bemis-Murcko scaffolds and
order these groups based on their sizes.
The order is decided by comparing the size of groups, where groups with a larger size
are placed before the ones with a smaller size.
Parameters
----------
molecules : list of rdkit.Chem.rdchem.Mol
Pre-computed RDKit molecule instances. We expect a one-on-one
correspondence between ``dataset.smiles`` and ``mols``, i.e.
``mols[i]`` corresponds to ``dataset.smiles[i]``.
include_chirality : bool
Whether to consider chirality in computing scaffolds.
log_every_n : None or int
Molecule related computation can take a long time for a large dataset and we want
to learn the progress of processing. This can be done by printing a message whenever
a batch of ``log_every_n`` molecules have been processed. If None, no messages will
be printed.
Returns
-------
scaffold_sets : list
Each element of the list is a list of int,
representing the indices of compounds with a same scaffold.
"""
if log_every_n is not None:
print('Start computing Bemis-Murcko scaffolds.')
scaffolds = defaultdict(list)
for i, mol in enumerate(molecules):
count_and_log('Computing Bemis-Murcko for compound',
i, len(molecules), log_every_n)
# For mols that have not been sanitized, we need to compute their ring information
try:
FastFindRings(mol)
mol_scaffold = MurckoScaffold.MurckoScaffoldSmiles(
mol=mol, includeChirality=include_chirality)
# Group molecules that have the same scaffold
scaffolds[mol_scaffold].append(i)
except:
print('Failed to compute the scaffold for molecule {:d} '
'and it will be excluded.'.format(i+1))
# Order groups of molecules by first comparing the size of groups
# and then the index of the first compound in the group.
scaffold_sets = [
scaffold_set for (scaffold, scaffold_set) in sorted(
scaffolds.items(), key=lambda x: (len(x[1]), x[1][0]), reverse=True)
]
return scaffold_sets
@staticmethod
def train_val_test_split(dataset, mols=None, sanitize=True, include_chirality=False,
frac_train=0.8, frac_val=0.1, frac_test=0.1, log_every_n=1000):
"""Split the dataset into training, validation and test set based on molecular scaffolds.
This spliting method ensures that molecules with a same scaffold will be collectively
in only one of the training, validation or test set. As a result, the fraction
of dataset to use for training and validation tend to be smaller than ``frac_train``
and ``frac_val``, while the fraction of dataset to use for test tends to be larger
than ``frac_test``.
Parameters
----------
dataset
We assume ``len(dataset)`` gives the size for the dataset, ``dataset[i]``
gives the ith datapoint and ``dataset.smiles[i]`` gives the SMILES for the
ith datapoint.
mols : None or list of rdkit.Chem.rdchem.Mol
None or pre-computed RDKit molecule instances. If not None, we expect a
one-on-one correspondence between ``dataset.smiles`` and ``mols``, i.e.
``mols[i]`` corresponds to ``dataset.smiles[i]``. Default to None.
sanitize : bool
This argument only comes into effect when ``mols`` is None and decides whether
sanitization is performed in initializing RDKit molecule instances. See
https://www.rdkit.org/docs/RDKit_Book.html for details of the sanitization.
Default to True.
include_chirality : bool
Whether to consider chirality in computing scaffolds. Default to False.
frac_train : float
Fraction of data to use for training. By default, we set this to be 0.8, i.e.
80% of the dataset is used for training.
frac_val : float
Fraction of data to use for validation. By default, we set this to be 0.1, i.e.
10% of the dataset is used for validation.
frac_test : float
Fraction of data to use for test. By default, we set this to be 0.1, i.e.
10% of the dataset is used for test.
log_every_n : None or int
Molecule related computation can take a long time for a large dataset and we want
to learn the progress of processing. This can be done by printing a message whenever
a batch of ``log_every_n`` molecules have been processed. If None, no messages will
be printed. Default to 1000.
Returns
-------
list of length 3
Subsets for training, validation and test, which are all :class:`Subset` instances.
"""
# Perform sanity check first as molecule related computation can take a long time.
train_val_test_sanity_check(frac_train, frac_val, frac_test)
molecules = prepare_mols(dataset, mols, sanitize)
scaffold_sets = ScaffoldSplitter.get_ordered_scaffold_sets(
molecules, include_chirality, log_every_n)
train_indices, val_indices, test_indices = [], [], []
train_cutoff = int(frac_train * len(molecules))
val_cutoff = int((frac_train + frac_val) * len(molecules))
for group_indices in scaffold_sets:
if len(train_indices) + len(group_indices) > train_cutoff:
if len(train_indices) + len(val_indices) + len(group_indices) > val_cutoff:
test_indices.extend(group_indices)
else:
val_indices.extend(group_indices)
else:
train_indices.extend(group_indices)
return [Subset(dataset, train_indices),
Subset(dataset, val_indices),
Subset(dataset, test_indices)]
@staticmethod
def k_fold_split(dataset, mols=None, sanitize=True,
include_chirality=False, k=5, log_every_n=1000):
"""Group molecules based on their scaffolds and sort groups based on their sizes.
The groups are then split for k-fold cross validation.
Same as usual k-fold splitting methods, each molecule will appear only once
in the validation set among all folds. In addition, this method ensures that
molecules with a same scaffold will be collectively in either the training
set or the validation set for each fold.
Note that the folds can be highly imbalanced depending on the
scaffold distribution in the dataset.
Parameters
----------
dataset
We assume ``len(dataset)`` gives the size for the dataset, ``dataset[i]``
gives the ith datapoint and ``dataset.smiles[i]`` gives the SMILES for the
ith datapoint.
mols : None or list of rdkit.Chem.rdchem.Mol
None or pre-computed RDKit molecule instances. If not None, we expect a
one-on-one correspondence between ``dataset.smiles`` and ``mols``, i.e.
``mols[i]`` corresponds to ``dataset.smiles[i]``. Default to None.
sanitize : bool
This argument only comes into effect when ``mols`` is None and decides whether
sanitization is performed in initializing RDKit molecule instances. See
https://www.rdkit.org/docs/RDKit_Book.html for details of the sanitization.
Default to True.
include_chirality : bool
Whether to consider chirality in computing scaffolds. Default to False.
k : int
Number of folds to use and should be no smaller than 2. Default to be 5.
log_every_n : None or int
Molecule related computation can take a long time for a large dataset and we want
to learn the progress of processing. This can be done by printing a message whenever
a batch of ``log_every_n`` molecules have been processed. If None, no messages will
be printed. Default to 1000.
Returns
-------
list of 2-tuples
Each element of the list represents a fold and is a 2-tuple (train_set, val_set).
"""
assert k >= 2, 'Expect the number of folds to be no smaller than 2, got {:d}'.format(k)
molecules = prepare_mols(dataset, mols, sanitize)
scaffold_sets = ScaffoldSplitter.get_ordered_scaffold_sets(
molecules, include_chirality, log_every_n)
# k buckets that form a relatively balanced partition of the dataset
index_buckets = [[] for _ in range(k)]
for group_indices in scaffold_sets:
bucket_chosen = int(np.argmin([len(bucket) for bucket in index_buckets]))
index_buckets[bucket_chosen].extend(group_indices)
all_folds = []
for i in range(k):
if log_every_n is not None:
print('Processing fold {:d}/{:d}'.format(i + 1, k))
train_indices = list(chain.from_iterable(index_buckets[:i] + index_buckets[i+1:]))
val_indices = index_buckets[i]
all_folds.append((Subset(dataset, train_indices), Subset(dataset, val_indices)))
return all_folds
class SingleTaskStratifiedSplitter(object):
"""Splits the dataset by stratification on a single task.
We sort the molecules based on their label values for a task and then repeatedly
take buckets of datapoints to augment the training, validation and test subsets.
"""
@staticmethod
def train_val_test_split(dataset, labels, task_id, frac_train=0.8, frac_val=0.1,
frac_test=0.1, bucket_size=10, random_state=None):
"""Split the dataset into training, validation and test subsets as stated above.
Parameters
----------
dataset
We assume ``len(dataset)`` gives the size for the dataset, ``dataset[i]``
gives the ith datapoint and ``dataset.smiles[i]`` gives the SMILES for the
ith datapoint.
labels : tensor of shape (N, T)
Dataset labels all tasks. N for the number of datapoints and T for the number
of tasks.
task_id : int
Index for the task.
frac_train : float
Fraction of data to use for training. By default, we set this to be 0.8, i.e.
80% of the dataset is used for training.
frac_val : float
Fraction of data to use for validation. By default, we set this to be 0.1, i.e.
10% of the dataset is used for validation.
frac_test : float
Fraction of data to use for test. By default, we set this to be 0.1, i.e.
10% of the dataset is used for test.
bucket_size : int
Size of bucket of datapoints. Default to 10.
random_state : None, int or array_like, optional
Random seed used to initialize the pseudo-random number generator.
Can be any integer between 0 and 2**32 - 1 inclusive, an array
(or other sequence) of such integers, or None (the default).
If seed is None, then RandomState will try to read data from /dev/urandom
(or the Windows analogue) if available or seed from the clock otherwise.
Returns
-------
list of length 3
Subsets for training, validation and test, which are all :class:`Subset` instances.
"""
train_val_test_sanity_check(frac_train, frac_val, frac_test)
if random_state is not None:
np.random.seed(random_state)
if not isinstance(labels, np.ndarray):
labels = F.asnumpy(labels)
task_labels = labels[:, task_id]
sorted_indices = np.argsort(task_labels)
train_bucket_cutoff = int(np.round(frac_train * bucket_size))
val_bucket_cutoff = int(np.round(frac_val * bucket_size)) + train_bucket_cutoff
train_indices, val_indices, test_indices = [], [], []
while sorted_indices.shape[0] >= bucket_size:
current_batch, sorted_indices = np.split(sorted_indices, [bucket_size])
shuffled = np.random.permutation(range(bucket_size))
train_indices.extend(
current_batch[shuffled[:train_bucket_cutoff]].tolist())
val_indices.extend(
current_batch[shuffled[train_bucket_cutoff:val_bucket_cutoff]].tolist())
test_indices.extend(
current_batch[shuffled[val_bucket_cutoff:]].tolist())
# Place rest samples in the training set.
train_indices.extend(sorted_indices.tolist())
return [Subset(dataset, train_indices),
Subset(dataset, val_indices),
Subset(dataset, test_indices)]
@staticmethod
def k_fold_split(dataset, labels, task_id, k=5, log=True):
"""Sort molecules based on their label values for a task and then split them
for k-fold cross validation by taking consecutive chunks.
Parameters
----------
dataset
We assume ``len(dataset)`` gives the size for the dataset, ``dataset[i]``
gives the ith datapoint and ``dataset.smiles[i]`` gives the SMILES for the
ith datapoint.
labels : tensor of shape (N, T)
Dataset labels all tasks. N for the number of datapoints and T for the number
of tasks.
task_id : int
Index for the task.
k : int
Number of folds to use and should be no smaller than 2. Default to be 5.
log : bool
Whether to print a message at the start of preparing each fold.
Returns
-------
list of 2-tuples
Each element of the list represents a fold and is a 2-tuple (train_set, val_set).
"""
if not isinstance(labels, np.ndarray):
labels = F.asnumpy(labels)
task_labels = labels[:, task_id]
sorted_indices = np.argsort(task_labels).tolist()
return base_k_fold_split(partial(indices_split, indices=sorted_indices), dataset, k, log)
# pylint: disable=C0111 # pylint: disable=C0111
"""Model Zoo Package""" """Model Zoo Package"""
from .classifiers import GCNClassifier, GATClassifier from .classifiers import GCNClassifier, GATClassifier
from .schnet import SchNet from .schnet import SchNet
from .mgcn import MGCNModel from .mgcn import MGCNModel
...@@ -9,3 +8,4 @@ from .dgmg import DGMG ...@@ -9,3 +8,4 @@ from .dgmg import DGMG
from .jtnn import DGLJTNNVAE from .jtnn import DGLJTNNVAE
from .pretrain import load_pretrained from .pretrain import load_pretrained
from .attentive_fp import AttentiveFP from .attentive_fp import AttentiveFP
from .acnn import ACNN
"""Atomic Convolutional Networks for Predicting Protein-Ligand Binding Affinity"""
# pylint: disable=C0103, C0123
import itertools
import torch
import torch.nn as nn
from ...nn.pytorch import AtomicConv
def truncated_normal_(tensor, mean=0., std=1.):
"""Fills the given tensor in-place with elements sampled from the truncated normal
distribution parameterized by mean and std.
The generated values follow a normal distribution with specified mean and
standard deviation, except that values whose magnitude is more than 2 std
from the mean are dropped.
We credit to Ruotian Luo for this implementation:
https://discuss.pytorch.org/t/implementing-truncated-normal-initializer/4778/15.
Parameters
----------
tensor : Float32 tensor of arbitrary shape
Tensor to be filled.
mean : float
Mean of the truncated normal distribution.
std : float
Standard deviation of the truncated normal distribution.
"""
shape = tensor.shape
tmp = tensor.new_empty(shape + (4,)).normal_()
valid = (tmp < 2) & (tmp > -2)
ind = valid.max(-1, keepdim=True)[1]
tensor.data.copy_(tmp.gather(-1, ind).squeeze(-1))
tensor.data.mul_(std).add_(mean)
class ACNNPredictor(nn.Module):
"""Predictor for ACNN.
Parameters
----------
in_size : int
Number of radial filters used.
hidden_sizes : list of int
Specifying the hidden sizes for all layers in the predictor.
weight_init_stddevs : list of float
Specifying the standard deviations to use for truncated normal
distributions in initialzing weights for the predictor.
dropouts : list of float
Specifying the dropouts to use for all layers in the predictor.
features_to_use : None or float tensor of shape (T)
In the original paper, these are atomic numbers to consider, representing the types
of atoms. T for the number of types of atomic numbers. Default to None.
num_tasks : int
Output size.
"""
def __init__(self, in_size, hidden_sizes, weight_init_stddevs,
dropouts, features_to_use, num_tasks):
super(ACNNPredictor, self).__init__()
if type(features_to_use) != type(None):
in_size *= len(features_to_use)
modules = []
for i, h in enumerate(hidden_sizes):
linear_layer = nn.Linear(in_size, h)
truncated_normal_(linear_layer.weight, std=weight_init_stddevs[i])
modules.append(linear_layer)
modules.append(nn.ReLU())
modules.append(nn.Dropout(dropouts[i]))
in_size = h
linear_layer = nn.Linear(in_size, num_tasks)
truncated_normal_(linear_layer.weight, std=weight_init_stddevs[-1])
modules.append(linear_layer)
self.project = nn.Sequential(*modules)
def forward(self, batch_size, frag1_node_indices_in_complex, frag2_node_indices_in_complex,
ligand_conv_out, protein_conv_out, complex_conv_out):
"""Perform the prediction.
Parameters
----------
batch_size : int
Number of datapoints in a batch.
frag1_node_indices_in_complex : Int64 tensor of shape (V1)
Indices for atoms in the first fragment (protein) in the batched complex.
frag2_node_indices_in_complex : list of int of length V2
Indices for atoms in the second fragment (ligand) in the batched complex.
ligand_conv_out : Float32 tensor of shape (V2, K * T)
Updated ligand node representations. V2 for the number of atoms in the
ligand, K for the number of radial filters, and T for the number of types
of atomic numbers.
protein_conv_out : Float32 tensor of shape (V1, K * T)
Updated protein node representations. V1 for the number of
atoms in the protein, K for the number of radial filters,
and T for the number of types of atomic numbers.
complex_conv_out : Float32 tensor of shape (V1 + V2, K * T)
Updated complex node representations. V1 and V2 separately
for the number of atoms in the ligand and protein, K for
the number of radial filters, and T for the number of
types of atomic numbers.
Returns
-------
Float32 tensor of shape (B, O)
Predicted protein-ligand binding affinity. B for the number
of protein-ligand pairs in the batch and O for the number of tasks.
"""
ligand_feats = self.project(ligand_conv_out) # (V1, O)
protein_feats = self.project(protein_conv_out) # (V2, O)
complex_feats = self.project(complex_conv_out) # (V1+V2, O)
ligand_energy = ligand_feats.reshape(batch_size, -1).sum(-1, keepdim=True) # (B, O)
protein_energy = protein_feats.reshape(batch_size, -1).sum(-1, keepdim=True) # (B, O)
complex_ligand_energy = complex_feats[frag1_node_indices_in_complex].reshape(
batch_size, -1).sum(-1, keepdim=True)
complex_protein_energy = complex_feats[frag2_node_indices_in_complex].reshape(
batch_size, -1).sum(-1, keepdim=True)
complex_energy = complex_ligand_energy + complex_protein_energy
return complex_energy - (ligand_energy + protein_energy)
class ACNN(nn.Module):
"""Atomic Convolutional Networks.
The model was proposed in `Atomic Convolutional Networks for
Predicting Protein-Ligand Binding Affinity <https://arxiv.org/abs/1703.10603>`__.
Parameters
----------
hidden_sizes : list of int
Specifying the hidden sizes for all layers in the predictor.
weight_init_stddevs : list of float
Specifying the standard deviations to use for truncated normal
distributions in initialzing weights for the predictor.
dropouts : list of float
Specifying the dropouts to use for all layers in the predictor.
features_to_use : None or float tensor of shape (T)
In the original paper, these are atomic numbers to consider, representing the types
of atoms. T for the number of types of atomic numbers. Default to None.
radial : None or list
If not None, the list consists of 3 lists of floats, separately for the
options of interaction cutoff, the options of rbf kernel mean and the
options of rbf kernel scaling. If None, a default option of
``[[12.0], [0.0, 2.0, 4.0, 6.0, 8.0], [4.0]]`` will be used.
num_tasks : int
Number of output tasks.
"""
def __init__(self, hidden_sizes, weight_init_stddevs, dropouts,
features_to_use=None, radial=None, num_tasks=1):
super(ACNN, self).__init__()
if radial is None:
radial = [[12.0], [0.0, 2.0, 4.0, 6.0, 8.0], [4.0]]
# Take the product of sets of options and get a list of 3-tuples.
radial_params = [x for x in itertools.product(*radial)]
radial_params = torch.stack(list(map(torch.tensor, zip(*radial_params))), dim=1)
interaction_cutoffs = radial_params[:, 0]
rbf_kernel_means = radial_params[:, 1]
rbf_kernel_scaling = radial_params[:, 2]
self.ligand_conv = AtomicConv(interaction_cutoffs, rbf_kernel_means,
rbf_kernel_scaling, features_to_use)
self.protein_conv = AtomicConv(interaction_cutoffs, rbf_kernel_means,
rbf_kernel_scaling, features_to_use)
self.complex_conv = AtomicConv(interaction_cutoffs, rbf_kernel_means,
rbf_kernel_scaling, features_to_use)
self.predictor = ACNNPredictor(radial_params.shape[0], hidden_sizes,
weight_init_stddevs, dropouts, features_to_use, num_tasks)
def forward(self, graph):
"""Apply the model for prediction.
Parameters
----------
graph : DGLHeteroGraph
DGLHeteroGraph consisting of the ligand graph, the protein graph
and the complex graph, along with preprocessed features.
Returns
-------
Float32 tensor of shape (B, O)
Predicted protein-ligand binding affinity. B for the number
of protein-ligand pairs in the batch and O for the number of tasks.
"""
ligand_graph = graph[('ligand_atom', 'ligand', 'ligand_atom')]
ligand_graph_node_feats = ligand_graph.ndata['atomic_number']
assert ligand_graph_node_feats.shape[-1] == 1
ligand_graph_distances = ligand_graph.edata['distance']
ligand_conv_out = self.ligand_conv(ligand_graph,
ligand_graph_node_feats,
ligand_graph_distances)
protein_graph = graph[('protein_atom', 'protein', 'protein_atom')]
protein_graph_node_feats = protein_graph.ndata['atomic_number']
assert protein_graph_node_feats.shape[-1] == 1
protein_graph_distances = protein_graph.edata['distance']
protein_conv_out = self.protein_conv(protein_graph,
protein_graph_node_feats,
protein_graph_distances)
complex_graph = graph[:, 'complex', :]
complex_graph_node_feats = complex_graph.ndata['atomic_number']
assert complex_graph_node_feats.shape[-1] == 1
complex_graph_distances = complex_graph.edata['distance']
complex_conv_out = self.complex_conv(complex_graph,
complex_graph_node_feats,
complex_graph_distances)
frag1_node_indices_in_complex = torch.where(complex_graph.ndata['_TYPE'] == 0)[0]
frag2_node_indices_in_complex = list(set(range(complex_graph.number_of_nodes())) -
set(frag1_node_indices_in_complex.tolist()))
return self.predictor(
graph.batch_size,
frag1_node_indices_in_complex,
frag2_node_indices_in_complex,
ligand_conv_out, protein_conv_out, complex_conv_out)
"""Utilities for using pretrained models.""" """Utilities for using pretrained models."""
import os import os
import numpy as np
import torch import torch
from rdkit import Chem from rdkit import Chem
...@@ -10,20 +11,21 @@ from .mgcn import MGCNModel ...@@ -10,20 +11,21 @@ from .mgcn import MGCNModel
from .mpnn import MPNNModel from .mpnn import MPNNModel
from .schnet import SchNet from .schnet import SchNet
from .attentive_fp import AttentiveFP from .attentive_fp import AttentiveFP
from .acnn import ACNN
from ...data.utils import _get_dgl_url, download, get_download_dir, extract_archive from ...data.utils import _get_dgl_url, download, get_download_dir, extract_archive
URL = { URL = {
'GCN_Tox21' : 'pre_trained/gcn_tox21.pth', 'GCN_Tox21': 'pre_trained/gcn_tox21.pth',
'GAT_Tox21' : 'pre_trained/gat_tox21.pth', 'GAT_Tox21': 'pre_trained/gat_tox21.pth',
'MGCN_Alchemy': 'pre_trained/mgcn_alchemy.pth', 'MGCN_Alchemy': 'pre_trained/mgcn_alchemy.pth',
'SCHNET_Alchemy': 'pre_trained/schnet_alchemy.pth', 'SCHNET_Alchemy': 'pre_trained/schnet_alchemy.pth',
'MPNN_Alchemy': 'pre_trained/mpnn_alchemy.pth', 'MPNN_Alchemy': 'pre_trained/mpnn_alchemy.pth',
'AttentiveFP_Aromaticity': 'pre_trained/attentivefp_aromaticity.pth', 'AttentiveFP_Aromaticity': 'pre_trained/attentivefp_aromaticity.pth',
'DGMG_ChEMBL_canonical' : 'pre_trained/dgmg_ChEMBL_canonical.pth', 'DGMG_ChEMBL_canonical': 'pre_trained/dgmg_ChEMBL_canonical.pth',
'DGMG_ChEMBL_random' : 'pre_trained/dgmg_ChEMBL_random.pth', 'DGMG_ChEMBL_random': 'pre_trained/dgmg_ChEMBL_random.pth',
'DGMG_ZINC_canonical' : 'pre_trained/dgmg_ZINC_canonical.pth', 'DGMG_ZINC_canonical': 'pre_trained/dgmg_ZINC_canonical.pth',
'DGMG_ZINC_random' : 'pre_trained/dgmg_ZINC_random.pth', 'DGMG_ZINC_random': 'pre_trained/dgmg_ZINC_random.pth',
'JTNN_ZINC':'pre_trained/JTNN_ZINC.pth' 'JTNN_ZINC': 'pre_trained/JTNN_ZINC.pth'
} }
def download_and_load_checkpoint(model_name, model, model_postfix, def download_and_load_checkpoint(model_name, model, model_postfix,
...@@ -56,6 +58,9 @@ def download_and_load_checkpoint(model_name, model, model_postfix, ...@@ -56,6 +58,9 @@ def download_and_load_checkpoint(model_name, model, model_postfix,
checkpoint = torch.load(local_pretrained_path, map_location='cpu') checkpoint = torch.load(local_pretrained_path, map_location='cpu')
model.load_state_dict(checkpoint['model_state_dict']) model.load_state_dict(checkpoint['model_state_dict'])
if log:
print('Pretrained model loaded')
return model return model
def load_pretrained(model_name, log=True): def load_pretrained(model_name, log=True):
...@@ -77,6 +82,14 @@ def load_pretrained(model_name, log=True): ...@@ -77,6 +82,14 @@ def load_pretrained(model_name, log=True):
* ``'DGMG_ZINC_canonical'`` * ``'DGMG_ZINC_canonical'``
* ``'DGMG_ZINC_random'`` * ``'DGMG_ZINC_random'``
* ``'JTNN_ZINC'`` * ``'JTNN_ZINC'``
* ``'ACNN_PDBBind_core_pocket_random'``
* ``'ACNN_PDBBind_core_pocket_scaffold'``
* ``'ACNN_PDBBind_core_pocket_stratified'``
* ``'ACNN_PDBBind_core_pocket_temporal'``
* ``'ACNN_PDBBind_refined_pocket_random'``
* ``'ACNN_PDBBind_refined_pocket_scaffold'``
* ``'ACNN_PDBBind_refined_pocket_stratified'``
* ``'ACNN_PDBBind_refined_pocket_temporal'``
log : bool log : bool
Whether to print progress for model loading Whether to print progress for model loading
...@@ -147,7 +160,13 @@ def load_pretrained(model_name, log=True): ...@@ -147,7 +160,13 @@ def load_pretrained(model_name, log=True):
hidden_size=450, hidden_size=450,
latent_size=56) latent_size=56)
if log: elif model_name.startswith('ACNN_PDBBind_core_pocket'):
print('Pretrained model loaded') model = ACNN(hidden_sizes=[32, 32, 16],
weight_init_stddevs=[1. / float(np.sqrt(32)), 1. / float(np.sqrt(32)),
1. / float(np.sqrt(16)), 0.01],
dropouts=[0., 0., 0.],
features_to_use=torch.tensor([
1., 6., 7., 8., 9., 11., 12., 15., 16., 17., 20., 25., 30., 35., 53.]),
radial=[[12.0], [0.0, 4.0, 8.0], [4.0]])
return download_and_load_checkpoint(model_name, model, URL[model_name], log=log) return download_and_load_checkpoint(model_name, model, URL[model_name], log=log)
...@@ -18,8 +18,9 @@ from .gatedgraphconv import GatedGraphConv ...@@ -18,8 +18,9 @@ from .gatedgraphconv import GatedGraphConv
from .densechebconv import DenseChebConv from .densechebconv import DenseChebConv
from .densegraphconv import DenseGraphConv from .densegraphconv import DenseGraphConv
from .densesageconv import DenseSAGEConv from .densesageconv import DenseSAGEConv
from .atomicconv import AtomicConv
__all__ = ['GraphConv', 'GATConv', 'TAGConv', 'RelGraphConv', 'SAGEConv', __all__ = ['GraphConv', 'GATConv', 'TAGConv', 'RelGraphConv', 'SAGEConv',
'SGConv', 'APPNPConv', 'GINConv', 'GatedGraphConv', 'GMMConv', 'SGConv', 'APPNPConv', 'GINConv', 'GatedGraphConv', 'GMMConv',
'ChebConv', 'AGNNConv', 'NNConv', 'DenseGraphConv', 'DenseSAGEConv', 'ChebConv', 'AGNNConv', 'NNConv', 'DenseGraphConv', 'DenseSAGEConv',
'DenseChebConv', 'EdgeConv'] 'DenseChebConv', 'EdgeConv', 'AtomicConv']
"""Torch Module for Atomic Convolution Layer"""
import numpy as np
import torch as th
import torch.nn as nn
class RadialPooling(nn.Module):
r"""Radial pooling from paper `Atomic Convolutional Networks for
Predicting Protein-Ligand Binding Affinity <https://arxiv.org/abs/1703.10603>`__.
We denote the distance between atom :math:`i` and :math:`j` by :math:`r_{ij}`.
A radial pooling layer transforms distances with radial filters. For radial filter
indexed by :math:`k`, it projects edge distances with
.. math::
h_{ij}^{k} = \exp(-\gamma_{k}|r_{ij}-r_{k}|^2)
If :math:`r_{ij} < c_k`,
.. math::
f_{ij}^{k} = 0.5 * \cos(\frac{\pi r_{ij}}{c_k} + 1),
else,
.. math::
f_{ij}^{k} = 0.
Finally,
.. math::
e_{ij}^{k} = h_{ij}^{k} * f_{ij}^{k}
Parameters
----------
interaction_cutoffs : float32 tensor of shape (K)
:math:`c_k` in the equations above. Roughly they can be considered as learnable cutoffs
and two atoms are considered as connected if the distance between them is smaller than
the cutoffs. K for the number of radial filters.
rbf_kernel_means : float32 tensor of shape (K)
:math:`r_k` in the equations above. K for the number of radial filters.
rbf_kernel_scaling : float32 tensor of shape (K)
:math:`\gamma_k` in the equations above. K for the number of radial filters.
"""
def __init__(self, interaction_cutoffs, rbf_kernel_means, rbf_kernel_scaling):
super(RadialPooling, self).__init__()
self.interaction_cutoffs = nn.Parameter(
interaction_cutoffs.reshape(-1, 1, 1), requires_grad=True)
self.rbf_kernel_means = nn.Parameter(
rbf_kernel_means.reshape(-1, 1, 1), requires_grad=True)
self.rbf_kernel_scaling = nn.Parameter(
rbf_kernel_scaling.reshape(-1, 1, 1), requires_grad=True)
def forward(self, distances):
"""Apply the layer to transform edge distances.
Parameters
----------
distances : Float32 tensor of shape (E, 1)
Distance between end nodes of edges. E for the number of edges.
Returns
-------
Float32 tensor of shape (K, E, 1)
Transformed edge distances. K for the number of radial filters.
"""
scaled_euclidean_distance = - self.rbf_kernel_scaling * \
(distances - self.rbf_kernel_means) ** 2 # (K, E, 1)
rbf_kernel_results = th.exp(scaled_euclidean_distance) # (K, E, 1)
cos_values = 0.5 * (th.cos(np.pi * distances / self.interaction_cutoffs) + 1) # (K, E, 1)
cutoff_values = th.where(
distances <= self.interaction_cutoffs,
cos_values, th.zeros_like(cos_values)) # (K, E, 1)
# Note that there appears to be an inconsistency between the paper and
# DeepChem's implementation. In the paper, the scaled_euclidean_distance first
# gets multiplied by cutoff_values, followed by exponentiation. Here we follow
# the practice of DeepChem.
return rbf_kernel_results * cutoff_values
def msg_func(edges):
"""Send messages along edges.
Parameters
----------
edges : EdgeBatch
A batch of edges.
Returns
-------
dict mapping 'm' to Float32 tensor of shape (E, K * T)
Messages computed. E for the number of edges, K for the number of
radial filters and T for the number of features to use
(types of atomic number in the paper).
"""
return {'m': th.einsum(
'ij,ik->ijk', edges.src['hv'], edges.data['he']).view(len(edges), -1)}
def reduce_func(nodes):
"""Collect messages and update node representations.
Parameters
----------
nodes : NodeBatch
A batch of nodes.
Returns
-------
dict mapping 'hv_new' to Float32 tensor of shape (V, K * T)
Updated node representations. V for the number of nodes, K for the number of
radial filters and T for the number of features to use
(types of atomic number in the paper).
"""
return {'hv_new': nodes.mailbox['m'].sum(1)}
class AtomicConv(nn.Module):
r"""Atomic Convolution Layer from paper `Atomic Convolutional Networks for
Predicting Protein-Ligand Binding Affinity <https://arxiv.org/abs/1703.10603>`__.
We denote the type of atom :math:`i` by :math:`z_i` and the distance between atom
:math:`i` and :math:`j` by :math:`r_{ij}`.
**Distance Transformation**
An atomic convolution layer first transforms distances with radial filters and
then perform a pooling operation.
For radial filter indexed by :math:`k`, it projects edge distances with
.. math::
h_{ij}^{k} = \exp(-\gamma_{k}|r_{ij}-r_{k}|^2)
If :math:`r_{ij} < c_k`,
.. math::
f_{ij}^{k} = 0.5 * \cos(\frac{\pi r_{ij}}{c_k} + 1),
else,
.. math::
f_{ij}^{k} = 0.
Finally,
.. math::
e_{ij}^{k} = h_{ij}^{k} * f_{ij}^{k}
**Aggregation**
For each type :math:`t`, each atom collects distance information from all neighbor atoms
of type :math:`t`:
.. math::
p_{i, t}^{k} = \sum_{j\in N(i)} e_{ij}^{k} * 1(z_j == t)
We concatenate the results for all RBF kernels and atom types.
Notes
-----
* This convolution operation is designed for molecular graphs in Chemistry, but it might
be possible to extend it to more general graphs.
* There seems to be an inconsistency about the definition of :math:`e_{ij}^{k}` in the
paper and the author's implementation. We follow the author's implementation. In the
paper, :math:`e_{ij}^{k}` was defined as
:math:`\exp(-\gamma_{k}|r_{ij}-r_{k}|^2 * f_{ij}^{k})`.
* :math:`\gamma_{k}`, :math:`r_k` and :math:`c_k` are all learnable.
Parameters
----------
interaction_cutoffs : float32 tensor of shape (K)
:math:`c_k` in the equations above. Roughly they can be considered as learnable cutoffs
and two atoms are considered as connected if the distance between them is smaller than
the cutoffs. K for the number of radial filters.
rbf_kernel_means : float32 tensor of shape (K)
:math:`r_k` in the equations above. K for the number of radial filters.
rbf_kernel_scaling : float32 tensor of shape (K)
:math:`\gamma_k` in the equations above. K for the number of radial filters.
features_to_use : None or float tensor of shape (T)
In the original paper, these are atomic numbers to consider, representing the types
of atoms. T for the number of types of atomic numbers. Default to None.
"""
def __init__(self, interaction_cutoffs, rbf_kernel_means,
rbf_kernel_scaling, features_to_use=None):
super(AtomicConv, self).__init__()
self.radial_pooling = RadialPooling(interaction_cutoffs=interaction_cutoffs,
rbf_kernel_means=rbf_kernel_means,
rbf_kernel_scaling=rbf_kernel_scaling)
if features_to_use is None:
self.num_channels = 1
self.features_to_use = None
else:
self.num_channels = len(features_to_use)
self.features_to_use = nn.Parameter(features_to_use, requires_grad=False)
def forward(self, graph, feat, distances):
"""Apply the atomic convolution layer.
Parameters
----------
graph : DGLGraph or BatchedDGLGraph
Topology based on which message passing is performed.
feat : Float32 tensor of shape (V, 1)
Initial node features, which are atomic numbers in the paper.
V for the number of nodes.
distances : Float32 tensor of shape (E, 1)
Distance between end nodes of edges. E for the number of edges.
Returns
-------
Float32 tensor of shape (V, K * T)
Updated node representations. V for the number of nodes, K for the
number of radial filters, and T for the number of types of atomic numbers.
"""
radial_pooled_values = self.radial_pooling(distances) # (K, E, 1)
graph = graph.local_var()
if self.features_to_use is not None:
feat = (feat == self.features_to_use).float() # (V, T)
graph.ndata['hv'] = feat
graph.edata['he'] = radial_pooled_values.transpose(1, 0).squeeze(-1) # (E, K)
graph.update_all(msg_func, reduce_func)
return graph.ndata['hv_new'].view(graph.number_of_nodes(), -1) # (V, K * T)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment