Unverified Commit cf9ba90f authored by Mufei Li's avatar Mufei Li Committed by GitHub
Browse files

[Chem] ACNN and various utilities (#1117)

* Add several splitting methods

* Update

* Update

* Update

* Update

* Update

* Fix

* Update

* Update

* Update

* Update

* Fix

* Fix

* Fix

* Fix

* Fix

* Update

* Update

* Update

* Update

* Update

* Update

* Update

* Finally

* CI
parent 6beef85b
"""Convert complexes into DGLHeteroGraphs"""
import numpy as np
from ..utils import k_nearest_neighbors
from .... import graph, bipartite, hetero_from_relations
from .... import backend as F
__all__ = ['ACNN_graph_construction_and_featurization']
def filter_out_hydrogens(mol):
"""Get indices for non-hydrogen atoms.
Parameters
----------
mol : rdkit.Chem.rdchem.Mol
RDKit molecule instance.
Returns
-------
indices_left : list of int
Indices of non-hydrogen atoms.
"""
indices_left = []
for i, atom in enumerate(mol.GetAtoms()):
atomic_num = atom.GetAtomicNum()
# Hydrogen atoms have an atomic number of 1.
if atomic_num != 1:
indices_left.append(i)
return indices_left
def get_atomic_numbers(mol, indices):
"""Get the atomic numbers for the specified atoms.
Parameters
----------
mol : rdkit.Chem.rdchem.Mol
RDKit molecule instance.
indices : list of int
Specifying atoms.
Returns
-------
list of int
Atomic numbers computed.
"""
atomic_numbers = []
for i in indices:
atom = mol.GetAtomWithIdx(i)
atomic_numbers.append(atom.GetAtomicNum())
return atomic_numbers
def ACNN_graph_construction_and_featurization(ligand_mol,
protein_mol,
ligand_coordinates,
protein_coordinates,
max_num_ligand_atoms=None,
max_num_protein_atoms=None,
neighbor_cutoff=12.,
max_num_neighbors=12,
strip_hydrogens=False):
"""Graph construction and featurization for `Atomic Convolutional Networks for
Predicting Protein-Ligand Binding Affinity <https://arxiv.org/abs/1703.10603>`__.
Parameters
----------
ligand_mol : rdkit.Chem.rdchem.Mol
RDKit molecule instance.
protein_mol : rdkit.Chem.rdchem.Mol
RDKit molecule instance.
ligand_coordinates : Float Tensor of shape (V1, 3)
Atom coordinates in a ligand.
protein_coordinates : Float Tensor of shape (V2, 3)
Atom coordinates in a protein.
max_num_ligand_atoms : int or None
Maximum number of atoms in ligands for zero padding.
If None, no zero padding will be performed. Default to None.
max_num_protein_atoms : int or None
Maximum number of atoms in proteins for zero padding.
If None, no zero padding will be performed. Default to None.
neighbor_cutoff : float
Distance cutoff to define 'neighboring'. Default to 12.
max_num_neighbors : int
Maximum number of neighbors allowed for each atom. Default to 12.
strip_hydrogens : bool
Whether to exclude hydrogen atoms. Default to False.
"""
assert ligand_coordinates is not None, 'Expect ligand_coordinates to be provided.'
assert protein_coordinates is not None, 'Expect protein_coordinates to be provided.'
if strip_hydrogens:
# Remove hydrogen atoms and their corresponding coordinates
ligand_atom_indices_left = filter_out_hydrogens(ligand_mol)
protein_atom_indices_left = filter_out_hydrogens(protein_mol)
ligand_coordinates = ligand_coordinates.take(ligand_atom_indices_left, axis=0)
protein_coordinates = protein_coordinates.take(protein_atom_indices_left, axis=0)
else:
ligand_atom_indices_left = list(range(ligand_mol.GetNumAtoms()))
protein_atom_indices_left = list(range(protein_mol.GetNumAtoms()))
# Compute number of nodes for each type
if max_num_ligand_atoms is None:
num_ligand_atoms = len(ligand_atom_indices_left)
else:
num_ligand_atoms = max_num_ligand_atoms
if max_num_protein_atoms is None:
num_protein_atoms = len(protein_atom_indices_left)
else:
num_protein_atoms = max_num_protein_atoms
# Construct graph for atoms in the ligand
ligand_srcs, ligand_dsts, ligand_dists = k_nearest_neighbors(
ligand_coordinates, neighbor_cutoff, max_num_neighbors)
ligand_graph = graph((ligand_srcs, ligand_dsts),
'ligand_atom', 'ligand', num_ligand_atoms)
ligand_graph.edata['distance'] = F.reshape(F.zerocopy_from_numpy(
np.array(ligand_dists).astype(np.float32)), (-1, 1))
# Construct graph for atoms in the protein
protein_srcs, protein_dsts, protein_dists = k_nearest_neighbors(
protein_coordinates, neighbor_cutoff, max_num_neighbors)
protein_graph = graph((protein_srcs, protein_dsts),
'protein_atom', 'protein', num_protein_atoms)
protein_graph.edata['distance'] = F.reshape(F.zerocopy_from_numpy(
np.array(protein_dists).astype(np.float32)), (-1, 1))
# Construct 4 graphs for complex representation, including the connection within
# protein atoms, the connection within ligand atoms and the connection between
# protein and ligand atoms.
complex_srcs, complex_dsts, complex_dists = k_nearest_neighbors(
np.concatenate([ligand_coordinates, protein_coordinates]),
neighbor_cutoff, max_num_neighbors)
complex_srcs = np.array(complex_srcs)
complex_dsts = np.array(complex_dsts)
complex_dists = np.array(complex_dists)
offset = num_ligand_atoms
# ('ligand_atom', 'complex', 'ligand_atom')
inter_ligand_indices = np.intersect1d(
(complex_srcs < offset).nonzero()[0],
(complex_dsts < offset).nonzero()[0],
assume_unique=True)
inter_ligand_graph = graph(
(complex_srcs[inter_ligand_indices].tolist(),
complex_dsts[inter_ligand_indices].tolist()),
'ligand_atom', 'complex', num_ligand_atoms)
inter_ligand_graph.edata['distance'] = F.reshape(F.zerocopy_from_numpy(
complex_dists[inter_ligand_indices].astype(np.float32)), (-1, 1))
# ('protein_atom', 'complex', 'protein_atom')
inter_protein_indices = np.intersect1d(
(complex_srcs >= offset).nonzero()[0],
(complex_dsts >= offset).nonzero()[0],
assume_unique=True)
inter_protein_graph = graph(
((complex_srcs[inter_protein_indices] - offset).tolist(),
(complex_dsts[inter_protein_indices] - offset).tolist()),
'protein_atom', 'complex', num_protein_atoms)
inter_protein_graph.edata['distance'] = F.reshape(F.zerocopy_from_numpy(
complex_dists[inter_protein_indices].astype(np.float32)), (-1, 1))
# ('ligand_atom', 'complex', 'protein_atom')
ligand_protein_indices = np.intersect1d(
(complex_srcs < offset).nonzero()[0],
(complex_dsts >= offset).nonzero()[0],
assume_unique=True)
ligand_protein_graph = bipartite(
(complex_srcs[ligand_protein_indices].tolist(),
(complex_dsts[ligand_protein_indices] - offset).tolist()),
'ligand_atom', 'complex', 'protein_atom',
(num_ligand_atoms, num_protein_atoms))
ligand_protein_graph.edata['distance'] = F.reshape(F.zerocopy_from_numpy(
complex_dists[ligand_protein_indices].astype(np.float32)), (-1, 1))
# ('protein_atom', 'complex', 'ligand_atom')
protein_ligand_indices = np.intersect1d(
(complex_srcs >= offset).nonzero()[0],
(complex_dsts < offset).nonzero()[0],
assume_unique=True)
protein_ligand_graph = bipartite(
((complex_srcs[protein_ligand_indices] - offset).tolist(),
complex_dsts[protein_ligand_indices].tolist()),
'protein_atom', 'complex', 'ligand_atom',
(num_protein_atoms, num_ligand_atoms))
protein_ligand_graph.edata['distance'] = F.reshape(F.zerocopy_from_numpy(
complex_dists[protein_ligand_indices].astype(np.float32)), (-1, 1))
# Merge the graphs
g = hetero_from_relations(
[protein_graph,
ligand_graph,
inter_ligand_graph,
inter_protein_graph,
ligand_protein_graph,
protein_ligand_graph]
)
# Get atomic numbers for all atoms left and set node features
ligand_atomic_numbers = np.array(get_atomic_numbers(ligand_mol, ligand_atom_indices_left))
# zero padding
ligand_atomic_numbers = np.concatenate([
ligand_atomic_numbers, np.zeros(num_ligand_atoms - len(ligand_atom_indices_left))])
protein_atomic_numbers = np.array(get_atomic_numbers(protein_mol, protein_atom_indices_left))
# zero padding
protein_atomic_numbers = np.concatenate([
protein_atomic_numbers, np.zeros(num_protein_atoms - len(protein_atom_indices_left))])
g.nodes['ligand_atom'].data['atomic_number'] = F.reshape(F.zerocopy_from_numpy(
ligand_atomic_numbers.astype(np.float32)), (-1, 1))
g.nodes['protein_atom'].data['atomic_number'] = F.reshape(F.zerocopy_from_numpy(
protein_atomic_numbers.astype(np.float32)), (-1, 1))
# Prepare mask indicating the existence of nodes
ligand_masks = np.zeros((num_ligand_atoms, 1))
ligand_masks[:len(ligand_atom_indices_left), :] = 1
g.nodes['ligand_atom'].data['mask'] = F.zerocopy_from_numpy(
ligand_masks.astype(np.float32))
protein_masks = np.zeros((num_protein_atoms, 1))
protein_masks[:len(protein_atom_indices_left), :] = 1
g.nodes['protein_atom'].data['mask'] = F.zerocopy_from_numpy(
protein_masks.astype(np.float32))
return g
import dgl.backend as F
import itertools
import numpy as np
from functools import partial
from collections import defaultdict
from dgl import DGLGraph
from .... import backend as F
try:
from rdkit import Chem
......@@ -12,18 +11,38 @@ try:
except ImportError:
pass
__all__ = ['one_hot_encoding', 'atom_type_one_hot', 'atomic_number_one_hot', 'atomic_number',
'atom_degree_one_hot', 'atom_degree', 'atom_total_degree_one_hot', 'atom_total_degree',
'atom_implicit_valence_one_hot', 'atom_implicit_valence', 'atom_hybridization_one_hot',
'atom_total_num_H_one_hot', 'atom_total_num_H', 'atom_formal_charge_one_hot',
'atom_formal_charge', 'atom_num_radical_electrons_one_hot',
'atom_num_radical_electrons', 'atom_is_aromatic_one_hot', 'atom_is_aromatic',
'atom_chiral_tag_one_hot', 'atom_mass', 'ConcatFeaturizer', 'BaseAtomFeaturizer',
'CanonicalAtomFeaturizer', 'mol_to_graph', 'smiles_to_bigraph',
'mol_to_bigraph', 'smiles_to_complete_graph', 'mol_to_complete_graph',
'bond_type_one_hot', 'bond_is_conjugated_one_hot', 'bond_is_conjugated',
'bond_is_in_ring_one_hot', 'bond_is_in_ring', 'bond_stereo_one_hot',
'BaseBondFeaturizer', 'CanonicalBondFeaturizer']
__all__ = ['one_hot_encoding',
'atom_type_one_hot',
'atomic_number_one_hot',
'atomic_number',
'atom_degree_one_hot',
'atom_degree',
'atom_total_degree_one_hot',
'atom_total_degree',
'atom_implicit_valence_one_hot',
'atom_implicit_valence',
'atom_hybridization_one_hot',
'atom_total_num_H_one_hot',
'atom_total_num_H',
'atom_formal_charge_one_hot',
'atom_formal_charge',
'atom_num_radical_electrons_one_hot',
'atom_num_radical_electrons',
'atom_is_aromatic_one_hot',
'atom_is_aromatic',
'atom_chiral_tag_one_hot',
'atom_mass',
'ConcatFeaturizer',
'BaseAtomFeaturizer',
'CanonicalAtomFeaturizer',
'bond_type_one_hot',
'bond_is_conjugated_one_hot',
'bond_is_conjugated',
'bond_is_in_ring_one_hot',
'bond_is_in_ring',
'bond_stereo_one_hot',
'BaseBondFeaturizer',
'CanonicalBondFeaturizer']
def one_hot_encoding(x, allowable_set, encode_unknown=False):
"""One-hot encoding.
......@@ -197,7 +216,8 @@ def atom_total_degree_one_hot(atom, allowable_set=None, encode_unknown=False):
return one_hot_encoding(atom.GetTotalDegree(), allowable_set, encode_unknown)
def atom_total_degree(atom):
"""
"""The degree of an atom including Hs.
See Also
--------
atom_degree
......@@ -855,232 +875,3 @@ class CanonicalBondFeaturizer(BaseBondFeaturizer):
bond_is_in_ring,
bond_stereo_one_hot]
)})
#################################################################
# DGLGraph Construction
#################################################################
def mol_to_graph(mol, graph_constructor, atom_featurizer, bond_featurizer):
"""Convert an RDKit molecule object into a DGLGraph and featurize for it.
Parameters
----------
mol : rdkit.Chem.rdchem.Mol
RDKit molecule holder
graph_constructor : callable
Takes an RDKit molecule as input and returns a DGLGraph
atom_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict
Featurization for atoms in a molecule, which can be used to update
ndata for a DGLGraph.
bond_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict
Featurization for bonds in a molecule, which can be used to update
edata for a DGLGraph.
Returns
-------
g : DGLGraph
Converted DGLGraph for the molecule
"""
new_order = rdmolfiles.CanonicalRankAtoms(mol)
mol = rdmolops.RenumberAtoms(mol, new_order)
g = graph_constructor(mol)
if atom_featurizer is not None:
g.ndata.update(atom_featurizer(mol))
if bond_featurizer is not None:
g.edata.update(bond_featurizer(mol))
return g
def construct_bigraph_from_mol(mol, add_self_loop=False):
"""Construct a bi-directed DGLGraph with topology only for the molecule.
The **i** th atom in the molecule, i.e. ``mol.GetAtomWithIdx(i)``, corresponds to the
**i** th node in the returned DGLGraph.
The **i** th bond in the molecule, i.e. ``mol.GetBondWithIdx(i)``, corresponds to the
**(2i)**-th and **(2i+1)**-th edges in the returned DGLGraph. The **(2i)**-th and
**(2i+1)**-th edges will be separately from **u** to **v** and **v** to **u**, where
**u** is ``bond.GetBeginAtomIdx()`` and **v** is ``bond.GetEndAtomIdx()``.
If self loops are added, the last **n** edges will separately be self loops for
atoms ``0, 1, ..., n-1``.
Parameters
----------
mol : rdkit.Chem.rdchem.Mol
RDKit molecule holder
add_self_loop : bool
Whether to add self loops in DGLGraphs. Default to False.
Returns
-------
g : DGLGraph
Empty bigraph topology of the molecule
"""
g = DGLGraph()
# Add nodes
num_atoms = mol.GetNumAtoms()
g.add_nodes(num_atoms)
# Add edges
src_list = []
dst_list = []
num_bonds = mol.GetNumBonds()
for i in range(num_bonds):
bond = mol.GetBondWithIdx(i)
u = bond.GetBeginAtomIdx()
v = bond.GetEndAtomIdx()
src_list.extend([u, v])
dst_list.extend([v, u])
g.add_edges(src_list, dst_list)
if add_self_loop:
nodes = g.nodes()
g.add_edges(nodes, nodes)
return g
def mol_to_bigraph(mol, add_self_loop=False,
atom_featurizer=None,
bond_featurizer=None):
"""Convert an RDKit molecule object into a bi-directed DGLGraph and featurize for it.
Parameters
----------
mol : rdkit.Chem.rdchem.Mol
RDKit molecule holder
add_self_loop : bool
Whether to add self loops in DGLGraphs. Default to False.
atom_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict
Featurization for atoms in a molecule, which can be used to update
ndata for a DGLGraph. Default to None.
bond_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict
Featurization for bonds in a molecule, which can be used to update
edata for a DGLGraph. Default to None.
Returns
-------
g : DGLGraph
Bi-directed DGLGraph for the molecule
"""
return mol_to_graph(mol, partial(construct_bigraph_from_mol, add_self_loop=add_self_loop),
atom_featurizer, bond_featurizer)
def smiles_to_bigraph(smiles, add_self_loop=False,
atom_featurizer=None,
bond_featurizer=None):
"""Convert a SMILES into a bi-directed DGLGraph and featurize for it.
Parameters
----------
smiles : str
String of SMILES
add_self_loop : bool
Whether to add self loops in DGLGraphs. Default to False.
atom_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict
Featurization for atoms in a molecule, which can be used to update
ndata for a DGLGraph. Default to None.
bond_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict
Featurization for bonds in a molecule, which can be used to update
edata for a DGLGraph. Default to None.
Returns
-------
g : DGLGraph
Bi-directed DGLGraph for the molecule
"""
mol = Chem.MolFromSmiles(smiles)
return mol_to_bigraph(mol, add_self_loop, atom_featurizer, bond_featurizer)
def construct_complete_graph_from_mol(mol, add_self_loop=False):
"""Construct a complete graph with topology only for the molecule
The **i** th atom in the molecule, i.e. ``mol.GetAtomWithIdx(i)``, corresponds to the
**i** th node in the returned DGLGraph.
The edges are in the order of (0, 0), (1, 0), (2, 0), ... (0, 1), (1, 1), (2, 1), ...
If self loops are not created, we will not have (0, 0), (1, 1), ...
Parameters
----------
mol : rdkit.Chem.rdchem.Mol
RDKit molecule holder
add_self_loop : bool
Whether to add self loops in DGLGraphs. Default to False.
Returns
-------
g : DGLGraph
Empty complete graph topology of the molecule
"""
g = DGLGraph()
num_atoms = mol.GetNumAtoms()
g.add_nodes(num_atoms)
if add_self_loop:
g.add_edges(
[i for i in range(num_atoms) for j in range(num_atoms)],
[j for i in range(num_atoms) for j in range(num_atoms)])
else:
g.add_edges(
[i for i in range(num_atoms) for j in range(num_atoms - 1)], [
j for i in range(num_atoms)
for j in range(num_atoms) if i != j
])
return g
def mol_to_complete_graph(mol, add_self_loop=False,
atom_featurizer=None,
bond_featurizer=None):
"""Convert an RDKit molecule into a complete DGLGraph and featurize for it.
Parameters
----------
mol : rdkit.Chem.rdchem.Mol
RDKit molecule holder
add_self_loop : bool
Whether to add self loops in DGLGraphs. Default to False.
atom_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict
Featurization for atoms in a molecule, which can be used to update
ndata for a DGLGraph. Default to None.
bond_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict
Featurization for bonds in a molecule, which can be used to update
edata for a DGLGraph. Default to None.
Returns
-------
g : DGLGraph
Complete DGLGraph for the molecule
"""
return mol_to_graph(mol, partial(construct_complete_graph_from_mol, add_self_loop=add_self_loop),
atom_featurizer, bond_featurizer)
def smiles_to_complete_graph(smiles, add_self_loop=False,
atom_featurizer=None,
bond_featurizer=None):
"""Convert a SMILES into a complete DGLGraph and featurize for it.
Parameters
----------
smiles : str
String of SMILES
add_self_loop : bool
Whether to add self loops in DGLGraphs. Default to False.
atom_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict
Featurization for atoms in a molecule, which can be used to update
ndata for a DGLGraph. Default to None.
bond_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict
Featurization for bonds in a molecule, which can be used to update
edata for a DGLGraph. Default to None.
Returns
-------
g : DGLGraph
Complete DGLGraph for the molecule
"""
mol = Chem.MolFromSmiles(smiles)
return mol_to_complete_graph(mol, add_self_loop, atom_featurizer, bond_featurizer)
"""Convert molecules into DGLGraphs."""
import numpy as np
from functools import partial
from .... import DGLGraph
try:
import mdtraj
from rdkit import Chem
from rdkit.Chem import rdmolfiles, rdmolops
except ImportError:
pass
__all__ = ['mol_to_graph',
'smiles_to_bigraph',
'mol_to_bigraph',
'smiles_to_complete_graph',
'mol_to_complete_graph',
'k_nearest_neighbors']
def mol_to_graph(mol, graph_constructor, node_featurizer, edge_featurizer):
"""Convert an RDKit molecule object into a DGLGraph and featurize for it.
Parameters
----------
mol : rdkit.Chem.rdchem.Mol
RDKit molecule holder
graph_constructor : callable
Takes an RDKit molecule as input and returns a DGLGraph
node_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict
Featurization for nodes like atoms in a molecule, which can be used to
update ndata for a DGLGraph.
edge_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict
Featurization for edges like bonds in a molecule, which can be used to
update edata for a DGLGraph.
Returns
-------
g : DGLGraph
Converted DGLGraph for the molecule
"""
new_order = rdmolfiles.CanonicalRankAtoms(mol)
mol = rdmolops.RenumberAtoms(mol, new_order)
g = graph_constructor(mol)
if node_featurizer is not None:
g.ndata.update(node_featurizer(mol))
if edge_featurizer is not None:
g.edata.update(edge_featurizer(mol))
return g
def construct_bigraph_from_mol(mol, add_self_loop=False):
"""Construct a bi-directed DGLGraph with topology only for the molecule.
The **i** th atom in the molecule, i.e. ``mol.GetAtomWithIdx(i)``, corresponds to the
**i** th node in the returned DGLGraph.
The **i** th bond in the molecule, i.e. ``mol.GetBondWithIdx(i)``, corresponds to the
**(2i)**-th and **(2i+1)**-th edges in the returned DGLGraph. The **(2i)**-th and
**(2i+1)**-th edges will be separately from **u** to **v** and **v** to **u**, where
**u** is ``bond.GetBeginAtomIdx()`` and **v** is ``bond.GetEndAtomIdx()``.
If self loops are added, the last **n** edges will separately be self loops for
atoms ``0, 1, ..., n-1``.
Parameters
----------
mol : rdkit.Chem.rdchem.Mol
RDKit molecule holder
add_self_loop : bool
Whether to add self loops in DGLGraphs. Default to False.
Returns
-------
g : DGLGraph
Empty bigraph topology of the molecule
"""
g = DGLGraph()
# Add nodes
num_atoms = mol.GetNumAtoms()
g.add_nodes(num_atoms)
# Add edges
src_list = []
dst_list = []
num_bonds = mol.GetNumBonds()
for i in range(num_bonds):
bond = mol.GetBondWithIdx(i)
u = bond.GetBeginAtomIdx()
v = bond.GetEndAtomIdx()
src_list.extend([u, v])
dst_list.extend([v, u])
g.add_edges(src_list, dst_list)
if add_self_loop:
nodes = g.nodes()
g.add_edges(nodes, nodes)
return g
def mol_to_bigraph(mol, add_self_loop=False,
node_featurizer=None,
edge_featurizer=None):
"""Convert an RDKit molecule object into a bi-directed DGLGraph and featurize for it.
Parameters
----------
mol : rdkit.Chem.rdchem.Mol
RDKit molecule holder
add_self_loop : bool
Whether to add self loops in DGLGraphs. Default to False.
node_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict
Featurization for nodes like atoms in a molecule, which can be used to update
ndata for a DGLGraph. Default to None.
edge_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict
Featurization for edges like bonds in a molecule, which can be used to update
edata for a DGLGraph. Default to None.
Returns
-------
g : DGLGraph
Bi-directed DGLGraph for the molecule
"""
return mol_to_graph(mol, partial(construct_bigraph_from_mol, add_self_loop=add_self_loop),
node_featurizer, edge_featurizer)
def smiles_to_bigraph(smiles, add_self_loop=False,
node_featurizer=None,
edge_featurizer=None):
"""Convert a SMILES into a bi-directed DGLGraph and featurize for it.
Parameters
----------
smiles : str
String of SMILES
add_self_loop : bool
Whether to add self loops in DGLGraphs. Default to False.
node_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict
Featurization for nodes like atoms in a molecule, which can be used to update
ndata for a DGLGraph. Default to None.
edge_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict
Featurization for edges like bonds in a molecule, which can be used to update
edata for a DGLGraph. Default to None.
Returns
-------
g : DGLGraph
Bi-directed DGLGraph for the molecule
"""
mol = Chem.MolFromSmiles(smiles)
return mol_to_bigraph(mol, add_self_loop, node_featurizer, edge_featurizer)
def construct_complete_graph_from_mol(mol, add_self_loop=False):
"""Construct a complete graph with topology only for the molecule
The **i** th atom in the molecule, i.e. ``mol.GetAtomWithIdx(i)``, corresponds to the
**i** th node in the returned DGLGraph.
The edges are in the order of (0, 0), (1, 0), (2, 0), ... (0, 1), (1, 1), (2, 1), ...
If self loops are not created, we will not have (0, 0), (1, 1), ...
Parameters
----------
mol : rdkit.Chem.rdchem.Mol
RDKit molecule holder
add_self_loop : bool
Whether to add self loops in DGLGraphs. Default to False.
Returns
-------
g : DGLGraph
Empty complete graph topology of the molecule
"""
g = DGLGraph()
num_atoms = mol.GetNumAtoms()
g.add_nodes(num_atoms)
if add_self_loop:
g.add_edges(
[i for i in range(num_atoms) for j in range(num_atoms)],
[j for i in range(num_atoms) for j in range(num_atoms)])
else:
g.add_edges(
[i for i in range(num_atoms) for j in range(num_atoms - 1)], [
j for i in range(num_atoms)
for j in range(num_atoms) if i != j
])
return g
def mol_to_complete_graph(mol, add_self_loop=False,
node_featurizer=None,
edge_featurizer=None):
"""Convert an RDKit molecule into a complete DGLGraph and featurize for it.
Parameters
----------
mol : rdkit.Chem.rdchem.Mol
RDKit molecule holder
add_self_loop : bool
Whether to add self loops in DGLGraphs. Default to False.
node_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict
Featurization for nodes like atoms in a molecule, which can be used to update
ndata for a DGLGraph. Default to None.
edge_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict
Featurization for edges like bonds in a molecule, which can be used to update
edata for a DGLGraph. Default to None.
Returns
-------
g : DGLGraph
Complete DGLGraph for the molecule
"""
return mol_to_graph(mol, partial(construct_complete_graph_from_mol, add_self_loop=add_self_loop),
node_featurizer, edge_featurizer)
def smiles_to_complete_graph(smiles, add_self_loop=False,
node_featurizer=None,
edge_featurizer=None):
"""Convert a SMILES into a complete DGLGraph and featurize for it.
Parameters
----------
smiles : str
String of SMILES
add_self_loop : bool
Whether to add self loops in DGLGraphs. Default to False.
node_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict
Featurization for nodes like atoms in a molecule, which can be used to update
ndata for a DGLGraph. Default to None.
edge_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict
Featurization for edges like bonds in a molecule, which can be used to update
edata for a DGLGraph. Default to None.
Returns
-------
g : DGLGraph
Complete DGLGraph for the molecule
"""
mol = Chem.MolFromSmiles(smiles)
return mol_to_complete_graph(mol, add_self_loop, node_featurizer, edge_featurizer)
def k_nearest_neighbors(coordinates, neighbor_cutoff, max_num_neighbors):
"""Find k nearest neighbors for each atom based on the 3D coordinates.
Parameters
----------
coordinates : numpy.ndarray of shape (N, 3)
The 3D coordinates of atoms in the molecule. N for the number of atoms.
neighbor_cutoff : float
Distance cutoff to define 'neighboring'.
max_num_neighbors : int or None.
If not None, then this specifies the maximum number of closest neighbors
allowed for each atom.
Returns
-------
neighbor_list : dict(int -> list of ints)
Mapping atom indices to their k nearest neighbors.
"""
num_atoms = coordinates.shape[0]
traj = mdtraj.Trajectory(coordinates.reshape((1, num_atoms, 3)), None)
neighbors = mdtraj.geometry.compute_neighborlist(traj, neighbor_cutoff)
srcs, dsts, distances = [], [], []
for i in range(num_atoms):
delta = coordinates[i] - coordinates.take(neighbors[i], axis=0)
dist = np.linalg.norm(delta, axis=1)
if max_num_neighbors is not None and len(neighbors[i]) > max_num_neighbors:
sorted_neighbors = list(zip(dist, neighbors[i]))
# Sort neighbors based on distance from smallest to largest
sorted_neighbors.sort(key=lambda tup: tup[0])
dsts.extend([i for _ in range(max_num_neighbors)])
srcs.extend([int(sorted_neighbors[j][1]) for j in range(max_num_neighbors)])
distances.extend([float(sorted_neighbors[j][0]) for j in range(max_num_neighbors)])
else:
dsts.extend([i for _ in range(len(neighbors[i]))])
srcs.extend(neighbors[i].tolist())
distances.extend(dist.tolist())
return srcs, dsts, distances
"""Utils for RDKit, mostly adapted from DeepChem
(https://github.com/deepchem/deepchem/blob/master/deepchem)."""
import warnings
from functools import partial
from multiprocessing import Pool
try:
import pdbfixer
import simtk
from rdkit import Chem
from rdkit.Chem import AllChem
from StringIO import StringIO
except ImportError:
from io import StringIO
__all__ = ['add_hydrogens_to_mol',
'get_mol_3D_coordinates',
'load_molecule',
'multiprocess_load_molecules']
def add_hydrogens_to_mol(mol):
"""Add hydrogens to an RDKit molecule instance.
Parameters
----------
mol : rdkit.Chem.rdchem.Mol
RDKit molecule instance.
Returns
-------
mol : rdkit.Chem.rdchem.Mol
RDKit molecule instance with hydrogens added. For failures in adding hydrogens,
the original RDKit molecule instance will be returned.
"""
try:
pdbblock = Chem.MolToPDBBlock(mol)
pdb_stringio = StringIO()
pdb_stringio.write(pdbblock)
pdb_stringio.seek(0)
fixer = pdbfixer.PDBFixer(pdbfile=pdb_stringio)
fixer.findMissingResidues()
fixer.findMissingAtoms()
fixer.addMissingAtoms()
fixer.addMissingHydrogens(7.4)
hydrogenated_io = StringIO()
simtk.openmm.app.PDBFile.writeFile(fixer.topology, fixer.positions,
hydrogenated_io)
hydrogenated_io.seek(0)
mol = Chem.MolFromPDBBlock(hydrogenated_io.read(), sanitize=False, removeHs=False)
pdb_stringio.close()
hydrogenated_io.close()
except ValueError:
warnings.warn('Failed to add hydrogens to the molecule.')
return mol
def get_mol_3D_coordinates(mol):
"""Get 3D coordinates of the molecule.
Parameters
----------
mol : rdkit.Chem.rdchem.Mol
RDKit molecule instance.
Returns
-------
numpy.ndarray of shape (N, 3) or None
The 3D coordinates of atoms in the molecule. N for the number of atoms in
the molecule. For failures in getting the conformations, None will be returned.
"""
try:
conf = mol.GetConformer()
conf_num_atoms = conf.GetNumAtoms()
mol_num_atoms = mol.GetNumAtoms()
assert mol_num_atoms == conf_num_atoms, \
'Expect the number of atoms in the molecule and its conformation ' \
'to be the same, got {:d} and {:d}'.format(mol_num_atoms, conf_num_atoms)
return conf.GetPositions()
except:
warnings.warn('Unable to get conformation of the molecule.')
return None
def load_molecule(molecule_file, add_hydrogens=False, sanitize=False, calc_charges=False,
remove_hs=False, use_conformation=True):
"""Load a molecule from a file.
Parameters
----------
molecule_file : str
Path to file for storing a molecule, which can be of format '.mol2', '.sdf',
'.pdbqt', or '.pdb'.
add_hydrogens : bool
Whether to add hydrogens via pdbfixer. Default to False.
sanitize : bool
Whether sanitization is performed in initializing RDKit molecule instances. See
https://www.rdkit.org/docs/RDKit_Book.html for details of the sanitization.
Default to False.
calc_charges : bool
Whether to add Gasteiger charges via RDKit. Setting this to be True will enforce
``add_hydrogens`` and ``sanitize`` to be True. Default to False.
remove_hs : bool
Whether to remove hydrogens via RDKit. Note that removing hydrogens can be quite
slow for large molecules. Default to False.
use_conformation : bool
Whether we need to extract molecular conformation from proteins and ligands.
Default to True.
Returns
-------
mol : rdkit.Chem.rdchem.Mol
RDKit molecule instance for the loaded molecule.
coordinates : np.ndarray of shape (N, 3) or None
The 3D coordinates of atoms in the molecule. N for the number of atoms in
the molecule. None will be returned if ``use_conformation`` is False or
we failed to get conformation information.
"""
if molecule_file.endswith('.mol2'):
mol = Chem.MolFromMol2File(molecule_file, sanitize=False, removeHs=False)
elif molecule_file.endswith('.sdf'):
supplier = Chem.SDMolSupplier(molecule_file, sanitize=False, removeHs=False)
mol = supplier[0]
elif molecule_file.endswith('.pdbqt'):
with open(molecule_file) as f:
pdbqt_data = f.readlines()
pdb_block = ''
for line in pdbqt_data:
pdb_block += '{}\n'.format(line[:66])
mol = Chem.MolFromPDBBlock(pdb_block, sanitize=False, removeHs=False)
elif molecule_file.endswith('.pdb'):
mol = Chem.MolFromPDBFile(molecule_file, sanitize=False, removeHs=False)
else:
return ValueError('Expect the format of the molecule_file to be '
'one of .mol2, .sdf, .pdbqt and .pdb, got {}'.format(molecule_file))
try:
if add_hydrogens or calc_charges:
mol = add_hydrogens_to_mol(mol)
if sanitize or calc_charges:
Chem.SanitizeMol(mol)
if calc_charges:
# Compute Gasteiger charges on the molecule.
try:
AllChem.ComputeGasteigerCharges(mol)
except:
warnings.warn('Unable to compute charges for the molecule.')
if remove_hs:
mol = Chem.RemoveHs(mol)
except:
return None, None
if use_conformation:
coordinates = get_mol_3D_coordinates(mol)
else:
coordinates = None
return mol, coordinates
def multiprocess_load_molecules(files, add_hydrogens=False, sanitize=False, calc_charges=False,
remove_hs=False, use_conformation=True, num_processes=2):
"""Load molecules from files with multiprocessing.
Parameters
----------
files : list of str
Each element is a path to a file storing a molecule, which can be of format '.mol2',
'.sdf', '.pdbqt', or '.pdb'.
add_hydrogens : bool
Whether to add hydrogens via pdbfixer. Default to False.
sanitize : bool
Whether sanitization is performed in initializing RDKit molecule instances. See
https://www.rdkit.org/docs/RDKit_Book.html for details of the sanitization.
Default to False.
calc_charges : bool
Whether to add Gasteiger charges via RDKit. Setting this to be True will enforce
``add_hydrogens`` and ``sanitize`` to be True. Default to False.
remove_hs : bool
Whether to remove hydrogens via RDKit. Note that removing hydrogens can be quite
slow for large molecules. Default to False.
use_conformation : bool
Whether we need to extract molecular conformation from proteins and ligands.
Default to True.
num_processes : int or None
Number of worker processes to use. If None,
then we will use the number of CPUs in the systetm. Default to 2.
Returns
-------
list of 2-tuples
The first element of each 2-tuple is an RDKit molecule instance. The second element
of each 2-tuple is the 3D atom coordinates of the corresponding molecule if
use_conformation is True and the coordinates has been successfully loaded. Otherwise,
it will be None.
"""
if num_processes == 1:
mols_loaded = []
for i, f in enumerate(files):
mols_loaded.append(load_molecule(
f, add_hydrogens=add_hydrogens, sanitize=sanitize, calc_charges=calc_charges,
remove_hs=remove_hs, use_conformation=use_conformation))
else:
with Pool(processes=num_processes) as pool:
mols_loaded = pool.map_async(partial(
load_molecule, add_hydrogens=add_hydrogens, sanitize=sanitize,
calc_charges=calc_charges, remove_hs=remove_hs,
use_conformation=use_conformation), files)
mols_loaded = mols_loaded.get()
return mols_loaded
This diff is collapsed.
# pylint: disable=C0111
"""Model Zoo Package"""
from .classifiers import GCNClassifier, GATClassifier
from .schnet import SchNet
from .mgcn import MGCNModel
......@@ -9,3 +8,4 @@ from .dgmg import DGMG
from .jtnn import DGLJTNNVAE
from .pretrain import load_pretrained
from .attentive_fp import AttentiveFP
from .acnn import ACNN
"""Atomic Convolutional Networks for Predicting Protein-Ligand Binding Affinity"""
# pylint: disable=C0103, C0123
import itertools
import torch
import torch.nn as nn
from ...nn.pytorch import AtomicConv
def truncated_normal_(tensor, mean=0., std=1.):
"""Fills the given tensor in-place with elements sampled from the truncated normal
distribution parameterized by mean and std.
The generated values follow a normal distribution with specified mean and
standard deviation, except that values whose magnitude is more than 2 std
from the mean are dropped.
We credit to Ruotian Luo for this implementation:
https://discuss.pytorch.org/t/implementing-truncated-normal-initializer/4778/15.
Parameters
----------
tensor : Float32 tensor of arbitrary shape
Tensor to be filled.
mean : float
Mean of the truncated normal distribution.
std : float
Standard deviation of the truncated normal distribution.
"""
shape = tensor.shape
tmp = tensor.new_empty(shape + (4,)).normal_()
valid = (tmp < 2) & (tmp > -2)
ind = valid.max(-1, keepdim=True)[1]
tensor.data.copy_(tmp.gather(-1, ind).squeeze(-1))
tensor.data.mul_(std).add_(mean)
class ACNNPredictor(nn.Module):
"""Predictor for ACNN.
Parameters
----------
in_size : int
Number of radial filters used.
hidden_sizes : list of int
Specifying the hidden sizes for all layers in the predictor.
weight_init_stddevs : list of float
Specifying the standard deviations to use for truncated normal
distributions in initialzing weights for the predictor.
dropouts : list of float
Specifying the dropouts to use for all layers in the predictor.
features_to_use : None or float tensor of shape (T)
In the original paper, these are atomic numbers to consider, representing the types
of atoms. T for the number of types of atomic numbers. Default to None.
num_tasks : int
Output size.
"""
def __init__(self, in_size, hidden_sizes, weight_init_stddevs,
dropouts, features_to_use, num_tasks):
super(ACNNPredictor, self).__init__()
if type(features_to_use) != type(None):
in_size *= len(features_to_use)
modules = []
for i, h in enumerate(hidden_sizes):
linear_layer = nn.Linear(in_size, h)
truncated_normal_(linear_layer.weight, std=weight_init_stddevs[i])
modules.append(linear_layer)
modules.append(nn.ReLU())
modules.append(nn.Dropout(dropouts[i]))
in_size = h
linear_layer = nn.Linear(in_size, num_tasks)
truncated_normal_(linear_layer.weight, std=weight_init_stddevs[-1])
modules.append(linear_layer)
self.project = nn.Sequential(*modules)
def forward(self, batch_size, frag1_node_indices_in_complex, frag2_node_indices_in_complex,
ligand_conv_out, protein_conv_out, complex_conv_out):
"""Perform the prediction.
Parameters
----------
batch_size : int
Number of datapoints in a batch.
frag1_node_indices_in_complex : Int64 tensor of shape (V1)
Indices for atoms in the first fragment (protein) in the batched complex.
frag2_node_indices_in_complex : list of int of length V2
Indices for atoms in the second fragment (ligand) in the batched complex.
ligand_conv_out : Float32 tensor of shape (V2, K * T)
Updated ligand node representations. V2 for the number of atoms in the
ligand, K for the number of radial filters, and T for the number of types
of atomic numbers.
protein_conv_out : Float32 tensor of shape (V1, K * T)
Updated protein node representations. V1 for the number of
atoms in the protein, K for the number of radial filters,
and T for the number of types of atomic numbers.
complex_conv_out : Float32 tensor of shape (V1 + V2, K * T)
Updated complex node representations. V1 and V2 separately
for the number of atoms in the ligand and protein, K for
the number of radial filters, and T for the number of
types of atomic numbers.
Returns
-------
Float32 tensor of shape (B, O)
Predicted protein-ligand binding affinity. B for the number
of protein-ligand pairs in the batch and O for the number of tasks.
"""
ligand_feats = self.project(ligand_conv_out) # (V1, O)
protein_feats = self.project(protein_conv_out) # (V2, O)
complex_feats = self.project(complex_conv_out) # (V1+V2, O)
ligand_energy = ligand_feats.reshape(batch_size, -1).sum(-1, keepdim=True) # (B, O)
protein_energy = protein_feats.reshape(batch_size, -1).sum(-1, keepdim=True) # (B, O)
complex_ligand_energy = complex_feats[frag1_node_indices_in_complex].reshape(
batch_size, -1).sum(-1, keepdim=True)
complex_protein_energy = complex_feats[frag2_node_indices_in_complex].reshape(
batch_size, -1).sum(-1, keepdim=True)
complex_energy = complex_ligand_energy + complex_protein_energy
return complex_energy - (ligand_energy + protein_energy)
class ACNN(nn.Module):
"""Atomic Convolutional Networks.
The model was proposed in `Atomic Convolutional Networks for
Predicting Protein-Ligand Binding Affinity <https://arxiv.org/abs/1703.10603>`__.
Parameters
----------
hidden_sizes : list of int
Specifying the hidden sizes for all layers in the predictor.
weight_init_stddevs : list of float
Specifying the standard deviations to use for truncated normal
distributions in initialzing weights for the predictor.
dropouts : list of float
Specifying the dropouts to use for all layers in the predictor.
features_to_use : None or float tensor of shape (T)
In the original paper, these are atomic numbers to consider, representing the types
of atoms. T for the number of types of atomic numbers. Default to None.
radial : None or list
If not None, the list consists of 3 lists of floats, separately for the
options of interaction cutoff, the options of rbf kernel mean and the
options of rbf kernel scaling. If None, a default option of
``[[12.0], [0.0, 2.0, 4.0, 6.0, 8.0], [4.0]]`` will be used.
num_tasks : int
Number of output tasks.
"""
def __init__(self, hidden_sizes, weight_init_stddevs, dropouts,
features_to_use=None, radial=None, num_tasks=1):
super(ACNN, self).__init__()
if radial is None:
radial = [[12.0], [0.0, 2.0, 4.0, 6.0, 8.0], [4.0]]
# Take the product of sets of options and get a list of 3-tuples.
radial_params = [x for x in itertools.product(*radial)]
radial_params = torch.stack(list(map(torch.tensor, zip(*radial_params))), dim=1)
interaction_cutoffs = radial_params[:, 0]
rbf_kernel_means = radial_params[:, 1]
rbf_kernel_scaling = radial_params[:, 2]
self.ligand_conv = AtomicConv(interaction_cutoffs, rbf_kernel_means,
rbf_kernel_scaling, features_to_use)
self.protein_conv = AtomicConv(interaction_cutoffs, rbf_kernel_means,
rbf_kernel_scaling, features_to_use)
self.complex_conv = AtomicConv(interaction_cutoffs, rbf_kernel_means,
rbf_kernel_scaling, features_to_use)
self.predictor = ACNNPredictor(radial_params.shape[0], hidden_sizes,
weight_init_stddevs, dropouts, features_to_use, num_tasks)
def forward(self, graph):
"""Apply the model for prediction.
Parameters
----------
graph : DGLHeteroGraph
DGLHeteroGraph consisting of the ligand graph, the protein graph
and the complex graph, along with preprocessed features.
Returns
-------
Float32 tensor of shape (B, O)
Predicted protein-ligand binding affinity. B for the number
of protein-ligand pairs in the batch and O for the number of tasks.
"""
ligand_graph = graph[('ligand_atom', 'ligand', 'ligand_atom')]
ligand_graph_node_feats = ligand_graph.ndata['atomic_number']
assert ligand_graph_node_feats.shape[-1] == 1
ligand_graph_distances = ligand_graph.edata['distance']
ligand_conv_out = self.ligand_conv(ligand_graph,
ligand_graph_node_feats,
ligand_graph_distances)
protein_graph = graph[('protein_atom', 'protein', 'protein_atom')]
protein_graph_node_feats = protein_graph.ndata['atomic_number']
assert protein_graph_node_feats.shape[-1] == 1
protein_graph_distances = protein_graph.edata['distance']
protein_conv_out = self.protein_conv(protein_graph,
protein_graph_node_feats,
protein_graph_distances)
complex_graph = graph[:, 'complex', :]
complex_graph_node_feats = complex_graph.ndata['atomic_number']
assert complex_graph_node_feats.shape[-1] == 1
complex_graph_distances = complex_graph.edata['distance']
complex_conv_out = self.complex_conv(complex_graph,
complex_graph_node_feats,
complex_graph_distances)
frag1_node_indices_in_complex = torch.where(complex_graph.ndata['_TYPE'] == 0)[0]
frag2_node_indices_in_complex = list(set(range(complex_graph.number_of_nodes())) -
set(frag1_node_indices_in_complex.tolist()))
return self.predictor(
graph.batch_size,
frag1_node_indices_in_complex,
frag2_node_indices_in_complex,
ligand_conv_out, protein_conv_out, complex_conv_out)
"""Utilities for using pretrained models."""
import os
import numpy as np
import torch
from rdkit import Chem
......@@ -10,20 +11,21 @@ from .mgcn import MGCNModel
from .mpnn import MPNNModel
from .schnet import SchNet
from .attentive_fp import AttentiveFP
from .acnn import ACNN
from ...data.utils import _get_dgl_url, download, get_download_dir, extract_archive
URL = {
'GCN_Tox21' : 'pre_trained/gcn_tox21.pth',
'GAT_Tox21' : 'pre_trained/gat_tox21.pth',
'GCN_Tox21': 'pre_trained/gcn_tox21.pth',
'GAT_Tox21': 'pre_trained/gat_tox21.pth',
'MGCN_Alchemy': 'pre_trained/mgcn_alchemy.pth',
'SCHNET_Alchemy': 'pre_trained/schnet_alchemy.pth',
'MPNN_Alchemy': 'pre_trained/mpnn_alchemy.pth',
'AttentiveFP_Aromaticity': 'pre_trained/attentivefp_aromaticity.pth',
'DGMG_ChEMBL_canonical' : 'pre_trained/dgmg_ChEMBL_canonical.pth',
'DGMG_ChEMBL_random' : 'pre_trained/dgmg_ChEMBL_random.pth',
'DGMG_ZINC_canonical' : 'pre_trained/dgmg_ZINC_canonical.pth',
'DGMG_ZINC_random' : 'pre_trained/dgmg_ZINC_random.pth',
'JTNN_ZINC':'pre_trained/JTNN_ZINC.pth'
'DGMG_ChEMBL_canonical': 'pre_trained/dgmg_ChEMBL_canonical.pth',
'DGMG_ChEMBL_random': 'pre_trained/dgmg_ChEMBL_random.pth',
'DGMG_ZINC_canonical': 'pre_trained/dgmg_ZINC_canonical.pth',
'DGMG_ZINC_random': 'pre_trained/dgmg_ZINC_random.pth',
'JTNN_ZINC': 'pre_trained/JTNN_ZINC.pth'
}
def download_and_load_checkpoint(model_name, model, model_postfix,
......@@ -56,6 +58,9 @@ def download_and_load_checkpoint(model_name, model, model_postfix,
checkpoint = torch.load(local_pretrained_path, map_location='cpu')
model.load_state_dict(checkpoint['model_state_dict'])
if log:
print('Pretrained model loaded')
return model
def load_pretrained(model_name, log=True):
......@@ -77,6 +82,14 @@ def load_pretrained(model_name, log=True):
* ``'DGMG_ZINC_canonical'``
* ``'DGMG_ZINC_random'``
* ``'JTNN_ZINC'``
* ``'ACNN_PDBBind_core_pocket_random'``
* ``'ACNN_PDBBind_core_pocket_scaffold'``
* ``'ACNN_PDBBind_core_pocket_stratified'``
* ``'ACNN_PDBBind_core_pocket_temporal'``
* ``'ACNN_PDBBind_refined_pocket_random'``
* ``'ACNN_PDBBind_refined_pocket_scaffold'``
* ``'ACNN_PDBBind_refined_pocket_stratified'``
* ``'ACNN_PDBBind_refined_pocket_temporal'``
log : bool
Whether to print progress for model loading
......@@ -147,7 +160,13 @@ def load_pretrained(model_name, log=True):
hidden_size=450,
latent_size=56)
if log:
print('Pretrained model loaded')
elif model_name.startswith('ACNN_PDBBind_core_pocket'):
model = ACNN(hidden_sizes=[32, 32, 16],
weight_init_stddevs=[1. / float(np.sqrt(32)), 1. / float(np.sqrt(32)),
1. / float(np.sqrt(16)), 0.01],
dropouts=[0., 0., 0.],
features_to_use=torch.tensor([
1., 6., 7., 8., 9., 11., 12., 15., 16., 17., 20., 25., 30., 35., 53.]),
radial=[[12.0], [0.0, 4.0, 8.0], [4.0]])
return download_and_load_checkpoint(model_name, model, URL[model_name], log=log)
......@@ -18,8 +18,9 @@ from .gatedgraphconv import GatedGraphConv
from .densechebconv import DenseChebConv
from .densegraphconv import DenseGraphConv
from .densesageconv import DenseSAGEConv
from .atomicconv import AtomicConv
__all__ = ['GraphConv', 'GATConv', 'TAGConv', 'RelGraphConv', 'SAGEConv',
'SGConv', 'APPNPConv', 'GINConv', 'GatedGraphConv', 'GMMConv',
'ChebConv', 'AGNNConv', 'NNConv', 'DenseGraphConv', 'DenseSAGEConv',
'DenseChebConv', 'EdgeConv']
'DenseChebConv', 'EdgeConv', 'AtomicConv']
"""Torch Module for Atomic Convolution Layer"""
import numpy as np
import torch as th
import torch.nn as nn
class RadialPooling(nn.Module):
r"""Radial pooling from paper `Atomic Convolutional Networks for
Predicting Protein-Ligand Binding Affinity <https://arxiv.org/abs/1703.10603>`__.
We denote the distance between atom :math:`i` and :math:`j` by :math:`r_{ij}`.
A radial pooling layer transforms distances with radial filters. For radial filter
indexed by :math:`k`, it projects edge distances with
.. math::
h_{ij}^{k} = \exp(-\gamma_{k}|r_{ij}-r_{k}|^2)
If :math:`r_{ij} < c_k`,
.. math::
f_{ij}^{k} = 0.5 * \cos(\frac{\pi r_{ij}}{c_k} + 1),
else,
.. math::
f_{ij}^{k} = 0.
Finally,
.. math::
e_{ij}^{k} = h_{ij}^{k} * f_{ij}^{k}
Parameters
----------
interaction_cutoffs : float32 tensor of shape (K)
:math:`c_k` in the equations above. Roughly they can be considered as learnable cutoffs
and two atoms are considered as connected if the distance between them is smaller than
the cutoffs. K for the number of radial filters.
rbf_kernel_means : float32 tensor of shape (K)
:math:`r_k` in the equations above. K for the number of radial filters.
rbf_kernel_scaling : float32 tensor of shape (K)
:math:`\gamma_k` in the equations above. K for the number of radial filters.
"""
def __init__(self, interaction_cutoffs, rbf_kernel_means, rbf_kernel_scaling):
super(RadialPooling, self).__init__()
self.interaction_cutoffs = nn.Parameter(
interaction_cutoffs.reshape(-1, 1, 1), requires_grad=True)
self.rbf_kernel_means = nn.Parameter(
rbf_kernel_means.reshape(-1, 1, 1), requires_grad=True)
self.rbf_kernel_scaling = nn.Parameter(
rbf_kernel_scaling.reshape(-1, 1, 1), requires_grad=True)
def forward(self, distances):
"""Apply the layer to transform edge distances.
Parameters
----------
distances : Float32 tensor of shape (E, 1)
Distance between end nodes of edges. E for the number of edges.
Returns
-------
Float32 tensor of shape (K, E, 1)
Transformed edge distances. K for the number of radial filters.
"""
scaled_euclidean_distance = - self.rbf_kernel_scaling * \
(distances - self.rbf_kernel_means) ** 2 # (K, E, 1)
rbf_kernel_results = th.exp(scaled_euclidean_distance) # (K, E, 1)
cos_values = 0.5 * (th.cos(np.pi * distances / self.interaction_cutoffs) + 1) # (K, E, 1)
cutoff_values = th.where(
distances <= self.interaction_cutoffs,
cos_values, th.zeros_like(cos_values)) # (K, E, 1)
# Note that there appears to be an inconsistency between the paper and
# DeepChem's implementation. In the paper, the scaled_euclidean_distance first
# gets multiplied by cutoff_values, followed by exponentiation. Here we follow
# the practice of DeepChem.
return rbf_kernel_results * cutoff_values
def msg_func(edges):
"""Send messages along edges.
Parameters
----------
edges : EdgeBatch
A batch of edges.
Returns
-------
dict mapping 'm' to Float32 tensor of shape (E, K * T)
Messages computed. E for the number of edges, K for the number of
radial filters and T for the number of features to use
(types of atomic number in the paper).
"""
return {'m': th.einsum(
'ij,ik->ijk', edges.src['hv'], edges.data['he']).view(len(edges), -1)}
def reduce_func(nodes):
"""Collect messages and update node representations.
Parameters
----------
nodes : NodeBatch
A batch of nodes.
Returns
-------
dict mapping 'hv_new' to Float32 tensor of shape (V, K * T)
Updated node representations. V for the number of nodes, K for the number of
radial filters and T for the number of features to use
(types of atomic number in the paper).
"""
return {'hv_new': nodes.mailbox['m'].sum(1)}
class AtomicConv(nn.Module):
r"""Atomic Convolution Layer from paper `Atomic Convolutional Networks for
Predicting Protein-Ligand Binding Affinity <https://arxiv.org/abs/1703.10603>`__.
We denote the type of atom :math:`i` by :math:`z_i` and the distance between atom
:math:`i` and :math:`j` by :math:`r_{ij}`.
**Distance Transformation**
An atomic convolution layer first transforms distances with radial filters and
then perform a pooling operation.
For radial filter indexed by :math:`k`, it projects edge distances with
.. math::
h_{ij}^{k} = \exp(-\gamma_{k}|r_{ij}-r_{k}|^2)
If :math:`r_{ij} < c_k`,
.. math::
f_{ij}^{k} = 0.5 * \cos(\frac{\pi r_{ij}}{c_k} + 1),
else,
.. math::
f_{ij}^{k} = 0.
Finally,
.. math::
e_{ij}^{k} = h_{ij}^{k} * f_{ij}^{k}
**Aggregation**
For each type :math:`t`, each atom collects distance information from all neighbor atoms
of type :math:`t`:
.. math::
p_{i, t}^{k} = \sum_{j\in N(i)} e_{ij}^{k} * 1(z_j == t)
We concatenate the results for all RBF kernels and atom types.
Notes
-----
* This convolution operation is designed for molecular graphs in Chemistry, but it might
be possible to extend it to more general graphs.
* There seems to be an inconsistency about the definition of :math:`e_{ij}^{k}` in the
paper and the author's implementation. We follow the author's implementation. In the
paper, :math:`e_{ij}^{k}` was defined as
:math:`\exp(-\gamma_{k}|r_{ij}-r_{k}|^2 * f_{ij}^{k})`.
* :math:`\gamma_{k}`, :math:`r_k` and :math:`c_k` are all learnable.
Parameters
----------
interaction_cutoffs : float32 tensor of shape (K)
:math:`c_k` in the equations above. Roughly they can be considered as learnable cutoffs
and two atoms are considered as connected if the distance between them is smaller than
the cutoffs. K for the number of radial filters.
rbf_kernel_means : float32 tensor of shape (K)
:math:`r_k` in the equations above. K for the number of radial filters.
rbf_kernel_scaling : float32 tensor of shape (K)
:math:`\gamma_k` in the equations above. K for the number of radial filters.
features_to_use : None or float tensor of shape (T)
In the original paper, these are atomic numbers to consider, representing the types
of atoms. T for the number of types of atomic numbers. Default to None.
"""
def __init__(self, interaction_cutoffs, rbf_kernel_means,
rbf_kernel_scaling, features_to_use=None):
super(AtomicConv, self).__init__()
self.radial_pooling = RadialPooling(interaction_cutoffs=interaction_cutoffs,
rbf_kernel_means=rbf_kernel_means,
rbf_kernel_scaling=rbf_kernel_scaling)
if features_to_use is None:
self.num_channels = 1
self.features_to_use = None
else:
self.num_channels = len(features_to_use)
self.features_to_use = nn.Parameter(features_to_use, requires_grad=False)
def forward(self, graph, feat, distances):
"""Apply the atomic convolution layer.
Parameters
----------
graph : DGLGraph or BatchedDGLGraph
Topology based on which message passing is performed.
feat : Float32 tensor of shape (V, 1)
Initial node features, which are atomic numbers in the paper.
V for the number of nodes.
distances : Float32 tensor of shape (E, 1)
Distance between end nodes of edges. E for the number of edges.
Returns
-------
Float32 tensor of shape (V, K * T)
Updated node representations. V for the number of nodes, K for the
number of radial filters, and T for the number of types of atomic numbers.
"""
radial_pooled_values = self.radial_pooling(distances) # (K, E, 1)
graph = graph.local_var()
if self.features_to_use is not None:
feat = (feat == self.features_to_use).float() # (V, T)
graph.ndata['hv'] = feat
graph.edata['he'] = radial_pooled_values.transpose(1, 0).squeeze(-1) # (E, K)
graph.update_all(msg_func, reduce_func)
return graph.ndata['hv_new'].view(graph.number_of_nodes(), -1) # (V, K * T)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment