Unverified commit 9df8cd32, authored by Mufei Li, committed by GitHub

[Model Zoo] Fix JTNN (#843)

* Update

* Update

* Update

* Update

* Update
parent 4e0e6697
-import rdkit
 import rdkit.Chem as Chem
 from scipy.sparse import csr_matrix
 from scipy.sparse.csgraph import minimum_spanning_tree
 from collections import defaultdict
-from rdkit.Chem.EnumerateStereoisomers import EnumerateStereoisomers, StereoEnumerationOptions
+from rdkit.Chem.EnumerateStereoisomers import EnumerateStereoisomers

 MST_MAX_WEIGHT = 100
 MAX_NCAND = 2000
@@ -29,7 +28,8 @@ def decode_stereo(smiles2D):
     dec_isomers = [Chem.MolFromSmiles(Chem.MolToSmiles(mol, isomericSmiles=True)) for mol in dec_isomers]

     smiles3D = [Chem.MolToSmiles(mol, isomericSmiles=True) for mol in dec_isomers]
-    chiralN = [atom.GetIdx() for atom in dec_isomers[0].GetAtoms() if int(atom.GetChiralTag()) > 0 and atom.GetSymbol() == "N"]
+    chiralN = [atom.GetIdx() for atom in dec_isomers[0].GetAtoms()
+               if int(atom.GetChiralTag()) > 0 and atom.GetSymbol() == "N"]
     if len(chiralN) > 0:
         for mol in dec_isomers:
             for idx in chiralN:
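For context on the hunk above, `decode_stereo` expands a 2D SMILES into isomeric 3D SMILES strings, with the `chiralN` bookkeeping handling chiral nitrogens specially. A minimal usage sketch; the import path and input molecule are assumptions:

```python
# Hedged usage sketch for decode_stereo; the import path and the input
# SMILES are assumptions, not part of this patch.
from jtnn.chemutils import decode_stereo

for smiles3D in decode_stereo("CC(N)C(=O)O"):  # alanine, stereocenter unset
    print(smiles3D)
```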
@@ -117,7 +117,8 @@ def tree_decomp(mol):
         cnei = nei_list[atom]
         bonds = [c for c in cnei if len(cliques[c]) == 2]
         rings = [c for c in cnei if len(cliques[c]) > 4]
-        if len(bonds) > 2 or (len(bonds) == 2 and len(cnei) > 2): #In general, if len(cnei) >= 3, a singleton should be added, but 1 bond + 2 ring is currently not dealt with.
+        # In general, if len(cnei) >= 3, a singleton should be added, but 1 bond + 2 ring is currently not dealt with.
+        if len(bonds) > 2 or (len(bonds) == 2 and len(cnei) > 2):
             cliques.append([atom])
             c2 = len(cliques) - 1
             for c1 in cnei:
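To make the singleton comment concrete, a small illustration with a hypothetical molecule (rdkit only, not part of the patch):

```python
# Hedged illustration of the singleton rule above: neopentane's central
# carbon lies in four bond cliques, so len(bonds) > 2 holds and tree_decomp
# appends the singleton clique [atom] to join them.
import rdkit.Chem as Chem

mol = Chem.MolFromSmiles("CC(C)(C)C")      # neopentane
print(mol.GetAtomWithIdx(1).GetDegree())   # 4 -> a singleton clique is added
```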
@@ -242,11 +243,13 @@ def enum_attach_nx(ctr_mol, nei_node, amap, singletons):
         for b1 in ctr_bonds:
             for b2 in nei_mol.GetBonds():
                 if ring_bond_equal(b1, b2):
-                    new_amap = amap + [(nei_idx, b1.GetBeginAtom().GetIdx(), b2.GetBeginAtom().GetIdx()), (nei_idx, b1.GetEndAtom().GetIdx(), b2.GetEndAtom().GetIdx())]
+                    new_amap = amap + [(nei_idx, b1.GetBeginAtom().GetIdx(), b2.GetBeginAtom().GetIdx()),
+                                       (nei_idx, b1.GetEndAtom().GetIdx(), b2.GetEndAtom().GetIdx())]
                     att_confs.append( new_amap )

                 if ring_bond_equal(b1, b2, reverse=True):
-                    new_amap = amap + [(nei_idx, b1.GetBeginAtom().GetIdx(), b2.GetEndAtom().GetIdx()), (nei_idx, b1.GetEndAtom().GetIdx(), b2.GetBeginAtom().GetIdx())]
+                    new_amap = amap + [(nei_idx, b1.GetBeginAtom().GetIdx(), b2.GetEndAtom().GetIdx()),
+                                       (nei_idx, b1.GetEndAtom().GetIdx(), b2.GetBeginAtom().GetIdx())]
                     att_confs.append( new_amap )

     return att_confs
......
 import torch
 from torch.utils.data import Dataset
-import numpy as np
 import dgl
 from dgl.data.utils import download, extract_archive, get_download_dir
......
 import torch
 import torch.nn as nn
 from .nnutils import cuda
-from .chemutils import get_mol
-#from mpn import atom_features, bond_features, ATOM_FDIM, BOND_FDIM
 import rdkit.Chem as Chem
-from dgl import DGLGraph, batch, unbatch, mean_nodes
+from dgl import DGLGraph, mean_nodes
 import dgl.function as DGLF
-from .line_profiler_integration import profile
 import os
-import numpy as np

-ELEM_LIST = ['C', 'N', 'O', 'S', 'F', 'Si', 'P', 'Cl', 'Br', 'Mg', 'Na', 'Ca', 'Fe', 'Al', 'I', 'B', 'K', 'Se', 'Zn', 'H', 'Cu', 'Mn', 'unknown']
+ELEM_LIST = ['C', 'N', 'O', 'S', 'F', 'Si', 'P', 'Cl', 'Br', 'Mg', 'Na', 'Ca',
+             'Fe', 'Al', 'I', 'B', 'K', 'Se', 'Zn', 'H', 'Cu', 'Mn', 'unknown']

 ATOM_FDIM = len(ELEM_LIST) + 6 + 5 + 1
 BOND_FDIM = 5
......
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from .mol_tree import Vocab
 from .mol_tree_nx import DGLMolTree
 from .chemutils import enum_assemble_nx, get_mol
 from .nnutils import GRUUpdate, cuda
-import copy
-import itertools
 from dgl import batch, dfs_labeled_edges_generator
 import dgl.function as DGLF
-import networkx as nx
-from .line_profiler_integration import profile
 import numpy as np

 MAX_NB = 8
@@ -265,7 +260,6 @@ class DGLJTNNDecoder(nn.Module):
         for step in range(MAX_DECODE_LEN):
             u, u_slots = stack[-1]
             udata = mol_tree.nodes[u].data
-            wid = udata['wid']
             x = udata['x']
             h = udata['h']
......
 import torch
 import torch.nn as nn
-from collections import deque
-from .mol_tree import Vocab
 from .nnutils import GRUUpdate, cuda
-import itertools
-import networkx as nx
-from dgl import batch, unbatch, bfs_edges_generator
+from dgl import batch, bfs_edges_generator
 import dgl.function as DGLF
-from .line_profiler_integration import profile
 import numpy as np

 MAX_NB = 8
......
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from .mol_tree import Vocab
 from .nnutils import cuda, move_dgl_to_cuda
 from .chemutils import set_atommap, copy_edit_mol, enum_assemble_nx, \
         attach_mols_nx, decode_stereo
@@ -11,13 +10,9 @@ from .mpn import DGLMPN
 from .mpn import mol2dgl_single as mol2dgl_enc
 from .jtmpn import DGLJTMPN
 from .jtmpn import mol2dgl_single as mol2dgl_dec
-from .line_profiler_integration import profile
-import rdkit
 import rdkit.Chem as Chem
-from rdkit import DataStructs
-from rdkit.Chem import AllChem
-import copy, math
+import copy
 from dgl import batch, unbatch
@@ -102,7 +97,6 @@ class DGLJTNNVAE(nn.Module):
         assm_loss, assm_acc = self.assm(mol_batch, mol_tree_batch, mol_vec)
         stereo_loss, stereo_acc = self.stereo(mol_batch, mol_vec)

-        all_vec = torch.cat([tree_vec, mol_vec], dim=1)
         loss = word_loss + topo_loss + assm_loss + 2 * stereo_loss + beta * kl_loss

         return loss, kl_loss, word_acc, topo_acc, assm_acc, stereo_acc
......
-import rdkit
 import rdkit.Chem as Chem
 import copy
......
@@ -3,7 +3,6 @@ import rdkit.Chem as Chem
 from .chemutils import get_clique_mol, tree_decomp, get_mol, get_smiles, \
     set_atommap, enum_assemble_nx, decode_stereo
 import numpy as np
-from .line_profiler_integration import profile

 class DGLMolTree(DGLGraph):
     def __init__(self, smiles):
......
@@ -2,16 +2,12 @@ import torch
 import torch.nn as nn
 import rdkit.Chem as Chem
 import torch.nn.functional as F
-from .nnutils import *
 from .chemutils import get_mol
-from networkx import Graph, DiGraph, convert_node_labels_to_integers
-from dgl import DGLGraph, batch, unbatch, mean_nodes
+from dgl import DGLGraph, mean_nodes
 import dgl.function as DGLF
-from functools import partial
-from .line_profiler_integration import profile
-import numpy as np

-ELEM_LIST = ['C', 'N', 'O', 'S', 'F', 'Si', 'P', 'Cl', 'Br', 'Mg', 'Na', 'Ca', 'Fe', 'Al', 'I', 'B', 'K', 'Se', 'Zn', 'H', 'Cu', 'Mn', 'unknown']
+ELEM_LIST = ['C', 'N', 'O', 'S', 'F', 'Si', 'P', 'Cl', 'Br', 'Mg', 'Na', 'Ca',
+             'Fe', 'Al', 'I', 'B', 'K', 'Se', 'Zn', 'H', 'Cu', 'Mn', 'unknown']

 ATOM_FDIM = len(ELEM_LIST) + 6 + 5 + 4 + 1
 BOND_FDIM = 5 + 6
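For orientation, `ATOM_FDIM` above decomposes as element one-hot plus degree (6), formal charge (5), chiral tag (4), and an aromaticity bit (1). Below is a sketch of a featurizer consistent with those constants; the helper `onek` and the exact category orderings are assumptions, not code from this patch:

```python
# Hedged sketch of an atom featurizer matching ATOM_FDIM above; `onek` and
# the category orderings are assumptions. ELEM_LIST is the list defined above.
import torch

def onek(x, allowable):
    # One-hot encode x over `allowable`, mapping unseen values to the last slot.
    if x not in allowable:
        x = allowable[-1]
    return [float(x == s) for s in allowable]

def atom_features(atom):  # atom: an rdkit Atom
    return torch.Tensor(
        onek(atom.GetSymbol(), ELEM_LIST)                  # len(ELEM_LIST)
        + onek(atom.GetDegree(), [0, 1, 2, 3, 4, 5])       # 6
        + onek(atom.GetFormalCharge(), [-1, -2, 1, 2, 0])  # 5
        + onek(int(atom.GetChiralTag()), [0, 1, 2, 3])     # 4
        + [float(atom.GetIsAromatic())]                    # 1
    )
```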
......
 import torch
 import torch.nn as nn
-from torch.autograd import Variable
 import os
......
@@ -86,7 +86,9 @@ are also two accompanying review papers that are well written [7], [8].
 ### Models

 - **Deep Generative Models of Graphs (DGMG)** [11]: A very general framework for graph distribution learning by
 progressively adding atoms and bonds.
-- **Junction Tree Variational Autoencoder for Molecular Graph Generation (JTNN)** [13]:
+- **Junction Tree Variational Autoencoder for Molecular Graph Generation (JTNN)** [13]: JTNNs incrementally
+expand molecules while maintaining chemical valency at every step, and can be used for both molecule generation
+and optimization.
 ### Example Usage of Pre-trained Models
......
@@ -35,23 +35,16 @@ encoded nodes(atoms) and edges(bonds), and other information for model to use.

 To start training, use `python train.py`. By default, the script will use the ZINC dataset
 with a preprocessed vocabulary, and save model checkpoints in the current working directory.

 ```
--s SAVE_PATH, --save_dir SAVE_PATH
-                        Path to save checkpoint models, default to be current
-                        working directory (default: ./)
--m MODEL_PATH, --model MODEL_PATH
-                        Path to load pre-trained model (default: None)
--b BATCH_SIZE, --batch BATCH_SIZE
-                        Batch size (default: 40)
--w HIDDEN_SIZE, --hidden HIDDEN_SIZE
-                        Size of representation vectors (default: 200)
--l LATENT_SIZE, --latent LATENT_SIZE
-                        Latent Size of node(atom) features and edge(atom)
-                        features (default: 56)
--d DEPTH, --depth DEPTH
-                        Depth of message passing hops (default: 3)
--z BETA, --beta BETA  Coefficient of KL Divergence term (default: 1.0)
--q LR, --lr LR        Learning Rate (default: 0.001)
--T, --test            Add this flag to run test mode (default: False)
+-s SAVE_PATH,   Path to save checkpoint models, default to be current
+                working directory (default: ./)
+-m MODEL_PATH,  Path to load pre-trained model (default: None)
+-b BATCH_SIZE,  Batch size (default: 40)
+-w HIDDEN_SIZE, Size of representation vectors (default: 200)
+-l LATENT_SIZE, Latent size of node (atom) features and edge (bond)
+                features (default: 56)
+-d DEPTH,       Depth of message passing hops (default: 3)
+-z BETA,        Coefficient of the KL divergence term (default: 1.0)
+-q LR,          Learning rate (default: 0.001)
 ```

 Model will be saved periodically.
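Checkpoints are written as `model.iter-<epoch>` under the save directory (see the `torch.save` call in the `train.py` hunk later in this commit). A minimal reload sketch; the constructor signature, `vocab` object, and path are assumptions:

```python
# Hedged sketch: reload a periodic checkpoint written by train.py.
# `vocab` and the hyperparameters must match the training run; the
# constructor signature and checkpoint path are assumptions.
import torch
from jtnn import DGLJTNNVAE

model = DGLJTNNVAE(vocab, 200, 56, 3)  # hidden size, latent size, depth
model.load_state_dict(torch.load("./model.iter-2"))
model.eval()
```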
@@ -70,21 +63,16 @@ If you want to use your own dataset, please create a file contains one SMILES a

 To start evaluation, use `python reconstruct_eval.py` with the following arguments:

 ```
--t TRAIN, --train TRAIN
-                        Training file name (default: test)
--m MODEL_PATH, --model MODEL_PATH
-                        Pre-trained model to be loaded for evalutaion. If not
-                        specified, would use pre-trained model from model zoo
-                        (default: None)
--w HIDDEN_SIZE, --hidden HIDDEN_SIZE
-                        Hidden size of representation vector, should be
-                        consistent with pre-trained model (default: 450)
--l LATENT_SIZE, --latent LATENT_SIZE
-                        Latent Size of node(atom) features and edge(atom)
-                        features, should be consistent with pre-trained model
-                        (default: 56)
--d DEPTH, --depth DEPTH
-                        Depth of message passing hops, should be consistent
-                        with pre-trained model (default: 3)
+-t TRAIN,       Training file name (default: test)
+-m MODEL_PATH,  Pre-trained model to be loaded for evaluation. If not
+                specified, the pre-trained model from the model zoo is used
+                (default: None)
+-w HIDDEN_SIZE, Hidden size of the representation vector; should be
+                consistent with the pre-trained model (default: 450)
+-l LATENT_SIZE, Latent size of node (atom) features and edge (bond)
+                features; should be consistent with the pre-trained model
+                (default: 56)
+-d DEPTH,       Depth of message passing hops; should be consistent
+                with the pre-trained model (default: 3)
 ```
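Per molecule, reconstruction follows the encode, sample, decode round trip of the `test()` helper that this commit removes from `train.py` (see the hunk further down). A condensed sketch; `dataset`, `vocab`, and `model` setup are assumed:

```python
# Hedged sketch of the reconstruction loop, adapted from the test() helper
# removed from train.py in this commit; dataset/vocab/model setup is assumed.
from torch.utils.data import DataLoader
from jtnn import *  # JTNNCollator, as used in train.py

dataloader = DataLoader(dataset, batch_size=1, shuffle=False,
                        collate_fn=JTNNCollator(vocab, False))
for batch in dataloader:
    print(batch['mol_trees'][0].smiles)            # ground-truth SMILES
    model.move_to_cuda(batch)
    _, tree_vec, mol_vec = model.encode(batch)
    tree_vec, mol_vec, _, _ = model.sample(tree_vec, mol_vec)
    print(model.decode(tree_vec, mol_vec))         # reconstructed SMILES
```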
......
-import rdkit
 import rdkit.Chem as Chem
 import torch
 from scipy.sparse import csr_matrix
 from scipy.sparse.csgraph import minimum_spanning_tree
 from collections import defaultdict
-from rdkit.Chem.EnumerateStereoisomers import EnumerateStereoisomers, StereoEnumerationOptions
+from rdkit.Chem.EnumerateStereoisomers import EnumerateStereoisomers
 from dgl import DGLGraph

 ELEM_LIST = ['C', 'N', 'O', 'S', 'F', 'Si', 'P', 'Cl', 'Br', 'Mg', 'Na',
@@ -406,16 +405,12 @@ def mol2dgl_dec(cand_batch):
     tree_mess_target_edges = [] # these edges on candidate graphs
     tree_mess_target_nodes = []
     n_nodes = 0
-    n_edges = 0
     atom_x = []
     bond_x = []

     for mol, mol_tree, ctr_node_id in cand_batch:
         n_atoms = mol.GetNumAtoms()
-        n_bonds = mol.GetNumBonds()
-        ctr_node = mol_tree.nodes_dict[ctr_node_id]
-        ctr_bid = ctr_node['idx']
         g = DGLGraph()

         for i, atom in enumerate(mol.GetAtoms()):
......
@@ -7,7 +7,8 @@ import os
 from .mol_tree import Vocab, DGLMolTree
 from .chemutils import mol2dgl_dec, mol2dgl_enc

-ELEM_LIST = ['C', 'N', 'O', 'S', 'F', 'Si', 'P', 'Cl', 'Br', 'Mg', 'Na', 'Ca', 'Fe', 'Al', 'I', 'B', 'K', 'Se', 'Zn', 'H', 'Cu', 'Mn', 'unknown']
+ELEM_LIST = ['C', 'N', 'O', 'S', 'F', 'Si', 'P', 'Cl', 'Br', 'Mg', 'Na', 'Ca',
+             'Fe', 'Al', 'I', 'B', 'K', 'Se', 'Zn', 'H', 'Cu', 'Mn', 'unknown']

 ATOM_FDIM_DEC = len(ELEM_LIST) + 6 + 5 + 1
 BOND_FDIM_DEC = 5
......
-import rdkit
-import rdkit.Chem as Chem
 import copy
 import numpy as np
 from dgl import DGLGraph
......
@@ -5,9 +5,8 @@ import torch.optim.lr_scheduler as lr_scheduler
 from dgl import model_zoo
 from torch.utils.data import DataLoader

-import math, random, sys
+import sys
 import argparse
-from collections import deque

 import rdkit

 from jtnn import *
@@ -42,8 +41,6 @@ parser.add_argument("-z", "--beta", dest="beta", default=1.0,
                     help="Coefficient of KL Divergence term")
 parser.add_argument("-q", "--lr", dest="lr", default=1e-3,
                     help="Learning Rate")
-parser.add_argument("-T", "--test", dest="test", action="store_true",
-                    help="Add this flag to run test mode")
 args = parser.parse_args()

 dataset = JTNNDataset(data=args.train, vocab=args.vocab, training=True)
@@ -131,34 +128,7 @@ def train():
         print("learning rate: %.6f" % scheduler.get_lr()[0])
         torch.save(model.state_dict(), args.save_path + "/model.iter-" + str(epoch))

-def test():
-    dataset.training = False
-    dataloader = DataLoader(
-        dataset,
-        batch_size=1,
-        shuffle=False,
-        num_workers=0,
-        collate_fn=JTNNCollator(vocab, False),
-        drop_last=True,
-        worker_init_fn=worker_init_fn)
-
-    # Just an example of molecule decoding; in reality you may want to sample
-    # tree and molecule vectors.
-    for it, batch in enumerate(dataloader):
-        gt_smiles = batch['mol_trees'][0].smiles
-        print(gt_smiles)
-        model.move_to_cuda(batch)
-        _, tree_vec, mol_vec = model.encode(batch)
-        tree_vec, mol_vec, _, _ = model.sample(tree_vec, mol_vec)
-        smiles = model.decode(tree_vec, mol_vec)
-        print(smiles)

 if __name__ == '__main__':
-    if args.test:
-        test()
-    else:
-        train()
+    train()
     print('# passes:', model.n_passes)
......
 # pylint: disable=C0111, C0103, E1101, W0611, W0612, W0703, C0200, R1710
 from collections import defaultdict

-import rdkit
 import rdkit.Chem as Chem
-from rdkit.Chem.EnumerateStereoisomers import (EnumerateStereoisomers,
-                                               StereoEnumerationOptions)
+from rdkit.Chem.EnumerateStereoisomers import EnumerateStereoisomers
 from scipy.sparse import csr_matrix
 from scipy.sparse.csgraph import minimum_spanning_tree
......
@@ -9,7 +9,6 @@ import torch.nn as nn
 import dgl.function as DGLF
 from dgl import DGLGraph, mean_nodes

-from .chemutils import get_mol
 from .nnutils import cuda

 ELEM_LIST = ['C', 'N', 'O', 'S', 'F', 'Si', 'P', 'Cl', 'Br', 'Mg', 'Na',
@@ -57,16 +56,12 @@ def mol2dgl_single(cand_batch):
     tree_mess_target_edges = [] # these edges on candidate graphs
     tree_mess_target_nodes = []
     n_nodes = 0
-    n_edges = 0
    atom_x = []
     bond_x = []

     for mol, mol_tree, ctr_node_id in cand_batch:
         n_atoms = mol.GetNumAtoms()
-        n_bonds = mol.GetNumBonds()
-        ctr_node = mol_tree.nodes_dict[ctr_node_id]
-        ctr_bid = ctr_node['idx']
         g = DGLGraph()

         for i, atom in enumerate(mol.GetAtoms()):
......
 # pylint: disable=C0111, C0103, E1101, W0611, W0612
-import copy
-import itertools
-
-import networkx as nx
 import numpy as np
 import torch
 import torch.nn as nn
@@ -12,7 +8,6 @@ import dgl.function as DGLF
 from dgl import batch, dfs_labeled_edges_generator

 from .chemutils import enum_assemble_nx, get_mol
-from .mol_tree import Vocab
 from .mol_tree_nx import DGLMolTree
 from .nnutils import GRUUpdate, cuda
@@ -274,7 +269,6 @@ class DGLJTNNDecoder(nn.Module):
         stack.append((0, self.vocab.get_slots(root_wid)))

         all_nodes = {0: root_node_dict}
-        h = {}
         first = True
         new_node_id = 0
         new_edge_id = 0
@@ -282,7 +276,6 @@
         for step in range(MAX_DECODE_LEN):
             u, u_slots = stack[-1]
             udata = mol_tree.nodes[u].data
-            wid = udata['wid']
             x = udata['x']
             h = udata['h']
......
 # pylint: disable=C0111, C0103, E1101, W0611, W0612
-import itertools
-from collections import deque
-
-import networkx as nx
 import numpy as np
 import torch
 import torch.nn as nn
 import dgl.function as DGLF
-from dgl import batch, bfs_edges_generator, unbatch
+from dgl import batch, bfs_edges_generator

-from .mol_tree import Vocab
 from .nnutils import GRUUpdate, cuda

 MAX_NB = 8
......