"src/vscode:/vscode.git/clone" did not exist on "95d374845373d9bcbc20df08475240b90d555962"
Unverified Commit 011656fd authored by jjhu94's avatar jjhu94 Committed by GitHub
Browse files

[DGL-LifeSci] Weave for Molecular Property Prediction (#1441)



* featurize for weave model

* weave module for molecular graphs

* Update

* completed weave model

* update the whole weave model

* Update

* update atom (node) features

* "featurizer"

* Add files via upload

add edge featurizer

* Add files via upload

* Update

* Update

* Update

* Update

* Update

* Update

* Update

* Update

* Update

* Update

* Update

* Update

* Update

* Update

* Update

* Update

* Update

* Update

* Update

* Update
Co-authored-by: default avatarmufeili <mufeili1996@gmail.com>
parent 28117cd9
...@@ -91,6 +91,27 @@ def test_gat_tox21(): ...@@ -91,6 +91,27 @@ def test_gat_tox21():
remove_file('GAT_Tox21_pre_trained.pth') remove_file('GAT_Tox21_pre_trained.pth')
def test_weave_tox21():
    """Smoke-test the pretrained Weave model for Tox21.

    Runs a forward pass on a batched pair of molecules, then on a single
    molecule in eval mode, using GPU when one is available.
    """
    # Prefer the first CUDA device; fall back to CPU.
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    atom_featurizer = WeaveAtomFeaturizer()
    bond_featurizer = WeaveEdgeFeaturizer(max_distance=2)
    # Complete graphs (with self loops) for two small molecules.
    graphs = [
        smiles_to_complete_graph(smi, node_featurizer=atom_featurizer,
                                 edge_featurizer=bond_featurizer,
                                 add_self_loop=True)
        for smi in ('CO', 'CCO')
    ]
    batched = dgl.batch(graphs)

    model = load_pretrained('Weave_Tox21').to(device)
    # Forward pass on the batch in training mode.
    model(batched.to(device), batched.ndata.pop('h').to(device),
          batched.edata.pop('e').to(device))
    # And on a single graph in eval mode.
    model.eval()
    single = graphs[0]
    model(single.to(device), single.ndata.pop('h').to(device),
          single.edata.pop('e').to(device))
    # Remove the downloaded checkpoint so the test leaves no artifacts.
    remove_file('Weave_Tox21_pre_trained.pth')
def chirality(atom): def chirality(atom):
try: try:
return one_hot_encoding(atom.GetProp('_CIPCode'), ['R', 'S']) + \ return one_hot_encoding(atom.GetProp('_CIPCode'), ['R', 'S']) + \
......
import dgl import dgl
import torch import torch
import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from dgl import DGLGraph from dgl import DGLGraph
...@@ -232,6 +233,36 @@ def test_mpnn_predictor(): ...@@ -232,6 +233,36 @@ def test_mpnn_predictor():
assert mpnn_predictor(bg, batch_node_feats, batch_edge_feats).shape == \ assert mpnn_predictor(bg, batch_node_feats, batch_edge_feats).shape == \
torch.Size([2, 2]) torch.Size([2, 2])
def test_weave_predictor():
    """Check output shapes of WeavePredictor in default and custom configs."""
    # Prefer the first CUDA device; fall back to CPU.
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    bg, batch_node_feats, batch_edge_feats = test_graph4()
    bg = bg.to(device)
    batch_node_feats = batch_node_feats.to(device)
    batch_edge_feats = batch_edge_feats.to(device)

    # Default setting: one task -> (num_graphs, 1).
    default_model = WeavePredictor(node_in_feats=1,
                                   edge_in_feats=2).to(device)
    assert default_model(bg, batch_node_feats, batch_edge_feats).shape == \
        torch.Size([2, 1])

    # Fully-specified setting: two tasks -> (num_graphs, 2).
    custom_model = WeavePredictor(node_in_feats=1,
                                  edge_in_feats=2,
                                  num_gnn_layers=2,
                                  gnn_hidden_feats=10,
                                  gnn_activation=F.relu,
                                  graph_feats=128,
                                  gaussian_expand=True,
                                  gaussian_memberships=None,
                                  readout_activation=nn.Tanh(),
                                  n_tasks=2).to(device)
    assert custom_model(bg, batch_node_feats, batch_edge_feats).shape == \
        torch.Size([2, 2])
if __name__ == '__main__': if __name__ == '__main__':
test_mlp_predictor() test_mlp_predictor()
test_gcn_predictor() test_gcn_predictor()
...@@ -240,3 +271,4 @@ if __name__ == '__main__': ...@@ -240,3 +271,4 @@ if __name__ == '__main__':
test_schnet_predictor() test_schnet_predictor()
test_mgcn_predictor() test_mgcn_predictor()
test_mpnn_predictor() test_mpnn_predictor()
test_weave_predictor()
...@@ -79,7 +79,27 @@ def test_mlp_readout(): ...@@ -79,7 +79,27 @@ def test_mlp_readout():
assert model(g, node_feats).shape == torch.Size([1, 3]) assert model(g, node_feats).shape == torch.Size([1, 3])
assert model(bg, batch_node_feats).shape == torch.Size([2, 3]) assert model(bg, batch_node_feats).shape == torch.Size([2, 3])
def test_weave_readout():
    """Check WeaveGather output shapes with and without Gaussian expansion."""
    # Prefer the first CUDA device; fall back to CPU.
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    g, node_feats = test_graph1()
    g = g.to(device)
    node_feats = node_feats.to(device)
    bg, batch_node_feats = test_graph2()
    bg = bg.to(device)
    batch_node_feats = batch_node_feats.to(device)

    # Default construction and gaussian_expand=False should both yield one
    # graph-level feature per graph.
    for extra_kwargs in ({}, {'gaussian_expand': False}):
        gather = WeaveGather(node_in_feats=1, **extra_kwargs).to(device)
        assert gather(g, node_feats).shape == torch.Size([1, 1])
        assert gather(bg, batch_node_feats).shape == torch.Size([2, 1])
if __name__ == '__main__': if __name__ == '__main__':
test_weighted_sum_and_max() test_weighted_sum_and_max()
test_attentive_fp_readout() test_attentive_fp_readout()
test_mlp_readout() test_mlp_readout()
test_weave_readout()
...@@ -160,21 +160,28 @@ class TestAtomFeaturizer(BaseAtomFeaturizer): ...@@ -160,21 +160,28 @@ class TestAtomFeaturizer(BaseAtomFeaturizer):
def test_base_atom_featurizer(): def test_base_atom_featurizer():
test_featurizer = TestAtomFeaturizer() test_featurizer = TestAtomFeaturizer()
assert test_featurizer.feat_size('h1') == 11
assert test_featurizer.feat_size('h2') == 5
mol = test_mol1() mol = test_mol1()
feats = test_featurizer(mol) feats = test_featurizer(mol)
torch.allclose(feats['h1'], torch.tensor([[0., 0., 0., 0., 1., 0., 0., 0., 1., 0., 0.], assert torch.allclose(feats['h1'],
torch.tensor([[0., 0., 0., 0., 1., 0., 0., 0., 1., 0., 0.],
[0., 0., 0., 0., 1., 0., 0., 0., 1., 0., 0.], [0., 0., 0., 0., 1., 0., 0., 0., 1., 0., 0.],
[0., 0., 1., 0., 0., 0., 0., 0., 1., 0., 0.]])) [0., 0., 1., 0., 0., 0., 0., 0., 1., 0., 0.]]))
torch.allclose(feats['h2'], torch.tensor([[1., 0., 0., 0., 0.], assert torch.allclose(feats['h2'],
torch.tensor([[1., 0., 0., 0., 0.],
[1., 0., 0., 0., 0.], [1., 0., 0., 0., 0.],
[1., 0., 0., 0., 0.]])) [1., 0., 0., 0., 0.]]))
def test_canonical_atom_featurizer(): def test_canonical_atom_featurizer():
test_featurizer = CanonicalAtomFeaturizer() test_featurizer = CanonicalAtomFeaturizer()
assert test_featurizer.feat_size() == 74
assert test_featurizer.feat_size('h') == 74
mol = test_mol1() mol = test_mol1()
feats = test_featurizer(mol) feats = test_featurizer(mol)
assert list(feats.keys()) == ['h'] assert list(feats.keys()) == ['h']
torch.allclose(feats['h'], torch.tensor([[1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., assert torch.allclose(feats['h'],
torch.tensor([[1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0.,
...@@ -196,6 +203,29 @@ def test_canonical_atom_featurizer(): ...@@ -196,6 +203,29 @@ def test_canonical_atom_featurizer():
0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 1., 0.,
0., 0.]])) 0., 0.]]))
def test_weave_atom_featurizer():
    """Check feature size and exact node features from WeaveAtomFeaturizer."""
    featurizer = WeaveAtomFeaturizer()
    assert featurizer.feat_size() == 27

    # Expected per-atom features for the three atoms of test_mol1(); the
    # single non-binary column holds partial charges, hence the loose rtol.
    expected = torch.tensor([[0.0000, 1.0000, 0.0000, 0.0000, 0.0000, 0.0000,
                              0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
                              0.0000, 0.0000, -0.0418, 0.0000, 0.0000, 0.0000,
                              1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
                              0.0000, 0.0000, 0.0000],
                             [0.0000, 1.0000, 0.0000, 0.0000, 0.0000, 0.0000,
                              0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
                              0.0000, 0.0000, 0.0402, 0.0000, 0.0000, 0.0000,
                              1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
                              0.0000, 0.0000, 0.0000],
                             [0.0000, 0.0000, 0.0000, 1.0000, 0.0000, 0.0000,
                              0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
                              0.0000, 0.0000, -0.3967, 0.0000, 0.0000, 0.0000,
                              1.0000, 1.0000, 1.0000, 0.0000, 0.0000, 0.0000,
                              0.0000, 0.0000, 0.0000]])

    feats = featurizer(test_mol1())
    assert list(feats.keys()) == ['h']
    assert torch.allclose(feats['h'], expected, rtol=1e-3)
def test_bond_type_one_hot(): def test_bond_type_one_hot():
mol = test_mol1() mol = test_mol1()
assert bond_type_one_hot(mol.GetBondWithIdx(0)) == [1, 0, 0, 0] assert bond_type_one_hot(mol.GetBondWithIdx(0)) == [1, 0, 0, 0]
...@@ -241,6 +271,8 @@ class TestBondFeaturizer(BaseBondFeaturizer): ...@@ -241,6 +271,8 @@ class TestBondFeaturizer(BaseBondFeaturizer):
def test_base_bond_featurizer(): def test_base_bond_featurizer():
test_featurizer = TestBondFeaturizer() test_featurizer = TestBondFeaturizer()
assert test_featurizer.feat_size('h1') == 2
assert test_featurizer.feat_size('h2') == 6
mol = test_mol1() mol = test_mol1()
feats = test_featurizer(mol) feats = test_featurizer(mol)
assert torch.allclose(feats['h1'], torch.tensor([[0., 0.], [0., 0.], [0., 0.], [0., 0.]])) assert torch.allclose(feats['h1'], torch.tensor([[0., 0.], [0., 0.], [0., 0.], [0., 0.]]))
...@@ -251,6 +283,8 @@ def test_base_bond_featurizer(): ...@@ -251,6 +283,8 @@ def test_base_bond_featurizer():
def test_canonical_bond_featurizer(): def test_canonical_bond_featurizer():
test_featurizer = CanonicalBondFeaturizer() test_featurizer = CanonicalBondFeaturizer()
assert test_featurizer.feat_size() == 12
assert test_featurizer.feat_size('e') == 12
mol = test_mol1() mol = test_mol1()
feats = test_featurizer(mol) feats = test_featurizer(mol)
assert torch.allclose(feats['e'], torch.tensor( assert torch.allclose(feats['e'], torch.tensor(
...@@ -259,6 +293,22 @@ def test_canonical_bond_featurizer(): ...@@ -259,6 +293,22 @@ def test_canonical_bond_featurizer():
[1., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.], [1., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.],
[1., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.]])) [1., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.]]))
def test_weave_edge_featurizer():
    """Check feature size and exact edge features from WeaveEdgeFeaturizer."""
    test_featurizer = WeaveEdgeFeaturizer()
    assert test_featurizer.feat_size() == 12

    # Expected features for all ordered atom pairs of test_mol1().
    expected = torch.tensor([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
                             [1., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0.],
                             [1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
                             [1., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0.],
                             [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
                             [1., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0.],
                             [1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
                             [1., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0.],
                             [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]])

    feats = test_featurizer(test_mol1())
    assert torch.allclose(feats['e'], expected)
if __name__ == '__main__': if __name__ == '__main__':
test_one_hot_encoding() test_one_hot_encoding()
test_atom_type_one_hot() test_atom_type_one_hot()
...@@ -287,6 +337,7 @@ if __name__ == '__main__': ...@@ -287,6 +337,7 @@ if __name__ == '__main__':
test_concat_featurizer() test_concat_featurizer()
test_base_atom_featurizer() test_base_atom_featurizer()
test_canonical_atom_featurizer() test_canonical_atom_featurizer()
test_weave_atom_featurizer()
test_bond_type_one_hot() test_bond_type_one_hot()
test_bond_is_conjugated_one_hot() test_bond_is_conjugated_one_hot()
test_bond_is_conjugated() test_bond_is_conjugated()
...@@ -295,3 +346,4 @@ if __name__ == '__main__': ...@@ -295,3 +346,4 @@ if __name__ == '__main__':
test_bond_stereo_one_hot() test_bond_stereo_one_hot()
test_base_bond_featurizer() test_base_bond_featurizer()
test_canonical_bond_featurizer() test_canonical_bond_featurizer()
test_weave_edge_featurizer()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment