"tutorials/vscode:/vscode.git/clone" did not exist on "8bc01c6309008ef60640c3da507babbe0024dcfc"
Unverified Commit 9df8cd32 authored by Mufei Li, committed by GitHub

[Model Zoo] Fix JTNN (#843)

* Update

* Update

* Update

* Update

* Update
parent 4e0e6697
import rdkit
import rdkit.Chem as Chem
from scipy.sparse import csr_matrix
from scipy.sparse.csgraph import minimum_spanning_tree
from collections import defaultdict
-from rdkit.Chem.EnumerateStereoisomers import EnumerateStereoisomers, StereoEnumerationOptions
+from rdkit.Chem.EnumerateStereoisomers import EnumerateStereoisomers
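# Assumed roles, inferred from the reference JT-VAE implementation:
# MST_MAX_WEIGHT inverts edge weights so that scipy's minimum_spanning_tree
# extracts a maximum spanning tree; MAX_NCAND caps the number of candidate
# attachment configurations enumerated per node.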
MST_MAX_WEIGHT = 100
MAX_NCAND = 2000
@@ -29,7 +28,8 @@ def decode_stereo(smiles2D):
    dec_isomers = [Chem.MolFromSmiles(Chem.MolToSmiles(mol, isomericSmiles=True)) for mol in dec_isomers]
    smiles3D = [Chem.MolToSmiles(mol, isomericSmiles=True) for mol in dec_isomers]
-    chiralN = [atom.GetIdx() for atom in dec_isomers[0].GetAtoms() if int(atom.GetChiralTag()) > 0 and atom.GetSymbol() == "N"]
+    chiralN = [atom.GetIdx() for atom in dec_isomers[0].GetAtoms()
+               if int(atom.GetChiralTag()) > 0 and atom.GetSymbol() == "N"]
    if len(chiralN) > 0:
        for mol in dec_isomers:
            for idx in chiralN:
@@ -117,7 +117,8 @@ def tree_decomp(mol):
        cnei = nei_list[atom]
        bonds = [c for c in cnei if len(cliques[c]) == 2]
        rings = [c for c in cnei if len(cliques[c]) > 4]
-        if len(bonds) > 2 or (len(bonds) == 2 and len(cnei) > 2): #In general, if len(cnei) >= 3, a singleton should be added, but 1 bond + 2 ring is currently not dealt with.
+        # In general, if len(cnei) >= 3, a singleton should be added, but 1 bond + 2 ring is currently not dealt with.
+        if len(bonds) > 2 or (len(bonds) == 2 and len(cnei) > 2):
            cliques.append([atom])
            c2 = len(cliques) - 1
            for c1 in cnei:
@@ -242,11 +243,13 @@ def enum_attach_nx(ctr_mol, nei_node, amap, singletons):
        for b1 in ctr_bonds:
            for b2 in nei_mol.GetBonds():
                if ring_bond_equal(b1, b2):
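                    # Each amap entry is a triple of
                    # (neighbor node id, ctr_mol atom idx, nei_mol atom idx);
                    # interpretation follows the reference JT-VAE code.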
-                    new_amap = amap + [(nei_idx, b1.GetBeginAtom().GetIdx(), b2.GetBeginAtom().GetIdx()), (nei_idx, b1.GetEndAtom().GetIdx(), b2.GetEndAtom().GetIdx())]
+                    new_amap = amap + [(nei_idx, b1.GetBeginAtom().GetIdx(), b2.GetBeginAtom().GetIdx()),
+                                       (nei_idx, b1.GetEndAtom().GetIdx(), b2.GetEndAtom().GetIdx())]
                    att_confs.append(new_amap)
                if ring_bond_equal(b1, b2, reverse=True):
-                    new_amap = amap + [(nei_idx, b1.GetBeginAtom().GetIdx(), b2.GetEndAtom().GetIdx()), (nei_idx, b1.GetEndAtom().GetIdx(), b2.GetBeginAtom().GetIdx())]
+                    new_amap = amap + [(nei_idx, b1.GetBeginAtom().GetIdx(), b2.GetEndAtom().GetIdx()),
+                                       (nei_idx, b1.GetEndAtom().GetIdx(), b2.GetBeginAtom().GetIdx())]
                    att_confs.append(new_amap)
    return att_confs
......
import torch
from torch.utils.data import Dataset
import numpy as np
import dgl
from dgl.data.utils import download, extract_archive, get_download_dir
......
import torch
import torch.nn as nn
from .nnutils import cuda
from .chemutils import get_mol
#from mpn import atom_features, bond_features, ATOM_FDIM, BOND_FDIM
import rdkit.Chem as Chem
-from dgl import DGLGraph, batch, unbatch, mean_nodes
+from dgl import DGLGraph, mean_nodes
import dgl.function as DGLF
from .line_profiler_integration import profile
import os
import numpy as np
-ELEM_LIST = ['C', 'N', 'O', 'S', 'F', 'Si', 'P', 'Cl', 'Br', 'Mg', 'Na', 'Ca', 'Fe', 'Al', 'I', 'B', 'K', 'Se', 'Zn', 'H', 'Cu', 'Mn', 'unknown']
+ELEM_LIST = ['C', 'N', 'O', 'S', 'F', 'Si', 'P', 'Cl', 'Br', 'Mg', 'Na', 'Ca',
+             'Fe', 'Al', 'I', 'B', 'K', 'Se', 'Zn', 'H', 'Cu', 'Mn', 'unknown']
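# Assumed feature layout, following the reference featurizers: atoms get an
# element one-hot plus degree (6), formal charge (5), and an aromaticity
# flag (1); bonds get a 4-way bond-type one-hot plus an in-ring flag.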
ATOM_FDIM = len(ELEM_LIST) + 6 + 5 + 1
BOND_FDIM = 5
......
import torch
import torch.nn as nn
import torch.nn.functional as F
from .mol_tree import Vocab
from .mol_tree_nx import DGLMolTree
from .chemutils import enum_assemble_nx, get_mol
from .nnutils import GRUUpdate, cuda
import copy
import itertools
from dgl import batch, dfs_labeled_edges_generator
import dgl.function as DGLF
import networkx as nx
from .line_profiler_integration import profile
import numpy as np
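# Upper bound on the number of neighbors padded per tree node during message
# passing (constant taken from the reference JT-VAE implementation).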
MAX_NB = 8
@@ -265,7 +260,6 @@ class DGLJTNNDecoder(nn.Module):
        for step in range(MAX_DECODE_LEN):
            u, u_slots = stack[-1]
            udata = mol_tree.nodes[u].data
-            wid = udata['wid']
            x = udata['x']
            h = udata['h']
......
import torch
import torch.nn as nn
from collections import deque
from .mol_tree import Vocab
from .nnutils import GRUUpdate, cuda
import itertools
import networkx as nx
-from dgl import batch, unbatch, bfs_edges_generator
+from dgl import batch, bfs_edges_generator
import dgl.function as DGLF
from .line_profiler_integration import profile
import numpy as np
MAX_NB = 8
......
import torch
import torch.nn as nn
import torch.nn.functional as F
from .mol_tree import Vocab
from .nnutils import cuda, move_dgl_to_cuda
from .chemutils import set_atommap, copy_edit_mol, enum_assemble_nx, \
attach_mols_nx, decode_stereo
@@ -11,13 +10,9 @@ from .mpn import DGLMPN
from .mpn import mol2dgl_single as mol2dgl_enc
from .jtmpn import DGLJTMPN
from .jtmpn import mol2dgl_single as mol2dgl_dec
-from .line_profiler_integration import profile
-import rdkit
import rdkit.Chem as Chem
-from rdkit import DataStructs
-from rdkit.Chem import AllChem
-import copy, math
+import copy
from dgl import batch, unbatch
@@ -102,7 +97,6 @@ class DGLJTNNVAE(nn.Module):
        assm_loss, assm_acc = self.assm(mol_batch, mol_tree_batch, mol_vec)
        stereo_loss, stereo_acc = self.stereo(mol_batch, mol_vec)
-        all_vec = torch.cat([tree_vec, mol_vec], dim=1)
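        # Objective: word/topology/assembly reconstruction terms, the stereo
        # term weighted by 2, and the KL term weighted by beta.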
        loss = word_loss + topo_loss + assm_loss + 2 * stereo_loss + beta * kl_loss
        return loss, kl_loss, word_acc, topo_acc, assm_acc, stereo_acc
......
import rdkit
import rdkit.Chem as Chem
import copy
......
@@ -3,7 +3,6 @@ import rdkit.Chem as Chem
from .chemutils import get_clique_mol, tree_decomp, get_mol, get_smiles, \
    set_atommap, enum_assemble_nx, decode_stereo
import numpy as np
-from .line_profiler_integration import profile
class DGLMolTree(DGLGraph):
    def __init__(self, smiles):
......
@@ -2,16 +2,12 @@ import torch
import torch.nn as nn
import rdkit.Chem as Chem
import torch.nn.functional as F
from .nnutils import *
from .chemutils import get_mol
-from networkx import Graph, DiGraph, convert_node_labels_to_integers
-from dgl import DGLGraph, batch, unbatch, mean_nodes
+from dgl import DGLGraph, mean_nodes
import dgl.function as DGLF
-from functools import partial
-from .line_profiler_integration import profile
import numpy as np
-ELEM_LIST = ['C', 'N', 'O', 'S', 'F', 'Si', 'P', 'Cl', 'Br', 'Mg', 'Na', 'Ca', 'Fe', 'Al', 'I', 'B', 'K', 'Se', 'Zn', 'H', 'Cu', 'Mn', 'unknown']
+ELEM_LIST = ['C', 'N', 'O', 'S', 'F', 'Si', 'P', 'Cl', 'Br', 'Mg', 'Na', 'Ca',
+             'Fe', 'Al', 'I', 'B', 'K', 'Se', 'Zn', 'H', 'Cu', 'Mn', 'unknown']
ATOM_FDIM = len(ELEM_LIST) + 6 + 5 + 4 + 1
BOND_FDIM = 5 + 6
......
import torch
import torch.nn as nn
from torch.autograd import Variable
import os
......
@@ -86,7 +86,9 @@ are also two accompanying review papers that are well written [7], [8].
### Models
- **Deep Generative Models of Graphs (DGMG)** [11]: A very general framework for graph distribution learning by
progressively adding atoms and bonds.
-- **Junction Tree Variational Autoencoder for Molecular Graph Generation (JTNN)** [13]:
+- **Junction Tree Variational Autoencoder for Molecular Graph Generation (JTNN)** [13]: JTNNs incrementally
+  expand molecules while maintaining chemical valency at every step. They can be used for both molecule
+  generation and optimization (see the sketch below).
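A minimal decoding sketch, modeled on the example in `train.py`. Illustrative only: the
constructor arguments, vocab handle, and checkpoint name are assumptions; `encode`,
`sample`, and `decode` mirror the calls in the training script.

```
# Sketch: decode a molecule with a trained DGLJTNNVAE; names/sizes assumed.
import torch
from torch.utils.data import DataLoader
from jtnn import JTNNDataset, JTNNCollator, DGLJTNNVAE  # assumed package layout

dataset = JTNNDataset(data='test', vocab='vocab', training=False)
dataloader = DataLoader(dataset, batch_size=1, shuffle=False,
                        collate_fn=JTNNCollator(dataset.vocab, False))

model = DGLJTNNVAE(dataset.vocab, hidden_size=450, latent_size=56, depth=3)
model.load_state_dict(torch.load('model.iter-2'))  # hypothetical checkpoint
model = model.cuda()

for batch in dataloader:
    model.move_to_cuda(batch)
    _, tree_vec, mol_vec = model.encode(batch)                 # encode input
    tree_vec, mol_vec, _, _ = model.sample(tree_vec, mol_vec)  # sample latents
    print(model.decode(tree_vec, mol_vec))                     # decoded SMILES
    break
```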
### Example Usage of Pre-trained Models
......
@@ -35,23 +35,16 @@ encoded nodes (atoms) and edges (bonds), and other information for the model to use.
To start training, use `python train.py`. By default, the script uses the ZINC dataset
with a preprocessed vocabulary and saves model checkpoints in the current working directory.
```
--s SAVE_PATH, --save_dir SAVE_PATH
-                      Path to save checkpoint models, default to be current
-                      working directory (default: ./)
--m MODEL_PATH, --model MODEL_PATH
-                      Path to load pre-trained model (default: None)
--b BATCH_SIZE, --batch BATCH_SIZE
-                      Batch size (default: 40)
--w HIDDEN_SIZE, --hidden HIDDEN_SIZE
-                      Size of representation vectors (default: 200)
--l LATENT_SIZE, --latent LATENT_SIZE
-                      Latent Size of node(atom) features and edge(atom)
-                      features (default: 56)
--d DEPTH, --depth DEPTH
-                      Depth of message passing hops (default: 3)
--z BETA, --beta BETA  Coefficient of KL Divergence term (default: 1.0)
--q LR, --lr LR        Learning Rate (default: 0.001)
--T, --test            Add this flag to run test mode (default: False)
+-s SAVE_PATH,   Path to save checkpoint models; defaults to the current
+                working directory (default: ./)
+-m MODEL_PATH,  Path to load a pre-trained model (default: None)
+-b BATCH_SIZE,  Batch size (default: 40)
+-w HIDDEN_SIZE, Size of representation vectors (default: 200)
+-l LATENT_SIZE, Latent size of node (atom) and edge (bond) features
+                (default: 56)
+-d DEPTH,       Depth of message passing hops (default: 3)
+-z BETA,        Coefficient of the KL divergence term (default: 1.0)
+-q LR,          Learning rate (default: 0.001)
```
The model checkpoint is saved periodically.
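For example, a typical invocation (path and values are illustrative):

```
python train.py -s ./checkpoints -b 40 -w 200 -l 56 -d 3 -z 1.0 -q 0.001
```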
@@ -70,22 +63,17 @@ If you want to use your own dataset, please create a file that contains one SMILES a
To start evaluation, use `python reconstruct_eval.py` with the following arguments
```
--t TRAIN, --train TRAIN
-                      Training file name (default: test)
--m MODEL_PATH, --model MODEL_PATH
-                      Pre-trained model to be loaded for evalutaion. If not
-                      specified, would use pre-trained model from model zoo
-                      (default: None)
--w HIDDEN_SIZE, --hidden HIDDEN_SIZE
-                      Hidden size of representation vector, should be
-                      consistent with pre-trained model (default: 450)
--l LATENT_SIZE, --latent LATENT_SIZE
-                      Latent Size of node(atom) features and edge(atom)
-                      features, should be consistent with pre-trained model
-                      (default: 56)
--d DEPTH, --depth DEPTH
-                      Depth of message passing hops, should be consistent
-                      with pre-trained model (default: 3)
+-t TRAIN,       Training file name (default: test)
+-m MODEL_PATH,  Pre-trained model to be loaded for evaluation. If not
+                specified, the pre-trained model from the model zoo is used
+                (default: None)
+-w HIDDEN_SIZE, Hidden size of representation vectors; should be
+                consistent with the pre-trained model (default: 450)
+-l LATENT_SIZE, Latent size of node (atom) and edge (bond) features;
+                should be consistent with the pre-trained model (default: 56)
+-d DEPTH,       Depth of message passing hops; should be consistent
+                with the pre-trained model (default: 3)
```
The script prints the success rate of reconstructing the input molecules.
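For example, using values consistent with the model-zoo weights:

```
python reconstruct_eval.py -t test -w 450 -l 56 -d 3
```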
......
import rdkit
import rdkit.Chem as Chem
import torch
from scipy.sparse import csr_matrix
from scipy.sparse.csgraph import minimum_spanning_tree
from collections import defaultdict
-from rdkit.Chem.EnumerateStereoisomers import EnumerateStereoisomers, StereoEnumerationOptions
+from rdkit.Chem.EnumerateStereoisomers import EnumerateStereoisomers
from dgl import DGLGraph
ELEM_LIST = ['C', 'N', 'O', 'S', 'F', 'Si', 'P', 'Cl', 'Br', 'Mg', 'Na',
@@ -406,16 +405,12 @@ def mol2dgl_dec(cand_batch):
    tree_mess_target_edges = []  # these edges on candidate graphs
    tree_mess_target_nodes = []
    n_nodes = 0
    n_edges = 0
    atom_x = []
    bond_x = []
    for mol, mol_tree, ctr_node_id in cand_batch:
        n_atoms = mol.GetNumAtoms()
        n_bonds = mol.GetNumBonds()
        ctr_node = mol_tree.nodes_dict[ctr_node_id]
        ctr_bid = ctr_node['idx']
        g = DGLGraph()
        for i, atom in enumerate(mol.GetAtoms()):
......
@@ -7,7 +7,8 @@ import os
from .mol_tree import Vocab, DGLMolTree
from .chemutils import mol2dgl_dec, mol2dgl_enc
-ELEM_LIST = ['C', 'N', 'O', 'S', 'F', 'Si', 'P', 'Cl', 'Br', 'Mg', 'Na', 'Ca', 'Fe', 'Al', 'I', 'B', 'K', 'Se', 'Zn', 'H', 'Cu', 'Mn', 'unknown']
+ELEM_LIST = ['C', 'N', 'O', 'S', 'F', 'Si', 'P', 'Cl', 'Br', 'Mg', 'Na', 'Ca',
+             'Fe', 'Al', 'I', 'B', 'K', 'Se', 'Zn', 'H', 'Cu', 'Mn', 'unknown']
ATOM_FDIM_DEC = len(ELEM_LIST) + 6 + 5 + 1
BOND_FDIM_DEC = 5
......
import rdkit
import rdkit.Chem as Chem
import copy
import numpy as np
from dgl import DGLGraph
......
@@ -5,9 +5,8 @@ import torch.optim.lr_scheduler as lr_scheduler
from dgl import model_zoo
from torch.utils.data import DataLoader
-import math, random, sys
+import sys
import argparse
-from collections import deque
import rdkit
from jtnn import *
@@ -42,8 +41,6 @@ parser.add_argument("-z", "--beta", dest="beta", default=1.0,
                    help="Coefficient of KL Divergence term")
parser.add_argument("-q", "--lr", dest="lr", default=1e-3,
                    help="Learning Rate")
-parser.add_argument("-T", "--test", dest="test", action="store_true",
-                    help="Add this flag to run test mode")
args = parser.parse_args()
dataset = JTNNDataset(data=args.train, vocab=args.vocab, training=True)
@@ -131,35 +128,8 @@ def train():
        print("learning rate: %.6f" % scheduler.get_lr()[0])
        torch.save(model.state_dict(), args.save_path + "/model.iter-" + str(epoch))

-def test():
-    dataset.training = False
-    dataloader = DataLoader(
-        dataset,
-        batch_size=1,
-        shuffle=False,
-        num_workers=0,
-        collate_fn=JTNNCollator(vocab, False),
-        drop_last=True,
-        worker_init_fn=worker_init_fn)
-
-    # Just an example of molecule decoding; in reality you may want to sample
-    # tree and molecule vectors.
-    for it, batch in enumerate(dataloader):
-        gt_smiles = batch['mol_trees'][0].smiles
-        print(gt_smiles)
-        model.move_to_cuda(batch)
-        _, tree_vec, mol_vec = model.encode(batch)
-        tree_vec, mol_vec, _, _ = model.sample(tree_vec, mol_vec)
-        smiles = model.decode(tree_vec, mol_vec)
-        print(smiles)

if __name__ == '__main__':
-    if args.test:
-        test()
-    else:
-        train()
+    train()
    print('# passes:', model.n_passes)
    print('Total # nodes processed:', model.n_nodes_total)
......
# pylint: disable=C0111, C0103, E1101, W0611, W0612, W0703, C0200, R1710
from collections import defaultdict
import rdkit
import rdkit.Chem as Chem
-from rdkit.Chem.EnumerateStereoisomers import (EnumerateStereoisomers,
-                                               StereoEnumerationOptions)
+from rdkit.Chem.EnumerateStereoisomers import EnumerateStereoisomers
from scipy.sparse import csr_matrix
from scipy.sparse.csgraph import minimum_spanning_tree
......
@@ -9,7 +9,6 @@ import torch.nn as nn
import dgl.function as DGLF
from dgl import DGLGraph, mean_nodes
-from .chemutils import get_mol
from .nnutils import cuda
ELEM_LIST = ['C', 'N', 'O', 'S', 'F', 'Si', 'P', 'Cl', 'Br', 'Mg', 'Na',
@@ -57,16 +56,12 @@ def mol2dgl_single(cand_batch):
    tree_mess_target_edges = []  # these edges on candidate graphs
    tree_mess_target_nodes = []
    n_nodes = 0
    n_edges = 0
    atom_x = []
    bond_x = []
    for mol, mol_tree, ctr_node_id in cand_batch:
        n_atoms = mol.GetNumAtoms()
        n_bonds = mol.GetNumBonds()
        ctr_node = mol_tree.nodes_dict[ctr_node_id]
        ctr_bid = ctr_node['idx']
        g = DGLGraph()
        for i, atom in enumerate(mol.GetAtoms()):
......
# pylint: disable=C0111, C0103, E1101, W0611, W0612
import copy
import itertools
import networkx as nx
import numpy as np
import torch
import torch.nn as nn
@@ -12,7 +8,6 @@ import dgl.function as DGLF
from dgl import batch, dfs_labeled_edges_generator
from .chemutils import enum_assemble_nx, get_mol
from .mol_tree import Vocab
from .mol_tree_nx import DGLMolTree
from .nnutils import GRUUpdate, cuda
@@ -274,7 +269,6 @@ class DGLJTNNDecoder(nn.Module):
        stack.append((0, self.vocab.get_slots(root_wid)))
        all_nodes = {0: root_node_dict}
        h = {}
-        first = True
        new_node_id = 0
        new_edge_id = 0
@@ -282,7 +276,6 @@ class DGLJTNNDecoder(nn.Module):
        for step in range(MAX_DECODE_LEN):
            u, u_slots = stack[-1]
            udata = mol_tree.nodes[u].data
-            wid = udata['wid']
            x = udata['x']
            h = udata['h']
......
# pylint: disable=C0111, C0103, E1101, W0611, W0612
import itertools
from collections import deque
import networkx as nx
import numpy as np
import torch
import torch.nn as nn
import dgl.function as DGLF
-from dgl import batch, bfs_edges_generator, unbatch
+from dgl import batch, bfs_edges_generator
from .mol_tree import Vocab
from .nnutils import GRUUpdate, cuda
MAX_NB = 8
......