Unverified Commit 2ce0e21b authored by Mufei Li's avatar Mufei Li Committed by GitHub
Browse files

Update (#1366)

parent 16627575
"""Convert molecules into DGLGraphs."""
import numpy as np
import torch
from dgl import DGLGraph
from functools import partial
from rdkit import Chem
from rdkit.Chem import rdmolfiles, rdmolops
try:
import mdtraj
except ImportError:
pass
from sklearn.neighbors import NearestNeighbors
__all__ = ['mol_to_graph',
'smiles_to_bigraph',
'mol_to_bigraph',
'smiles_to_complete_graph',
'mol_to_complete_graph',
'k_nearest_neighbors']
'k_nearest_neighbors',
'mol_to_nearest_neighbor_graph',
'smiles_to_nearest_neighbor_graph']
def mol_to_graph(mol, graph_constructor, node_featurizer, edge_featurizer, canonical_atom_order):
"""Convert an RDKit molecule object into a DGLGraph and featurize for it.
......@@ -262,51 +260,207 @@ def smiles_to_complete_graph(smiles, add_self_loop=False,
return mol_to_complete_graph(mol, add_self_loop, node_featurizer,
edge_featurizer, canonical_atom_order)
def k_nearest_neighbors(coordinates, neighbor_cutoff, max_num_neighbors):
"""Find k nearest neighbors for each atom based on the 3D coordinates and
return the resulted edges.
def k_nearest_neighbors(coordinates, neighbor_cutoff, max_num_neighbors=None,
p_distance=2, self_loops=False):
"""Find k nearest neighbors for each atom
For each atom, find its k nearest neighbors and return edges
from these neighbors to it.
We do not guarantee that the edges are sorted according to the distance
between atoms.
Parameters
----------
coordinates : numpy.ndarray of shape (N, 3)
The 3D coordinates of atoms in the molecule. N for the number of atoms.
coordinates : numpy.ndarray of shape (N, D)
The coordinates of atoms in the molecule. N for the number of atoms
and D for the dimensions of the coordinates.
neighbor_cutoff : float
Distance cutoff to define 'neighboring'.
If the distance between a pair of nodes is larger than neighbor_cutoff,
they will not be considered as neighboring nodes.
max_num_neighbors : int or None.
If not None, then this specifies the maximum number of closest neighbors
allowed for each atom.
If not None, then this specifies the maximum number of neighbors
allowed for each atom. Default to None.
p_distance : int
We compute the distance between neighbors using Minkowski (:math:`l_p`)
distance. When ``p_distance = 1``, Minkowski distance is equivalent to
Manhattan distance. When ``p_distance = 2``, Minkowski distance is
equivalent to the standard Euclidean distance. Default to 2.
self_loops : bool
Whether to allow a node to be its own neighbor. Default to False.
Returns
-------
srcs : list of int
Source nodes.
dsts : list of int
Destination nodes.
Destination nodes, corresponding to ``srcs``.
distances : list of float
Distances between the end nodes.
Distances between the end nodes, corresponding to ``srcs`` and ``dsts``.
"""
num_atoms = coordinates.shape[0]
traj = mdtraj.Trajectory(coordinates.reshape((1, num_atoms, 3)), None)
neighbors = mdtraj.geometry.compute_neighborlist(traj, neighbor_cutoff)
srcs, dsts, distances = [], [], []
model = NearestNeighbors(radius=neighbor_cutoff, p=p_distance)
model.fit(coordinates)
dists_, nbrs = model.radius_neighbors(coordinates)
srcs, dsts, dists = [], [], []
for i in range(num_atoms):
delta = coordinates[i] - coordinates.take(neighbors[i], axis=0)
dist = np.linalg.norm(delta, axis=1)
if max_num_neighbors is not None and len(neighbors[i]) > max_num_neighbors:
sorted_neighbors = list(zip(dist, neighbors[i]))
dists_i = dists_[i].tolist()
nbrs_i = nbrs[i].tolist()
if not self_loops:
dists_i.remove(0)
nbrs_i.remove(i)
if max_num_neighbors is not None and len(nbrs_i) > max_num_neighbors:
packed_nbrs = list(zip(dists_i, nbrs_i))
# Sort neighbors based on distance from smallest to largest
sorted_neighbors.sort(key=lambda tup: tup[0])
packed_nbrs.sort(key=lambda tup: tup[0])
dists_i, nbrs_i = map(list, zip(*packed_nbrs))
dsts.extend([i for _ in range(max_num_neighbors)])
srcs.extend([int(sorted_neighbors[j][1]) for j in range(max_num_neighbors)])
distances.extend([float(sorted_neighbors[j][0]) for j in range(max_num_neighbors)])
srcs.extend(nbrs_i[:max_num_neighbors])
dists.extend(dists_i[:max_num_neighbors])
else:
dsts.extend([i for _ in range(len(neighbors[i]))])
srcs.extend(neighbors[i].tolist())
distances.extend(dist.tolist())
dsts.extend([i for _ in range(len(nbrs_i))])
srcs.extend(nbrs_i)
dists.extend(dists_i)
return srcs, dsts, dists
def mol_to_nearest_neighbor_graph(mol,
coordinates,
neighbor_cutoff,
max_num_neighbors=None,
p_distance=2,
add_self_loop=False,
node_featurizer=None,
edge_featurizer=None,
canonical_atom_order=True,
keep_dists=False,
dist_field='dist'):
"""Convert an RDKit molecule into a nearest neighbor graph and featurize for it.
Different from bigraph and complete graph, the nearest neighbor graph
may not be symmetric since i is the closest neighbor of j does not
necessarily suggest the other way.
Parameters
----------
mol : rdkit.Chem.rdchem.Mol
RDKit molecule holder
coordinates : numpy.ndarray of shape (N, D)
The coordinates of atoms in the molecule. N for the number of atoms
and D for the dimensions of the coordinates.
neighbor_cutoff : float
If the distance between a pair of nodes is larger than neighbor_cutoff,
they will not be considered as neighboring nodes.
max_num_neighbors : int or None.
If not None, then this specifies the maximum number of neighbors
allowed for each atom. Default to None.
p_distance : int
We compute the distance between neighbors using Minkowski (:math:`l_p`)
distance. When ``p_distance = 1``, Minkowski distance is equivalent to
Manhattan distance. When ``p_distance = 2``, Minkowski distance is
equivalent to the standard Euclidean distance. Default to 2.
add_self_loop : bool
Whether to add self loops in DGLGraphs. Default to False.
node_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict
Featurization for nodes like atoms in a molecule, which can be used to update
ndata for a DGLGraph. Default to None.
edge_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict
Featurization for edges like bonds in a molecule, which can be used to update
edata for a DGLGraph. Default to None.
canonical_atom_order : bool
Whether to use a canonical order of atoms returned by RDKit. Setting it
to true might change the order of atoms in the graph constructed. Default
to True.
keep_dists : bool
Whether to store the distance between neighboring atoms in ``edata`` of the
constructed DGLGraphs. Default to False.
dist_field : str
Field for storing distance between neighboring atoms in ``edata``. This comes
into effect only when ``keep_dists=True``. Default to ``'dist'``.
"""
if canonical_atom_order:
new_order = rdmolfiles.CanonicalRankAtoms(mol)
mol = rdmolops.RenumberAtoms(mol, new_order)
srcs, dsts, dists = k_nearest_neighbors(coordinates=coordinates,
neighbor_cutoff=neighbor_cutoff,
max_num_neighbors=max_num_neighbors,
p_distance=p_distance,
self_loops=add_self_loop)
g = DGLGraph()
# Add nodes first since some nodes may be completely isolated
num_atoms = mol.GetNumAtoms()
g.add_nodes(num_atoms)
# Add edges
g.add_edges(srcs, dsts)
if node_featurizer is not None:
g.ndata.update(node_featurizer(mol))
return srcs, dsts, distances
if edge_featurizer is not None:
g.edata.update(edge_featurizer(mol))
# Todo(Mufei): smiles_to_knn_graph, mol_to_knn_graph
if keep_dists:
assert dist_field not in g.edata, \
'Expect {} to be reserved for distance between neighboring atoms.'
g.edata[dist_field] = torch.tensor(dists).float().reshape(-1, 1)
return g
def smiles_to_nearest_neighbor_graph(smiles,
coordinates,
neighbor_cutoff,
max_num_neighbors=None,
p_distance=2,
add_self_loop=False,
node_featurizer=None,
edge_featurizer=None,
canonical_atom_order=True,
keep_dists=False,
dist_field='dist'):
"""Convert a SMILES into a nearest neighbor graph and featurize for it.
Different from bigraph and complete graph, the nearest neighbor graph
may not be symmetric since i is the closest neighbor of j does not
necessarily suggest the other way.
Parameters
----------
smiles : str
String of SMILES
coordinates : numpy.ndarray of shape (N, D)
The coordinates of atoms in the molecule. N for the number of atoms
and D for the dimensions of the coordinates.
neighbor_cutoff : float
If the distance between a pair of nodes is larger than neighbor_cutoff,
they will not be considered as neighboring nodes.
max_num_neighbors : int or None.
If not None, then this specifies the maximum number of neighbors
allowed for each atom. Default to None.
p_distance : int
We compute the distance between neighbors using Minkowski (:math:`l_p`)
distance. When ``p_distance = 1``, Minkowski distance is equivalent to
Manhattan distance. When ``p_distance = 2``, Minkowski distance is
equivalent to the standard Euclidean distance. Default to 2.
add_self_loop : bool
Whether to add self loops in DGLGraphs. Default to False.
node_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict
Featurization for nodes like atoms in a molecule, which can be used to update
ndata for a DGLGraph. Default to None.
edge_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict
Featurization for edges like bonds in a molecule, which can be used to update
edata for a DGLGraph. Default to None.
canonical_atom_order : bool
Whether to use a canonical order of atoms returned by RDKit. Setting it
to true might change the order of atoms in the graph constructed. Default
to True.
keep_dists : bool
Whether to store the distance between neighboring atoms in ``edata`` of the
constructed DGLGraphs. Default to False.
dist_field : str
Field for storing distance between neighboring atoms in ``edata``. This comes
into effect only when ``keep_dists=True``. Default to ``'dist'``.
"""
mol = Chem.MolFromSmiles(smiles)
return mol_to_nearest_neighbor_graph(
mol, coordinates, neighbor_cutoff, max_num_neighbors, p_distance, add_self_loop,
node_featurizer, edge_featurizer, canonical_atom_order, keep_dists, dist_field)
......@@ -27,7 +27,7 @@ setup(
if package.startswith('dgllife')],
install_requires=[
'torch>=1'
'scikit-learn>=0.21.2',
'scikit-learn>=0.22.2',
'pandas>=0.25.1',
'requests>=2.22.0',
'tqdm'
......
......@@ -4,6 +4,7 @@ import torch
from dgllife.utils.featurizers import *
from dgllife.utils.mol_to_graph import *
from rdkit import Chem
from rdkit.Chem import AllChem
test_smiles1 = 'CCO'
test_smiles2 = 'Fc1ccccc1'
......@@ -131,18 +132,101 @@ def test_k_nearest_neighbors():
neighbor_cutoff = 1.
max_num_neighbors = 2
srcs, dsts, dists = k_nearest_neighbors(coordinates, neighbor_cutoff, max_num_neighbors)
assert srcs == [2, 3, 2, 0, 0, 1, 0, 2, 5, 4]
assert dsts == [0, 0, 1, 1, 2, 2, 3, 3, 4, 5]
assert dists == [0.07071067811865474,
0.07810249675906654,
0.07071067811865477,
0.1,
0.07071067811865474,
0.07071067811865477,
0.07810249675906654,
0.07810249675906654,
0.14142135623730956,
0.14142135623730956]
assert srcs == [2, 3, 2, 0, 0, 1, 0, 2, 1, 5, 4]
assert dsts == [0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5]
assert dists == [0.07071067811865478, 0.0781024967590666, 0.07071067811865483,
0.1, 0.07071067811865478, 0.07071067811865483, 0.0781024967590666,
0.0781024967590666, 1.0, 0.14142135623730956, 0.14142135623730956]
# Test the case where self loops are included
srcs, dsts, dists = k_nearest_neighbors(coordinates, neighbor_cutoff,
max_num_neighbors, self_loops=True)
assert srcs == [0, 2, 1, 2, 2, 0, 3, 0, 4, 5, 4, 5]
assert dsts == [0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5]
assert dists == [0.0, 0.07071067811865478, 0.0, 0.07071067811865483, 0.0,
0.07071067811865478, 0.0, 0.0781024967590666, 0.0,
0.14142135623730956, 0.14142135623730956, 0.0]
# Test the case where max_num_neighbors is not given
srcs, dsts, dists = k_nearest_neighbors(coordinates, neighbor_cutoff=10.)
assert srcs == [1, 2, 3, 4, 5, 0, 2, 3, 4, 5, 0, 1, 3, 4, 5,
0, 1, 2, 4, 5, 0, 1, 2, 3, 5, 0, 1, 2, 3, 4]
assert dsts == [0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2,
3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5]
assert dists == [0.1, 0.07071067811865478, 0.0781024967590666, 1.1,
1.2041594578792296, 0.1, 0.07071067811865483,
0.12688577540449525, 1.0, 1.104536101718726,
0.07071067811865478, 0.07071067811865483,
0.0781024967590666, 1.0511898020814319, 1.151086443322134,
0.0781024967590666, 0.12688577540449525, 0.0781024967590666,
1.1027692415006867, 1.202538980657176, 1.1, 1.0,
1.0511898020814319, 1.1027692415006867, 0.14142135623730956,
1.2041594578792296, 1.104536101718726, 1.151086443322134,
1.202538980657176, 0.14142135623730956]
def test_smiles_to_nearest_neighbor_graph():
mol = Chem.MolFromSmiles(test_smiles1)
AllChem.EmbedMolecule(mol)
coordinates = mol.GetConformers()[0].GetPositions()
# Test node featurizer
test_node_featurizer = TestAtomFeaturizer()
g = smiles_to_nearest_neighbor_graph(test_smiles1, coordinates, neighbor_cutoff=10,
node_featurizer=test_node_featurizer)
assert torch.allclose(g.ndata['hv'], torch.tensor([[6.], [8.], [6.]]))
assert g.number_of_edges() == 6
assert 'dist' not in g.edata
# Test self loops
g = smiles_to_nearest_neighbor_graph(test_smiles1, coordinates, neighbor_cutoff=10,
add_self_loop=True)
assert g.number_of_edges() == 9
# Test max_num_neighbors
g = smiles_to_nearest_neighbor_graph(test_smiles1, coordinates, neighbor_cutoff=10,
max_num_neighbors=1, add_self_loop=True)
assert g.number_of_edges() == 3
# Test pairwise distances
g = smiles_to_nearest_neighbor_graph(test_smiles1, coordinates,
neighbor_cutoff=10, keep_dists=True)
assert 'dist' in g.edata
coordinates = torch.from_numpy(coordinates)
srcs, dsts = g.edges()
dist = torch.norm(
coordinates[srcs] - coordinates[dsts], dim=1, p=2).float().reshape(-1, 1)
assert torch.allclose(dist, g.edata['dist'])
def test_mol_to_nearest_neighbor_graph():
mol = Chem.MolFromSmiles(test_smiles1)
AllChem.EmbedMolecule(mol)
coordinates = mol.GetConformers()[0].GetPositions()
# Test node featurizer
test_node_featurizer = TestAtomFeaturizer()
g = mol_to_nearest_neighbor_graph(mol, coordinates, neighbor_cutoff=10,
node_featurizer=test_node_featurizer)
assert torch.allclose(g.ndata['hv'], torch.tensor([[6.], [8.], [6.]]))
assert g.number_of_edges() == 6
assert 'dist' not in g.edata
# Test self loops
g = mol_to_nearest_neighbor_graph(mol, coordinates, neighbor_cutoff=10, add_self_loop=True)
assert g.number_of_edges() == 9
# Test max_num_neighbors
g = mol_to_nearest_neighbor_graph(mol, coordinates, neighbor_cutoff=10,
max_num_neighbors=1, add_self_loop=True)
assert g.number_of_edges() == 3
# Test pairwise distances
g = mol_to_nearest_neighbor_graph(mol, coordinates, neighbor_cutoff=10, keep_dists=True)
assert 'dist' in g.edata
coordinates = torch.from_numpy(coordinates)
srcs, dsts = g.edges()
dist = torch.norm(
coordinates[srcs] - coordinates[dsts], dim=1, p=2).float().reshape(-1, 1)
assert torch.allclose(dist, g.edata['dist'])
if __name__ == '__main__':
test_smiles_to_bigraph()
......@@ -150,3 +234,5 @@ if __name__ == '__main__':
test_smiles_to_complete_graph()
test_mol_to_complete_graph()
test_k_nearest_neighbors()
test_smiles_to_nearest_neighbor_graph()
test_mol_to_nearest_neighbor_graph()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment