Unverified Commit ac932c66 authored by Gan Quan's avatar Gan Quan Committed by GitHub
Browse files

[Model] Junction Tree VAE update (#157)

* cherry picking optimization from jtnn

* adding official code.  TODO: fix DGLMolTree

* updating to current api.  vae test still failing

* reverting to list stacking

* reverting to list stacking

* cleaning x flags (stupid windows)

* cleaning x flags (stupid windows)

* adding stats

* optimization

* updating dgl stats

* update again

* more optimization

* looks like computation is faster

* removing profiling code

* cleaning obsolete code

* remove comparison warning

* readme update

* official implementation got a lot faster

* minor fixes

* unbatch by slicing frames

* working around unbatch

* reduce pack

* oops

* support frame read/write with slices

* reverting back to readout as unbatch-by-slicing slows down backward

* reverting to unbatch by splitting; slicing is unfriendly to backward

* replacing lru cache with static object factory

* cherry picking optimization from jtnn

* unbatch by slicing frames

* reduce pack

* oops

* support frame read/write with slices

* reverting to unbatch by splitting; slicing is unfriendly to backward

* replacing lru cache with static object factory

* replacing Scheme object with namedtuple

* forgot the find edges interface

* subclassing namedtuple

* updating to the latest api spec

* bugfix

* bfs with edges

* dfs toy test case

* clean up

* style fix

* bugfix

* update to latest api; include traversal

* replacing with readout

* simplify decoder

* oops

* cleanup

* reducing number of sets

* more speed up

* profile results

* random fixes

* fixing tvmarray handling incontiguous dlpack input

* fancier dataloader

* fix a potential context mismatch

* todo: support pickling or using scipy in multiprocessing load

* pickling support

* resorting to suggested way of pickling

* custom attribute pickling check

* working around a weird pytorch pickling bug

* including partial frame case

* enabling multiprocessing dataloader

* pickling everything now

* really works

* oops

* updated profiling results

* cleanup

* fix as requested

* cleaning random blank lines

* removing profiler outputs

* starting decoding

* testing, WIP

* tree decoding

* graph decoding, WIP

* graph decoding works

* oops

* fixing legacy apis

* trimming number of candidate structures

* sampling cleanups

* removing comparison test

* updated description
parent 4682b76e
......@@ -11,3 +11,16 @@ python3 vaetrain_dgl.py
```
The script will automatically download the data, which is the same as the one in the
original repository.
To disable CUDA, run with `NOCUDA` variable set:
```
NOCUDA=1 python3 vaetrain_dgl.py
```
To decode for new molecules, run
```
python3 vaetrain_dgl.py -T
```
Currently, decoding involves encoding a training example, sampling from the posterior
distribution, and decoding a molecule from that.
from .mol_tree import Vocab
from .jtnn_vae import DGLJTNNVAE
from .mpn import DGLMPN, mol2dgl
from .nnutils import create_var
from .datautils import JTNNDataset
from .mpn import DGLMPN
from .nnutils import create_var, cuda
from .datautils import JTNNDataset, JTNNCollator
from .chemutils import decode_stereo
from .line_profiler_integration import profile
......@@ -251,8 +251,7 @@ def enum_attach_nx(ctr_mol, nei_node, amap, singletons):
return att_confs
#Try rings first: Speed-Up
def enum_assemble_nx(graph, node_idx, neighbors, prev_nodes=[], prev_amap=[]):
node = graph.nodes[node_idx]
def enum_assemble_nx(node, neighbors, prev_nodes=[], prev_amap=[]):
all_attach_confs = []
singletons = [nei_node['nid'] for nei_node in neighbors + prev_nodes if nei_node['mol'].GetNumAtoms() == 1]
......@@ -301,21 +300,21 @@ def enum_assemble_nx(graph, node_idx, neighbors, prev_nodes=[], prev_amap=[]):
#Only used for debugging purpose
def dfs_assemble_nx(graph, cur_mol, global_amap, fa_amap, cur_node_id, fa_node_id):
cur_node = graph.nodes[cur_node_id]
fa_node = graph.nodes[fa_node_id] if fa_node_id is not None else None
cur_node = graph.nodes_dict[cur_node_id]
fa_node = graph.nodes_dict[fa_node_id] if fa_node_id is not None else None
fa_nid = fa_node['nid'] if fa_node is not None else -1
prev_nodes = [fa_node] if fa_node is not None else []
children_id = [nei for nei in graph[cur_node_id] if graph.nodes[nei]['nid'] != fa_nid]
children = [graph.nodes[nei] for nei in children_id]
children_id = [nei for nei in graph[cur_node_id] if graph.nodes_dict[nei]['nid'] != fa_nid]
children = [graph.nodes_dict[nei] for nei in children_id]
neighbors = [nei for nei in children if nei['mol'].GetNumAtoms() > 1]
neighbors = sorted(neighbors, key=lambda x:x['mol'].GetNumAtoms(), reverse=True)
singletons = [nei for nei in children if nei['mol'].GetNumAtoms() == 1]
neighbors = singletons + neighbors
cur_amap = [(fa_nid,a2,a1) for nid,a1,a2 in fa_amap if nid == cur_node['nid']]
cands = enum_assemble_nx(graph, cur_node_id, neighbors, prev_nodes, cur_amap)
cands = enum_assemble_nx(graph.nodes_dict[cur_node_id], neighbors, prev_nodes, cur_amap)
if len(cands) == 0:
return
......
import torch
from torch.utils.data import Dataset
import numpy as np
import dgl
from dgl.data.utils import download, extract_archive, get_download_dir
from .mol_tree_nx import DGLMolTree
from .mol_tree import Vocab
from .mpn import mol2dgl_single as mol2dgl_enc
from .jtmpn import mol2dgl_single as mol2dgl_dec
_url = 'https://www.dropbox.com/s/4ypr0e0abcbsvoh/jtnn.zip?dl=1'
def _unpack_field(examples, field):
return [e[field] for e in examples]
def _set_node_id(mol_tree, vocab):
wid = []
for i, node in enumerate(mol_tree.nodes_dict):
mol_tree.nodes_dict[node]['idx'] = i
wid.append(vocab.get_index(mol_tree.nodes_dict[node]['smiles']))
return wid
class JTNNDataset(Dataset):
def __init__(self, data, vocab):
def __init__(self, data, vocab, training=True):
self.dir = get_download_dir()
self.zip_file_path='{}/jtnn.zip'.format(self.dir)
download(_url, path=self.zip_file_path)
......@@ -20,14 +37,186 @@ class JTNNDataset(Dataset):
print('Loading finished.')
print('\tNum samples:', len(self.data))
print('\tVocab file:', self.vocab_file)
self.training = training
self.vocab = Vocab([x.strip("\r\n ") for x in open(self.vocab_file)])
def __len__(self):
return len(self.data)
def __getitem__(self, idx):
from .mol_tree_nx import DGLMolTree
smiles = self.data[idx]
mol_tree = DGLMolTree(smiles)
mol_tree.recover()
mol_tree.assemble()
return mol_tree
wid = _set_node_id(mol_tree, self.vocab)
# prebuild the molecule graph
mol_graph, atom_x_enc, bond_x_enc = mol2dgl_enc(mol_tree.smiles)
result = {
'mol_tree': mol_tree,
'mol_graph': mol_graph,
'atom_x_enc': atom_x_enc,
'bond_x_enc': bond_x_enc,
'wid': wid,
}
if not self.training:
return result
# prebuild the candidate graph list
cands = []
for node_id, node in mol_tree.nodes_dict.items():
# fill in ground truth
if node['label'] not in node['cands']:
node['cands'].append(node['label'])
node['cand_mols'].append(node['label_mol'])
if node['is_leaf'] or len(node['cands']) == 1:
continue
cands.extend([(cand, mol_tree, node_id)
for cand in node['cand_mols']])
if len(cands) > 0:
cand_graphs, atom_x_dec, bond_x_dec, tree_mess_src_e, \
tree_mess_tgt_e, tree_mess_tgt_n = mol2dgl_dec(cands)
else:
cand_graphs = []
atom_x_dec = torch.zeros(0, atom_x_enc.shape[1])
bond_x_dec = torch.zeros(0, bond_x_enc.shape[1])
tree_mess_src_e = torch.zeros(0, 2).long()
tree_mess_tgt_e = torch.zeros(0, 2).long()
tree_mess_tgt_n = torch.zeros(0, 2).long()
# prebuild the stereoisomers
cands = mol_tree.stereo_cands
if len(cands) > 1:
if mol_tree.smiles3D not in cands:
cands.append(mol_tree.smiles3D)
stereo_graphs = [mol2dgl_enc(c) for c in cands]
stereo_cand_graphs, stereo_atom_x_enc, stereo_bond_x_enc = \
zip(*stereo_graphs)
stereo_atom_x_enc = torch.cat(stereo_atom_x_enc)
stereo_bond_x_enc = torch.cat(stereo_bond_x_enc)
stereo_cand_label = [(cands.index(mol_tree.smiles3D), len(cands))]
else:
stereo_cand_graphs = []
stereo_atom_x_enc = torch.zeros(0, atom_x_enc.shape[1])
stereo_bond_x_enc = torch.zeros(0, bond_x_enc.shape[1])
stereo_cand_label = []
result.update({
'cand_graphs': cand_graphs,
'atom_x_dec': atom_x_dec,
'bond_x_dec': bond_x_dec,
'tree_mess_src_e': tree_mess_src_e,
'tree_mess_tgt_e': tree_mess_tgt_e,
'tree_mess_tgt_n': tree_mess_tgt_n,
'stereo_cand_graphs': stereo_cand_graphs,
'stereo_atom_x_enc': stereo_atom_x_enc,
'stereo_bond_x_enc': stereo_bond_x_enc,
'stereo_cand_label': stereo_cand_label,
})
return result
class JTNNCollator(object):
def __init__(self, vocab, training):
self.vocab = vocab
self.training = training
@staticmethod
def _batch_and_set(graphs, atom_x, bond_x, flatten):
if flatten:
graphs = [g for f in graphs for g in f]
graph_batch = dgl.batch(graphs)
graph_batch.ndata['x'] = atom_x
graph_batch.edata.update({
'x': bond_x,
'src_x': atom_x.new(bond_x.shape[0], atom_x.shape[1]).zero_(),
})
return graph_batch
def __call__(self, examples):
# get list of trees
mol_trees = _unpack_field(examples, 'mol_tree')
wid = _unpack_field(examples, 'wid')
for _wid, mol_tree in zip(wid, mol_trees):
mol_tree.ndata['wid'] = torch.LongTensor(_wid)
# TODO: either support pickling or get around ctypes pointers using scipy
# batch molecule graphs
mol_graphs = _unpack_field(examples, 'mol_graph')
atom_x = torch.cat(_unpack_field(examples, 'atom_x_enc'))
bond_x = torch.cat(_unpack_field(examples, 'bond_x_enc'))
mol_graph_batch = self._batch_and_set(mol_graphs, atom_x, bond_x, False)
result = {
'mol_trees': mol_trees,
'mol_graph_batch': mol_graph_batch,
}
if not self.training:
return result
# batch candidate graphs
cand_graphs = _unpack_field(examples, 'cand_graphs')
cand_batch_idx = []
atom_x = torch.cat(_unpack_field(examples, 'atom_x_dec'))
bond_x = torch.cat(_unpack_field(examples, 'bond_x_dec'))
tree_mess_src_e = _unpack_field(examples, 'tree_mess_src_e')
tree_mess_tgt_e = _unpack_field(examples, 'tree_mess_tgt_e')
tree_mess_tgt_n = _unpack_field(examples, 'tree_mess_tgt_n')
n_graph_nodes = 0
n_tree_nodes = 0
for i in range(len(cand_graphs)):
tree_mess_tgt_e[i] += n_graph_nodes
tree_mess_src_e[i] += n_tree_nodes
tree_mess_tgt_n[i] += n_graph_nodes
n_graph_nodes += sum(g.number_of_nodes() for g in cand_graphs[i])
n_tree_nodes += mol_trees[i].number_of_nodes()
cand_batch_idx.extend([i] * len(cand_graphs[i]))
tree_mess_tgt_e = torch.cat(tree_mess_tgt_e)
tree_mess_src_e = torch.cat(tree_mess_src_e)
tree_mess_tgt_n = torch.cat(tree_mess_tgt_n)
cand_graph_batch = self._batch_and_set(cand_graphs, atom_x, bond_x, True)
# batch stereoisomers
stereo_cand_graphs = _unpack_field(examples, 'stereo_cand_graphs')
atom_x = torch.cat(_unpack_field(examples, 'stereo_atom_x_enc'))
bond_x = torch.cat(_unpack_field(examples, 'stereo_bond_x_enc'))
stereo_cand_batch_idx = []
for i in range(len(stereo_cand_graphs)):
stereo_cand_batch_idx.extend([i] * len(stereo_cand_graphs[i]))
if len(stereo_cand_batch_idx) > 0:
stereo_cand_labels = [
(label, length)
for ex in _unpack_field(examples, 'stereo_cand_label')
for label, length in ex
]
stereo_cand_labels, stereo_cand_lengths = zip(*stereo_cand_labels)
stereo_cand_graph_batch = self._batch_and_set(
stereo_cand_graphs, atom_x, bond_x, True)
else:
stereo_cand_labels = []
stereo_cand_lengths = []
stereo_cand_graph_batch = None
stereo_cand_batch_idx = []
result.update({
'cand_graph_batch': cand_graph_batch,
'cand_batch_idx': cand_batch_idx,
'tree_mess_tgt_e': tree_mess_tgt_e,
'tree_mess_src_e': tree_mess_src_e,
'tree_mess_tgt_n': tree_mess_tgt_n,
'stereo_cand_graph_batch': stereo_cand_graph_batch,
'stereo_cand_batch_idx': stereo_cand_batch_idx,
'stereo_cand_labels': stereo_cand_labels,
'stereo_cand_lengths': stereo_cand_lengths,
})
return result
......@@ -4,10 +4,11 @@ from .nnutils import cuda
from .chemutils import get_mol
#from mpn import atom_features, bond_features, ATOM_FDIM, BOND_FDIM
import rdkit.Chem as Chem
from dgl import DGLGraph, line_graph, batch, unbatch
from dgl import DGLGraph, batch, unbatch, mean_nodes
import dgl.function as DGLF
from .line_profiler_integration import profile
import os
import numpy as np
ELEM_LIST = ['C', 'N', 'O', 'S', 'F', 'Si', 'P', 'Cl', 'Br', 'Mg', 'Na', 'Ca', 'Fe', 'Al', 'I', 'B', 'K', 'Se', 'Zn', 'H', 'Cu', 'Mn', 'unknown']
......@@ -27,81 +28,81 @@ def onek_encoding_unk(x, allowable_set):
# the 2-D graph first, then enumerate all possible 3-D forms and find the
# one with highest score.
def atom_features(atom):
return cuda(torch.Tensor(onek_encoding_unk(atom.GetSymbol(), ELEM_LIST)
return (torch.Tensor(onek_encoding_unk(atom.GetSymbol(), ELEM_LIST)
+ onek_encoding_unk(atom.GetDegree(), [0,1,2,3,4,5])
+ onek_encoding_unk(atom.GetFormalCharge(), [-1,-2,1,2,0])
+ [atom.GetIsAromatic()]))
def bond_features(bond):
bt = bond.GetBondType()
return cuda(torch.Tensor([bt == Chem.rdchem.BondType.SINGLE, bt == Chem.rdchem.BondType.DOUBLE, bt == Chem.rdchem.BondType.TRIPLE, bt == Chem.rdchem.BondType.AROMATIC, bond.IsInRing()]))
return (torch.Tensor([bt == Chem.rdchem.BondType.SINGLE, bt == Chem.rdchem.BondType.DOUBLE, bt == Chem.rdchem.BondType.TRIPLE, bt == Chem.rdchem.BondType.AROMATIC, bond.IsInRing()]))
@profile
def mol2dgl(cand_batch, mol_tree_batch):
def mol2dgl_single(cand_batch):
cand_graphs = []
tree_mess_source_edges = [] # map these edges from trees to...
tree_mess_target_edges = [] # these edges on candidate graphs
tree_mess_target_nodes = []
n_nodes = 0
n_edges = 0
atom_x = []
bond_x = []
for mol, mol_tree, ctr_node_id in cand_batch:
atom_feature_list = []
bond_feature_list = []
ctr_node = mol_tree.nodes[ctr_node_id]
n_atoms = mol.GetNumAtoms()
n_bonds = mol.GetNumBonds()
ctr_node = mol_tree.nodes_dict[ctr_node_id]
ctr_bid = ctr_node['idx']
g = DGLGraph()
for atom in mol.GetAtoms():
atom_feature_list.append(atom_features(atom))
g.add_node(atom.GetIdx())
for i, atom in enumerate(mol.GetAtoms()):
assert i == atom.GetIdx()
atom_x.append(atom_features(atom))
g.add_nodes(n_atoms)
for bond in mol.GetBonds():
bond_src = []
bond_dst = []
for i, bond in enumerate(mol.GetBonds()):
a1 = bond.GetBeginAtom()
a2 = bond.GetEndAtom()
begin_idx = a1.GetIdx()
end_idx = a2.GetIdx()
features = bond_features(bond)
g.add_edge(begin_idx, end_idx)
bond_feature_list.append(features)
g.add_edge(end_idx, begin_idx)
bond_feature_list.append(features)
bond_src.append(begin_idx)
bond_dst.append(end_idx)
bond_x.append(features)
bond_src.append(end_idx)
bond_dst.append(begin_idx)
bond_x.append(features)
x_nid, y_nid = a1.GetAtomMapNum(), a2.GetAtomMapNum()
# Tree node ID in the batch
x_bid = mol_tree.nodes[x_nid - 1]['idx'] if x_nid > 0 else -1
y_bid = mol_tree.nodes[y_nid - 1]['idx'] if y_nid > 0 else -1
x_bid = mol_tree.nodes_dict[x_nid - 1]['idx'] if x_nid > 0 else -1
y_bid = mol_tree.nodes_dict[y_nid - 1]['idx'] if y_nid > 0 else -1
if x_bid >= 0 and y_bid >= 0 and x_bid != y_bid:
if (x_bid, y_bid) in mol_tree_batch.edge_list:
tree_mess_target_edges.append(
(begin_idx + n_nodes, end_idx + n_nodes))
if mol_tree.has_edge_between(x_bid, y_bid):
tree_mess_target_edges.append((begin_idx + n_nodes, end_idx + n_nodes))
tree_mess_source_edges.append((x_bid, y_bid))
tree_mess_target_nodes.append(end_idx + n_nodes)
if (y_bid, x_bid) in mol_tree_batch.edge_list:
tree_mess_target_edges.append(
(end_idx + n_nodes, begin_idx + n_nodes))
if mol_tree.has_edge_between(y_bid, x_bid):
tree_mess_target_edges.append((end_idx + n_nodes, begin_idx + n_nodes))
tree_mess_source_edges.append((y_bid, x_bid))
tree_mess_target_nodes.append(begin_idx + n_nodes)
n_nodes += len(g.nodes)
atom_x = torch.stack(atom_feature_list)
g.set_n_repr({'x': atom_x})
if len(bond_feature_list) > 0:
bond_x = torch.stack(bond_feature_list)
g.set_e_repr({
'x': bond_x,
'src_x': atom_x.new(len(bond_feature_list), ATOM_FDIM).zero_()
})
n_nodes += n_atoms
g.add_edges(bond_src, bond_dst)
cand_graphs.append(g)
return cand_graphs, tree_mess_source_edges, tree_mess_target_edges, \
tree_mess_target_nodes
return cand_graphs, torch.stack(atom_x), \
torch.stack(bond_x) if len(bond_x) > 0 else torch.zeros(0), \
torch.LongTensor(tree_mess_source_edges), \
torch.LongTensor(tree_mess_target_edges), \
torch.LongTensor(tree_mess_target_nodes)
mpn_loopy_bp_msg = DGLF.copy_src(src='msg', out='msg')
mpn_loopy_bp_reduce = DGLF.sum(msgs='msg', out='accum_msg')
mpn_loopy_bp_reduce = DGLF.sum(msg='msg', out='accum_msg')
class LoopyBPUpdate(nn.Module):
......@@ -112,8 +113,8 @@ class LoopyBPUpdate(nn.Module):
self.W_h = nn.Linear(hidden_size, hidden_size, bias=False)
def forward(self, node):
msg_input = node['msg_input']
msg_delta = self.W_h(node['accum_msg'] + node['alpha'])
msg_input = node.data['msg_input']
msg_delta = self.W_h(node.data['accum_msg'] + node.data['alpha'])
msg = torch.relu(msg_input + msg_delta)
return {'msg': msg}
......@@ -129,11 +130,11 @@ else:
if PAPER:
mpn_gather_reduce = [
DGLF.sum(msgs='msg', out='m'),
DGLF.sum(msgs='alpha', out='accum_alpha'),
DGLF.sum(msg='msg', out='m'),
DGLF.sum(msg='alpha', out='accum_alpha'),
]
else:
mpn_gather_reduce = DGLF.sum(msgs='msg', out='m')
mpn_gather_reduce = DGLF.sum(msg='msg', out='m')
class GatherUpdate(nn.Module):
......@@ -146,11 +147,11 @@ class GatherUpdate(nn.Module):
def forward(self, node):
if PAPER:
#m = node['m']
m = node['m'] + node['accum_alpha']
m = node.data['m'] + node.data['accum_alpha']
else:
m = node['m'] + node['alpha']
m = node.data['m'] + node.data['alpha']
return {
'h': torch.relu(self.W_o(torch.cat([node['x'], m], 1))),
'h': torch.relu(self.W_o(torch.cat([node.data['x'], m], 1))),
}
......@@ -171,25 +172,21 @@ class DGLJTMPN(nn.Module):
self.n_edges_total = 0
self.n_passes = 0
@profile
def forward(self, cand_batch, mol_tree_batch):
cand_graphs, tree_mess_src_edges, tree_mess_tgt_edges, tree_mess_tgt_nodes = \
mol2dgl(cand_batch, mol_tree_batch)
cand_graphs, tree_mess_src_edges, tree_mess_tgt_edges, tree_mess_tgt_nodes = cand_batch
n_samples = len(cand_graphs)
cand_graphs = batch(cand_graphs)
cand_line_graph = line_graph(cand_graphs, no_backtracking=True)
cand_line_graph = cand_graphs.line_graph(backtracking=False, shared=True)
n_nodes = len(cand_graphs.nodes)
n_edges = len(cand_graphs.edges)
n_nodes = cand_graphs.number_of_nodes()
n_edges = cand_graphs.number_of_edges()
cand_graphs = self.run(
cand_graphs, cand_line_graph, tree_mess_src_edges, tree_mess_tgt_edges,
tree_mess_tgt_nodes, mol_tree_batch)
cand_graphs = unbatch(cand_graphs)
g_repr = torch.stack([g.get_n_repr()['h'].mean(0) for g in cand_graphs], 0)
g_repr = mean_nodes(cand_graphs, 'h')
self.n_samples_total += n_samples
self.n_nodes_total += n_nodes
......@@ -198,50 +195,48 @@ class DGLJTMPN(nn.Module):
return g_repr
@profile
def run(self, cand_graphs, cand_line_graph, tree_mess_src_edges, tree_mess_tgt_edges,
tree_mess_tgt_nodes, mol_tree_batch):
n_nodes = len(cand_graphs.nodes)
n_nodes = cand_graphs.number_of_nodes()
cand_graphs.update_edge(
#*zip(*cand_graphs.edge_list),
edge_func=lambda src, dst, edge: {'src_x': src['x']},
batchable=True,
cand_graphs.apply_edges(
func=lambda edges: {'src_x': edges.src['x']},
)
bond_features = cand_line_graph.get_n_repr()['x']
source_features = cand_line_graph.get_n_repr()['src_x']
bond_features = cand_line_graph.ndata['x']
source_features = cand_line_graph.ndata['src_x']
features = torch.cat([source_features, bond_features], 1)
msg_input = self.W_i(features)
cand_line_graph.set_n_repr({
cand_line_graph.ndata.update({
'msg_input': msg_input,
'msg': torch.relu(msg_input),
'accum_msg': torch.zeros_like(msg_input),
})
zero_node_state = bond_features.new(n_nodes, self.hidden_size).zero_()
cand_graphs.set_n_repr({
cand_graphs.ndata.update({
'm': zero_node_state.clone(),
'h': zero_node_state.clone(),
})
cand_graphs.edata['alpha'] = \
cuda(torch.zeros(cand_graphs.number_of_edges(), self.hidden_size))
cand_graphs.ndata['alpha'] = zero_node_state
if tree_mess_src_edges.shape[0] > 0:
if PAPER:
cand_graphs.set_e_repr({
'alpha': cuda(torch.zeros(len(cand_graphs.edge_list), self.hidden_size))
})
alpha = mol_tree_batch.get_e_repr(*zip(*tree_mess_src_edges))['m']
cand_graphs.set_e_repr({'alpha': alpha}, *zip(*tree_mess_tgt_edges))
src_u, src_v = tree_mess_src_edges.unbind(1)
tgt_u, tgt_v = tree_mess_tgt_edges.unbind(1)
alpha = mol_tree_batch.edges[src_u, src_v].data['m']
cand_graphs.edges[tgt_u, tgt_v].data['alpha'] = alpha
else:
alpha = mol_tree_batch.get_e_repr(*zip(*tree_mess_src_edges))['m']
node_idx = (torch.LongTensor(tree_mess_tgt_nodes)
src_u, src_v = tree_mess_src_edges.unbind(1)
alpha = mol_tree_batch.edges[src_u, src_v].data['m']
node_idx = (tree_mess_tgt_nodes
.to(device=zero_node_state.device)[:, None]
.expand_as(alpha))
node_alpha = zero_node_state.clone().scatter_add(0, node_idx, alpha)
cand_graphs.set_n_repr({'alpha': node_alpha})
cand_graphs.update_edge(
#*zip(*cand_graphs.edge_list),
edge_func=lambda src, dst, edge: {'alpha': src['alpha']},
batchable=True,
cand_graphs.ndata['alpha'] = node_alpha
cand_graphs.apply_edges(
func=lambda edges: {'alpha': edges.src['alpha']},
)
for i in range(self.depth - 1):
......@@ -249,14 +244,12 @@ class DGLJTMPN(nn.Module):
mpn_loopy_bp_msg,
mpn_loopy_bp_reduce,
self.loopy_bp_updater,
True
)
cand_graphs.update_all(
mpn_gather_msg,
mpn_gather_reduce,
self.gather_updater,
True
)
return cand_graphs
......@@ -2,52 +2,87 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
from .mol_tree import Vocab
from .mol_tree_nx import DGLMolTree
from .chemutils import enum_assemble_nx, get_mol
from .nnutils import GRUUpdate, cuda
import copy
import itertools
from dgl import batch, line_graph
from dgl import batch, dfs_labeled_edges_generator
import dgl.function as DGLF
import networkx as nx
from .line_profiler_integration import profile
import numpy as np
MAX_NB = 8
MAX_DECODE_LEN = 100
def dfs_order(forest, roots):
'''
Returns edge source, edge destination, tree ID, and whether u is generating
a new children
'''
edge_list = []
edges = dfs_labeled_edges_generator(forest, roots, has_reverse_edge=True)
for e, l in zip(*edges):
# I exploited the fact that the reverse edge ID equal to 1 xor forward
# edge ID for molecule trees. Normally, I should locate reverse edges
# using find_edges().
yield e ^ l, l
for i, root in enumerate(roots):
edge_list.append([])
# The following gives the DFS order on edge on a tree.
for u, v, t in nx.dfs_labeled_edges(forest, root):
if u == v or t == 'nontree':
continue
elif t == 'forward':
edge_list[-1].append((u, v, i, 1))
elif t == 'reverse':
edge_list[-1].append((v, u, i, 0))
dec_tree_node_msg = DGLF.copy_edge(edge='m', out='m')
dec_tree_node_reduce = DGLF.sum(msg='m', out='h')
for edges in itertools.zip_longest(*edge_list):
edges = (e for e in edges if e is not None)
u, v, i, p = zip(*edges)
yield u, v, i, p
def dec_tree_node_update(nodes):
return {'new': nodes.data['new'].clone().zero_()}
dec_tree_node_msg = DGLF.copy_edge(edge='m', out='m')
dec_tree_node_reduce = DGLF.sum(msgs='m', out='h')
dec_tree_edge_msg = [DGLF.copy_src(src='m', out='m'), DGLF.copy_src(src='rm', out='rm')]
dec_tree_edge_reduce = [DGLF.sum(msg='m', out='s'), DGLF.sum(msg='rm', out='accum_rm')]
def dec_tree_node_update(node):
return {'new': node['new'].clone().zero_()}
def have_slots(fa_slots, ch_slots):
if len(fa_slots) > 2 and len(ch_slots) > 2:
return True
matches = []
for i,s1 in enumerate(fa_slots):
a1,c1,h1 = s1
for j,s2 in enumerate(ch_slots):
a2,c2,h2 = s2
if a1 == a2 and c1 == c2 and (a1 != "C" or h1 + h2 >= 4):
matches.append( (i,j) )
dec_tree_edge_msg = [DGLF.copy_src(src='m', out='m'), DGLF.copy_src(src='rm', out='rm')]
dec_tree_edge_reduce = [DGLF.sum(msgs='m', out='s'), DGLF.sum(msgs='rm', out='accum_rm')]
if len(matches) == 0: return False
fa_match,ch_match = list(zip(*matches))
if len(set(fa_match)) == 1 and 1 < len(fa_slots) <= 2: #never remove atom from ring
fa_slots.pop(fa_match[0])
if len(set(ch_match)) == 1 and 1 < len(ch_slots) <= 2: #never remove atom from ring
ch_slots.pop(ch_match[0])
return True
def can_assemble(mol_tree, u, v_node_dict):
u_node_dict = mol_tree.nodes_dict[u]
u_neighbors = mol_tree.successors(u)
u_neighbors_node_dict = [
mol_tree.nodes_dict[_u]
for _u in u_neighbors
if _u in mol_tree.nodes_dict
]
neis = u_neighbors_node_dict + [v_node_dict]
for i,nei in enumerate(neis):
nei['nid'] = i
neighbors = [nei for nei in neis if nei['mol'].GetNumAtoms() > 1]
neighbors = sorted(neighbors, key=lambda x:x['mol'].GetNumAtoms(), reverse=True)
singletons = [nei for nei in neis if nei['mol'].GetNumAtoms() == 1]
neighbors = singletons + neighbors
cands = enum_assemble_nx(u_node_dict, neighbors)
return len(cands) > 0
def create_node_dict(smiles, clique=[]):
return dict(
smiles=smiles,
mol=get_mol(smiles),
clique=clique,
)
class DGLJTNNDecoder(nn.Module):
......@@ -70,32 +105,30 @@ class DGLJTNNDecoder(nn.Module):
self.W_o = nn.Linear(hidden_size, self.vocab_size)
self.U_s = nn.Linear(hidden_size, 1)
@profile
def forward(self, mol_trees, tree_vec):
'''
The training procedure which computes the prediction loss given the
ground truth tree
'''
mol_tree_batch = batch(mol_trees)
mol_tree_batch_lg = line_graph(mol_tree_batch, no_backtracking=True)
mol_tree_batch_lg = mol_tree_batch.line_graph(backtracking=False, shared=True)
n_trees = len(mol_trees)
return self.run(mol_tree_batch, mol_tree_batch_lg, n_trees, tree_vec)
@profile
def run(self, mol_tree_batch, mol_tree_batch_lg, n_trees, tree_vec):
root_ids = mol_tree_batch.node_offset[:-1]
n_nodes = len(mol_tree_batch.nodes)
edge_list = mol_tree_batch.edge_list
n_edges = len(edge_list)
node_offset = np.cumsum([0] + mol_tree_batch.batch_num_nodes)
root_ids = node_offset[:-1]
n_nodes = mol_tree_batch.number_of_nodes()
n_edges = mol_tree_batch.number_of_edges()
mol_tree_batch.set_n_repr({
'x': self.embedding(mol_tree_batch.get_n_repr()['wid']),
mol_tree_batch.ndata.update({
'x': self.embedding(mol_tree_batch.ndata['wid']),
'h': cuda(torch.zeros(n_nodes, self.hidden_size)),
'new': cuda(torch.ones(n_nodes).byte()), # whether it's newly generated node
})
mol_tree_batch.set_e_repr({
mol_tree_batch.edata.update({
's': cuda(torch.zeros(n_edges, self.hidden_size)),
'm': cuda(torch.zeros(n_edges, self.hidden_size)),
'r': cuda(torch.zeros(n_edges, self.hidden_size)),
......@@ -106,10 +139,8 @@ class DGLJTNNDecoder(nn.Module):
'accum_rm': cuda(torch.zeros(n_edges, self.hidden_size)),
})
mol_tree_batch.update_edge(
#*zip(*edge_list),
edge_func=lambda src, dst, edge: {'src_x': src['x'], 'dst_x': dst['x']},
batchable=True,
mol_tree_batch.apply_edges(
func=lambda edges: {'src_x': edges.src['x'], 'dst_x': edges.dst['x']},
)
# input tensors for stop prediction (p) and label prediction (q)
......@@ -124,52 +155,57 @@ class DGLJTNNDecoder(nn.Module):
dec_tree_node_msg,
dec_tree_node_reduce,
dec_tree_node_update,
batchable=True,
)
# Extract hidden states and store them for stop/label prediction
h = mol_tree_batch.get_n_repr(root_ids)['h']
x = mol_tree_batch.get_n_repr(root_ids)['x']
h = mol_tree_batch.nodes[root_ids].data['h']
x = mol_tree_batch.nodes[root_ids].data['x']
p_inputs.append(torch.cat([x, h, tree_vec], 1))
t_set = list(range(len(root_ids)))
# If the out degree is 0 we don't generate any edges at all
root_out_degrees = mol_tree_batch.out_degrees(root_ids)
q_inputs.append(torch.cat([h, tree_vec], 1))
q_targets.append(mol_tree_batch.get_n_repr(root_ids)['wid'])
q_targets.append(mol_tree_batch.nodes[root_ids].data['wid'])
# Traverse the tree and predict on children
for u, v, i, p in dfs_order(mol_tree_batch, root_ids):
assert set(t_set).issuperset(i)
ip = dict(zip(i, p))
# TODO: context
p_targets.append(cuda(torch.tensor([ip.get(_i, 0) for _i in t_set])))
t_set = list(i)
eid = mol_tree_batch.get_edge_id(u, v)
for eid, p in dfs_order(mol_tree_batch, root_ids):
u, v = mol_tree_batch.find_edges(eid)
p_target_list = torch.zeros_like(root_out_degrees)
p_target_list[root_out_degrees > 0] = 1 - p
p_target_list = p_target_list[root_out_degrees >= 0]
p_targets.append(torch.tensor(p_target_list))
root_out_degrees -= (root_out_degrees == 0).long()
root_out_degrees -= torch.tensor(np.isin(root_ids, v).astype('int64'))
mol_tree_batch_lg.pull(
eid,
dec_tree_edge_msg,
dec_tree_edge_reduce,
self.dec_tree_edge_update,
batchable=True,
)
is_new = mol_tree_batch.get_n_repr(v)['new']
is_new = mol_tree_batch.nodes[v].data['new']
mol_tree_batch.pull(
v,
dec_tree_node_msg,
dec_tree_node_reduce,
dec_tree_node_update,
batchable=True,
)
# Extract
h = mol_tree_batch.get_n_repr(v)['h']
x = mol_tree_batch.get_n_repr(v)['x']
p_inputs.append(torch.cat([x, h, tree_vec[t_set]], 1))
n_repr = mol_tree_batch.nodes[v].data
h = n_repr['h']
x = n_repr['x']
tree_vec_set = tree_vec[root_out_degrees >= 0]
wid = n_repr['wid']
p_inputs.append(torch.cat([x, h, tree_vec_set], 1))
# Only newly generated nodes are needed for label prediction
# NOTE: The following works since the uncomputed messages are zeros.
q_inputs.append(torch.cat([h[is_new], tree_vec[t_set][is_new]], 1))
q_targets.append(mol_tree_batch.get_n_repr(v)['wid'][is_new])
p_targets.append(cuda(torch.tensor([0 for _ in t_set])))
q_inputs.append(torch.cat([h, tree_vec_set], 1)[is_new])
q_targets.append(wid[is_new])
p_targets.append(torch.zeros((root_out_degrees == 0).sum()).long())
# Batch compute the stop/label prediction losses
p_inputs = torch.cat(p_inputs, 0)
p_targets = torch.cat(p_targets, 0)
p_targets = cuda(torch.cat(p_targets, 0))
q_inputs = torch.cat(q_inputs, 0)
q_targets = torch.cat(q_targets, 0)
......@@ -183,4 +219,161 @@ class DGLJTNNDecoder(nn.Module):
p_acc = ((p > 0).long() == p_targets).sum().float() / p_targets.shape[0]
q_acc = (q.max(1)[1] == q_targets).float().sum() / q_targets.shape[0]
self.q_inputs = q_inputs
self.q_targets = q_targets
self.q = q
self.p_inputs = p_inputs
self.p_targets = p_targets
self.p = p
return q_loss, p_loss, q_acc, p_acc
def decode(self, mol_vec):
assert mol_vec.shape[0] == 1
mol_tree = DGLMolTree(None)
init_hidden = cuda(torch.zeros(1, self.hidden_size))
root_hidden = torch.cat([init_hidden, mol_vec], 1)
root_hidden = F.relu(self.W(root_hidden))
root_score = self.W_o(root_hidden)
_, root_wid = torch.max(root_score, 1)
root_wid = root_wid.view(1)
mol_tree.add_nodes(1) # root
mol_tree.nodes[0].data['wid'] = root_wid
mol_tree.nodes[0].data['x'] = self.embedding(root_wid)
mol_tree.nodes[0].data['h'] = init_hidden
mol_tree.nodes[0].data['fail'] = cuda(torch.tensor([0]))
mol_tree.nodes_dict[0] = root_node_dict = create_node_dict(
self.vocab.get_smiles(root_wid))
stack, trace = [], []
stack.append((0, self.vocab.get_slots(root_wid)))
all_nodes = {0: root_node_dict}
h = {}
first = True
new_node_id = 0
new_edge_id = 0
for step in range(MAX_DECODE_LEN):
u, u_slots = stack[-1]
udata = mol_tree.nodes[u].data
wid = udata['wid']
x = udata['x']
h = udata['h']
# Predict stop
p_input = torch.cat([x, h, mol_vec], 1)
p_score = torch.sigmoid(self.U_s(torch.relu(self.U(p_input))))
backtrack = (p_score.item() < 0.5)
if not backtrack:
# Predict next clique. Note that the prediction may fail due
# to lack of assemblable components
mol_tree.add_nodes(1)
new_node_id += 1
v = new_node_id
mol_tree.add_edges(u, v)
uv = new_edge_id
new_edge_id += 1
if first:
mol_tree.edata.update({
's': cuda(torch.zeros(1, self.hidden_size)),
'm': cuda(torch.zeros(1, self.hidden_size)),
'r': cuda(torch.zeros(1, self.hidden_size)),
'z': cuda(torch.zeros(1, self.hidden_size)),
'src_x': cuda(torch.zeros(1, self.hidden_size)),
'dst_x': cuda(torch.zeros(1, self.hidden_size)),
'rm': cuda(torch.zeros(1, self.hidden_size)),
'accum_rm': cuda(torch.zeros(1, self.hidden_size)),
})
first = False
mol_tree.edges[uv].data['src_x'] = mol_tree.nodes[u].data['x']
# keeping dst_x 0 is fine as h on new edge doesn't depend on that.
# DGL doesn't dynamically maintain a line graph.
mol_tree_lg = mol_tree.line_graph(backtracking=False, shared=True)
mol_tree_lg.pull(
uv,
dec_tree_edge_msg,
dec_tree_edge_reduce,
self.dec_tree_edge_update.update_zm,
)
mol_tree.pull(
v,
dec_tree_node_msg,
dec_tree_node_reduce,
)
vdata = mol_tree.nodes[v].data
h_v = vdata['h']
q_input = torch.cat([h_v, mol_vec], 1)
q_score = torch.softmax(self.W_o(torch.relu(self.W(q_input))), -1)
_, sort_wid = torch.sort(q_score, 1, descending=True)
sort_wid = sort_wid.squeeze()
next_wid = None
for wid in sort_wid.tolist()[:5]:
slots = self.vocab.get_slots(wid)
cand_node_dict = create_node_dict(self.vocab.get_smiles(wid))
if (have_slots(u_slots, slots) and can_assemble(mol_tree, u, cand_node_dict)):
next_wid = wid
next_slots = slots
next_node_dict = cand_node_dict
break
if next_wid is None:
# Failed adding an actual children; v is a spurious node
# and we mark it.
vdata['fail'] = cuda(torch.tensor([1]))
backtrack = True
else:
next_wid = cuda(torch.tensor([next_wid]))
vdata['wid'] = next_wid
vdata['x'] = self.embedding(next_wid)
mol_tree.nodes_dict[v] = next_node_dict
all_nodes[v] = next_node_dict
stack.append((v, next_slots))
mol_tree.add_edge(v, u)
vu = new_edge_id
new_edge_id += 1
mol_tree.edges[uv].data['dst_x'] = mol_tree.nodes[v].data['x']
mol_tree.edges[vu].data['src_x'] = mol_tree.nodes[v].data['x']
mol_tree.edges[vu].data['dst_x'] = mol_tree.nodes[u].data['x']
# DGL doesn't dynamically maintain a line graph.
mol_tree_lg = mol_tree.line_graph(backtracking=False, shared=True)
mol_tree_lg.apply_nodes(
self.dec_tree_edge_update.update_r,
uv
)
if backtrack:
if len(stack) == 1:
break # At root, terminate
pu, _ = stack[-2]
u_pu = mol_tree.edge_id(u, pu)
mol_tree_lg.pull(
u_pu,
dec_tree_edge_msg,
dec_tree_edge_reduce,
self.dec_tree_edge_update,
)
mol_tree.pull(
pu,
dec_tree_node_msg,
dec_tree_node_reduce,
)
stack.pop()
effective_nodes = mol_tree.filter_nodes(lambda nodes: nodes.data['fail'] != 1)
effective_nodes, _ = torch.sort(effective_nodes)
return mol_tree, all_nodes, effective_nodes
......@@ -5,43 +5,24 @@ from .mol_tree import Vocab
from .nnutils import GRUUpdate, cuda
import itertools
import networkx as nx
from dgl import batch, unbatch, line_graph
from dgl import batch, unbatch, bfs_edges_generator
import dgl.function as DGLF
from .line_profiler_integration import profile
import numpy as np
MAX_NB = 8
def level_order(forest, roots):
'''
Given the forest and the list of root nodes,
returns iterator of list of edges ordered by depth, first in bottom-up
and then top-down
'''
edge_list = []
node_depth = {}
edge_list.append([])
for root in roots:
node_depth[root] = 0
for u, v in nx.bfs_edges(forest, root):
node_depth[v] = node_depth[u] + 1
if len(edge_list) == node_depth[u]:
edge_list.append([])
edge_list[node_depth[u]].append((u, v))
for edges in reversed(edge_list):
u, v = zip(*edges)
yield v, u
for edges in edge_list:
u, v = zip(*edges)
yield u, v
edges = bfs_edges_generator(forest, roots)
_, leaves = forest.find_edges(edges[-1])
edges_back = bfs_edges_generator(forest, roots, reversed=True)
yield from reversed(edges_back)
yield from edges
enc_tree_msg = [DGLF.copy_src(src='m', out='m'), DGLF.copy_src(src='rm', out='rm')]
enc_tree_reduce = [DGLF.sum(msgs='m', out='s'), DGLF.sum(msgs='rm', out='accum_rm')]
enc_tree_reduce = [DGLF.sum(msg='m', out='s'), DGLF.sum(msg='rm', out='accum_rm')]
enc_tree_gather_msg = DGLF.copy_edge(edge='m', out='m')
enc_tree_gather_reduce = DGLF.sum(msgs='m', out='m')
enc_tree_gather_reduce = DGLF.sum(msg='m', out='m')
class EncoderGatherUpdate(nn.Module):
def __init__(self, hidden_size):
......@@ -50,9 +31,9 @@ class EncoderGatherUpdate(nn.Module):
self.W = nn.Linear(2 * hidden_size, hidden_size)
def forward(self, node):
x = node['x']
m = node['m']
def forward(self, nodes):
x = nodes.data['x']
m = nodes.data['m']
return {
'h': torch.relu(self.W(torch.cat([x, m], 1))),
}
......@@ -73,34 +54,32 @@ class DGLJTNNEncoder(nn.Module):
self.enc_tree_update = GRUUpdate(hidden_size)
self.enc_tree_gather_update = EncoderGatherUpdate(hidden_size)
@profile
def forward(self, mol_trees):
mol_tree_batch = batch(mol_trees)
# Build line graph to prepare for belief propagation
mol_tree_batch_lg = line_graph(mol_tree_batch, no_backtracking=True)
mol_tree_batch_lg = mol_tree_batch.line_graph(backtracking=False, shared=True)
return self.run(mol_tree_batch, mol_tree_batch_lg)
@profile
def run(self, mol_tree_batch, mol_tree_batch_lg):
# Since tree roots are designated to 0. In the batched graph we can
# simply find the corresponding node ID by looking at node_offset
root_ids = mol_tree_batch.node_offset[:-1]
n_nodes = len(mol_tree_batch.nodes)
edge_list = mol_tree_batch.edge_list
n_edges = len(edge_list)
node_offset = np.cumsum([0] + mol_tree_batch.batch_num_nodes)
root_ids = node_offset[:-1]
n_nodes = mol_tree_batch.number_of_nodes()
n_edges = mol_tree_batch.number_of_edges()
# Assign structure embeddings to tree nodes
mol_tree_batch.set_n_repr({
'x': self.embedding(mol_tree_batch.get_n_repr()['wid']),
mol_tree_batch.ndata.update({
'x': self.embedding(mol_tree_batch.ndata['wid']),
'h': cuda(torch.zeros(n_nodes, self.hidden_size)),
})
# Initialize the intermediate variables according to Eq (4)-(8).
# Also initialize the src_x and dst_x fields.
# TODO: context?
mol_tree_batch.set_e_repr({
mol_tree_batch.edata.update({
's': cuda(torch.zeros(n_edges, self.hidden_size)),
'm': cuda(torch.zeros(n_edges, self.hidden_size)),
'r': cuda(torch.zeros(n_edges, self.hidden_size)),
......@@ -112,10 +91,8 @@ class DGLJTNNEncoder(nn.Module):
})
# Send the source/destination node features to edges
mol_tree_batch.update_edge(
#*zip(*edge_list),
edge_func=lambda src, dst, edge: {'src_x': src['x'], 'dst_x': dst['x']},
batchable=True,
mol_tree_batch.apply_edges(
func=lambda edges: {'src_x': edges.src['x'], 'dst_x': edges.dst['x']},
)
# Message passing
......@@ -123,14 +100,13 @@ class DGLJTNNEncoder(nn.Module):
# messages, and the uncomputed messages are zero vectors. Essentially,
# we can always compute s_ij as the sum of incoming m_ij, no matter
# if m_ij is actually computed or not.
for u, v in level_order(mol_tree_batch, root_ids):
eid = mol_tree_batch.get_edge_id(u, v)
for eid in level_order(mol_tree_batch, root_ids):
#eid = mol_tree_batch.edge_ids(u, v)
mol_tree_batch_lg.pull(
eid,
enc_tree_msg,
enc_tree_reduce,
self.enc_tree_update,
batchable=True,
)
# Readout
......@@ -138,9 +114,8 @@ class DGLJTNNEncoder(nn.Module):
enc_tree_gather_msg,
enc_tree_gather_reduce,
self.enc_tree_gather_update,
batchable=True,
)
root_vecs = mol_tree_batch.get_n_repr(root_ids)['h']
root_vecs = mol_tree_batch.nodes[root_ids].data['h']
return mol_tree_batch, root_vecs
......@@ -2,11 +2,15 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
from .mol_tree import Vocab
from .nnutils import create_var, cuda
from .nnutils import create_var, cuda, move_dgl_to_cuda
from .chemutils import set_atommap, copy_edit_mol, enum_assemble_nx, \
attach_mols_nx, decode_stereo
from .jtnn_enc import DGLJTNNEncoder
from .jtnn_dec import DGLJTNNDecoder
from .mpn import DGLMPN, mol2dgl
from .mpn import DGLMPN
from .mpn import mol2dgl_single as mol2dgl_enc
from .jtmpn import DGLJTMPN
from .jtmpn import mol2dgl_single as mol2dgl_dec
from .line_profiler_integration import profile
import rdkit
......@@ -15,15 +19,7 @@ from rdkit import DataStructs
from rdkit.Chem import AllChem
import copy, math
def dgl_set_batch_nodeID(mol_batch, vocab):
tot = 0
for mol_tree in mol_batch:
wid = []
for node in mol_tree.nodes:
mol_tree.nodes[node]['idx'] = tot
tot += 1
wid.append(vocab.get_index(mol_tree.nodes[node]['smiles']))
mol_tree.set_n_repr({'wid': cuda(torch.LongTensor(wid))})
from dgl import batch, unbatch
class DGLJTNNVAE(nn.Module):
......@@ -51,86 +47,75 @@ class DGLJTNNVAE(nn.Module):
self.n_edges_total = 0
self.n_tree_nodes_total = 0
@profile
def encode(self, mol_batch):
dgl_set_batch_nodeID(mol_batch, self.vocab)
@staticmethod
def move_to_cuda(mol_batch):
for t in mol_batch['mol_trees']:
move_dgl_to_cuda(t)
move_dgl_to_cuda(mol_batch['mol_graph_batch'])
if 'cand_graph_batch' in mol_batch:
move_dgl_to_cuda(mol_batch['cand_graph_batch'])
if mol_batch.get('stereo_cand_graph_batch') is not None:
move_dgl_to_cuda(mol_batch['stereo_cand_graph_batch'])
smiles_batch = [mol_tree.smiles for mol_tree in mol_batch]
mol_graphs = mol2dgl(smiles_batch)
def encode(self, mol_batch):
mol_graphs = mol_batch['mol_graph_batch']
mol_vec = self.mpn(mol_graphs)
# mol_batch is a junction tree
mol_tree_batch, tree_vec = self.jtnn(mol_batch)
self.n_nodes_total += sum(len(g.nodes) for g in mol_graphs)
self.n_edges_total += sum(len(g.edges) for g in mol_graphs)
self.n_tree_nodes_total += sum(len(t.nodes) for t in mol_batch)
mol_tree_batch, tree_vec = self.jtnn(mol_batch['mol_trees'])
self.n_nodes_total += mol_graphs.number_of_nodes()
self.n_edges_total += mol_graphs.number_of_edges()
self.n_tree_nodes_total += sum(t.number_of_nodes() for t in mol_batch['mol_trees'])
self.n_passes += 1
return mol_tree_batch, tree_vec, mol_vec
@profile
def forward(self, mol_batch, beta=0, e1=None, e2=None):
batch_size = len(mol_batch)
mol_tree_batch, tree_vec, mol_vec = self.encode(mol_batch)
def sample(self, tree_vec, mol_vec, e1=None, e2=None):
tree_mean = self.T_mean(tree_vec)
tree_log_var = -torch.abs(self.T_var(tree_vec))
mol_mean = self.G_mean(mol_vec)
mol_log_var = -torch.abs(self.G_var(mol_vec))
self.tree_mean = tree_mean
self.tree_log_var = tree_log_var
self.mol_mean = mol_mean
self.mol_log_var = mol_log_var
epsilon = cuda(torch.randn(*tree_mean.shape)) if e1 is None else e1
tree_vec = tree_mean + torch.exp(tree_log_var / 2) * epsilon
epsilon = cuda(torch.randn(*mol_mean.shape)) if e2 is None else e2
mol_vec = mol_mean + torch.exp(mol_log_var / 2) * epsilon
z_mean = torch.cat([tree_mean, mol_mean], dim=1)
z_log_var = torch.cat([tree_log_var, mol_log_var], dim=1)
kl_loss = -0.5 * torch.sum(1.0 + z_log_var - z_mean * z_mean - torch.exp(z_log_var)) / batch_size
z_mean = torch.cat([tree_mean, mol_mean], 1)
z_log_var = torch.cat([tree_log_var, mol_log_var], 1)
self.z_mean = z_mean
self.z_log_var = z_log_var
return tree_vec, mol_vec, z_mean, z_log_var
epsilon = cuda(torch.randn(batch_size, self.latent_size // 2)) if e1 is None else e1
tree_vec = tree_mean + torch.exp(tree_log_var / 2) * epsilon
epsilon = cuda(torch.randn(batch_size, self.latent_size // 2)) if e2 is None else e2
mol_vec = mol_mean + torch.exp(mol_log_var / 2) * epsilon
def forward(self, mol_batch, beta=0, e1=None, e2=None):
self.move_to_cuda(mol_batch)
mol_trees = mol_batch['mol_trees']
batch_size = len(mol_trees)
mol_tree_batch, tree_vec, mol_vec = self.encode(mol_batch)
self.tree_vec = tree_vec
self.mol_vec = mol_vec
tree_vec, mol_vec, z_mean, z_log_var = self.sample(tree_vec, mol_vec, e1, e2)
kl_loss = -0.5 * torch.sum(1.0 + z_log_var - z_mean * z_mean - torch.exp(z_log_var)) / batch_size
word_loss, topo_loss, word_acc, topo_acc = self.decoder(mol_batch, tree_vec)
word_loss, topo_loss, word_acc, topo_acc = self.decoder(mol_trees, tree_vec)
assm_loss, assm_acc = self.assm(mol_batch, mol_tree_batch, mol_vec)
stereo_loss, stereo_acc = self.stereo(mol_batch, mol_vec)
self.word_loss_v = word_loss
self.topo_loss_v = topo_loss
self.assm_loss_v = assm_loss
self.stereo_loss_v = stereo_loss
all_vec = torch.cat([tree_vec, mol_vec], dim=1)
loss = word_loss + topo_loss + assm_loss + 2 * stereo_loss + beta * kl_loss
self.all_vec = all_vec
return loss, kl_loss, word_acc, topo_acc, assm_acc, stereo_acc
@profile
def assm(self, mol_batch, mol_tree_batch, mol_vec):
cands = []
batch_idx = []
for i, mol_tree in enumerate(mol_batch):
for node_id, node in mol_tree.nodes.items():
if node['is_leaf'] or len(node['cands']) == 1:
continue
cands.extend([(cand, mol_tree, node_id) for cand in node['cand_mols']])
batch_idx.extend([i] * len(node['cands']))
cands = [mol_batch['cand_graph_batch'],
mol_batch['tree_mess_src_e'],
mol_batch['tree_mess_tgt_e'],
mol_batch['tree_mess_tgt_n']]
cand_vec = self.jtmpn(cands, mol_tree_batch)
cand_vec = self.G_mean(cand_vec)
batch_idx = cuda(torch.LongTensor(batch_idx))
batch_idx = cuda(torch.LongTensor(mol_batch['cand_batch_idx']))
mol_vec = mol_vec[batch_idx]
mol_vec = mol_vec.view(-1, 1, self.latent_size // 2)
......@@ -139,13 +124,13 @@ class DGLJTNNVAE(nn.Module):
cnt, tot, acc = 0, 0, 0
all_loss = []
for i, mol_tree in enumerate(mol_batch):
comp_nodes = [node_id for node_id, node in mol_tree.nodes.items()
for i, mol_tree in enumerate(mol_batch['mol_trees']):
comp_nodes = [node_id for node_id, node in mol_tree.nodes_dict.items()
if len(node['cands']) > 1 and not node['is_leaf']]
cnt += len(comp_nodes)
# segmented accuracy and cross entropy
for node_id in comp_nodes:
node = mol_tree.nodes[node_id]
node = mol_tree.nodes_dict[node_id]
label = node['cands'].index(node['label'])
ncand = len(node['cands'])
cur_score = scores[tot:tot+ncand]
......@@ -158,36 +143,28 @@ class DGLJTNNVAE(nn.Module):
all_loss.append(
F.cross_entropy(cur_score.view(1, -1), label, size_average=False))
all_loss = sum(all_loss) / len(mol_batch)
all_loss = sum(all_loss) / len(mol_batch['mol_trees'])
return all_loss, acc / cnt
@profile
def stereo(self, mol_batch, mol_vec):
stereo_cands, batch_idx = [], []
labels = []
for i, mol_tree in enumerate(mol_batch):
cands = mol_tree.stereo_cands
if len(cands) == 1:
continue
if mol_tree.smiles3D not in cands:
cands.append(mol_tree.smiles3D)
stereo_cands.extend(cands)
batch_idx.extend([i] * len(cands))
labels.append((cands.index(mol_tree.smiles3D), len(cands)))
stereo_cands = mol_batch['stereo_cand_graph_batch']
batch_idx = mol_batch['stereo_cand_batch_idx']
labels = mol_batch['stereo_cand_labels']
lengths = mol_batch['stereo_cand_lengths']
if len(labels) == 0:
# Only one stereoisomer exists; do nothing
return cuda(torch.tensor(0.)), 1.
batch_idx = cuda(torch.LongTensor(batch_idx))
stereo_cands = self.mpn(mol2dgl(stereo_cands))
stereo_cands = self.mpn(stereo_cands)
stereo_cands = self.G_mean(stereo_cands)
stereo_labels = mol_vec[batch_idx]
scores = F.cosine_similarity(stereo_cands, stereo_labels)
st, acc = 0, 0
all_loss = []
for label, le in labels:
for label, le in zip(labels, lengths):
cur_scores = scores[st:st+le]
if cur_scores.data[label].item() >= cur_scores.max().item():
acc += 1
......@@ -198,3 +175,134 @@ class DGLJTNNVAE(nn.Module):
all_loss = sum(all_loss) / len(labels)
return all_loss, acc / len(labels)
def decode(self, tree_vec, mol_vec):
mol_tree, nodes_dict, effective_nodes = self.decoder.decode(tree_vec)
effective_nodes_list = effective_nodes.tolist()
nodes_dict = [nodes_dict[v] for v in effective_nodes_list]
for i, (node_id, node) in enumerate(zip(effective_nodes_list, nodes_dict)):
node['idx'] = i
node['nid'] = i + 1
node['is_leaf'] = True
if mol_tree.in_degree(node_id) > 1:
node['is_leaf'] = False
set_atommap(node['mol'], node['nid'])
mol_tree_sg = mol_tree.subgraph(effective_nodes)
mol_tree_sg.copy_from_parent()
mol_tree_msg, _ = self.jtnn([mol_tree_sg])
mol_tree_msg = unbatch(mol_tree_msg)[0]
mol_tree_msg.nodes_dict = nodes_dict
cur_mol = copy_edit_mol(nodes_dict[0]['mol'])
global_amap = [{}] + [{} for node in nodes_dict]
global_amap[1] = {atom.GetIdx(): atom.GetIdx() for atom in cur_mol.GetAtoms()}
cur_mol = self.dfs_assemble(mol_tree_msg, mol_vec, cur_mol, global_amap, [], 0, None)
if cur_mol is None:
return None
cur_mol = cur_mol.GetMol()
set_atommap(cur_mol)
cur_mol = Chem.MolFromSmiles(Chem.MolToSmiles(cur_mol))
if cur_mol is None:
return None
smiles2D = Chem.MolToSmiles(cur_mol)
stereo_cands = decode_stereo(smiles2D)
if len(stereo_cands) == 1:
return stereo_cands[0]
stereo_graphs = [mol2dgl_enc(c) for c in stereo_cands]
stereo_cand_graphs, atom_x, bond_x = \
zip(*stereo_graphs)
stereo_cand_graphs = batch(stereo_cand_graphs)
atom_x = cuda(torch.cat(atom_x))
bond_x = cuda(torch.cat(bond_x))
stereo_cand_graphs.ndata['x'] = atom_x
stereo_cand_graphs.edata['x'] = bond_x
stereo_cand_graphs.edata['src_x'] = atom_x.new(
bond_x.shape[0], atom_x.shape[1]).zero_()
stereo_vecs = self.mpn(stereo_cand_graphs)
stereo_vecs = self.G_mean(stereo_vecs)
scores = F.cosine_similarity(stereo_vecs, mol_vec)
_, max_id = scores.max(0)
return stereo_cands[max_id.item()]
def dfs_assemble(self, mol_tree_msg, mol_vec, cur_mol,
global_amap, fa_amap, cur_node_id, fa_node_id):
nodes_dict = mol_tree_msg.nodes_dict
fa_node = nodes_dict[fa_node_id] if fa_node_id is not None else None
cur_node = nodes_dict[cur_node_id]
fa_nid = fa_node['nid'] if fa_node is not None else -1
prev_nodes = [fa_node] if fa_node is not None else []
children_node_id = [v for v in mol_tree_msg.successors(cur_node_id).tolist()
if nodes_dict[v]['nid'] != fa_nid]
children = [nodes_dict[v] for v in children_node_id]
neighbors = [nei for nei in children if nei['mol'].GetNumAtoms() > 1]
neighbors = sorted(neighbors, key=lambda x: x['mol'].GetNumAtoms(), reverse=True)
singletons = [nei for nei in children if nei['mol'].GetNumAtoms() == 1]
neighbors = singletons + neighbors
cur_amap = [(fa_nid, a2, a1) for nid, a1, a2 in fa_amap if nid == cur_node['nid']]
cands = enum_assemble_nx(cur_node, neighbors, prev_nodes, cur_amap)
if len(cands) == 0:
return None
cand_smiles, cand_mols, cand_amap = list(zip(*cands))
cands = [(candmol, mol_tree_msg, cur_node_id) for candmol in cand_mols]
cand_graphs, atom_x, bond_x, tree_mess_src_edges, \
tree_mess_tgt_edges, tree_mess_tgt_nodes = mol2dgl_dec(
cands)
cand_graphs = batch(cand_graphs)
atom_x = cuda(atom_x)
bond_x = cuda(bond_x)
cand_graphs.ndata['x'] = atom_x
cand_graphs.edata['x'] = bond_x
cand_graphs.edata['src_x'] = atom_x.new(bond_x.shape[0], atom_x.shape[1]).zero_()
cand_vecs = self.jtmpn(
(cand_graphs, tree_mess_src_edges, tree_mess_tgt_edges, tree_mess_tgt_nodes),
mol_tree_msg,
)
cand_vecs = self.G_mean(cand_vecs)
mol_vec = mol_vec.squeeze()
scores = cand_vecs @ mol_vec
_, cand_idx = torch.sort(scores, descending=True)
backup_mol = Chem.RWMol(cur_mol)
for i in range(len(cand_idx)):
cur_mol = Chem.RWMol(backup_mol)
pred_amap = cand_amap[cand_idx[i].item()]
new_global_amap = copy.deepcopy(global_amap)
for nei_id, ctr_atom, nei_atom in pred_amap:
if nei_id == fa_nid:
continue
new_global_amap[nei_id][nei_atom] = new_global_amap[cur_node['nid']][ctr_atom]
cur_mol = attach_mols_nx(cur_mol, children, [], new_global_amap)
new_mol = cur_mol.GetMol()
new_mol = Chem.MolFromSmiles(Chem.MolToSmiles(new_mol))
if new_mol is None:
continue
result = True
for nei_node_id, nei_node in zip(children_node_id, children):
if nei_node['is_leaf']:
continue
cur_mol = self.dfs_assemble(
mol_tree_msg, mol_vec, cur_mol, new_global_amap, pred_amap,
nei_node_id, cur_node_id)
if cur_mol is None:
result = False
break
if result:
return cur_mol
return None
......@@ -2,10 +2,17 @@ from dgl import DGLGraph
import rdkit.Chem as Chem
from .chemutils import get_clique_mol, tree_decomp, get_mol, get_smiles, \
set_atommap, enum_assemble_nx, decode_stereo
import numpy as np
from .line_profiler_integration import profile
class DGLMolTree(DGLGraph):
def __init__(self, smiles):
DGLGraph.__init__(self)
self.nodes_dict = {}
if smiles is None:
return
self.smiles = smiles
self.mol = get_mol(smiles)
......@@ -21,39 +28,43 @@ class DGLMolTree(DGLGraph):
for i, c in enumerate(cliques):
cmol = get_clique_mol(self.mol, c)
csmiles = get_smiles(cmol)
self.add_node(
i,
self.nodes_dict[i] = dict(
smiles=csmiles,
mol=get_mol(csmiles),
clique=c,
)
if min(c) == 0:
root = i
self.add_nodes(len(cliques))
# The clique with atom ID 0 becomes root
if root > 0:
for attr in self.nodes[0]:
self.nodes[0][attr], self.nodes[root][attr] = \
self.nodes[root][attr], self.nodes[0][attr]
for attr in self.nodes_dict[0]:
self.nodes_dict[0][attr], self.nodes_dict[root][attr] = \
self.nodes_dict[root][attr], self.nodes_dict[0][attr]
for _x, _y in edges:
src = np.zeros((len(edges) * 2,), dtype='int')
dst = np.zeros((len(edges) * 2,), dtype='int')
for i, (_x, _y) in enumerate(edges):
x = 0 if _x == root else root if _x == 0 else _x
y = 0 if _y == root else root if _y == 0 else _y
self.add_edge(x, y)
self.add_edge(y, x)
for i in self.nodes:
self.nodes[i]['nid'] = i + 1
if len(self[i]) > 1: # Leaf node mol is not marked
set_atommap(self.nodes[i]['mol'], self.nodes[i]['nid'])
self.nodes[i]['is_leaf'] = (len(self[i]) == 1)
src[2 * i] = x
dst[2 * i] = y
src[2 * i + 1] = y
dst[2 * i + 1] = x
self.add_edges(src, dst)
for i in self.nodes_dict:
self.nodes_dict[i]['nid'] = i + 1
if self.out_degree(i) > 1: # Leaf node mol is not marked
set_atommap(self.nodes_dict[i]['mol'], self.nodes_dict[i]['nid'])
self.nodes_dict[i]['is_leaf'] = (self.out_degree(i) == 1)
# avoiding DiGraph.size()
def treesize(self):
return len(self.nodes)
return self.number_of_nodes()
def _recover_node(self, i, original_mol):
node = self.nodes[i]
node = self.nodes_dict[i]
clique = []
clique.extend(node['clique'])
......@@ -61,8 +72,8 @@ class DGLMolTree(DGLGraph):
for cidx in node['clique']:
original_mol.GetAtomWithIdx(cidx).SetAtomMapNum(node['nid'])
for j in self[i]:
nei_node = self.nodes[j]
for j in self.successors(i).numpy():
nei_node = self.nodes_dict[j]
clique.extend(nei_node['clique'])
if nei_node['is_leaf']: # Leaf node, no need to mark
continue
......@@ -83,25 +94,27 @@ class DGLMolTree(DGLGraph):
return node['label']
def _assemble_node(self, i):
neighbors = [self.nodes[j] for j in self[i] if self.nodes[j]['mol'].GetNumAtoms() > 1]
neighbors = [self.nodes_dict[j] for j in self.successors(i).numpy()
if self.nodes_dict[j]['mol'].GetNumAtoms() > 1]
neighbors = sorted(neighbors, key=lambda x: x['mol'].GetNumAtoms(), reverse=True)
singletons = [self.nodes[j] for j in self[i] if self.nodes[j]['mol'].GetNumAtoms() == 1]
singletons = [self.nodes_dict[j] for j in self.successors(i).numpy()
if self.nodes_dict[j]['mol'].GetNumAtoms() == 1]
neighbors = singletons + neighbors
cands = enum_assemble_nx(self, i, neighbors)
cands = enum_assemble_nx(self.nodes_dict[i], neighbors)
if len(cands) > 0:
self.nodes[i]['cands'], self.nodes[i]['cand_mols'], _ = list(zip(*cands))
self.nodes[i]['cands'] = list(self.nodes[i]['cands'])
self.nodes[i]['cand_mols'] = list(self.nodes[i]['cand_mols'])
self.nodes_dict[i]['cands'], self.nodes_dict[i]['cand_mols'], _ = list(zip(*cands))
self.nodes_dict[i]['cands'] = list(self.nodes_dict[i]['cands'])
self.nodes_dict[i]['cand_mols'] = list(self.nodes_dict[i]['cand_mols'])
else:
self.nodes[i]['cands'] = []
self.nodes[i]['cand_mols'] = []
self.nodes_dict[i]['cands'] = []
self.nodes_dict[i]['cand_mols'] = []
def recover(self):
for i in self.nodes:
for i in self.nodes_dict:
self._recover_node(i, self.mol)
def assemble(self):
for i in self.nodes:
for i in self.nodes_dict:
self._assemble_node(i)
......@@ -4,11 +4,12 @@ import rdkit.Chem as Chem
import torch.nn.functional as F
from .nnutils import *
from .chemutils import get_mol
from networkx import Graph, DiGraph, line_graph, convert_node_labels_to_integers
from dgl import DGLGraph, line_graph, batch, unbatch
from networkx import Graph, DiGraph, convert_node_labels_to_integers
from dgl import DGLGraph, batch, unbatch, mean_nodes
import dgl.function as DGLF
from functools import partial
from .line_profiler_integration import profile
import numpy as np
ELEM_LIST = ['C', 'N', 'O', 'S', 'F', 'Si', 'P', 'Cl', 'Br', 'Mg', 'Na', 'Ca', 'Fe', 'Al', 'I', 'B', 'K', 'Se', 'Zn', 'H', 'Cu', 'Mn', 'unknown']
......@@ -22,7 +23,7 @@ def onek_encoding_unk(x, allowable_set):
return [x == s for s in allowable_set]
def atom_features(atom):
return cuda(torch.Tensor(onek_encoding_unk(atom.GetSymbol(), ELEM_LIST)
return (torch.Tensor(onek_encoding_unk(atom.GetSymbol(), ELEM_LIST)
+ onek_encoding_unk(atom.GetDegree(), [0,1,2,3,4,5])
+ onek_encoding_unk(atom.GetFormalCharge(), [-1,-2,1,2,0])
+ onek_encoding_unk(int(atom.GetChiralTag()), [0,1,2,3])
......@@ -33,46 +34,45 @@ def bond_features(bond):
stereo = int(bond.GetStereo())
fbond = [bt == Chem.rdchem.BondType.SINGLE, bt == Chem.rdchem.BondType.DOUBLE, bt == Chem.rdchem.BondType.TRIPLE, bt == Chem.rdchem.BondType.AROMATIC, bond.IsInRing()]
fstereo = onek_encoding_unk(stereo, [0,1,2,3,4,5])
return cuda(torch.Tensor(fbond + fstereo))
@profile
def mol2dgl(smiles_batch):
n_nodes = 0
graph_list = []
for smiles in smiles_batch:
atom_feature_list = []
bond_feature_list = []
bond_source_feature_list = []
graph = DGLGraph()
return (torch.Tensor(fbond + fstereo))
def mol2dgl_single(smiles):
n_edges = 0
atom_x = []
bond_x = []
mol = get_mol(smiles)
for atom in mol.GetAtoms():
graph.add_node(atom.GetIdx())
atom_feature_list.append(atom_features(atom))
for bond in mol.GetBonds():
n_atoms = mol.GetNumAtoms()
n_bonds = mol.GetNumBonds()
graph = DGLGraph()
for i, atom in enumerate(mol.GetAtoms()):
assert i == atom.GetIdx()
atom_x.append(atom_features(atom))
graph.add_nodes(n_atoms)
bond_src = []
bond_dst = []
for i, bond in enumerate(mol.GetBonds()):
begin_idx = bond.GetBeginAtom().GetIdx()
end_idx = bond.GetEndAtom().GetIdx()
features = bond_features(bond)
graph.add_edge(begin_idx, end_idx)
bond_feature_list.append(features)
bond_src.append(begin_idx)
bond_dst.append(end_idx)
bond_x.append(features)
# set up the reverse direction
graph.add_edge(end_idx, begin_idx)
bond_feature_list.append(features)
atom_x = torch.stack(atom_feature_list)
graph.set_n_repr({'x': atom_x})
if len(bond_feature_list) > 0:
bond_x = torch.stack(bond_feature_list)
graph.set_e_repr({
'x': bond_x,
'src_x': atom_x.new(len(bond_feature_list), ATOM_FDIM).zero_()
})
graph_list.append(graph)
bond_src.append(end_idx)
bond_dst.append(begin_idx)
bond_x.append(features)
graph.add_edges(bond_src, bond_dst)
return graph_list
n_edges += n_bonds
return graph, torch.stack(atom_x), \
torch.stack(bond_x) if len(bond_x) > 0 else torch.zeros(0)
mpn_loopy_bp_msg = DGLF.copy_src(src='msg', out='msg')
mpn_loopy_bp_reduce = DGLF.sum(msgs='msg', out='accum_msg')
mpn_loopy_bp_reduce = DGLF.sum(msg='msg', out='accum_msg')
class LoopyBPUpdate(nn.Module):
......@@ -82,15 +82,15 @@ class LoopyBPUpdate(nn.Module):
self.W_h = nn.Linear(hidden_size, hidden_size, bias=False)
def forward(self, node):
msg_input = node['msg_input']
msg_delta = self.W_h(node['accum_msg'])
def forward(self, nodes):
msg_input = nodes.data['msg_input']
msg_delta = self.W_h(nodes.data['accum_msg'])
msg = F.relu(msg_input + msg_delta)
return {'msg': msg}
mpn_gather_msg = DGLF.copy_edge(edge='msg', out='msg')
mpn_gather_reduce = DGLF.sum(msgs='msg', out='m')
mpn_gather_reduce = DGLF.sum(msg='msg', out='m')
class GatherUpdate(nn.Module):
......@@ -100,10 +100,10 @@ class GatherUpdate(nn.Module):
self.W_o = nn.Linear(ATOM_FDIM + hidden_size, hidden_size)
def forward(self, node):
m = node['m']
def forward(self, nodes):
m = nodes.data['m']
return {
'h': F.relu(self.W_o(torch.cat([node['x'], m], 1))),
'h': F.relu(self.W_o(torch.cat([nodes.data['x'], m], 1))),
}
......@@ -124,19 +124,18 @@ class DGLMPN(nn.Module):
self.n_edges_total = 0
self.n_passes = 0
@profile
def forward(self, mol_graph_list):
n_samples = len(mol_graph_list)
def forward(self, mol_graph):
n_samples = mol_graph.batch_size
mol_graph = batch(mol_graph_list)
mol_line_graph = line_graph(mol_graph, no_backtracking=True)
mol_line_graph = mol_graph.line_graph(backtracking=False, shared=True)
n_nodes = len(mol_graph.nodes)
n_edges = len(mol_graph.edges)
n_nodes = mol_graph.number_of_nodes()
n_edges = mol_graph.number_of_edges()
mol_graph = self.run(mol_graph, mol_line_graph)
mol_graph_list = unbatch(mol_graph)
g_repr = torch.stack([g.get_n_repr()['h'].mean(0) for g in mol_graph_list], 0)
# TODO: replace with unbatch or readout
g_repr = mean_nodes(mol_graph, 'h')
self.n_samples_total += n_samples
self.n_nodes_total += n_nodes
......@@ -145,27 +144,25 @@ class DGLMPN(nn.Module):
return g_repr
@profile
def run(self, mol_graph, mol_line_graph):
n_nodes = len(mol_graph.nodes)
n_nodes = mol_graph.number_of_nodes()
mol_graph.update_edge(
#*zip(*mol_graph.edge_list),
edge_func=lambda src, dst, edge: {'src_x': src['x']},
batchable=True,
mol_graph.apply_edges(
func=lambda edges: {'src_x': edges.src['x']},
)
bond_features = mol_line_graph.get_n_repr()['x']
source_features = mol_line_graph.get_n_repr()['src_x']
e_repr = mol_line_graph.ndata
bond_features = e_repr['x']
source_features = e_repr['src_x']
features = torch.cat([source_features, bond_features], 1)
msg_input = self.W_i(features)
mol_line_graph.set_n_repr({
mol_line_graph.ndata.update({
'msg_input': msg_input,
'msg': F.relu(msg_input),
'accum_msg': torch.zeros_like(msg_input),
})
mol_graph.set_n_repr({
mol_graph.ndata.update({
'm': bond_features.new(n_nodes, self.hidden_size).zero_(),
'h': bond_features.new(n_nodes, self.hidden_size).zero_(),
})
......@@ -175,14 +172,12 @@ class DGLMPN(nn.Module):
mpn_loopy_bp_msg,
mpn_loopy_bp_reduce,
self.loopy_bp_updater,
True
)
mol_graph.update_all(
mpn_gather_msg,
mpn_gather_reduce,
self.gather_updater,
True
)
return mol_graph
import torch
import torch.nn as nn
from torch.autograd import Variable
import os
def create_var(tensor, requires_grad=None):
if requires_grad is None:
......@@ -9,7 +10,7 @@ def create_var(tensor, requires_grad=None):
return Variable(tensor, requires_grad=requires_grad)
def cuda(tensor):
if torch.cuda.is_available():
if torch.cuda.is_available() and not os.getenv('NOCUDA', None):
return tensor.cuda()
else:
return tensor
......@@ -25,17 +26,29 @@ class GRUUpdate(nn.Module):
self.U_r = nn.Linear(hidden_size, hidden_size)
self.W_h = nn.Linear(2 * hidden_size, hidden_size)
def forward(self, node):
src_x = node['src_x']
dst_x = node['dst_x']
s = node['s']
rm = node['accum_rm']
def update_zm(self, node):
src_x = node.data['src_x']
s = node.data['s']
rm = node.data['accum_rm']
z = torch.sigmoid(self.W_z(torch.cat([src_x, s], 1)))
m = torch.tanh(self.W_h(torch.cat([src_x, rm], 1)))
m = (1 - z) * s + z * m
return {'m': m, 'z': z}
def update_r(self, node, zm=None):
dst_x = node.data['dst_x']
m = node.data['m'] if zm is None else zm['m']
r_1 = self.W_r(dst_x)
r_2 = self.U_r(m)
r = torch.sigmoid(r_1 + r_2)
return {'r': r, 'rm': r * m}
def forward(self, node):
dic = self.update_zm(node)
dic.update(self.update_r(node, zm=dic))
return dic
return {'m': m, 'r': r, 'z': z, 'rm': r * m}
def move_dgl_to_cuda(g):
g.ndata.update({k: cuda(g.ndata[k]) for k in g.ndata})
g.edata.update({k: cuda(g.edata[k]) for k in g.edata})
......@@ -11,8 +11,12 @@ import rdkit
from jtnn import *
lg = rdkit.RDLogger.logger()
lg.setLevel(rdkit.RDLogger.CRITICAL)
torch.multiprocessing.set_sharing_strategy('file_system')
def worker_init_fn(id_):
lg = rdkit.RDLogger.logger()
lg.setLevel(rdkit.RDLogger.CRITICAL)
worker_init_fn(None)
parser = OptionParser()
parser.add_option("-t", "--train", dest="train", default='train', help='Training file name')
......@@ -25,10 +29,11 @@ parser.add_option("-l", "--latent", dest="latent_size", default=56)
parser.add_option("-d", "--depth", dest="depth", default=3)
parser.add_option("-z", "--beta", dest="beta", default=1.0)
parser.add_option("-q", "--lr", dest="lr", default=1e-3)
parser.add_option("-T", "--test", dest="test", action="store_true")
opts,args = parser.parse_args()
dataset = JTNNDataset(data=opts.train, vocab=opts.vocab)
vocab = Vocab([x.strip("\r\n ") for x in open(dataset.vocab_file)])
dataset = JTNNDataset(data=opts.train, vocab=opts.vocab, training=True)
vocab = dataset.vocab
batch_size = int(opts.batch_size)
hidden_size = int(opts.hidden_size)
......@@ -48,39 +53,37 @@ else:
else:
nn.init.xavier_normal(param)
if torch.cuda.is_available():
model = model.cuda()
model = cuda(model)
print("Model #Params: %dK" % (sum([x.nelement() for x in model.parameters()]) / 1000,))
optimizer = optim.Adam(model.parameters(), lr=lr)
scheduler = lr_scheduler.ExponentialLR(optimizer, 0.9)
scheduler.step()
MAX_EPOCH = 1
MAX_EPOCH = 100
PRINT_ITER = 20
@profile
def train():
dataset.training = True
dataloader = DataLoader(
dataset,
batch_size=batch_size,
shuffle=True,
num_workers=0,
collate_fn=lambda x:x,
drop_last=True)
num_workers=4,
collate_fn=JTNNCollator(vocab, True),
drop_last=True,
worker_init_fn=worker_init_fn)
for epoch in range(MAX_EPOCH):
word_acc,topo_acc,assm_acc,steo_acc = 0,0,0,0
for it, batch in enumerate(dataloader):
for mol_tree in batch:
for node_id, node in mol_tree.nodes.items():
if node['label'] not in node['cands']:
node['cands'].append(node['label'])
node['cand_mols'].append(node['label_mol'])
model.zero_grad()
try:
loss, kl_div, wacc, tacc, sacc, dacc = model(batch, beta)
except:
print([t.smiles for t in batch['mol_trees']])
raise
loss.backward()
optimizer.step()
......@@ -95,8 +98,8 @@ def train():
assm_acc = assm_acc / PRINT_ITER * 100
steo_acc = steo_acc / PRINT_ITER * 100
print("KL: %.1f, Word: %.2f, Topo: %.2f, Assm: %.2f, Steo: %.2f" % (
kl_div, word_acc, topo_acc, assm_acc, steo_acc))
print("KL: %.1f, Word: %.2f, Topo: %.2f, Assm: %.2f, Steo: %.2f, Loss: %.6f" % (
kl_div, word_acc, topo_acc, assm_acc, steo_acc, loss.item()))
word_acc,topo_acc,assm_acc,steo_acc = 0,0,0,0
sys.stdout.flush()
......@@ -110,7 +113,32 @@ def train():
print("learning rate: %.6f" % scheduler.get_lr()[0])
torch.save(model.state_dict(), opts.save_path + "/model.iter-" + str(epoch))
def test():
dataset.training = False
dataloader = DataLoader(
dataset,
batch_size=1,
shuffle=False,
num_workers=0,
collate_fn=JTNNCollator(vocab, False),
drop_last=True,
worker_init_fn=worker_init_fn)
# Just an example of molecule decoding; in reality you may want to sample
# tree and molecule vectors.
for it, batch in enumerate(dataloader):
gt_smiles = batch['mol_trees'][0].smiles
print(gt_smiles)
model.move_to_cuda(batch)
_, tree_vec, mol_vec = model.encode(batch)
tree_vec, mol_vec, _, _ = model.sample(tree_vec, mol_vec)
smiles = model.decode(tree_vec, mol_vec)
print(smiles)
if __name__ == '__main__':
if opts.test:
test()
else:
train()
print('# passes:', model.n_passes)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment