Unverified commit ac932c66, authored by Gan Quan, committed by GitHub

[Model] Junction Tree VAE update (#157)

* cherry picking optimization from jtnn

* adding official code.  TODO: fix DGLMolTree

* updating to current api.  vae test still failing

* reverting to list stacking

* reverting to list stacking

* cleaning x flags (stupid windows)

* cleaning x flags (stupid windows)

* adding stats

* optimization

* updating dgl stats

* update again

* more optimization

* looks like computation is faster

* removing profiling code

* cleaning obsolete code

* remove comparison warning

* readme update

* official implementation got a lot faster

* minor fixes

* unbatch by slicing frames

* working around unbatch

* reduce pack

* oops

* support frame read/write with slices

* reverting back to readout as unbatch-by-slicing slows down backward

* reverting to unbatch by splitting; slicing is unfriendly to backward

* replacing lru cache with static object factory

* cherry picking optimization from jtnn

* unbatch by slicing frames

* reduce pack

* oops

* support frame read/write with slices

* reverting to unbatch by splitting; slicing is unfriendly to backward

* replacing lru cache with static object factory

* replacing Scheme object with namedtuple

* forgot the find edges interface

* subclassing namedtuple

* updating to the latest api spec

* bugfix

* bfs with edges

* dfs toy test case

* clean up

* style fix

* bugfix

* update to latest api; include traversal

* replacing with readout

* simplify decoder

* oops

* cleanup

* reducing number of sets

* more speed up

* profile results

* random fixes

* fixing tvmarray handling incontiguous dlpack input

* fancier dataloader

* fix a potential context mismatch

* todo: support pickling or using scipy in multiprocessing load

* pickling support

* resorting to suggested way of pickling

* custom attribute pickling check

* working around a weird pytorch pickling bug

* including partial frame case

* enabling multiprocessing dataloader

* pickling everything now

* really works

* oops

* updated profiling results

* cleanup

* fix as requested

* cleaning random blank lines

* removing profiler outputs

* starting decoding

* testing, WIP

* tree decoding

* graph decoding, WIP

* graph decoding works

* oops

* fixing legacy apis

* trimming number of candidate structures

* sampling cleanups

* removing comparison test

* updated description
parent 4682b76e

README.md
@@ -11,3 +11,16 @@ python3 vaetrain_dgl.py
 ```
 The script will automatically download the data, which is the same as the one in the
 original repository.
+To disable CUDA, run with the `NOCUDA` variable set:
+```
+NOCUDA=1 python3 vaetrain_dgl.py
+```
+To decode for new molecules, run
+```
+python3 vaetrain_dgl.py -T
+```
+Currently, decoding involves encoding a training example, sampling from the posterior
+distribution, and decoding a molecule from that.
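A minimal sketch of that decode flow, assuming the interfaces introduced in the diffs below (`JTNNDataset`/`JTNNCollator` from `datautils.py`; `encode`, `sample`, and `decode` from `jtnn_vae.py`). The constructor arguments are illustrative guesses, not taken from this commit:

```python
# Hypothetical driver: encode one training example, sample from the
# posterior, and decode a molecule from that sample.
from jtnn import JTNNDataset, JTNNCollator, DGLJTNNVAE

dataset = JTNNDataset(data='train', vocab='vocab', training=False)
collator = JTNNCollator(dataset.vocab, training=False)
model = DGLJTNNVAE(dataset.vocab, hidden_size=450, latent_size=56, depth=3)  # assumed args

batch = collator([dataset[0]])                             # a single training example
_, tree_vec, mol_vec = model.encode(batch)                 # posterior parameters
tree_vec, mol_vec, _, _ = model.sample(tree_vec, mol_vec)  # reparameterized sample
print(model.decode(tree_vec, mol_vec))                     # SMILES string, or None on failure
```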
jtnn/__init__.py
 from .mol_tree import Vocab
 from .jtnn_vae import DGLJTNNVAE
-from .mpn import DGLMPN, mol2dgl
+from .mpn import DGLMPN
-from .nnutils import create_var
+from .nnutils import create_var, cuda
-from .datautils import JTNNDataset
+from .datautils import JTNNDataset, JTNNCollator
 from .chemutils import decode_stereo
+from .line_profiler_integration import profile
jtnn/chemutils.py
@@ -251,8 +251,7 @@ def enum_attach_nx(ctr_mol, nei_node, amap, singletons):
     return att_confs

 #Try rings first: Speed-Up
-def enum_assemble_nx(graph, node_idx, neighbors, prev_nodes=[], prev_amap=[]):
-    node = graph.nodes[node_idx]
+def enum_assemble_nx(node, neighbors, prev_nodes=[], prev_amap=[]):
     all_attach_confs = []
     singletons = [nei_node['nid'] for nei_node in neighbors + prev_nodes if nei_node['mol'].GetNumAtoms() == 1]
@@ -301,21 +300,21 @@ def enum_assemble_nx(graph, node_idx, neighbors, prev_nodes=[], prev_amap=[]):
 #Only used for debugging purpose
 def dfs_assemble_nx(graph, cur_mol, global_amap, fa_amap, cur_node_id, fa_node_id):
-    cur_node = graph.nodes[cur_node_id]
-    fa_node = graph.nodes[fa_node_id] if fa_node_id is not None else None
+    cur_node = graph.nodes_dict[cur_node_id]
+    fa_node = graph.nodes_dict[fa_node_id] if fa_node_id is not None else None

     fa_nid = fa_node['nid'] if fa_node is not None else -1
     prev_nodes = [fa_node] if fa_node is not None else []

-    children_id = [nei for nei in graph[cur_node_id] if graph.nodes[nei]['nid'] != fa_nid]
-    children = [graph.nodes[nei] for nei in children_id]
+    children_id = [nei for nei in graph[cur_node_id] if graph.nodes_dict[nei]['nid'] != fa_nid]
+    children = [graph.nodes_dict[nei] for nei in children_id]
     neighbors = [nei for nei in children if nei['mol'].GetNumAtoms() > 1]
     neighbors = sorted(neighbors, key=lambda x:x['mol'].GetNumAtoms(), reverse=True)
     singletons = [nei for nei in children if nei['mol'].GetNumAtoms() == 1]
     neighbors = singletons + neighbors

     cur_amap = [(fa_nid,a2,a1) for nid,a1,a2 in fa_amap if nid == cur_node['nid']]
-    cands = enum_assemble_nx(graph, cur_node_id, neighbors, prev_nodes, cur_amap)
+    cands = enum_assemble_nx(graph.nodes_dict[cur_node_id], neighbors, prev_nodes, cur_amap)
     if len(cands) == 0:
         return
...
jtnn/datautils.py
+import torch
 from torch.utils.data import Dataset
 import numpy as np
 import dgl
 from dgl.data.utils import download, extract_archive, get_download_dir
+from .mol_tree_nx import DGLMolTree
+from .mol_tree import Vocab
+from .mpn import mol2dgl_single as mol2dgl_enc
+from .jtmpn import mol2dgl_single as mol2dgl_dec

 _url = 'https://www.dropbox.com/s/4ypr0e0abcbsvoh/jtnn.zip?dl=1'

+def _unpack_field(examples, field):
+    return [e[field] for e in examples]
+
+def _set_node_id(mol_tree, vocab):
+    wid = []
+    for i, node in enumerate(mol_tree.nodes_dict):
+        mol_tree.nodes_dict[node]['idx'] = i
+        wid.append(vocab.get_index(mol_tree.nodes_dict[node]['smiles']))
+    return wid
+
 class JTNNDataset(Dataset):
-    def __init__(self, data, vocab):
+    def __init__(self, data, vocab, training=True):
         self.dir = get_download_dir()
         self.zip_file_path='{}/jtnn.zip'.format(self.dir)
         download(_url, path=self.zip_file_path)
@@ -20,14 +37,186 @@ class JTNNDataset(Dataset):
         print('Loading finished.')
         print('\tNum samples:', len(self.data))
         print('\tVocab file:', self.vocab_file)
+        self.training = training
+        self.vocab = Vocab([x.strip("\r\n ") for x in open(self.vocab_file)])

     def __len__(self):
         return len(self.data)

     def __getitem__(self, idx):
-        from .mol_tree_nx import DGLMolTree
         smiles = self.data[idx]
         mol_tree = DGLMolTree(smiles)
         mol_tree.recover()
         mol_tree.assemble()
-        return mol_tree
+
+        wid = _set_node_id(mol_tree, self.vocab)
+
+        # prebuild the molecule graph
+        mol_graph, atom_x_enc, bond_x_enc = mol2dgl_enc(mol_tree.smiles)
+
+        result = {
+            'mol_tree': mol_tree,
+            'mol_graph': mol_graph,
+            'atom_x_enc': atom_x_enc,
+            'bond_x_enc': bond_x_enc,
+            'wid': wid,
+        }
+
+        if not self.training:
+            return result
+
+        # prebuild the candidate graph list
+        cands = []
+        for node_id, node in mol_tree.nodes_dict.items():
+            # fill in ground truth
+            if node['label'] not in node['cands']:
+                node['cands'].append(node['label'])
+                node['cand_mols'].append(node['label_mol'])
+
+            if node['is_leaf'] or len(node['cands']) == 1:
+                continue
+            cands.extend([(cand, mol_tree, node_id)
+                          for cand in node['cand_mols']])
+        if len(cands) > 0:
+            cand_graphs, atom_x_dec, bond_x_dec, tree_mess_src_e, \
+                tree_mess_tgt_e, tree_mess_tgt_n = mol2dgl_dec(cands)
+        else:
+            cand_graphs = []
+            atom_x_dec = torch.zeros(0, atom_x_enc.shape[1])
+            bond_x_dec = torch.zeros(0, bond_x_enc.shape[1])
+            tree_mess_src_e = torch.zeros(0, 2).long()
+            tree_mess_tgt_e = torch.zeros(0, 2).long()
+            tree_mess_tgt_n = torch.zeros(0, 2).long()

+        # prebuild the stereoisomers
+        cands = mol_tree.stereo_cands
+        if len(cands) > 1:
+            if mol_tree.smiles3D not in cands:
+                cands.append(mol_tree.smiles3D)
+
+            stereo_graphs = [mol2dgl_enc(c) for c in cands]
+            stereo_cand_graphs, stereo_atom_x_enc, stereo_bond_x_enc = \
+                zip(*stereo_graphs)
+            stereo_atom_x_enc = torch.cat(stereo_atom_x_enc)
+            stereo_bond_x_enc = torch.cat(stereo_bond_x_enc)
+            stereo_cand_label = [(cands.index(mol_tree.smiles3D), len(cands))]
+        else:
+            stereo_cand_graphs = []
+            stereo_atom_x_enc = torch.zeros(0, atom_x_enc.shape[1])
+            stereo_bond_x_enc = torch.zeros(0, bond_x_enc.shape[1])
+            stereo_cand_label = []
+
+        result.update({
+            'cand_graphs': cand_graphs,
+            'atom_x_dec': atom_x_dec,
+            'bond_x_dec': bond_x_dec,
+            'tree_mess_src_e': tree_mess_src_e,
+            'tree_mess_tgt_e': tree_mess_tgt_e,
+            'tree_mess_tgt_n': tree_mess_tgt_n,
+            'stereo_cand_graphs': stereo_cand_graphs,
+            'stereo_atom_x_enc': stereo_atom_x_enc,
+            'stereo_bond_x_enc': stereo_bond_x_enc,
+            'stereo_cand_label': stereo_cand_label,
+        })
+
+        return result
+
+class JTNNCollator(object):
+    def __init__(self, vocab, training):
+        self.vocab = vocab
+        self.training = training
+
+    @staticmethod
+    def _batch_and_set(graphs, atom_x, bond_x, flatten):
+        if flatten:
+            graphs = [g for f in graphs for g in f]
+        graph_batch = dgl.batch(graphs)
+        graph_batch.ndata['x'] = atom_x
+        graph_batch.edata.update({
+            'x': bond_x,
+            'src_x': atom_x.new(bond_x.shape[0], atom_x.shape[1]).zero_(),
+        })
+        return graph_batch
+
+    def __call__(self, examples):
+        # get list of trees
+        mol_trees = _unpack_field(examples, 'mol_tree')
+        wid = _unpack_field(examples, 'wid')
+        for _wid, mol_tree in zip(wid, mol_trees):
+            mol_tree.ndata['wid'] = torch.LongTensor(_wid)
+
+        # TODO: either support pickling or get around ctypes pointers using scipy
+        # batch molecule graphs
+        mol_graphs = _unpack_field(examples, 'mol_graph')
+        atom_x = torch.cat(_unpack_field(examples, 'atom_x_enc'))
+        bond_x = torch.cat(_unpack_field(examples, 'bond_x_enc'))
+        mol_graph_batch = self._batch_and_set(mol_graphs, atom_x, bond_x, False)
+
+        result = {
+            'mol_trees': mol_trees,
+            'mol_graph_batch': mol_graph_batch,
+        }
+
+        if not self.training:
+            return result
+
+        # batch candidate graphs
+        cand_graphs = _unpack_field(examples, 'cand_graphs')
+        cand_batch_idx = []
+        atom_x = torch.cat(_unpack_field(examples, 'atom_x_dec'))
+        bond_x = torch.cat(_unpack_field(examples, 'bond_x_dec'))
+        tree_mess_src_e = _unpack_field(examples, 'tree_mess_src_e')
+        tree_mess_tgt_e = _unpack_field(examples, 'tree_mess_tgt_e')
+        tree_mess_tgt_n = _unpack_field(examples, 'tree_mess_tgt_n')
+
+        n_graph_nodes = 0
+        n_tree_nodes = 0
+        for i in range(len(cand_graphs)):
+            tree_mess_tgt_e[i] += n_graph_nodes
+            tree_mess_src_e[i] += n_tree_nodes
+            tree_mess_tgt_n[i] += n_graph_nodes
+            n_graph_nodes += sum(g.number_of_nodes() for g in cand_graphs[i])
+            n_tree_nodes += mol_trees[i].number_of_nodes()
+            cand_batch_idx.extend([i] * len(cand_graphs[i]))
+        tree_mess_tgt_e = torch.cat(tree_mess_tgt_e)
+        tree_mess_src_e = torch.cat(tree_mess_src_e)
+        tree_mess_tgt_n = torch.cat(tree_mess_tgt_n)
+
+        cand_graph_batch = self._batch_and_set(cand_graphs, atom_x, bond_x, True)
+
+        # batch stereoisomers
+        stereo_cand_graphs = _unpack_field(examples, 'stereo_cand_graphs')
+        atom_x = torch.cat(_unpack_field(examples, 'stereo_atom_x_enc'))
+        bond_x = torch.cat(_unpack_field(examples, 'stereo_bond_x_enc'))
+        stereo_cand_batch_idx = []
+        for i in range(len(stereo_cand_graphs)):
+            stereo_cand_batch_idx.extend([i] * len(stereo_cand_graphs[i]))
+
+        if len(stereo_cand_batch_idx) > 0:
+            stereo_cand_labels = [
+                (label, length)
+                for ex in _unpack_field(examples, 'stereo_cand_label')
+                for label, length in ex
+            ]
+            stereo_cand_labels, stereo_cand_lengths = zip(*stereo_cand_labels)
+            stereo_cand_graph_batch = self._batch_and_set(
+                stereo_cand_graphs, atom_x, bond_x, True)
+        else:
+            stereo_cand_labels = []
+            stereo_cand_lengths = []
+            stereo_cand_graph_batch = None
+            stereo_cand_batch_idx = []
+
+        result.update({
+            'cand_graph_batch': cand_graph_batch,
+            'cand_batch_idx': cand_batch_idx,
+            'tree_mess_tgt_e': tree_mess_tgt_e,
+            'tree_mess_src_e': tree_mess_src_e,
+            'tree_mess_tgt_n': tree_mess_tgt_n,
+            'stereo_cand_graph_batch': stereo_cand_graph_batch,
+            'stereo_cand_batch_idx': stereo_cand_batch_idx,
+            'stereo_cand_labels': stereo_cand_labels,
+            'stereo_cand_lengths': stereo_cand_lengths,
+        })
+        return result
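The commit log above mentions enabling a multiprocessing dataloader; a sketch of how the dataset and collator would plug into `torch.utils.data.DataLoader` (batch size and worker count are arbitrary; the pickling support added in this commit is what makes `num_workers > 0` viable):

```python
from torch.utils.data import DataLoader

dataset = JTNNDataset(data='train', vocab='vocab', training=True)
loader = DataLoader(
    dataset,
    batch_size=40,
    shuffle=True,
    num_workers=4,   # workers must pickle DGLMolTree objects, hence the pickling work
    collate_fn=JTNNCollator(dataset.vocab, training=True))

for batch in loader:
    # batch['mol_graph_batch'] is one batched DGLGraph; batch['mol_trees'] is
    # the list of junction trees, matching the collator output above.
    break
```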
jtnn/jtmpn.py
@@ -4,10 +4,11 @@ from .nnutils import cuda
 from .chemutils import get_mol
 #from mpn import atom_features, bond_features, ATOM_FDIM, BOND_FDIM
 import rdkit.Chem as Chem
-from dgl import DGLGraph, line_graph, batch, unbatch
+from dgl import DGLGraph, batch, unbatch, mean_nodes
 import dgl.function as DGLF
 from .line_profiler_integration import profile
 import os
+import numpy as np

 ELEM_LIST = ['C', 'N', 'O', 'S', 'F', 'Si', 'P', 'Cl', 'Br', 'Mg', 'Na', 'Ca', 'Fe', 'Al', 'I', 'B', 'K', 'Se', 'Zn', 'H', 'Cu', 'Mn', 'unknown']
@@ -27,81 +28,81 @@ def onek_encoding_unk(x, allowable_set):
 # the 2-D graph first, then enumerate all possible 3-D forms and find the
 # one with highest score.
 def atom_features(atom):
-    return cuda(torch.Tensor(onek_encoding_unk(atom.GetSymbol(), ELEM_LIST)
+    return (torch.Tensor(onek_encoding_unk(atom.GetSymbol(), ELEM_LIST)
             + onek_encoding_unk(atom.GetDegree(), [0,1,2,3,4,5])
             + onek_encoding_unk(atom.GetFormalCharge(), [-1,-2,1,2,0])
             + [atom.GetIsAromatic()]))

 def bond_features(bond):
     bt = bond.GetBondType()
-    return cuda(torch.Tensor([bt == Chem.rdchem.BondType.SINGLE, bt == Chem.rdchem.BondType.DOUBLE, bt == Chem.rdchem.BondType.TRIPLE, bt == Chem.rdchem.BondType.AROMATIC, bond.IsInRing()]))
+    return (torch.Tensor([bt == Chem.rdchem.BondType.SINGLE, bt == Chem.rdchem.BondType.DOUBLE, bt == Chem.rdchem.BondType.TRIPLE, bt == Chem.rdchem.BondType.AROMATIC, bond.IsInRing()]))

-@profile
-def mol2dgl(cand_batch, mol_tree_batch):
+def mol2dgl_single(cand_batch):
     cand_graphs = []
     tree_mess_source_edges = [] # map these edges from trees to...
     tree_mess_target_edges = [] # these edges on candidate graphs
     tree_mess_target_nodes = []
     n_nodes = 0
+    n_edges = 0
+    atom_x = []
+    bond_x = []

     for mol, mol_tree, ctr_node_id in cand_batch:
-        atom_feature_list = []
-        bond_feature_list = []
-        ctr_node = mol_tree.nodes[ctr_node_id]
+        n_atoms = mol.GetNumAtoms()
+        n_bonds = mol.GetNumBonds()
+        ctr_node = mol_tree.nodes_dict[ctr_node_id]
         ctr_bid = ctr_node['idx']
         g = DGLGraph()

-        for atom in mol.GetAtoms():
-            atom_feature_list.append(atom_features(atom))
-            g.add_node(atom.GetIdx())
+        for i, atom in enumerate(mol.GetAtoms()):
+            assert i == atom.GetIdx()
+            atom_x.append(atom_features(atom))
+        g.add_nodes(n_atoms)

-        for bond in mol.GetBonds():
+        bond_src = []
+        bond_dst = []
+        for i, bond in enumerate(mol.GetBonds()):
             a1 = bond.GetBeginAtom()
             a2 = bond.GetEndAtom()
             begin_idx = a1.GetIdx()
             end_idx = a2.GetIdx()
             features = bond_features(bond)

-            g.add_edge(begin_idx, end_idx)
-            bond_feature_list.append(features)
-            g.add_edge(end_idx, begin_idx)
-            bond_feature_list.append(features)
+            bond_src.append(begin_idx)
+            bond_dst.append(end_idx)
+            bond_x.append(features)
+            bond_src.append(end_idx)
+            bond_dst.append(begin_idx)
+            bond_x.append(features)

             x_nid, y_nid = a1.GetAtomMapNum(), a2.GetAtomMapNum()
             # Tree node ID in the batch
-            x_bid = mol_tree.nodes[x_nid - 1]['idx'] if x_nid > 0 else -1
-            y_bid = mol_tree.nodes[y_nid - 1]['idx'] if y_nid > 0 else -1
+            x_bid = mol_tree.nodes_dict[x_nid - 1]['idx'] if x_nid > 0 else -1
+            y_bid = mol_tree.nodes_dict[y_nid - 1]['idx'] if y_nid > 0 else -1
             if x_bid >= 0 and y_bid >= 0 and x_bid != y_bid:
-                if (x_bid, y_bid) in mol_tree_batch.edge_list:
-                    tree_mess_target_edges.append(
-                        (begin_idx + n_nodes, end_idx + n_nodes))
+                if mol_tree.has_edge_between(x_bid, y_bid):
+                    tree_mess_target_edges.append((begin_idx + n_nodes, end_idx + n_nodes))
                     tree_mess_source_edges.append((x_bid, y_bid))
                     tree_mess_target_nodes.append(end_idx + n_nodes)
-                if (y_bid, x_bid) in mol_tree_batch.edge_list:
-                    tree_mess_target_edges.append(
-                        (end_idx + n_nodes, begin_idx + n_nodes))
+                if mol_tree.has_edge_between(y_bid, x_bid):
+                    tree_mess_target_edges.append((end_idx + n_nodes, begin_idx + n_nodes))
                     tree_mess_source_edges.append((y_bid, x_bid))
                     tree_mess_target_nodes.append(begin_idx + n_nodes)

-        n_nodes += len(g.nodes)
+        n_nodes += n_atoms
+        g.add_edges(bond_src, bond_dst)

-        atom_x = torch.stack(atom_feature_list)
-        g.set_n_repr({'x': atom_x})
-        if len(bond_feature_list) > 0:
-            bond_x = torch.stack(bond_feature_list)
-            g.set_e_repr({
-                'x': bond_x,
-                'src_x': atom_x.new(len(bond_feature_list), ATOM_FDIM).zero_()
-            })
         cand_graphs.append(g)

-    return cand_graphs, tree_mess_source_edges, tree_mess_target_edges, \
-            tree_mess_target_nodes
+    return cand_graphs, torch.stack(atom_x), \
+            torch.stack(bond_x) if len(bond_x) > 0 else torch.zeros(0), \
+            torch.LongTensor(tree_mess_source_edges), \
+            torch.LongTensor(tree_mess_target_edges), \
+            torch.LongTensor(tree_mess_target_nodes)

 mpn_loopy_bp_msg = DGLF.copy_src(src='msg', out='msg')
-mpn_loopy_bp_reduce = DGLF.sum(msgs='msg', out='accum_msg')
+mpn_loopy_bp_reduce = DGLF.sum(msg='msg', out='accum_msg')

 class LoopyBPUpdate(nn.Module):
@@ -112,8 +113,8 @@ class LoopyBPUpdate(nn.Module):
         self.W_h = nn.Linear(hidden_size, hidden_size, bias=False)

     def forward(self, node):
-        msg_input = node['msg_input']
-        msg_delta = self.W_h(node['accum_msg'] + node['alpha'])
+        msg_input = node.data['msg_input']
+        msg_delta = self.W_h(node.data['accum_msg'] + node.data['alpha'])
         msg = torch.relu(msg_input + msg_delta)
         return {'msg': msg}

@@ -129,11 +130,11 @@ else:
 if PAPER:
     mpn_gather_reduce = [
-        DGLF.sum(msgs='msg', out='m'),
-        DGLF.sum(msgs='alpha', out='accum_alpha'),
+        DGLF.sum(msg='msg', out='m'),
+        DGLF.sum(msg='alpha', out='accum_alpha'),
     ]
 else:
-    mpn_gather_reduce = DGLF.sum(msgs='msg', out='m')
+    mpn_gather_reduce = DGLF.sum(msg='msg', out='m')

 class GatherUpdate(nn.Module):
@@ -146,11 +147,11 @@ class GatherUpdate(nn.Module):
     def forward(self, node):
         if PAPER:
             #m = node['m']
-            m = node['m'] + node['accum_alpha']
+            m = node.data['m'] + node.data['accum_alpha']
         else:
-            m = node['m'] + node['alpha']
+            m = node.data['m'] + node.data['alpha']
         return {
-            'h': torch.relu(self.W_o(torch.cat([node['x'], m], 1))),
+            'h': torch.relu(self.W_o(torch.cat([node.data['x'], m], 1))),
         }

@@ -171,25 +172,21 @@ class DGLJTMPN(nn.Module):
         self.n_edges_total = 0
         self.n_passes = 0

-    @profile
     def forward(self, cand_batch, mol_tree_batch):
-        cand_graphs, tree_mess_src_edges, tree_mess_tgt_edges, tree_mess_tgt_nodes = \
-                mol2dgl(cand_batch, mol_tree_batch)
+        cand_graphs, tree_mess_src_edges, tree_mess_tgt_edges, tree_mess_tgt_nodes = cand_batch

         n_samples = len(cand_graphs)

-        cand_graphs = batch(cand_graphs)
-        cand_line_graph = line_graph(cand_graphs, no_backtracking=True)
+        cand_line_graph = cand_graphs.line_graph(backtracking=False, shared=True)

-        n_nodes = len(cand_graphs.nodes)
-        n_edges = len(cand_graphs.edges)
+        n_nodes = cand_graphs.number_of_nodes()
+        n_edges = cand_graphs.number_of_edges()

         cand_graphs = self.run(
             cand_graphs, cand_line_graph, tree_mess_src_edges, tree_mess_tgt_edges,
             tree_mess_tgt_nodes, mol_tree_batch)

-        cand_graphs = unbatch(cand_graphs)
-        g_repr = torch.stack([g.get_n_repr()['h'].mean(0) for g in cand_graphs], 0)
+        g_repr = mean_nodes(cand_graphs, 'h')

         self.n_samples_total += n_samples
         self.n_nodes_total += n_nodes
@@ -198,65 +195,61 @@ class DGLJTMPN(nn.Module):
         return g_repr

-    @profile
     def run(self, cand_graphs, cand_line_graph, tree_mess_src_edges, tree_mess_tgt_edges,
             tree_mess_tgt_nodes, mol_tree_batch):
-        n_nodes = len(cand_graphs.nodes)
+        n_nodes = cand_graphs.number_of_nodes()

-        cand_graphs.update_edge(
-            #*zip(*cand_graphs.edge_list),
-            edge_func=lambda src, dst, edge: {'src_x': src['x']},
-            batchable=True,
+        cand_graphs.apply_edges(
+            func=lambda edges: {'src_x': edges.src['x']},
         )

-        bond_features = cand_line_graph.get_n_repr()['x']
-        source_features = cand_line_graph.get_n_repr()['src_x']
+        bond_features = cand_line_graph.ndata['x']
+        source_features = cand_line_graph.ndata['src_x']

         features = torch.cat([source_features, bond_features], 1)
         msg_input = self.W_i(features)
-        cand_line_graph.set_n_repr({
+        cand_line_graph.ndata.update({
             'msg_input': msg_input,
             'msg': torch.relu(msg_input),
             'accum_msg': torch.zeros_like(msg_input),
         })
         zero_node_state = bond_features.new(n_nodes, self.hidden_size).zero_()
-        cand_graphs.set_n_repr({
+        cand_graphs.ndata.update({
             'm': zero_node_state.clone(),
             'h': zero_node_state.clone(),
         })

-        if PAPER:
-            cand_graphs.set_e_repr({
-                'alpha': cuda(torch.zeros(len(cand_graphs.edge_list), self.hidden_size))
-            })
-        if PAPER:
-            alpha = mol_tree_batch.get_e_repr(*zip(*tree_mess_src_edges))['m']
-            cand_graphs.set_e_repr({'alpha': alpha}, *zip(*tree_mess_tgt_edges))
-        else:
-            alpha = mol_tree_batch.get_e_repr(*zip(*tree_mess_src_edges))['m']
-            node_idx = (torch.LongTensor(tree_mess_tgt_nodes)
-                        .to(device=zero_node_state.device)[:, None]
-                        .expand_as(alpha))
-            node_alpha = zero_node_state.clone().scatter_add(0, node_idx, alpha)
-            cand_graphs.set_n_repr({'alpha': node_alpha})
-            cand_graphs.update_edge(
-                #*zip(*cand_graphs.edge_list),
-                edge_func=lambda src, dst, edge: {'alpha': src['alpha']},
-                batchable=True,
-            )
+        cand_graphs.edata['alpha'] = \
+            cuda(torch.zeros(cand_graphs.number_of_edges(), self.hidden_size))
+        cand_graphs.ndata['alpha'] = zero_node_state
+        if tree_mess_src_edges.shape[0] > 0:
+            if PAPER:
+                src_u, src_v = tree_mess_src_edges.unbind(1)
+                tgt_u, tgt_v = tree_mess_tgt_edges.unbind(1)
+                alpha = mol_tree_batch.edges[src_u, src_v].data['m']
+                cand_graphs.edges[tgt_u, tgt_v].data['alpha'] = alpha
+            else:
+                src_u, src_v = tree_mess_src_edges.unbind(1)
+                alpha = mol_tree_batch.edges[src_u, src_v].data['m']
+                node_idx = (tree_mess_tgt_nodes
+                            .to(device=zero_node_state.device)[:, None]
+                            .expand_as(alpha))
+                node_alpha = zero_node_state.clone().scatter_add(0, node_idx, alpha)
+                cand_graphs.ndata['alpha'] = node_alpha
+                cand_graphs.apply_edges(
+                    func=lambda edges: {'alpha': edges.src['alpha']},
+                )

         for i in range(self.depth - 1):
             cand_line_graph.update_all(
                 mpn_loopy_bp_msg,
                 mpn_loopy_bp_reduce,
                 self.loopy_bp_updater,
-                True
             )

         cand_graphs.update_all(
             mpn_gather_msg,
             mpn_gather_reduce,
             self.gather_updater,
-            True
         )

         return cand_graphs
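For readers tracking the API migration in this file (`set_n_repr`/`get_n_repr`/`update_edge` giving way to `ndata`/`edata`/`apply_edges`), a toy sketch of the new calls, assuming the 0.x `DGLGraph` API this diff targets:

```python
import torch
import dgl

g = dgl.DGLGraph()
g.add_nodes(3)
g.add_edges([0, 1], [1, 2])

g.ndata['x'] = torch.randn(3, 4)                        # was: g.set_n_repr({'x': ...})
g.apply_edges(lambda edges: {'src_x': edges.src['x']})  # was: g.update_edge(edge_func=...)
print(g.edata['src_x'].shape)                           # torch.Size([2, 4])
```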
jtnn/jtnn_dec.py
@@ -2,52 +2,87 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from .mol_tree import Vocab
+from .mol_tree_nx import DGLMolTree
+from .chemutils import enum_assemble_nx, get_mol
 from .nnutils import GRUUpdate, cuda
 import copy
 import itertools
-from dgl import batch, line_graph
+from dgl import batch, dfs_labeled_edges_generator
 import dgl.function as DGLF
 import networkx as nx
 from .line_profiler_integration import profile
+import numpy as np

 MAX_NB = 8
 MAX_DECODE_LEN = 100

 def dfs_order(forest, roots):
-    '''
-    Returns edge source, edge destination, tree ID, and whether u is generating
-    a new child
-    '''
-    edge_list = []
-
-    for i, root in enumerate(roots):
-        edge_list.append([])
-        # The following gives the DFS order on edges of a tree.
-        for u, v, t in nx.dfs_labeled_edges(forest, root):
-            if u == v or t == 'nontree':
-                continue
-            elif t == 'forward':
-                edge_list[-1].append((u, v, i, 1))
-            elif t == 'reverse':
-                edge_list[-1].append((v, u, i, 0))
-
-    for edges in itertools.zip_longest(*edge_list):
-        edges = (e for e in edges if e is not None)
-        u, v, i, p = zip(*edges)
-        yield u, v, i, p
+    edges = dfs_labeled_edges_generator(forest, roots, has_reverse_edge=True)
+    for e, l in zip(*edges):
+        # This exploits the fact that a reverse edge ID equals the forward
+        # edge ID xor 1 for molecule trees. Normally, reverse edges should
+        # be located with find_edges().
+        yield e ^ l, l
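The xor in `dfs_order` deserves a worked example. It assumes `DGLMolTree` inserts every tree edge together with its reverse, so edge 2k and edge 2k+1 are mutual reverses; that pairing, not anything in DGL itself, is what makes the trick valid:

```python
# With edges added in (u, v), (v, u) pairs, xor-ing an ID with 1 flips
# between a forward edge and its reverse. So `e ^ l` yields e itself when
# the DFS label l == 0 (entering a child) and the reverse edge when l == 1
# (backtracking to the parent).
for forward in (0, 2, 4):
    reverse = forward ^ 1
    assert reverse == forward + 1
    assert reverse ^ 1 == forward
```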
 dec_tree_node_msg = DGLF.copy_edge(edge='m', out='m')
-dec_tree_node_reduce = DGLF.sum(msgs='m', out='h')
+dec_tree_node_reduce = DGLF.sum(msg='m', out='h')

-def dec_tree_node_update(node):
-    return {'new': node['new'].clone().zero_()}
+def dec_tree_node_update(nodes):
+    return {'new': nodes.data['new'].clone().zero_()}

 dec_tree_edge_msg = [DGLF.copy_src(src='m', out='m'), DGLF.copy_src(src='rm', out='rm')]
-dec_tree_edge_reduce = [DGLF.sum(msgs='m', out='s'), DGLF.sum(msgs='rm', out='accum_rm')]
+dec_tree_edge_reduce = [DGLF.sum(msg='m', out='s'), DGLF.sum(msg='rm', out='accum_rm')]

+def have_slots(fa_slots, ch_slots):
+    if len(fa_slots) > 2 and len(ch_slots) > 2:
+        return True
+    matches = []
+    for i,s1 in enumerate(fa_slots):
+        a1,c1,h1 = s1
+        for j,s2 in enumerate(ch_slots):
+            a2,c2,h2 = s2
+            if a1 == a2 and c1 == c2 and (a1 != "C" or h1 + h2 >= 4):
+                matches.append( (i,j) )
+
+    if len(matches) == 0: return False
+
+    fa_match,ch_match = list(zip(*matches))
+    if len(set(fa_match)) == 1 and 1 < len(fa_slots) <= 2: #never remove atom from ring
+        fa_slots.pop(fa_match[0])
+    if len(set(ch_match)) == 1 and 1 < len(ch_slots) <= 2: #never remove atom from ring
+        ch_slots.pop(ch_match[0])
+
+    return True
+
+def can_assemble(mol_tree, u, v_node_dict):
+    u_node_dict = mol_tree.nodes_dict[u]
+    u_neighbors = mol_tree.successors(u)
+    u_neighbors_node_dict = [
+        mol_tree.nodes_dict[_u]
+        for _u in u_neighbors
+        if _u in mol_tree.nodes_dict
+    ]
+    neis = u_neighbors_node_dict + [v_node_dict]
+    for i,nei in enumerate(neis):
+        nei['nid'] = i
+
+    neighbors = [nei for nei in neis if nei['mol'].GetNumAtoms() > 1]
+    neighbors = sorted(neighbors, key=lambda x:x['mol'].GetNumAtoms(), reverse=True)
+    singletons = [nei for nei in neis if nei['mol'].GetNumAtoms() == 1]
+    neighbors = singletons + neighbors
+    cands = enum_assemble_nx(u_node_dict, neighbors)
+    return len(cands) > 0
+
+def create_node_dict(smiles, clique=[]):
+    return dict(
+        smiles=smiles,
+        mol=get_mol(smiles),
+        clique=clique,
+    )
+
 class DGLJTNNDecoder(nn.Module):
@@ -70,32 +105,30 @@ class DGLJTNNDecoder(nn.Module):
         self.W_o = nn.Linear(hidden_size, self.vocab_size)
         self.U_s = nn.Linear(hidden_size, 1)

-    @profile
     def forward(self, mol_trees, tree_vec):
         '''
         The training procedure which computes the prediction loss given the
         ground truth tree
         '''
         mol_tree_batch = batch(mol_trees)
-        mol_tree_batch_lg = line_graph(mol_tree_batch, no_backtracking=True)
+        mol_tree_batch_lg = mol_tree_batch.line_graph(backtracking=False, shared=True)
         n_trees = len(mol_trees)

         return self.run(mol_tree_batch, mol_tree_batch_lg, n_trees, tree_vec)

-    @profile
     def run(self, mol_tree_batch, mol_tree_batch_lg, n_trees, tree_vec):
-        root_ids = mol_tree_batch.node_offset[:-1]
-        n_nodes = len(mol_tree_batch.nodes)
-        edge_list = mol_tree_batch.edge_list
-        n_edges = len(edge_list)
+        node_offset = np.cumsum([0] + mol_tree_batch.batch_num_nodes)
+        root_ids = node_offset[:-1]
+        n_nodes = mol_tree_batch.number_of_nodes()
+        n_edges = mol_tree_batch.number_of_edges()

-        mol_tree_batch.set_n_repr({
-            'x': self.embedding(mol_tree_batch.get_n_repr()['wid']),
+        mol_tree_batch.ndata.update({
+            'x': self.embedding(mol_tree_batch.ndata['wid']),
             'h': cuda(torch.zeros(n_nodes, self.hidden_size)),
             'new': cuda(torch.ones(n_nodes).byte()), # whether it's newly generated node
         })

-        mol_tree_batch.set_e_repr({
+        mol_tree_batch.edata.update({
             's': cuda(torch.zeros(n_edges, self.hidden_size)),
             'm': cuda(torch.zeros(n_edges, self.hidden_size)),
             'r': cuda(torch.zeros(n_edges, self.hidden_size)),
@@ -106,10 +139,8 @@ class DGLJTNNDecoder(nn.Module):
             'accum_rm': cuda(torch.zeros(n_edges, self.hidden_size)),
         })

-        mol_tree_batch.update_edge(
-            #*zip(*edge_list),
-            edge_func=lambda src, dst, edge: {'src_x': src['x'], 'dst_x': dst['x']},
-            batchable=True,
+        mol_tree_batch.apply_edges(
+            func=lambda edges: {'src_x': edges.src['x'], 'dst_x': edges.dst['x']},
         )

         # input tensors for stop prediction (p) and label prediction (q)
@@ -124,52 +155,57 @@ class DGLJTNNDecoder(nn.Module):
             dec_tree_node_msg,
             dec_tree_node_reduce,
             dec_tree_node_update,
-            batchable=True,
         )

         # Extract hidden states and store them for stop/label prediction
-        h = mol_tree_batch.get_n_repr(root_ids)['h']
-        x = mol_tree_batch.get_n_repr(root_ids)['x']
+        h = mol_tree_batch.nodes[root_ids].data['h']
+        x = mol_tree_batch.nodes[root_ids].data['x']
         p_inputs.append(torch.cat([x, h, tree_vec], 1))
-        t_set = list(range(len(root_ids)))
+        # If the out degree is 0 we don't generate any edges at all
+        root_out_degrees = mol_tree_batch.out_degrees(root_ids)
         q_inputs.append(torch.cat([h, tree_vec], 1))
-        q_targets.append(mol_tree_batch.get_n_repr(root_ids)['wid'])
+        q_targets.append(mol_tree_batch.nodes[root_ids].data['wid'])

         # Traverse the tree and predict on children
-        for u, v, i, p in dfs_order(mol_tree_batch, root_ids):
-            assert set(t_set).issuperset(i)
-            ip = dict(zip(i, p))
-            # TODO: context
-            p_targets.append(cuda(torch.tensor([ip.get(_i, 0) for _i in t_set])))
-            t_set = list(i)
-            eid = mol_tree_batch.get_edge_id(u, v)
+        for eid, p in dfs_order(mol_tree_batch, root_ids):
+            u, v = mol_tree_batch.find_edges(eid)
+
+            p_target_list = torch.zeros_like(root_out_degrees)
+            p_target_list[root_out_degrees > 0] = 1 - p
+            p_target_list = p_target_list[root_out_degrees >= 0]
+            p_targets.append(torch.tensor(p_target_list))
+
+            root_out_degrees -= (root_out_degrees == 0).long()
+            root_out_degrees -= torch.tensor(np.isin(root_ids, v).astype('int64'))

             mol_tree_batch_lg.pull(
                 eid,
                 dec_tree_edge_msg,
                 dec_tree_edge_reduce,
                 self.dec_tree_edge_update,
-                batchable=True,
             )
-            is_new = mol_tree_batch.get_n_repr(v)['new']
+            is_new = mol_tree_batch.nodes[v].data['new']
             mol_tree_batch.pull(
                 v,
                 dec_tree_node_msg,
                 dec_tree_node_reduce,
                 dec_tree_node_update,
-                batchable=True,
             )
             # Extract
-            h = mol_tree_batch.get_n_repr(v)['h']
-            x = mol_tree_batch.get_n_repr(v)['x']
-            p_inputs.append(torch.cat([x, h, tree_vec[t_set]], 1))
+            n_repr = mol_tree_batch.nodes[v].data
+            h = n_repr['h']
+            x = n_repr['x']
+            tree_vec_set = tree_vec[root_out_degrees >= 0]
+            wid = n_repr['wid']
+            p_inputs.append(torch.cat([x, h, tree_vec_set], 1))
             # Only newly generated nodes are needed for label prediction
             # NOTE: The following works since the uncomputed messages are zeros.
-            q_inputs.append(torch.cat([h[is_new], tree_vec[t_set][is_new]], 1))
-            q_targets.append(mol_tree_batch.get_n_repr(v)['wid'][is_new])
-        p_targets.append(cuda(torch.tensor([0 for _ in t_set])))
+            q_inputs.append(torch.cat([h, tree_vec_set], 1)[is_new])
+            q_targets.append(wid[is_new])
+        p_targets.append(torch.zeros((root_out_degrees == 0).sum()).long())

         # Batch compute the stop/label prediction losses
         p_inputs = torch.cat(p_inputs, 0)
-        p_targets = torch.cat(p_targets, 0)
+        p_targets = cuda(torch.cat(p_targets, 0))
         q_inputs = torch.cat(q_inputs, 0)
         q_targets = torch.cat(q_targets, 0)
@@ -183,4 +219,161 @@ class DGLJTNNDecoder(nn.Module):
         p_acc = ((p > 0).long() == p_targets).sum().float() / p_targets.shape[0]
         q_acc = (q.max(1)[1] == q_targets).float().sum() / q_targets.shape[0]

+        self.q_inputs = q_inputs
+        self.q_targets = q_targets
+        self.q = q
+        self.p_inputs = p_inputs
+        self.p_targets = p_targets
+        self.p = p
+
         return q_loss, p_loss, q_acc, p_acc
+
+    def decode(self, mol_vec):
+        assert mol_vec.shape[0] == 1
+
+        mol_tree = DGLMolTree(None)
+
+        init_hidden = cuda(torch.zeros(1, self.hidden_size))
+
+        root_hidden = torch.cat([init_hidden, mol_vec], 1)
+        root_hidden = F.relu(self.W(root_hidden))
+        root_score = self.W_o(root_hidden)
+        _, root_wid = torch.max(root_score, 1)
+        root_wid = root_wid.view(1)
+
+        mol_tree.add_nodes(1)   # root
+        mol_tree.nodes[0].data['wid'] = root_wid
+        mol_tree.nodes[0].data['x'] = self.embedding(root_wid)
+        mol_tree.nodes[0].data['h'] = init_hidden
+        mol_tree.nodes[0].data['fail'] = cuda(torch.tensor([0]))
+        mol_tree.nodes_dict[0] = root_node_dict = create_node_dict(
+            self.vocab.get_smiles(root_wid))
+
+        stack, trace = [], []
+        stack.append((0, self.vocab.get_slots(root_wid)))
+
+        all_nodes = {0: root_node_dict}
+        h = {}
+        first = True
+        new_node_id = 0
+        new_edge_id = 0
+
+        for step in range(MAX_DECODE_LEN):
+            u, u_slots = stack[-1]
+            udata = mol_tree.nodes[u].data
+            wid = udata['wid']
+            x = udata['x']
+            h = udata['h']
+
+            # Predict stop
+            p_input = torch.cat([x, h, mol_vec], 1)
+            p_score = torch.sigmoid(self.U_s(torch.relu(self.U(p_input))))
+            backtrack = (p_score.item() < 0.5)
+
+            if not backtrack:
+                # Predict next clique. Note that the prediction may fail due
+                # to lack of assemblable components
+                mol_tree.add_nodes(1)
+                new_node_id += 1
+                v = new_node_id
+                mol_tree.add_edges(u, v)
+                uv = new_edge_id
+                new_edge_id += 1
+
+                if first:
+                    mol_tree.edata.update({
+                        's': cuda(torch.zeros(1, self.hidden_size)),
+                        'm': cuda(torch.zeros(1, self.hidden_size)),
+                        'r': cuda(torch.zeros(1, self.hidden_size)),
+                        'z': cuda(torch.zeros(1, self.hidden_size)),
+                        'src_x': cuda(torch.zeros(1, self.hidden_size)),
+                        'dst_x': cuda(torch.zeros(1, self.hidden_size)),
+                        'rm': cuda(torch.zeros(1, self.hidden_size)),
+                        'accum_rm': cuda(torch.zeros(1, self.hidden_size)),
+                    })
+                    first = False
+
+                mol_tree.edges[uv].data['src_x'] = mol_tree.nodes[u].data['x']
+                # keeping dst_x 0 is fine as h on new edge doesn't depend on that.
+
+                # DGL doesn't dynamically maintain a line graph.
+                mol_tree_lg = mol_tree.line_graph(backtracking=False, shared=True)
+
+                mol_tree_lg.pull(
+                    uv,
+                    dec_tree_edge_msg,
+                    dec_tree_edge_reduce,
+                    self.dec_tree_edge_update.update_zm,
+                )
+                mol_tree.pull(
+                    v,
+                    dec_tree_node_msg,
+                    dec_tree_node_reduce,
+                )
+
+                vdata = mol_tree.nodes[v].data
+                h_v = vdata['h']
+                q_input = torch.cat([h_v, mol_vec], 1)
+                q_score = torch.softmax(self.W_o(torch.relu(self.W(q_input))), -1)
+                _, sort_wid = torch.sort(q_score, 1, descending=True)
+                sort_wid = sort_wid.squeeze()
+
+                next_wid = None
+                for wid in sort_wid.tolist()[:5]:
+                    slots = self.vocab.get_slots(wid)
+                    cand_node_dict = create_node_dict(self.vocab.get_smiles(wid))
+                    if (have_slots(u_slots, slots) and can_assemble(mol_tree, u, cand_node_dict)):
+                        next_wid = wid
+                        next_slots = slots
+                        next_node_dict = cand_node_dict
+                        break
+
+                if next_wid is None:
+                    # Failed adding an actual child; v is a spurious node
+                    # and we mark it.
+                    vdata['fail'] = cuda(torch.tensor([1]))
+                    backtrack = True
+                else:
+                    next_wid = cuda(torch.tensor([next_wid]))
+                    vdata['wid'] = next_wid
+                    vdata['x'] = self.embedding(next_wid)
+                    mol_tree.nodes_dict[v] = next_node_dict
+                    all_nodes[v] = next_node_dict
+                    stack.append((v, next_slots))
+                    mol_tree.add_edge(v, u)
+                    vu = new_edge_id
+                    new_edge_id += 1
+                    mol_tree.edges[uv].data['dst_x'] = mol_tree.nodes[v].data['x']
+                    mol_tree.edges[vu].data['src_x'] = mol_tree.nodes[v].data['x']
+                    mol_tree.edges[vu].data['dst_x'] = mol_tree.nodes[u].data['x']
+
+                    # DGL doesn't dynamically maintain a line graph.
+                    mol_tree_lg = mol_tree.line_graph(backtracking=False, shared=True)
+                    mol_tree_lg.apply_nodes(
+                        self.dec_tree_edge_update.update_r,
+                        uv
+                    )
+
+            if backtrack:
+                if len(stack) == 1:
+                    break   # At root, terminate
+
+                pu, _ = stack[-2]
+                u_pu = mol_tree.edge_id(u, pu)
+
+                mol_tree_lg.pull(
+                    u_pu,
+                    dec_tree_edge_msg,
+                    dec_tree_edge_reduce,
+                    self.dec_tree_edge_update,
+                )
+                mol_tree.pull(
+                    pu,
+                    dec_tree_node_msg,
+                    dec_tree_node_reduce,
+                )
+                stack.pop()
+
+        effective_nodes = mol_tree.filter_nodes(lambda nodes: nodes.data['fail'] != 1)
+        effective_nodes, _ = torch.sort(effective_nodes)
+        return mol_tree, all_nodes, effective_nodes
jtnn/jtnn_enc.py
@@ -5,43 +5,24 @@ from .mol_tree import Vocab
 from .nnutils import GRUUpdate, cuda
 import itertools
 import networkx as nx
-from dgl import batch, unbatch, line_graph
+from dgl import batch, unbatch, bfs_edges_generator
 import dgl.function as DGLF
 from .line_profiler_integration import profile
+import numpy as np

 MAX_NB = 8

 def level_order(forest, roots):
-    '''
-    Given the forest and the list of root nodes,
-    returns iterator of list of edges ordered by depth, first in bottom-up
-    and then top-down
-    '''
-    edge_list = []
-    node_depth = {}
-    edge_list.append([])
-    for root in roots:
-        node_depth[root] = 0
-        for u, v in nx.bfs_edges(forest, root):
-            node_depth[v] = node_depth[u] + 1
-            if len(edge_list) == node_depth[u]:
-                edge_list.append([])
-            edge_list[node_depth[u]].append((u, v))
-
-    for edges in reversed(edge_list):
-        u, v = zip(*edges)
-        yield v, u
-
-    for edges in edge_list:
-        u, v = zip(*edges)
-        yield u, v
+    edges = bfs_edges_generator(forest, roots)
+    _, leaves = forest.find_edges(edges[-1])
+    edges_back = bfs_edges_generator(forest, roots, reversed=True)
+    yield from reversed(edges_back)
+    yield from edges
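A toy check of the traversal that the new `level_order` builds on, assuming the `bfs_edges_generator` semantics used above (a list with one tensor of edge IDs per BFS frontier). The junction trees store both edge directions, so the generator can replay frontiers deepest-first (bottom-up, over reversed edges) and then shallowest-first (top-down):

```python
import dgl

g = dgl.DGLGraph()
g.add_nodes(3)
g.add_edges([0, 1], [1, 2])   # path 0 -> 1 -> 2; edge 0 = (0, 1), edge 1 = (1, 2)

frontiers = dgl.bfs_edges_generator(g, [0])
print([f.tolist() for f in frontiers])   # [[0], [1]]: one frontier per depth
```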
 enc_tree_msg = [DGLF.copy_src(src='m', out='m'), DGLF.copy_src(src='rm', out='rm')]
-enc_tree_reduce = [DGLF.sum(msgs='m', out='s'), DGLF.sum(msgs='rm', out='accum_rm')]
+enc_tree_reduce = [DGLF.sum(msg='m', out='s'), DGLF.sum(msg='rm', out='accum_rm')]
 enc_tree_gather_msg = DGLF.copy_edge(edge='m', out='m')
-enc_tree_gather_reduce = DGLF.sum(msgs='m', out='m')
+enc_tree_gather_reduce = DGLF.sum(msg='m', out='m')

 class EncoderGatherUpdate(nn.Module):
     def __init__(self, hidden_size):
@@ -50,9 +31,9 @@ class EncoderGatherUpdate(nn.Module):
         self.W = nn.Linear(2 * hidden_size, hidden_size)

-    def forward(self, node):
-        x = node['x']
-        m = node['m']
+    def forward(self, nodes):
+        x = nodes.data['x']
+        m = nodes.data['m']
         return {
             'h': torch.relu(self.W(torch.cat([x, m], 1))),
         }

@@ -73,34 +54,32 @@ class DGLJTNNEncoder(nn.Module):
         self.enc_tree_update = GRUUpdate(hidden_size)
         self.enc_tree_gather_update = EncoderGatherUpdate(hidden_size)

-    @profile
     def forward(self, mol_trees):
         mol_tree_batch = batch(mol_trees)

         # Build line graph to prepare for belief propagation
-        mol_tree_batch_lg = line_graph(mol_tree_batch, no_backtracking=True)
+        mol_tree_batch_lg = mol_tree_batch.line_graph(backtracking=False, shared=True)

         return self.run(mol_tree_batch, mol_tree_batch_lg)

-    @profile
     def run(self, mol_tree_batch, mol_tree_batch_lg):
         # Since tree roots are designated to 0. In the batched graph we can
         # simply find the corresponding node ID by looking at node_offset
-        root_ids = mol_tree_batch.node_offset[:-1]
-        n_nodes = len(mol_tree_batch.nodes)
-        edge_list = mol_tree_batch.edge_list
-        n_edges = len(edge_list)
+        node_offset = np.cumsum([0] + mol_tree_batch.batch_num_nodes)
+        root_ids = node_offset[:-1]
+        n_nodes = mol_tree_batch.number_of_nodes()
+        n_edges = mol_tree_batch.number_of_edges()

         # Assign structure embeddings to tree nodes
-        mol_tree_batch.set_n_repr({
-            'x': self.embedding(mol_tree_batch.get_n_repr()['wid']),
+        mol_tree_batch.ndata.update({
+            'x': self.embedding(mol_tree_batch.ndata['wid']),
             'h': cuda(torch.zeros(n_nodes, self.hidden_size)),
         })

         # Initialize the intermediate variables according to Eq (4)-(8).
         # Also initialize the src_x and dst_x fields.
         # TODO: context?
-        mol_tree_batch.set_e_repr({
+        mol_tree_batch.edata.update({
             's': cuda(torch.zeros(n_edges, self.hidden_size)),
             'm': cuda(torch.zeros(n_edges, self.hidden_size)),
             'r': cuda(torch.zeros(n_edges, self.hidden_size)),
@@ -112,10 +91,8 @@ class DGLJTNNEncoder(nn.Module):
         })

         # Send the source/destination node features to edges
-        mol_tree_batch.update_edge(
-            #*zip(*edge_list),
-            edge_func=lambda src, dst, edge: {'src_x': src['x'], 'dst_x': dst['x']},
-            batchable=True,
+        mol_tree_batch.apply_edges(
+            func=lambda edges: {'src_x': edges.src['x'], 'dst_x': edges.dst['x']},
         )

         # Message passing
@@ -123,14 +100,13 @@ class DGLJTNNEncoder(nn.Module):
         # messages, and the uncomputed messages are zero vectors. Essentially,
         # we can always compute s_ij as the sum of incoming m_ij, no matter
         # if m_ij is actually computed or not.
-        for u, v in level_order(mol_tree_batch, root_ids):
-            eid = mol_tree_batch.get_edge_id(u, v)
+        for eid in level_order(mol_tree_batch, root_ids):
+            #eid = mol_tree_batch.edge_ids(u, v)
             mol_tree_batch_lg.pull(
                 eid,
                 enc_tree_msg,
                 enc_tree_reduce,
                 self.enc_tree_update,
-                batchable=True,
             )

         # Readout
@@ -138,9 +114,8 @@ class DGLJTNNEncoder(nn.Module):
             enc_tree_gather_msg,
             enc_tree_gather_reduce,
             self.enc_tree_gather_update,
-            batchable=True,
         )

-        root_vecs = mol_tree_batch.get_n_repr(root_ids)['h']
+        root_vecs = mol_tree_batch.nodes[root_ids].data['h']

         return mol_tree_batch, root_vecs
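The `line_graph(backtracking=False, shared=True)` calls above are what enforce the belief-propagation constraint that a message along (u, v) never feeds the reverse message (v, u). A tiny sketch, again assuming the 0.x `DGLGraph` API:

```python
import dgl

g = dgl.DGLGraph()
g.add_nodes(2)
g.add_edges([0, 1], [1, 0])   # edge 0 = (0, 1), edge 1 = (1, 0)

lg = g.line_graph(backtracking=False, shared=True)
print(lg.number_of_nodes())   # 2: one line-graph node per edge of g
print(lg.number_of_edges())   # 0: (0, 1) -> (1, 0) would be backtracking
```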
jtnn/jtnn_vae.py
@@ -2,11 +2,15 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from .mol_tree import Vocab
-from .nnutils import create_var, cuda
+from .nnutils import create_var, cuda, move_dgl_to_cuda
+from .chemutils import set_atommap, copy_edit_mol, enum_assemble_nx, \
+    attach_mols_nx, decode_stereo
 from .jtnn_enc import DGLJTNNEncoder
 from .jtnn_dec import DGLJTNNDecoder
-from .mpn import DGLMPN, mol2dgl
+from .mpn import DGLMPN
+from .mpn import mol2dgl_single as mol2dgl_enc
 from .jtmpn import DGLJTMPN
+from .jtmpn import mol2dgl_single as mol2dgl_dec
 from .line_profiler_integration import profile

 import rdkit
@@ -15,15 +19,7 @@ from rdkit import DataStructs
 from rdkit.Chem import AllChem
 import copy, math

-def dgl_set_batch_nodeID(mol_batch, vocab):
-    tot = 0
-    for mol_tree in mol_batch:
-        wid = []
-        for node in mol_tree.nodes:
-            mol_tree.nodes[node]['idx'] = tot
-            tot += 1
-            wid.append(vocab.get_index(mol_tree.nodes[node]['smiles']))
-        mol_tree.set_n_repr({'wid': cuda(torch.LongTensor(wid))})
+from dgl import batch, unbatch

 class DGLJTNNVAE(nn.Module):
@@ -51,86 +47,75 @@ class DGLJTNNVAE(nn.Module):
         self.n_edges_total = 0
         self.n_tree_nodes_total = 0

-    @profile
-    def encode(self, mol_batch):
-        dgl_set_batch_nodeID(mol_batch, self.vocab)
-
-        smiles_batch = [mol_tree.smiles for mol_tree in mol_batch]
-        mol_graphs = mol2dgl(smiles_batch)
+    @staticmethod
+    def move_to_cuda(mol_batch):
+        for t in mol_batch['mol_trees']:
+            move_dgl_to_cuda(t)
+
+        move_dgl_to_cuda(mol_batch['mol_graph_batch'])
+        if 'cand_graph_batch' in mol_batch:
+            move_dgl_to_cuda(mol_batch['cand_graph_batch'])
+        if mol_batch.get('stereo_cand_graph_batch') is not None:
+            move_dgl_to_cuda(mol_batch['stereo_cand_graph_batch'])
+
+    def encode(self, mol_batch):
+        mol_graphs = mol_batch['mol_graph_batch']
         mol_vec = self.mpn(mol_graphs)

-        # mol_batch is a junction tree
-        mol_tree_batch, tree_vec = self.jtnn(mol_batch)
+        mol_tree_batch, tree_vec = self.jtnn(mol_batch['mol_trees'])

-        self.n_nodes_total += sum(len(g.nodes) for g in mol_graphs)
-        self.n_edges_total += sum(len(g.edges) for g in mol_graphs)
-        self.n_tree_nodes_total += sum(len(t.nodes) for t in mol_batch)
+        self.n_nodes_total += mol_graphs.number_of_nodes()
+        self.n_edges_total += mol_graphs.number_of_edges()
+        self.n_tree_nodes_total += sum(t.number_of_nodes() for t in mol_batch['mol_trees'])
         self.n_passes += 1

         return mol_tree_batch, tree_vec, mol_vec

-    @profile
-    def forward(self, mol_batch, beta=0, e1=None, e2=None):
-        batch_size = len(mol_batch)
-
-        mol_tree_batch, tree_vec, mol_vec = self.encode(mol_batch)
-
+    def sample(self, tree_vec, mol_vec, e1=None, e2=None):
         tree_mean = self.T_mean(tree_vec)
         tree_log_var = -torch.abs(self.T_var(tree_vec))
         mol_mean = self.G_mean(mol_vec)
         mol_log_var = -torch.abs(self.G_var(mol_vec))

-        self.tree_mean = tree_mean
-        self.tree_log_var = tree_log_var
-        self.mol_mean = mol_mean
-        self.mol_log_var = mol_log_var
-
-        z_mean = torch.cat([tree_mean, mol_mean], dim=1)
-        z_log_var = torch.cat([tree_log_var, mol_log_var], dim=1)
-        kl_loss = -0.5 * torch.sum(1.0 + z_log_var - z_mean * z_mean - torch.exp(z_log_var)) / batch_size
-        self.z_mean = z_mean
-        self.z_log_var = z_log_var
-
-        epsilon = cuda(torch.randn(batch_size, self.latent_size // 2)) if e1 is None else e1
-        tree_vec = tree_mean + torch.exp(tree_log_var / 2) * epsilon
-        epsilon = cuda(torch.randn(batch_size, self.latent_size // 2)) if e2 is None else e2
-        mol_vec = mol_mean + torch.exp(mol_log_var / 2) * epsilon
-
-        self.tree_vec = tree_vec
-        self.mol_vec = mol_vec
+        epsilon = cuda(torch.randn(*tree_mean.shape)) if e1 is None else e1
+        tree_vec = tree_mean + torch.exp(tree_log_var / 2) * epsilon
+        epsilon = cuda(torch.randn(*mol_mean.shape)) if e2 is None else e2
+        mol_vec = mol_mean + torch.exp(mol_log_var / 2) * epsilon
+
+        z_mean = torch.cat([tree_mean, mol_mean], 1)
+        z_log_var = torch.cat([tree_log_var, mol_log_var], 1)
+
+        return tree_vec, mol_vec, z_mean, z_log_var
+
+    def forward(self, mol_batch, beta=0, e1=None, e2=None):
+        self.move_to_cuda(mol_batch)
+        mol_trees = mol_batch['mol_trees']
+        batch_size = len(mol_trees)
+
+        mol_tree_batch, tree_vec, mol_vec = self.encode(mol_batch)
+
+        tree_vec, mol_vec, z_mean, z_log_var = self.sample(tree_vec, mol_vec, e1, e2)
+        kl_loss = -0.5 * torch.sum(1.0 + z_log_var - z_mean * z_mean - torch.exp(z_log_var)) / batch_size

-        word_loss, topo_loss, word_acc, topo_acc = self.decoder(mol_batch, tree_vec)
+        word_loss, topo_loss, word_acc, topo_acc = self.decoder(mol_trees, tree_vec)
         assm_loss, assm_acc = self.assm(mol_batch, mol_tree_batch, mol_vec)
         stereo_loss, stereo_acc = self.stereo(mol_batch, mol_vec)

-        self.word_loss_v = word_loss
-        self.topo_loss_v = topo_loss
-        self.assm_loss_v = assm_loss
-        self.stereo_loss_v = stereo_loss
-
         all_vec = torch.cat([tree_vec, mol_vec], dim=1)
         loss = word_loss + topo_loss + assm_loss + 2 * stereo_loss + beta * kl_loss
-        self.all_vec = all_vec

         return loss, kl_loss, word_acc, topo_acc, assm_acc, stereo_acc
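Written out, the `kl_loss` line in `forward` is the standard closed-form KL term between the diagonal-Gaussian posterior and the unit-Gaussian prior, averaged over the batch; here μ is `z_mean`, log σ² is `z_log_var`, and the `-torch.abs(...)` lines in `sample` clamp log σ² ≤ 0:

```latex
\mathcal{L}_{\mathrm{KL}}
  = \frac{1}{B}\sum_{n=1}^{B} D_{\mathrm{KL}}\!\left(
        \mathcal{N}\!\left(\mu_n, \operatorname{diag}(\sigma_n^2)\right)
        \,\middle\|\, \mathcal{N}(0, I)\right)
  = -\frac{1}{2B}\sum_{n=1}^{B}\sum_{j}
        \left(1 + \log\sigma_{nj}^2 - \mu_{nj}^2 - \sigma_{nj}^2\right)
```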
@profile
def assm(self, mol_batch, mol_tree_batch, mol_vec): def assm(self, mol_batch, mol_tree_batch, mol_vec):
cands = [] cands = [mol_batch['cand_graph_batch'],
batch_idx = [] mol_batch['tree_mess_src_e'],
mol_batch['tree_mess_tgt_e'],
for i, mol_tree in enumerate(mol_batch): mol_batch['tree_mess_tgt_n']]
for node_id, node in mol_tree.nodes.items():
if node['is_leaf'] or len(node['cands']) == 1:
continue
cands.extend([(cand, mol_tree, node_id) for cand in node['cand_mols']])
batch_idx.extend([i] * len(node['cands']))
cand_vec = self.jtmpn(cands, mol_tree_batch) cand_vec = self.jtmpn(cands, mol_tree_batch)
cand_vec = self.G_mean(cand_vec) cand_vec = self.G_mean(cand_vec)
batch_idx = cuda(torch.LongTensor(batch_idx)) batch_idx = cuda(torch.LongTensor(mol_batch['cand_batch_idx']))
mol_vec = mol_vec[batch_idx] mol_vec = mol_vec[batch_idx]
mol_vec = mol_vec.view(-1, 1, self.latent_size // 2) mol_vec = mol_vec.view(-1, 1, self.latent_size // 2)
...@@ -139,13 +124,13 @@ class DGLJTNNVAE(nn.Module): ...@@ -139,13 +124,13 @@ class DGLJTNNVAE(nn.Module):
cnt, tot, acc = 0, 0, 0 cnt, tot, acc = 0, 0, 0
all_loss = [] all_loss = []
for i, mol_tree in enumerate(mol_batch): for i, mol_tree in enumerate(mol_batch['mol_trees']):
comp_nodes = [node_id for node_id, node in mol_tree.nodes.items() comp_nodes = [node_id for node_id, node in mol_tree.nodes_dict.items()
if len(node['cands']) > 1 and not node['is_leaf']] if len(node['cands']) > 1 and not node['is_leaf']]
cnt += len(comp_nodes) cnt += len(comp_nodes)
# segmented accuracy and cross entropy # segmented accuracy and cross entropy
for node_id in comp_nodes: for node_id in comp_nodes:
node = mol_tree.nodes[node_id] node = mol_tree.nodes_dict[node_id]
label = node['cands'].index(node['label']) label = node['cands'].index(node['label'])
ncand = len(node['cands']) ncand = len(node['cands'])
cur_score = scores[tot:tot+ncand] cur_score = scores[tot:tot+ncand]
...@@ -158,36 +143,28 @@ class DGLJTNNVAE(nn.Module): ...@@ -158,36 +143,28 @@ class DGLJTNNVAE(nn.Module):
all_loss.append( all_loss.append(
F.cross_entropy(cur_score.view(1, -1), label, size_average=False)) F.cross_entropy(cur_score.view(1, -1), label, size_average=False))
all_loss = sum(all_loss) / len(mol_batch) all_loss = sum(all_loss) / len(mol_batch['mol_trees'])
return all_loss, acc / cnt return all_loss, acc / cnt
@profile
def stereo(self, mol_batch, mol_vec): def stereo(self, mol_batch, mol_vec):
stereo_cands, batch_idx = [], [] stereo_cands = mol_batch['stereo_cand_graph_batch']
labels = [] batch_idx = mol_batch['stereo_cand_batch_idx']
for i, mol_tree in enumerate(mol_batch): labels = mol_batch['stereo_cand_labels']
cands = mol_tree.stereo_cands lengths = mol_batch['stereo_cand_lengths']
if len(cands) == 1:
continue
if mol_tree.smiles3D not in cands:
cands.append(mol_tree.smiles3D)
stereo_cands.extend(cands)
batch_idx.extend([i] * len(cands))
labels.append((cands.index(mol_tree.smiles3D), len(cands)))
if len(labels) == 0: if len(labels) == 0:
# Only one stereoisomer exists; do nothing # Only one stereoisomer exists; do nothing
return cuda(torch.tensor(0.)), 1. return cuda(torch.tensor(0.)), 1.
batch_idx = cuda(torch.LongTensor(batch_idx)) batch_idx = cuda(torch.LongTensor(batch_idx))
stereo_cands = self.mpn(mol2dgl(stereo_cands)) stereo_cands = self.mpn(stereo_cands)
stereo_cands = self.G_mean(stereo_cands) stereo_cands = self.G_mean(stereo_cands)
stereo_labels = mol_vec[batch_idx] stereo_labels = mol_vec[batch_idx]
scores = F.cosine_similarity(stereo_cands, stereo_labels) scores = F.cosine_similarity(stereo_cands, stereo_labels)
st, acc = 0, 0 st, acc = 0, 0
all_loss = [] all_loss = []
for label, le in labels: for label, le in zip(labels, lengths):
cur_scores = scores[st:st+le] cur_scores = scores[st:st+le]
if cur_scores.data[label].item() >= cur_scores.max().item(): if cur_scores.data[label].item() >= cur_scores.max().item():
acc += 1 acc += 1
...@@ -198,3 +175,134 @@ class DGLJTNNVAE(nn.Module): ...@@ -198,3 +175,134 @@ class DGLJTNNVAE(nn.Module):
all_loss = sum(all_loss) / len(labels) all_loss = sum(all_loss) / len(labels)
return all_loss, acc / len(labels) return all_loss, acc / len(labels)
+    def decode(self, tree_vec, mol_vec):
+        mol_tree, nodes_dict, effective_nodes = self.decoder.decode(tree_vec)
+        effective_nodes_list = effective_nodes.tolist()
+        nodes_dict = [nodes_dict[v] for v in effective_nodes_list]
+
+        for i, (node_id, node) in enumerate(zip(effective_nodes_list, nodes_dict)):
+            node['idx'] = i
+            node['nid'] = i + 1
+            node['is_leaf'] = True
+            if mol_tree.in_degree(node_id) > 1:
+                node['is_leaf'] = False
+                set_atommap(node['mol'], node['nid'])
+
+        mol_tree_sg = mol_tree.subgraph(effective_nodes)
+        mol_tree_sg.copy_from_parent()
+        mol_tree_msg, _ = self.jtnn([mol_tree_sg])
+        mol_tree_msg = unbatch(mol_tree_msg)[0]
+        mol_tree_msg.nodes_dict = nodes_dict
+
+        cur_mol = copy_edit_mol(nodes_dict[0]['mol'])
+        global_amap = [{}] + [{} for node in nodes_dict]
+        global_amap[1] = {atom.GetIdx(): atom.GetIdx() for atom in cur_mol.GetAtoms()}
+
+        cur_mol = self.dfs_assemble(mol_tree_msg, mol_vec, cur_mol, global_amap, [], 0, None)
+
+        if cur_mol is None:
+            return None
+
+        cur_mol = cur_mol.GetMol()
+        set_atommap(cur_mol)
+        cur_mol = Chem.MolFromSmiles(Chem.MolToSmiles(cur_mol))
+        if cur_mol is None:
+            return None
+
+        smiles2D = Chem.MolToSmiles(cur_mol)
+        stereo_cands = decode_stereo(smiles2D)
+        if len(stereo_cands) == 1:
+            return stereo_cands[0]
+        stereo_graphs = [mol2dgl_enc(c) for c in stereo_cands]
+        stereo_cand_graphs, atom_x, bond_x = zip(*stereo_graphs)
+        stereo_cand_graphs = batch(stereo_cand_graphs)
+        atom_x = cuda(torch.cat(atom_x))
+        bond_x = cuda(torch.cat(bond_x))
+        stereo_cand_graphs.ndata['x'] = atom_x
+        stereo_cand_graphs.edata['x'] = bond_x
+        stereo_cand_graphs.edata['src_x'] = atom_x.new(
+            bond_x.shape[0], atom_x.shape[1]).zero_()
+        stereo_vecs = self.mpn(stereo_cand_graphs)
+        stereo_vecs = self.G_mean(stereo_vecs)
+        scores = F.cosine_similarity(stereo_vecs, mol_vec)
+        _, max_id = scores.max(0)
+        return stereo_cands[max_id.item()]
+    def dfs_assemble(self, mol_tree_msg, mol_vec, cur_mol,
+                     global_amap, fa_amap, cur_node_id, fa_node_id):
+        nodes_dict = mol_tree_msg.nodes_dict
+        fa_node = nodes_dict[fa_node_id] if fa_node_id is not None else None
+        cur_node = nodes_dict[cur_node_id]
+
+        fa_nid = fa_node['nid'] if fa_node is not None else -1
+        prev_nodes = [fa_node] if fa_node is not None else []
+
+        children_node_id = [v for v in mol_tree_msg.successors(cur_node_id).tolist()
+                            if nodes_dict[v]['nid'] != fa_nid]
+        children = [nodes_dict[v] for v in children_node_id]
+        neighbors = [nei for nei in children if nei['mol'].GetNumAtoms() > 1]
+        neighbors = sorted(neighbors, key=lambda x: x['mol'].GetNumAtoms(), reverse=True)
+        singletons = [nei for nei in children if nei['mol'].GetNumAtoms() == 1]
+        neighbors = singletons + neighbors
+
+        cur_amap = [(fa_nid, a2, a1) for nid, a1, a2 in fa_amap if nid == cur_node['nid']]
+        cands = enum_assemble_nx(cur_node, neighbors, prev_nodes, cur_amap)
+        if len(cands) == 0:
+            return None
+
+        cand_smiles, cand_mols, cand_amap = list(zip(*cands))
+        cands = [(candmol, mol_tree_msg, cur_node_id) for candmol in cand_mols]
+
+        cand_graphs, atom_x, bond_x, tree_mess_src_edges, \
+            tree_mess_tgt_edges, tree_mess_tgt_nodes = mol2dgl_dec(cands)
+        cand_graphs = batch(cand_graphs)
+        atom_x = cuda(atom_x)
+        bond_x = cuda(bond_x)
+        cand_graphs.ndata['x'] = atom_x
+        cand_graphs.edata['x'] = bond_x
+        cand_graphs.edata['src_x'] = atom_x.new(bond_x.shape[0], atom_x.shape[1]).zero_()
+
+        cand_vecs = self.jtmpn(
+            (cand_graphs, tree_mess_src_edges, tree_mess_tgt_edges, tree_mess_tgt_nodes),
+            mol_tree_msg,
+        )
+        cand_vecs = self.G_mean(cand_vecs)
+        mol_vec = mol_vec.squeeze()
+        scores = cand_vecs @ mol_vec
+
+        _, cand_idx = torch.sort(scores, descending=True)
+
+        backup_mol = Chem.RWMol(cur_mol)
+        for i in range(len(cand_idx)):
+            cur_mol = Chem.RWMol(backup_mol)
+            pred_amap = cand_amap[cand_idx[i].item()]
+            new_global_amap = copy.deepcopy(global_amap)
+
+            for nei_id, ctr_atom, nei_atom in pred_amap:
+                if nei_id == fa_nid:
+                    continue
+                new_global_amap[nei_id][nei_atom] = new_global_amap[cur_node['nid']][ctr_atom]
+
+            cur_mol = attach_mols_nx(cur_mol, children, [], new_global_amap)
+            new_mol = cur_mol.GetMol()
+            new_mol = Chem.MolFromSmiles(Chem.MolToSmiles(new_mol))
+
+            if new_mol is None:
+                continue
+
+            result = True
+            for nei_node_id, nei_node in zip(children_node_id, children):
+                if nei_node['is_leaf']:
+                    continue
+                cur_mol = self.dfs_assemble(
+                    mol_tree_msg, mol_vec, cur_mol, new_global_amap, pred_amap,
+                    nei_node_id, cur_node_id)
+                if cur_mol is None:
+                    result = False
+                    break
+
+            if result:
+                return cur_mol
+
+        return None
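
For orientation, the ranking that `dfs_assemble` performs above can be summarized in one line; this is our own transcription of the code, not notation from the commit. Each candidate attachment graph G_c is embedded by the junction-tree message passing network, projected by `G_mean`, and scored against the latent graph vector; candidates are then tried greedily in descending score order, backtracking when RDKit sanitization fails:

```
% Our transcription; z_mol denotes the (squeezed) mol_vec above.
s_c = \mathrm{G\_mean}\big(\mathrm{JTMPN}(G_c)\big)^{\top} z_{\mathrm{mol}},
\qquad
\hat{c}_1, \hat{c}_2, \ldots = \operatorname{argsort}_c(-s_c)
```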
@@ -2,10 +2,17 @@ from dgl import DGLGraph
 import rdkit.Chem as Chem
 from .chemutils import get_clique_mol, tree_decomp, get_mol, get_smiles, \
         set_atommap, enum_assemble_nx, decode_stereo
+import numpy as np
+from .line_profiler_integration import profile

 class DGLMolTree(DGLGraph):
     def __init__(self, smiles):
         DGLGraph.__init__(self)
+        self.nodes_dict = {}
+
+        if smiles is None:
+            return
+
         self.smiles = smiles
         self.mol = get_mol(smiles)
@@ -21,39 +28,43 @@ class DGLMolTree(DGLGraph):
         for i, c in enumerate(cliques):
             cmol = get_clique_mol(self.mol, c)
             csmiles = get_smiles(cmol)
-            self.add_node(
-                i,
+            self.nodes_dict[i] = dict(
                 smiles=csmiles,
                 mol=get_mol(csmiles),
                 clique=c,
             )
             if min(c) == 0:
                 root = i
+        self.add_nodes(len(cliques))

         # The clique with atom ID 0 becomes root
         if root > 0:
-            for attr in self.nodes[0]:
-                self.nodes[0][attr], self.nodes[root][attr] = \
-                    self.nodes[root][attr], self.nodes[0][attr]
+            for attr in self.nodes_dict[0]:
+                self.nodes_dict[0][attr], self.nodes_dict[root][attr] = \
+                    self.nodes_dict[root][attr], self.nodes_dict[0][attr]

-        for _x, _y in edges:
+        src = np.zeros((len(edges) * 2,), dtype='int')
+        dst = np.zeros((len(edges) * 2,), dtype='int')
+        for i, (_x, _y) in enumerate(edges):
             x = 0 if _x == root else root if _x == 0 else _x
             y = 0 if _y == root else root if _y == 0 else _y
-            self.add_edge(x, y)
-            self.add_edge(y, x)
+            src[2 * i] = x
+            dst[2 * i] = y
+            src[2 * i + 1] = y
+            dst[2 * i + 1] = x
+        self.add_edges(src, dst)

-        for i in self.nodes:
-            self.nodes[i]['nid'] = i + 1
-            if len(self[i]) > 1:  # Leaf node mol is not marked
-                set_atommap(self.nodes[i]['mol'], self.nodes[i]['nid'])
-            self.nodes[i]['is_leaf'] = (len(self[i]) == 1)
+        for i in self.nodes_dict:
+            self.nodes_dict[i]['nid'] = i + 1
+            if self.out_degree(i) > 1:  # Leaf node mol is not marked
+                set_atommap(self.nodes_dict[i]['mol'], self.nodes_dict[i]['nid'])
+            self.nodes_dict[i]['is_leaf'] = (self.out_degree(i) == 1)

+    # avoiding DiGraph.size()
     def treesize(self):
-        return len(self.nodes)
+        return self.number_of_nodes()

     def _recover_node(self, i, original_mol):
-        node = self.nodes[i]
+        node = self.nodes_dict[i]

         clique = []
         clique.extend(node['clique'])
@@ -61,8 +72,8 @@ class DGLMolTree(DGLGraph):
             for cidx in node['clique']:
                 original_mol.GetAtomWithIdx(cidx).SetAtomMapNum(node['nid'])

-        for j in self[i]:
-            nei_node = self.nodes[j]
+        for j in self.successors(i).numpy():
+            nei_node = self.nodes_dict[j]
             clique.extend(nei_node['clique'])
             if nei_node['is_leaf']:  # Leaf node, no need to mark
                 continue
@@ -83,25 +94,27 @@ class DGLMolTree(DGLGraph):
         return node['label']

     def _assemble_node(self, i):
-        neighbors = [self.nodes[j] for j in self[i] if self.nodes[j]['mol'].GetNumAtoms() > 1]
+        neighbors = [self.nodes_dict[j] for j in self.successors(i).numpy()
+                     if self.nodes_dict[j]['mol'].GetNumAtoms() > 1]
         neighbors = sorted(neighbors, key=lambda x: x['mol'].GetNumAtoms(), reverse=True)
-        singletons = [self.nodes[j] for j in self[i] if self.nodes[j]['mol'].GetNumAtoms() == 1]
+        singletons = [self.nodes_dict[j] for j in self.successors(i).numpy()
+                      if self.nodes_dict[j]['mol'].GetNumAtoms() == 1]
         neighbors = singletons + neighbors

-        cands = enum_assemble_nx(self, i, neighbors)
+        cands = enum_assemble_nx(self.nodes_dict[i], neighbors)

         if len(cands) > 0:
-            self.nodes[i]['cands'], self.nodes[i]['cand_mols'], _ = list(zip(*cands))
-            self.nodes[i]['cands'] = list(self.nodes[i]['cands'])
-            self.nodes[i]['cand_mols'] = list(self.nodes[i]['cand_mols'])
+            self.nodes_dict[i]['cands'], self.nodes_dict[i]['cand_mols'], _ = list(zip(*cands))
+            self.nodes_dict[i]['cands'] = list(self.nodes_dict[i]['cands'])
+            self.nodes_dict[i]['cand_mols'] = list(self.nodes_dict[i]['cand_mols'])
         else:
-            self.nodes[i]['cands'] = []
-            self.nodes[i]['cand_mols'] = []
+            self.nodes_dict[i]['cands'] = []
+            self.nodes_dict[i]['cand_mols'] = []

     def recover(self):
-        for i in self.nodes:
+        for i in self.nodes_dict:
             self._recover_node(i, self.mol)

     def assemble(self):
-        for i in self.nodes:
+        for i in self.nodes_dict:
             self._assemble_node(i)
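
As a quick illustration of the reworked class, a hypothetical usage sketch (ours, not part of the commit; `'CCO'` is an arbitrary SMILES and the import path assumes the package layout implied by this commit's imports):

```
from jtnn.mol_tree import DGLMolTree   # hypothetical import path

tree = DGLMolTree('CCO')
print(tree.treesize())                 # number of cliques in the junction tree
for i, node in tree.nodes_dict.items():
    # per-clique metadata now lives in nodes_dict rather than in DGL node frames
    print(i, node['smiles'], node['nid'], node['is_leaf'])
```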
@@ -4,11 +4,12 @@ import rdkit.Chem as Chem
 import torch.nn.functional as F
 from .nnutils import *
 from .chemutils import get_mol
-from networkx import Graph, DiGraph, line_graph, convert_node_labels_to_integers
-from dgl import DGLGraph, line_graph, batch, unbatch
+from networkx import Graph, DiGraph, convert_node_labels_to_integers
+from dgl import DGLGraph, batch, unbatch, mean_nodes
 import dgl.function as DGLF
 from functools import partial
 from .line_profiler_integration import profile
+import numpy as np

 ELEM_LIST = ['C', 'N', 'O', 'S', 'F', 'Si', 'P', 'Cl', 'Br', 'Mg', 'Na', 'Ca', 'Fe', 'Al', 'I', 'B', 'K', 'Se', 'Zn', 'H', 'Cu', 'Mn', 'unknown']
@@ -22,7 +23,7 @@ def onek_encoding_unk(x, allowable_set):
     return [x == s for s in allowable_set]

 def atom_features(atom):
-    return cuda(torch.Tensor(onek_encoding_unk(atom.GetSymbol(), ELEM_LIST)
+    return (torch.Tensor(onek_encoding_unk(atom.GetSymbol(), ELEM_LIST)
             + onek_encoding_unk(atom.GetDegree(), [0,1,2,3,4,5])
             + onek_encoding_unk(atom.GetFormalCharge(), [-1,-2,1,2,0])
             + onek_encoding_unk(int(atom.GetChiralTag()), [0,1,2,3])
@@ -33,46 +34,45 @@ def bond_features(bond):
     stereo = int(bond.GetStereo())
     fbond = [bt == Chem.rdchem.BondType.SINGLE, bt == Chem.rdchem.BondType.DOUBLE, bt == Chem.rdchem.BondType.TRIPLE, bt == Chem.rdchem.BondType.AROMATIC, bond.IsInRing()]
     fstereo = onek_encoding_unk(stereo, [0,1,2,3,4,5])
-    return cuda(torch.Tensor(fbond + fstereo))
+    return (torch.Tensor(fbond + fstereo))

-@profile
-def mol2dgl(smiles_batch):
-    n_nodes = 0
-    graph_list = []
-    for smiles in smiles_batch:
-        atom_feature_list = []
-        bond_feature_list = []
-        bond_source_feature_list = []
-        graph = DGLGraph()
-        mol = get_mol(smiles)
-        for atom in mol.GetAtoms():
-            graph.add_node(atom.GetIdx())
-            atom_feature_list.append(atom_features(atom))
-        for bond in mol.GetBonds():
-            begin_idx = bond.GetBeginAtom().GetIdx()
-            end_idx = bond.GetEndAtom().GetIdx()
-            features = bond_features(bond)
-            graph.add_edge(begin_idx, end_idx)
-            bond_feature_list.append(features)
-            # set up the reverse direction
-            graph.add_edge(end_idx, begin_idx)
-            bond_feature_list.append(features)
-
-        atom_x = torch.stack(atom_feature_list)
-        graph.set_n_repr({'x': atom_x})
-        if len(bond_feature_list) > 0:
-            bond_x = torch.stack(bond_feature_list)
-            graph.set_e_repr({
-                'x': bond_x,
-                'src_x': atom_x.new(len(bond_feature_list), ATOM_FDIM).zero_()
-            })
-        graph_list.append(graph)
-
-    return graph_list
+def mol2dgl_single(smiles):
+    n_edges = 0
+
+    atom_x = []
+    bond_x = []
+
+    mol = get_mol(smiles)
+    n_atoms = mol.GetNumAtoms()
+    n_bonds = mol.GetNumBonds()
+    graph = DGLGraph()
+    for i, atom in enumerate(mol.GetAtoms()):
+        assert i == atom.GetIdx()
+        atom_x.append(atom_features(atom))
+    graph.add_nodes(n_atoms)
+
+    bond_src = []
+    bond_dst = []
+    for i, bond in enumerate(mol.GetBonds()):
+        begin_idx = bond.GetBeginAtom().GetIdx()
+        end_idx = bond.GetEndAtom().GetIdx()
+        features = bond_features(bond)
+        bond_src.append(begin_idx)
+        bond_dst.append(end_idx)
+        bond_x.append(features)
+        # set up the reverse direction
+        bond_src.append(end_idx)
+        bond_dst.append(begin_idx)
+        bond_x.append(features)
+    graph.add_edges(bond_src, bond_dst)
+
+    n_edges += n_bonds
+    return graph, torch.stack(atom_x), \
+        torch.stack(bond_x) if len(bond_x) > 0 else torch.zeros(0)
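
To make the new contract concrete, a small sketch (ours, not from the commit): `mol2dgl_single` now returns the graph and the feature tensors separately, and the caller attaches them, mirroring what `decode` does above with `mol2dgl_enc`:

```
# 'c1ccccc1' (benzene) is an arbitrary example SMILES.
graph, atom_x, bond_x = mol2dgl_single('c1ccccc1')
graph.ndata['x'] = atom_x
graph.edata['x'] = bond_x
# 'src_x' starts as zeros; DGLMPN.run later fills it with source-atom features via apply_edges.
graph.edata['src_x'] = atom_x.new(bond_x.shape[0], atom_x.shape[1]).zero_()
# A DGLMPN forward pass would then consume batch([graph]) and return one vector per molecule.
```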
 mpn_loopy_bp_msg = DGLF.copy_src(src='msg', out='msg')
-mpn_loopy_bp_reduce = DGLF.sum(msgs='msg', out='accum_msg')
+mpn_loopy_bp_reduce = DGLF.sum(msg='msg', out='accum_msg')

 class LoopyBPUpdate(nn.Module):
@@ -82,15 +82,15 @@ class LoopyBPUpdate(nn.Module):
         self.W_h = nn.Linear(hidden_size, hidden_size, bias=False)

-    def forward(self, node):
-        msg_input = node['msg_input']
-        msg_delta = self.W_h(node['accum_msg'])
+    def forward(self, nodes):
+        msg_input = nodes.data['msg_input']
+        msg_delta = self.W_h(nodes.data['accum_msg'])
         msg = F.relu(msg_input + msg_delta)
         return {'msg': msg}

 mpn_gather_msg = DGLF.copy_edge(edge='msg', out='msg')
-mpn_gather_reduce = DGLF.sum(msgs='msg', out='m')
+mpn_gather_reduce = DGLF.sum(msg='msg', out='m')

 class GatherUpdate(nn.Module):
@@ -100,10 +100,10 @@ class GatherUpdate(nn.Module):
         self.W_o = nn.Linear(ATOM_FDIM + hidden_size, hidden_size)

-    def forward(self, node):
-        m = node['m']
+    def forward(self, nodes):
+        m = nodes.data['m']
         return {
-            'h': F.relu(self.W_o(torch.cat([node['x'], m], 1))),
+            'h': F.relu(self.W_o(torch.cat([nodes.data['x'], m], 1))),
         }

@@ -124,19 +124,18 @@ class DGLMPN(nn.Module):
         self.n_edges_total = 0
         self.n_passes = 0

-    @profile
-    def forward(self, mol_graph_list):
-        n_samples = len(mol_graph_list)
+    def forward(self, mol_graph):
+        n_samples = mol_graph.batch_size

-        mol_graph = batch(mol_graph_list)
-        mol_line_graph = line_graph(mol_graph, no_backtracking=True)
+        mol_line_graph = mol_graph.line_graph(backtracking=False, shared=True)

-        n_nodes = len(mol_graph.nodes)
-        n_edges = len(mol_graph.edges)
+        n_nodes = mol_graph.number_of_nodes()
+        n_edges = mol_graph.number_of_edges()

         mol_graph = self.run(mol_graph, mol_line_graph)
-        mol_graph_list = unbatch(mol_graph)

-        g_repr = torch.stack([g.get_n_repr()['h'].mean(0) for g in mol_graph_list], 0)
+        # TODO: replace with unbatch or readout
+        g_repr = mean_nodes(mol_graph, 'h')

         self.n_samples_total += n_samples
         self.n_nodes_total += n_nodes
@@ -145,27 +144,25 @@ class DGLMPN(nn.Module):
         return g_repr

-    @profile
     def run(self, mol_graph, mol_line_graph):
-        n_nodes = len(mol_graph.nodes)
+        n_nodes = mol_graph.number_of_nodes()

-        mol_graph.update_edge(
-            #*zip(*mol_graph.edge_list),
-            edge_func=lambda src, dst, edge: {'src_x': src['x']},
-            batchable=True,
+        mol_graph.apply_edges(
+            func=lambda edges: {'src_x': edges.src['x']},
         )

-        bond_features = mol_line_graph.get_n_repr()['x']
-        source_features = mol_line_graph.get_n_repr()['src_x']
+        e_repr = mol_line_graph.ndata
+        bond_features = e_repr['x']
+        source_features = e_repr['src_x']

         features = torch.cat([source_features, bond_features], 1)
         msg_input = self.W_i(features)
-        mol_line_graph.set_n_repr({
+        mol_line_graph.ndata.update({
             'msg_input': msg_input,
             'msg': F.relu(msg_input),
             'accum_msg': torch.zeros_like(msg_input),
         })
-        mol_graph.set_n_repr({
+        mol_graph.ndata.update({
             'm': bond_features.new(n_nodes, self.hidden_size).zero_(),
             'h': bond_features.new(n_nodes, self.hidden_size).zero_(),
         })
@@ -175,14 +172,12 @@ class DGLMPN(nn.Module):
             mpn_loopy_bp_msg,
             mpn_loopy_bp_reduce,
             self.loopy_bp_updater,
-            True
         )

         mol_graph.update_all(
             mpn_gather_msg,
             mpn_gather_reduce,
             self.gather_updater,
-            True
         )

         return mol_graph
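
For reference, the update that `run` implements on the bond line graph, written in our own notation (transcribed from the code above, not stated in the diff): each directed bond (u→v) carries a message that is refreshed for a fixed number of passes without backtracking, after which atom states are gathered and mean-pooled:

```
% Our transcription; W_i, W_h, W_o are the linear layers in DGLMPN,
% x_u an atom feature vector and x_{uv} a bond feature vector.
m_{uv}^{(0)}   = \mathrm{ReLU}\big(W_i\,[x_u \,\|\, x_{uv}]\big)
m_{uv}^{(t+1)} = \mathrm{ReLU}\Big(W_i\,[x_u \,\|\, x_{uv}]
                 + W_h \sum_{w \in N(u)\setminus\{v\}} m_{wu}^{(t)}\Big)
h_v = \mathrm{ReLU}\Big(W_o\,\big[x_v \,\|\, \sum_{u \in N(v)} m_{uv}\big]\Big),
\qquad g = \frac{1}{|V|}\sum_{v} h_v
```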
 import torch
 import torch.nn as nn
 from torch.autograd import Variable
+import os

 def create_var(tensor, requires_grad=None):
     if requires_grad is None:
@@ -9,7 +10,7 @@ def create_var(tensor, requires_grad=None):
     return Variable(tensor, requires_grad=requires_grad)

 def cuda(tensor):
-    if torch.cuda.is_available():
+    if torch.cuda.is_available() and not os.getenv('NOCUDA', None):
         return tensor.cuda()
     else:
         return tensor
@@ -25,17 +26,29 @@ class GRUUpdate(nn.Module):
         self.U_r = nn.Linear(hidden_size, hidden_size)
         self.W_h = nn.Linear(2 * hidden_size, hidden_size)

-    def forward(self, node):
-        src_x = node['src_x']
-        dst_x = node['dst_x']
-        s = node['s']
-        rm = node['accum_rm']
+    def update_zm(self, node):
+        src_x = node.data['src_x']
+        s = node.data['s']
+        rm = node.data['accum_rm']
         z = torch.sigmoid(self.W_z(torch.cat([src_x, s], 1)))
         m = torch.tanh(self.W_h(torch.cat([src_x, rm], 1)))
         m = (1 - z) * s + z * m
+        return {'m': m, 'z': z}
+
+    def update_r(self, node, zm=None):
+        dst_x = node.data['dst_x']
+        m = node.data['m'] if zm is None else zm['m']
         r_1 = self.W_r(dst_x)
         r_2 = self.U_r(m)
         r = torch.sigmoid(r_1 + r_2)
+        return {'r': r, 'rm': r * m}
+
+    def forward(self, node):
+        dic = self.update_zm(node)
+        dic.update(self.update_r(node, zm=dic))
+        return dic
-        return {'m': m, 'r': r, 'z': z, 'rm': r * m}
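
Written out (our notation, following the data fields in the code; `s` and `accum_rm` are sums produced by reduce functions upstream of this module), the tree-message GRU above computes:

```
% Transcription of GRUUpdate, not notation from the commit.
z         = \sigma\big(W_z\,[\,x_{src} \,\|\, s\,]\big)
\tilde{m} = \tanh\big(W_h\,[\,x_{src} \,\|\, \mathrm{accum\_rm}\,]\big)
m         = (1 - z) \odot s + z \odot \tilde{m}
r         = \sigma\big(W_r\,x_{dst} + U_r\,m\big),
\qquad \mathrm{rm} = r \odot m
```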
+def move_dgl_to_cuda(g):
+    g.ndata.update({k: cuda(g.ndata[k]) for k in g.ndata})
+    g.edata.update({k: cuda(g.edata[k]) for k in g.edata})
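
A one-line usage note (ours): since every tensor goes through `cuda()`, the helper is a no-op when the `NOCUDA` environment variable is set.

```
# Hypothetical, reusing `graph` from the mol2dgl_single sketch above:
move_dgl_to_cuda(graph)   # all tensors in graph.ndata / graph.edata move to the GPU, if enabled
```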
@@ -11,8 +11,12 @@ import rdkit
 from jtnn import *

-lg = rdkit.RDLogger.logger()
-lg.setLevel(rdkit.RDLogger.CRITICAL)
+torch.multiprocessing.set_sharing_strategy('file_system')
+
+def worker_init_fn(id_):
+    lg = rdkit.RDLogger.logger()
+    lg.setLevel(rdkit.RDLogger.CRITICAL)
+worker_init_fn(None)

 parser = OptionParser()
 parser.add_option("-t", "--train", dest="train", default='train', help='Training file name')
@@ -25,10 +29,11 @@ parser.add_option("-l", "--latent", dest="latent_size", default=56)
 parser.add_option("-d", "--depth", dest="depth", default=3)
 parser.add_option("-z", "--beta", dest="beta", default=1.0)
 parser.add_option("-q", "--lr", dest="lr", default=1e-3)
+parser.add_option("-T", "--test", dest="test", action="store_true")
 opts,args = parser.parse_args()

-dataset = JTNNDataset(data=opts.train, vocab=opts.vocab)
-vocab = Vocab([x.strip("\r\n ") for x in open(dataset.vocab_file)])
+dataset = JTNNDataset(data=opts.train, vocab=opts.vocab, training=True)
+vocab = dataset.vocab

 batch_size = int(opts.batch_size)
 hidden_size = int(opts.hidden_size)
@@ -48,39 +53,37 @@ else:
     else:
         nn.init.xavier_normal(param)

-if torch.cuda.is_available():
-    model = model.cuda()
+model = cuda(model)

 print("Model #Params: %dK" % (sum([x.nelement() for x in model.parameters()]) / 1000,))

 optimizer = optim.Adam(model.parameters(), lr=lr)
 scheduler = lr_scheduler.ExponentialLR(optimizer, 0.9)
 scheduler.step()

-MAX_EPOCH = 1
+MAX_EPOCH = 100
 PRINT_ITER = 20

-@profile
 def train():
+    dataset.training = True
     dataloader = DataLoader(
             dataset,
             batch_size=batch_size,
             shuffle=True,
-            num_workers=0,
-            collate_fn=lambda x:x,
-            drop_last=True)
+            num_workers=4,
+            collate_fn=JTNNCollator(vocab, True),
+            drop_last=True,
+            worker_init_fn=worker_init_fn)

     for epoch in range(MAX_EPOCH):
         word_acc,topo_acc,assm_acc,steo_acc = 0,0,0,0

         for it, batch in enumerate(dataloader):
-            for mol_tree in batch:
-                for node_id, node in mol_tree.nodes.items():
-                    if node['label'] not in node['cands']:
-                        node['cands'].append(node['label'])
-                        node['cand_mols'].append(node['label_mol'])
-
             model.zero_grad()
-            loss, kl_div, wacc, tacc, sacc, dacc = model(batch, beta)
+            try:
+                loss, kl_div, wacc, tacc, sacc, dacc = model(batch, beta)
+            except:
+                print([t.smiles for t in batch['mol_trees']])
+                raise
             loss.backward()
             optimizer.step()
@@ -95,8 +98,8 @@ def train():
                 assm_acc = assm_acc / PRINT_ITER * 100
                 steo_acc = steo_acc / PRINT_ITER * 100

-                print("KL: %.1f, Word: %.2f, Topo: %.2f, Assm: %.2f, Steo: %.2f" % (
-                    kl_div, word_acc, topo_acc, assm_acc, steo_acc))
+                print("KL: %.1f, Word: %.2f, Topo: %.2f, Assm: %.2f, Steo: %.2f, Loss: %.6f" % (
+                    kl_div, word_acc, topo_acc, assm_acc, steo_acc, loss.item()))
                 word_acc,topo_acc,assm_acc,steo_acc = 0,0,0,0
                 sys.stdout.flush()
@@ -110,8 +113,33 @@ def train():
         print("learning rate: %.6f" % scheduler.get_lr()[0])
         torch.save(model.state_dict(), opts.save_path + "/model.iter-" + str(epoch))

+def test():
+    dataset.training = False
+    dataloader = DataLoader(
+            dataset,
+            batch_size=1,
+            shuffle=False,
+            num_workers=0,
+            collate_fn=JTNNCollator(vocab, False),
+            drop_last=True,
+            worker_init_fn=worker_init_fn)
+
+    # Just an example of molecule decoding; in reality you may want to sample
+    # tree and molecule vectors.
+    for it, batch in enumerate(dataloader):
+        gt_smiles = batch['mol_trees'][0].smiles
+        print(gt_smiles)
+        model.move_to_cuda(batch)
+        _, tree_vec, mol_vec = model.encode(batch)
+        tree_vec, mol_vec, _, _ = model.sample(tree_vec, mol_vec)
+        smiles = model.decode(tree_vec, mol_vec)
+        print(smiles)
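
The comment in `test()` leaves prior sampling to the reader; a sketch of what that could look like (our code, not part of the commit; it assumes the tree and graph halves of the latent code each take `latent_size // 2` dimensions, as the `view` in `assm` above suggests):

```
# Hypothetical: decode from vectors drawn from the prior instead of the posterior.
latent = int(opts.latent_size)
tree_vec = cuda(torch.randn(1, latent // 2))
mol_vec = cuda(torch.randn(1, latent // 2))
print(model.decode(tree_vec, mol_vec))
```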
 if __name__ == '__main__':
-    train()
+    if opts.test:
+        test()
+    else:
+        train()

     print('# passes:', model.n_passes)
     print('Total # nodes processed:', model.n_nodes_total)