"git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "91ddd2a25b848df0fa1262d4f1cd98c7ccb87750"
Unverified Commit 545cc065 authored by Mufei Li's avatar Mufei Li Committed by GitHub
Browse files

[DGL-LifeSci] WLN for Reaction Center Prediction (#1360)

* Update

* Update

* Update

* Update

* Update

* Update

* Update

* Update

* Update

* Update

* Update

* Update

* Update

* Update

* Update

* Update

* Update

* Update

* Update

* Update

* Update

* Update

* Update

* Update

* Update

* Update

* Update

* Update

* Update

* Update

* Update

* Update

* Update

* Update

* Update

* Update

* Update

* Update

* Update

* Update

* Update

* Update
parent 5a1ef70f
...@@ -26,14 +26,15 @@ setup( ...@@ -26,14 +26,15 @@ setup(
packages=[package for package in find_packages() packages=[package for package in find_packages()
if package.startswith('dgllife')], if package.startswith('dgllife')],
install_requires=[ install_requires=[
'dgl>=0.4',
'torch>=1' 'torch>=1'
'scikit-learn>=0.21.2', 'scikit-learn>=0.21.2',
'pandas>=0.25.1', 'pandas>=0.25.1',
'requests>=2.22.0' 'requests>=2.22.0',
'tqdm'
], ],
url='https://github.com/dmlc/dgl/tree/master/apps/life_sci', url='https://github.com/dmlc/dgl/tree/master/apps/life_sci',
classifiers=[ classifiers=[
'Development Status :: 3 - Alpha',
'Programming Language :: Python :: 3', 'Programming Language :: Python :: 3',
] ],
) )
import os import os
from dgllife.data import * from dgllife.data import *
from dgllife.data.uspto import get_bond_changes, process_file
def remove_file(fname): def remove_file(fname):
if os.path.isfile(fname): if os.path.isfile(fname):
...@@ -10,14 +11,17 @@ def remove_file(fname): ...@@ -10,14 +11,17 @@ def remove_file(fname):
pass pass
def test_pubchem_aromaticity(): def test_pubchem_aromaticity():
print('Test pubchem aromaticity')
dataset = PubChemBioAssayAromaticity() dataset = PubChemBioAssayAromaticity()
remove_file('pubchem_aromaticity_dglgraph.bin') remove_file('pubchem_aromaticity_dglgraph.bin')
def test_tox21(): def test_tox21():
print('Test Tox21')
dataset = Tox21() dataset = Tox21()
remove_file('tox21_dglgraph.bin') remove_file('tox21_dglgraph.bin')
def test_alchemy(): def test_alchemy():
print('Test Alchemy')
dataset = TencentAlchemyDataset(mode='valid', dataset = TencentAlchemyDataset(mode='valid',
node_featurizer=None, node_featurizer=None,
edge_featurizer=None) edge_featurizer=None)
...@@ -27,10 +31,47 @@ def test_alchemy(): ...@@ -27,10 +31,47 @@ def test_alchemy():
load=False) load=False)
def test_pdbbind(): def test_pdbbind():
print('Test PDBBind')
dataset = PDBBind(subset='core', remove_hs=True) dataset = PDBBind(subset='core', remove_hs=True)
def test_wln_reaction():
print('Test datasets for reaction prediction with WLN')
reaction1 = '[CH2:15]([CH:16]([CH3:17])[CH3:18])[Mg+:19].[CH2:20]1[O:21][CH2:22][CH2:23]' \
'[CH2:24]1.[Cl-:14].[OH:1][c:2]1[n:3][cH:4][c:5]([C:6](=[O:7])[N:8]([O:9]' \
'[CH3:10])[CH3:11])[cH:12][cH:13]1>>[OH:1][c:2]1[n:3][cH:4][c:5]([C:6](=[O:7])' \
'[CH2:15][CH:16]([CH3:17])[CH3:18])[cH:12][cH:13]1\n'
reaction2 = '[CH3:14][NH2:15].[N+:1](=[O:2])([O-:3])[c:4]1[cH:5][c:6]([C:7](=[O:8])[OH:9])' \
'[cH:10][cH:11][c:12]1[Cl:13].[OH2:16]>>[N+:1](=[O:2])([O-:3])[c:4]1[cH:5][c:6]' \
'([C:7](=[O:8])[OH:9])[cH:10][cH:11][c:12]1[NH:15][CH3:14]\n'
reactions = [reaction1, reaction2]
# Test utility functions
assert get_bond_changes(reaction2) == {('12', '13', 0.0), ('12', '15', 1.0)}
with open('test.txt', 'w') as f:
for reac in reactions:
f.write(reac)
process_file('test.txt')
with open('test.txt.proc', 'r') as f:
lines = f.readlines()
for i in range(len(lines)):
l = lines[i].strip()
react = reactions[i].strip()
bond_changes = get_bond_changes(react)
assert l == '{} {}'.format(
react,
';'.join(['{}-{}-{}'.format(x[0], x[1], x[2]) for x in bond_changes]))
remove_file('test.txt.proc')
# Test configured dataset
dataset = WLNReactionDataset('test.txt', 'test_graphs.bin')
remove_file('test.txt')
remove_file('test.txt.proc')
remove_file('test_graphs.bin')
if __name__ == '__main__': if __name__ == '__main__':
test_pubchem_aromaticity() test_pubchem_aromaticity()
test_tox21() test_tox21()
test_alchemy() test_alchemy()
test_pdbbind() test_pdbbind()
test_wln_reaction()
...@@ -209,6 +209,32 @@ def test_mpnn_gnn(): ...@@ -209,6 +209,32 @@ def test_mpnn_gnn():
assert gnn(g, node_feats, edge_feats).shape == torch.Size([3, 2]) assert gnn(g, node_feats, edge_feats).shape == torch.Size([3, 2])
assert gnn(bg, batch_node_feats, batch_edge_feats).shape == torch.Size([8, 2]) assert gnn(bg, batch_node_feats, batch_edge_feats).shape == torch.Size([8, 2])
def test_wln():
if torch.cuda.is_available():
device = torch.device('cuda:0')
else:
device = torch.device('cpu')
g, node_feats, edge_feats = test_graph3()
g, node_feats, edge_feats = g.to(device), node_feats.to(device), edge_feats.to(device)
bg, batch_node_feats, batch_edge_feats = test_graph4()
bg, batch_node_feats, batch_edge_feats = bg.to(device), batch_node_feats.to(device), \
batch_edge_feats.to(device)
# Test default setting
gnn = WLN(node_in_feats=1,
edge_in_feats=2)
assert gnn(g, node_feats, edge_feats).shape == torch.Size([3, 300])
assert gnn(bg, batch_node_feats, batch_edge_feats).shape == torch.Size([8, 300])
# Test configured setting
gnn = WLN(node_in_feats=1,
edge_in_feats=2,
node_out_feats=3,
n_layers=1)
assert gnn(g, node_feats, edge_feats).shape == torch.Size([3, 3])
assert gnn(bg, batch_node_feats, batch_edge_feats).shape == torch.Size([8, 3])
if __name__ == '__main__': if __name__ == '__main__':
test_gcn() test_gcn()
test_gat() test_gat()
...@@ -216,3 +242,4 @@ if __name__ == '__main__': ...@@ -216,3 +242,4 @@ if __name__ == '__main__':
test_schnet_gnn() test_schnet_gnn()
test_mgcn_gnn() test_mgcn_gnn()
test_mpnn_gnn() test_mpnn_gnn()
test_wln()
import dgl
import torch
from dgl import DGLGraph
from dgllife.model.model_zoo import *
def get_complete_graph(num_nodes):
edge_list = []
for i in range(num_nodes):
for j in range(num_nodes):
edge_list.append((i, j))
return DGLGraph(edge_list)
def test_graph1():
"""
Bi-directed graphs and complete graphs for the molecules.
In addition to node features/edge features, we also return
features for the pairs of nodes.
"""
mol_graph = DGLGraph([(0, 1), (0, 2), (1, 2)])
node_feats = torch.arange(mol_graph.number_of_nodes()).float().reshape(-1, 1)
edge_feats = torch.arange(2 * mol_graph.number_of_edges()).float().reshape(-1, 2)
complete_graph = get_complete_graph(mol_graph.number_of_nodes())
atom_pair_feats = torch.arange(complete_graph.number_of_edges()).float().reshape(-1, 1)
return mol_graph, node_feats, edge_feats, complete_graph, atom_pair_feats
def test_graph2():
"""Batched version of test_graph1"""
mol_graph1 = DGLGraph([(0, 1), (0, 2), (1, 2)])
mol_graph2 = DGLGraph([(0, 1), (1, 2), (1, 3), (1, 4)])
batch_mol_graph = dgl.batch([mol_graph1, mol_graph2])
node_feats = torch.arange(batch_mol_graph.number_of_nodes()).float().reshape(-1, 1)
edge_feats = torch.arange(2 * batch_mol_graph.number_of_edges()).float().reshape(-1, 2)
complete_graph1 = get_complete_graph(mol_graph1.number_of_nodes())
complete_graph2 = get_complete_graph(mol_graph2.number_of_nodes())
batch_complete_graph = dgl.batch([complete_graph1, complete_graph2])
atom_pair_feats = torch.arange(batch_complete_graph.number_of_edges()).float().reshape(-1, 1)
return batch_mol_graph, node_feats, edge_feats, batch_complete_graph, atom_pair_feats
def test_wln_reaction_center():
if torch.cuda.is_available():
device = torch.device('cuda:0')
else:
device = torch.device('cpu')
mol_graph, node_feats, edge_feats, complete_graph, atom_pair_feats = test_graph1()
mol_graph = mol_graph.to(device)
node_feats, edge_feats = node_feats.to(device), edge_feats.to(device)
complete_graph = complete_graph.to(device)
atom_pair_feats = atom_pair_feats.to(device)
batch_mol_graph, batch_node_feats, batch_edge_feats, batch_complete_graph, \
batch_atom_pair_feats = test_graph2()
batch_mol_graph = batch_mol_graph.to(device)
batch_node_feats, batch_edge_feats = batch_node_feats.to(device), batch_edge_feats.to(device)
batch_complete_graph = batch_complete_graph.to(device)
batch_atom_pair_feats = batch_atom_pair_feats.to(device)
# Test default setting
model = WLNReactionCenter(node_in_feats=1,
edge_in_feats=2,
node_pair_in_feats=1).to(device)
assert model(mol_graph, complete_graph, node_feats, edge_feats, atom_pair_feats).shape == \
torch.Size([complete_graph.number_of_edges(), 5])
assert model(batch_mol_graph, batch_complete_graph, batch_node_feats,
batch_edge_feats, batch_atom_pair_feats).shape == \
torch.Size([batch_complete_graph.number_of_edges(), 5])
# Test configured setting
model = WLNReactionCenter(node_in_feats=1,
edge_in_feats=2,
node_pair_in_feats=1,
node_out_feats=1,
n_layers=1,
n_tasks=1).to(device)
assert model(mol_graph, complete_graph, node_feats, edge_feats, atom_pair_feats).shape == \
torch.Size([complete_graph.number_of_edges(), 1])
assert model(batch_mol_graph, batch_complete_graph, batch_node_feats,
batch_edge_feats, batch_atom_pair_feats).shape == \
torch.Size([batch_complete_graph.number_of_edges(), 1])
if __name__ == '__main__':
test_wln_reaction_center()
...@@ -54,6 +54,16 @@ def test_atom_total_degree(): ...@@ -54,6 +54,16 @@ def test_atom_total_degree():
assert atom_total_degree(mol.GetAtomWithIdx(0)) == [4] assert atom_total_degree(mol.GetAtomWithIdx(0)) == [4]
assert atom_total_degree(mol.GetAtomWithIdx(2)) == [2] assert atom_total_degree(mol.GetAtomWithIdx(2)) == [2]
def test_atom_explicit_valence_one_hot():
mol = test_mol1()
assert atom_implicit_valence_one_hot(mol.GetAtomWithIdx(0), [1, 2, 3]) == [1, 0, 0]
assert atom_implicit_valence_one_hot(mol.GetAtomWithIdx(1), [1, 2, 3]) == [0, 1, 0]
def test_atom_explicit_valence():
mol = test_mol1()
assert atom_explicit_valence(mol.GetAtomWithIdx(0)) == [1]
assert atom_explicit_valence(mol.GetAtomWithIdx(1)) == [2]
def test_atom_implicit_valence_one_hot(): def test_atom_implicit_valence_one_hot():
mol = test_mol1() mol = test_mol1()
assert atom_implicit_valence_one_hot(mol.GetAtomWithIdx(0), [1, 2, 3]) == [0, 0, 1] assert atom_implicit_valence_one_hot(mol.GetAtomWithIdx(0), [1, 2, 3]) == [0, 0, 1]
...@@ -106,6 +116,18 @@ def test_atom_is_aromatic(): ...@@ -106,6 +116,18 @@ def test_atom_is_aromatic():
mol = test_mol2() mol = test_mol2()
assert atom_is_aromatic(mol.GetAtomWithIdx(0)) == [1] assert atom_is_aromatic(mol.GetAtomWithIdx(0)) == [1]
def test_atom_is_in_ring_one_hot():
mol = test_mol1()
assert atom_is_in_ring_one_hot(mol.GetAtomWithIdx(0)) == [1, 0]
mol = test_mol2()
assert atom_is_in_ring_one_hot(mol.GetAtomWithIdx(0)) == [0, 1]
def test_atom_is_in_ring():
mol = test_mol1()
assert atom_is_in_ring(mol.GetAtomWithIdx(0)) == [0]
mol = test_mol2()
assert atom_is_in_ring(mol.GetAtomWithIdx(0)) == [1]
def test_atom_chiral_tag_one_hot(): def test_atom_chiral_tag_one_hot():
mol = test_mol1() mol = test_mol1()
assert atom_chiral_tag_one_hot(mol.GetAtomWithIdx(0)) == [1, 0, 0, 0] assert atom_chiral_tag_one_hot(mol.GetAtomWithIdx(0)) == [1, 0, 0, 0]
...@@ -246,6 +268,7 @@ if __name__ == '__main__': ...@@ -246,6 +268,7 @@ if __name__ == '__main__':
test_atom_degree() test_atom_degree()
test_atom_total_degree_one_hot() test_atom_total_degree_one_hot()
test_atom_total_degree() test_atom_total_degree()
test_atom_explicit_valence()
test_atom_implicit_valence_one_hot() test_atom_implicit_valence_one_hot()
test_atom_implicit_valence() test_atom_implicit_valence()
test_atom_hybridization_one_hot() test_atom_hybridization_one_hot()
...@@ -257,6 +280,8 @@ if __name__ == '__main__': ...@@ -257,6 +280,8 @@ if __name__ == '__main__':
test_atom_num_radical_electrons() test_atom_num_radical_electrons()
test_atom_is_aromatic_one_hot() test_atom_is_aromatic_one_hot()
test_atom_is_aromatic() test_atom_is_aromatic()
test_atom_is_in_ring_one_hot()
test_atom_is_in_ring()
test_atom_chiral_tag_one_hot() test_atom_chiral_tag_one_hot()
test_atom_mass() test_atom_mass()
test_concat_featurizer() test_concat_featurizer()
......
...@@ -7,6 +7,9 @@ from rdkit import Chem ...@@ -7,6 +7,9 @@ from rdkit import Chem
test_smiles1 = 'CCO' test_smiles1 = 'CCO'
test_smiles2 = 'Fc1ccccc1' test_smiles2 = 'Fc1ccccc1'
test_smiles3 = '[CH2:1]([CH3:2])[N:3]1[CH2:4][CH2:5][C:6]([CH3:16])' \
'([CH3:17])[c:7]2[cH:8][cH:9][c:10]([N+:13]([O-:14])=[O:15])' \
'[cH:11][c:12]21.[CH3:18][CH2:19][O:20][C:21]([CH3:22])=[O:23]'
class TestAtomFeaturizer(BaseAtomFeaturizer): class TestAtomFeaturizer(BaseAtomFeaturizer):
def __init__(self): def __init__(self):
...@@ -37,6 +40,16 @@ def test_smiles_to_bigraph(): ...@@ -37,6 +40,16 @@ def test_smiles_to_bigraph():
[1.], [1.], [1.], [1.], [1.], [1.], [1.], [1.], [1.], [1.],
[1.], [1.], [1.], [1.]])) [1.], [1.], [1.], [1.]]))
# Test the case where atoms come with a default order and we do not
# want to change the order, which is related to the application of
# reaction center prediction.
g3 = smiles_to_bigraph(test_smiles3, node_featurizer=test_node_featurizer,
canonical_atom_order=False)
assert torch.allclose(g3.ndata['hv'], torch.tensor([[6.], [6.], [7.], [6.], [6.], [6.],
[6.], [6.], [6.], [6.], [6.], [6.],
[7.], [8.], [8.], [6.], [6.], [6.],
[6.], [8.], [6.], [6.], [8.]]))
def test_mol_to_bigraph(): def test_mol_to_bigraph():
mol1 = Chem.MolFromSmiles(test_smiles1) mol1 = Chem.MolFromSmiles(test_smiles1)
g1 = mol_to_bigraph(mol1, add_self_loop=True) g1 = mol_to_bigraph(mol1, add_self_loop=True)
...@@ -57,24 +70,56 @@ def test_mol_to_bigraph(): ...@@ -57,24 +70,56 @@ def test_mol_to_bigraph():
[1.], [1.], [1.], [1.], [1.], [1.], [1.], [1.], [1.], [1.],
[1.], [1.], [1.], [1.]])) [1.], [1.], [1.], [1.]]))
# Test the case where atoms come with a default order and we do not
# want to change the order, which is related to the application of
# reaction center prediction.
mol3 = Chem.MolFromSmiles(test_smiles3)
g3 = mol_to_bigraph(mol3, node_featurizer=test_node_featurizer,
canonical_atom_order=False)
assert torch.allclose(g3.ndata['hv'], torch.tensor([[6.], [6.], [7.], [6.], [6.], [6.],
[6.], [6.], [6.], [6.], [6.], [6.],
[7.], [8.], [8.], [6.], [6.], [6.],
[6.], [8.], [6.], [6.], [8.]]))
def test_smiles_to_complete_graph(): def test_smiles_to_complete_graph():
test_node_featurizer = TestAtomFeaturizer() test_node_featurizer = TestAtomFeaturizer()
g = smiles_to_complete_graph(test_smiles1, add_self_loop=False, g1 = smiles_to_complete_graph(test_smiles1, add_self_loop=False,
node_featurizer=test_node_featurizer) node_featurizer=test_node_featurizer)
src, dst = g.edges() src, dst = g1.edges()
assert torch.allclose(src, torch.LongTensor([0, 0, 1, 1, 2, 2])) assert torch.allclose(src, torch.LongTensor([0, 0, 1, 1, 2, 2]))
assert torch.allclose(dst, torch.LongTensor([1, 2, 0, 2, 0, 1])) assert torch.allclose(dst, torch.LongTensor([1, 2, 0, 2, 0, 1]))
assert torch.allclose(g.ndata['hv'], torch.tensor([[6.], [8.], [6.]])) assert torch.allclose(g1.ndata['hv'], torch.tensor([[6.], [8.], [6.]]))
# Test the case where atoms come with a default order and we do not
# want to change the order, which is related to the application of
# reaction center prediction.
g2 = smiles_to_complete_graph(test_smiles3, node_featurizer=test_node_featurizer,
canonical_atom_order=False)
assert torch.allclose(g2.ndata['hv'], torch.tensor([[6.], [6.], [7.], [6.], [6.], [6.],
[6.], [6.], [6.], [6.], [6.], [6.],
[7.], [8.], [8.], [6.], [6.], [6.],
[6.], [8.], [6.], [6.], [8.]]))
def test_mol_to_complete_graph(): def test_mol_to_complete_graph():
test_node_featurizer = TestAtomFeaturizer() test_node_featurizer = TestAtomFeaturizer()
mol1 = Chem.MolFromSmiles(test_smiles1) mol1 = Chem.MolFromSmiles(test_smiles1)
g = mol_to_complete_graph(mol1, add_self_loop=False, g1 = mol_to_complete_graph(mol1, add_self_loop=False,
node_featurizer=test_node_featurizer) node_featurizer=test_node_featurizer)
src, dst = g.edges() src, dst = g1.edges()
assert torch.allclose(src, torch.LongTensor([0, 0, 1, 1, 2, 2])) assert torch.allclose(src, torch.LongTensor([0, 0, 1, 1, 2, 2]))
assert torch.allclose(dst, torch.LongTensor([1, 2, 0, 2, 0, 1])) assert torch.allclose(dst, torch.LongTensor([1, 2, 0, 2, 0, 1]))
assert torch.allclose(g.ndata['hv'], torch.tensor([[6.], [8.], [6.]])) assert torch.allclose(g1.ndata['hv'], torch.tensor([[6.], [8.], [6.]]))
# Test the case where atoms come with a default order and we do not
# want to change the order, which is related to the application of
# reaction center prediction.
mol2 = Chem.MolFromSmiles(test_smiles3)
g2 = mol_to_complete_graph(mol2, node_featurizer=test_node_featurizer,
canonical_atom_order=False)
assert torch.allclose(g2.ndata['hv'], torch.tensor([[6.], [6.], [7.], [6.], [6.], [6.],
[6.], [6.], [6.], [6.], [6.], [6.],
[7.], [8.], [8.], [6.], [6.], [6.],
[6.], [8.], [6.], [6.], [8.]]))
def test_k_nearest_neighbors(): def test_k_nearest_neighbors():
coordinates = np.array([[0.1, 0.1, 0.1], coordinates = np.array([[0.1, 0.1, 0.1],
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment