"...pytorch/git@developer.sourcefind.cn:OpenDAS/dgl.git" did not exist on "7612af0f180fd090a3651aa07f265b00f135f663"
Commit ec4216dd authored by Gan Quan, committed by Minjie Wang

[MODEL] junction tree vae example (#70)

* junction tree vae example

* added README and moved data to dropbox

* auto download; some fixes for Python 3
Junction Tree VAE - example for training
===
This is a direct modification of https://github.com/wengong-jin/icml18-jtnn.
You need to have RDKit installed.
To run the model, use
```
python3 vaetrain_dgl.py
```
The script will automatically download the data, which is the same as that
used in the original repository.
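
For reference, a minimal sketch of constructing the dataset and model
programmatically, mirroring the defaults in `vaetrain_dgl.py` (the
hyperparameter values below are the script's defaults, not requirements):
```
from jtnn import JTNNDataset, Vocab, DGLJTNNVAE

dataset = JTNNDataset(data='train', vocab='vocab')  # auto-downloads the data
vocab = Vocab([x.strip("\r\n ") for x in open(dataset.vocab_file)])
model = DGLJTNNVAE(vocab, hidden_size=200, latent_size=56, depth=3)
```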
# File: jtnn/__init__.py
from .mol_tree import Vocab
from .jtnn_vae import DGLJTNNVAE
from .mpn import DGLMPN, mol2dgl
from .nnutils import create_var
from .datautils import JTNNDataset
from .chemutils import decode_stereo
from .line_profiler_integration import profile
# File: jtnn/chemutils.py
import rdkit
import rdkit.Chem as Chem
from scipy.sparse import csr_matrix
from scipy.sparse.csgraph import minimum_spanning_tree
from collections import defaultdict
from rdkit.Chem.EnumerateStereoisomers import EnumerateStereoisomers, StereoEnumerationOptions
MST_MAX_WEIGHT = 100
MAX_NCAND = 2000
def set_atommap(mol, num=0):
for atom in mol.GetAtoms():
atom.SetAtomMapNum(num)
def get_mol(smiles):
mol = Chem.MolFromSmiles(smiles)
if mol is None:
return None
Chem.Kekulize(mol)
return mol
def get_smiles(mol):
return Chem.MolToSmiles(mol, kekuleSmiles=True)
def decode_stereo(smiles2D):
mol = Chem.MolFromSmiles(smiles2D)
dec_isomers = list(EnumerateStereoisomers(mol))
dec_isomers = [Chem.MolFromSmiles(Chem.MolToSmiles(mol, isomericSmiles=True)) for mol in dec_isomers]
smiles3D = [Chem.MolToSmiles(mol, isomericSmiles=True) for mol in dec_isomers]
chiralN = [atom.GetIdx() for atom in dec_isomers[0].GetAtoms() if int(atom.GetChiralTag()) > 0 and atom.GetSymbol() == "N"]
if len(chiralN) > 0:
for mol in dec_isomers:
for idx in chiralN:
mol.GetAtomWithIdx(idx).SetChiralTag(Chem.rdchem.ChiralType.CHI_UNSPECIFIED)
smiles3D.append(Chem.MolToSmiles(mol, isomericSmiles=True))
return smiles3D
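# Illustrative example (canonical forms and ordering may vary with the RDKit
# version): decode_stereo('CC(N)O') enumerates the unassigned stereocenter
# and returns both enantiomers, e.g. ['C[C@H](N)O', 'C[C@@H](N)O'].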
def sanitize(mol):
try:
smiles = get_smiles(mol)
mol = get_mol(smiles)
except Exception as e:
return None
return mol
def copy_atom(atom):
new_atom = Chem.Atom(atom.GetSymbol())
new_atom.SetFormalCharge(atom.GetFormalCharge())
new_atom.SetAtomMapNum(atom.GetAtomMapNum())
return new_atom
def copy_edit_mol(mol):
new_mol = Chem.RWMol(Chem.MolFromSmiles(''))
for atom in mol.GetAtoms():
new_atom = copy_atom(atom)
new_mol.AddAtom(new_atom)
for bond in mol.GetBonds():
a1 = bond.GetBeginAtom().GetIdx()
a2 = bond.GetEndAtom().GetIdx()
bt = bond.GetBondType()
new_mol.AddBond(a1, a2, bt)
return new_mol
def get_clique_mol(mol, atoms):
smiles = Chem.MolFragmentToSmiles(mol, atoms, kekuleSmiles=True)
new_mol = Chem.MolFromSmiles(smiles, sanitize=False)
new_mol = copy_edit_mol(new_mol).GetMol()
new_mol = sanitize(new_mol) #We assume this is not None
return new_mol
def tree_decomp(mol):
n_atoms = mol.GetNumAtoms()
if n_atoms == 1:
return [[0]], []
cliques = []
for bond in mol.GetBonds():
a1 = bond.GetBeginAtom().GetIdx()
a2 = bond.GetEndAtom().GetIdx()
if not bond.IsInRing():
cliques.append([a1,a2])
ssr = [list(x) for x in Chem.GetSymmSSSR(mol)]
cliques.extend(ssr)
nei_list = [[] for i in range(n_atoms)]
for i in range(len(cliques)):
for atom in cliques[i]:
nei_list[atom].append(i)
#Merge Rings with intersection > 2 atoms
for i in range(len(cliques)):
if len(cliques[i]) <= 2: continue
for atom in cliques[i]:
for j in nei_list[atom]:
if i >= j or len(cliques[j]) <= 2: continue
inter = set(cliques[i]) & set(cliques[j])
if len(inter) > 2:
cliques[i].extend(cliques[j])
cliques[i] = list(set(cliques[i]))
cliques[j] = []
cliques = [c for c in cliques if len(c) > 0]
nei_list = [[] for i in range(n_atoms)]
for i in range(len(cliques)):
for atom in cliques[i]:
nei_list[atom].append(i)
#Build edges and add singleton cliques
edges = defaultdict(int)
for atom in range(n_atoms):
if len(nei_list[atom]) <= 1:
continue
cnei = nei_list[atom]
bonds = [c for c in cnei if len(cliques[c]) == 2]
rings = [c for c in cnei if len(cliques[c]) > 4]
if len(bonds) > 2 or (len(bonds) == 2 and len(cnei) > 2): #In general, if len(cnei) >= 3, a singleton should be added, but 1 bond + 2 ring is currently not dealt with.
cliques.append([atom])
c2 = len(cliques) - 1
for c1 in cnei:
edges[(c1,c2)] = 1
elif len(rings) > 2: #Multiple (n>2) complex rings
cliques.append([atom])
c2 = len(cliques) - 1
for c1 in cnei:
edges[(c1,c2)] = MST_MAX_WEIGHT - 1
else:
for i in range(len(cnei)):
for j in range(i + 1, len(cnei)):
c1,c2 = cnei[i],cnei[j]
inter = set(cliques[c1]) & set(cliques[c2])
if edges[(c1,c2)] < len(inter):
edges[(c1,c2)] = len(inter) #cnei[i] < cnei[j] by construction
edges = [u + (MST_MAX_WEIGHT-v,) for u,v in edges.items()]
if len(edges) == 0:
return cliques, edges
#Compute Maximum Spanning Tree
row,col,data = list(zip(*edges))
n_clique = len(cliques)
clique_graph = csr_matrix( (data,(row,col)), shape=(n_clique,n_clique) )
junc_tree = minimum_spanning_tree(clique_graph)
row,col = junc_tree.nonzero()
edges = [(row[i],col[i]) for i in range(len(row))]
return (cliques, edges)
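# Illustrative example: for ethanol ('CCO', no rings) the cliques are the
# two non-ring bonds [[0, 1], [1, 2]], which share atom 1, and the maximum
# spanning tree is the single junction-tree edge (0, 1) between them.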
def atom_equal(a1, a2):
return a1.GetSymbol() == a2.GetSymbol() and a1.GetFormalCharge() == a2.GetFormalCharge()
#Bond type not considered because all aromatic (so SINGLE matches DOUBLE)
def ring_bond_equal(b1, b2, reverse=False):
b1 = (b1.GetBeginAtom(), b1.GetEndAtom())
if reverse:
b2 = (b2.GetEndAtom(), b2.GetBeginAtom())
else:
b2 = (b2.GetBeginAtom(), b2.GetEndAtom())
return atom_equal(b1[0], b2[0]) and atom_equal(b1[1], b2[1])
def attach_mols_nx(ctr_mol, neighbors, prev_nodes, nei_amap):
prev_nids = [node['nid'] for node in prev_nodes]
for nei_node in prev_nodes + neighbors:
nei_id, nei_mol = nei_node['nid'], nei_node['mol']
amap = nei_amap[nei_id]
for atom in nei_mol.GetAtoms():
if atom.GetIdx() not in amap:
new_atom = copy_atom(atom)
amap[atom.GetIdx()] = ctr_mol.AddAtom(new_atom)
if nei_mol.GetNumBonds() == 0:
nei_atom = nei_mol.GetAtomWithIdx(0)
ctr_atom = ctr_mol.GetAtomWithIdx(amap[0])
ctr_atom.SetAtomMapNum(nei_atom.GetAtomMapNum())
else:
for bond in nei_mol.GetBonds():
a1 = amap[bond.GetBeginAtom().GetIdx()]
a2 = amap[bond.GetEndAtom().GetIdx()]
if ctr_mol.GetBondBetweenAtoms(a1, a2) is None:
ctr_mol.AddBond(a1, a2, bond.GetBondType())
elif nei_id in prev_nids: #father node overrides
ctr_mol.RemoveBond(a1, a2)
ctr_mol.AddBond(a1, a2, bond.GetBondType())
return ctr_mol
def local_attach_nx(ctr_mol, neighbors, prev_nodes, amap_list):
ctr_mol = copy_edit_mol(ctr_mol)
nei_amap = {nei['nid']: {} for nei in prev_nodes + neighbors}
for nei_id,ctr_atom,nei_atom in amap_list:
nei_amap[nei_id][nei_atom] = ctr_atom
ctr_mol = attach_mols_nx(ctr_mol, neighbors, prev_nodes, nei_amap)
return ctr_mol.GetMol()
#This version records idx mapping between ctr_mol and nei_mol
def enum_attach_nx(ctr_mol, nei_node, amap, singletons):
nei_mol,nei_idx = nei_node['mol'], nei_node['nid']
att_confs = []
black_list = [atom_idx for nei_id,atom_idx,_ in amap if nei_id in singletons]
ctr_atoms = [atom for atom in ctr_mol.GetAtoms() if atom.GetIdx() not in black_list]
ctr_bonds = [bond for bond in ctr_mol.GetBonds()]
if nei_mol.GetNumBonds() == 0: #neighbor singleton
nei_atom = nei_mol.GetAtomWithIdx(0)
used_list = [atom_idx for _,atom_idx,_ in amap]
for atom in ctr_atoms:
if atom_equal(atom, nei_atom) and atom.GetIdx() not in used_list:
new_amap = amap + [(nei_idx, atom.GetIdx(), 0)]
att_confs.append( new_amap )
elif nei_mol.GetNumBonds() == 1: #neighbor is a bond
bond = nei_mol.GetBondWithIdx(0)
bond_val = int(bond.GetBondTypeAsDouble())
b1,b2 = bond.GetBeginAtom(), bond.GetEndAtom()
for atom in ctr_atoms:
#Optimize if atom is carbon (other atoms may change valence)
if atom.GetAtomicNum() == 6 and atom.GetTotalNumHs() < bond_val:
continue
if atom_equal(atom, b1):
new_amap = amap + [(nei_idx, atom.GetIdx(), b1.GetIdx())]
att_confs.append( new_amap )
elif atom_equal(atom, b2):
new_amap = amap + [(nei_idx, atom.GetIdx(), b2.GetIdx())]
att_confs.append( new_amap )
else:
#intersection is an atom
for a1 in ctr_atoms:
for a2 in nei_mol.GetAtoms():
if atom_equal(a1, a2):
#Optimize if atom is carbon (other atoms may change valence)
if a1.GetAtomicNum() == 6 and a1.GetTotalNumHs() + a2.GetTotalNumHs() < 4:
continue
new_amap = amap + [(nei_idx, a1.GetIdx(), a2.GetIdx())]
att_confs.append( new_amap )
#intersection is a bond
if ctr_mol.GetNumBonds() > 1:
for b1 in ctr_bonds:
for b2 in nei_mol.GetBonds():
if ring_bond_equal(b1, b2):
new_amap = amap + [(nei_idx, b1.GetBeginAtom().GetIdx(), b2.GetBeginAtom().GetIdx()), (nei_idx, b1.GetEndAtom().GetIdx(), b2.GetEndAtom().GetIdx())]
att_confs.append( new_amap )
if ring_bond_equal(b1, b2, reverse=True):
new_amap = amap + [(nei_idx, b1.GetBeginAtom().GetIdx(), b2.GetEndAtom().GetIdx()), (nei_idx, b1.GetEndAtom().GetIdx(), b2.GetBeginAtom().GetIdx())]
att_confs.append( new_amap )
return att_confs
#Try rings first: Speed-Up
def enum_assemble_nx(graph, node_idx, neighbors, prev_nodes=[], prev_amap=[]):
node = graph.nodes[node_idx]
all_attach_confs = []
singletons = [nei_node['nid'] for nei_node in neighbors + prev_nodes if nei_node['mol'].GetNumAtoms() == 1]
def search(cur_amap, depth):
if len(all_attach_confs) > MAX_NCAND:
return
if depth == len(neighbors):
all_attach_confs.append(cur_amap)
return
nei_node = neighbors[depth]
cand_amap = enum_attach_nx(node['mol'], nei_node, cur_amap, singletons)
cand_smiles = set()
candidates = []
for amap in cand_amap:
cand_mol = local_attach_nx(node['mol'], neighbors[:depth+1], prev_nodes, amap)
cand_mol = sanitize(cand_mol)
if cand_mol is None:
continue
smiles = get_smiles(cand_mol)
if smiles in cand_smiles:
continue
cand_smiles.add(smiles)
candidates.append(amap)
if len(candidates) == 0:
return []
for new_amap in candidates:
search(new_amap, depth + 1)
search(prev_amap, 0)
cand_smiles = set()
candidates = []
for amap in all_attach_confs:
cand_mol = local_attach_nx(node['mol'], neighbors, prev_nodes, amap)
cand_mol = Chem.MolFromSmiles(Chem.MolToSmiles(cand_mol))
smiles = Chem.MolToSmiles(cand_mol)
if smiles in cand_smiles:
continue
cand_smiles.add(smiles)
Chem.Kekulize(cand_mol)
candidates.append( (smiles,cand_mol,amap) )
return candidates
#Only used for debugging purposes
def dfs_assemble_nx(graph, cur_mol, global_amap, fa_amap, cur_node_id, fa_node_id):
cur_node = graph.nodes[cur_node_id]
fa_node = graph.nodes[fa_node_id] if fa_node_id is not None else None
fa_nid = fa_node['nid'] if fa_node is not None else -1
prev_nodes = [fa_node] if fa_node is not None else []
children_id = [nei for nei in graph[cur_node_id] if graph.nodes[nei]['nid'] != fa_nid]
children = [graph.nodes[nei] for nei in children_id]
neighbors = [nei for nei in children if nei['mol'].GetNumAtoms() > 1]
neighbors = sorted(neighbors, key=lambda x:x['mol'].GetNumAtoms(), reverse=True)
singletons = [nei for nei in children if nei['mol'].GetNumAtoms() == 1]
neighbors = singletons + neighbors
cur_amap = [(fa_nid,a2,a1) for nid,a1,a2 in fa_amap if nid == cur_node['nid']]
cands = enum_assemble_nx(graph, cur_node_id, neighbors, prev_nodes, cur_amap)
if len(cands) == 0:
return
cand_smiles, _, cand_amap = zip(*cands)
label_idx = cand_smiles.index(cur_node['label'])
label_amap = cand_amap[label_idx]
for nei_id,ctr_atom,nei_atom in label_amap:
if nei_id == fa_nid:
continue
global_amap[nei_id][nei_atom] = global_amap[cur_node['nid']][ctr_atom]
cur_mol = attach_mols_nx(cur_mol, children, [], global_amap) #father is already attached
for nei_node_id, nei_node in zip(children_id, children):
if not nei_node['is_leaf']:
dfs_assemble_nx(graph, cur_mol, global_amap, label_amap, nei_node_id, cur_node_id)
# File: jtnn/datautils.py
from torch.utils.data import Dataset
import numpy as np
import dgl
from dgl.data.utils import download, extract_archive, get_download_dir
_url = 'https://www.dropbox.com/s/4ypr0e0abcbsvoh/jtnn.zip?dl=1'
class JTNNDataset(Dataset):
def __init__(self, data, vocab):
self.dir = get_download_dir()
self.zip_file_path='{}/jtnn.zip'.format(self.dir)
download(_url, path=self.zip_file_path)
extract_archive(self.zip_file_path, '{}/jtnn'.format(self.dir))
print('Loading data...')
data_file = '{}/jtnn/{}.txt'.format(self.dir, data)
with open(data_file) as f:
self.data = [line.strip("\r\n ").split()[0] for line in f]
self.vocab_file = '{}/jtnn/{}.txt'.format(self.dir, vocab)
print('Loading finished.')
print('\tNum samples:', len(self.data))
print('\tVocab file:', self.vocab_file)
def __len__(self):
return len(self.data)
def __getitem__(self, idx):
from .mol_tree_nx import DGLMolTree
smiles = self.data[idx]
mol_tree = DGLMolTree(smiles)
mol_tree.recover()
mol_tree.assemble()
return mol_tree
# File: jtnn/jtmpn.py
import torch
import torch.nn as nn
from .nnutils import cuda
from .chemutils import get_mol
#from mpn import atom_features, bond_features, ATOM_FDIM, BOND_FDIM
import rdkit.Chem as Chem
from dgl import DGLGraph, line_graph, batch, unbatch
import dgl.function as DGLF
from .line_profiler_integration import profile
import os
ELEM_LIST = ['C', 'N', 'O', 'S', 'F', 'Si', 'P', 'Cl', 'Br', 'Mg', 'Na', 'Ca', 'Fe', 'Al', 'I', 'B', 'K', 'Se', 'Zn', 'H', 'Cu', 'Mn', 'unknown']
ATOM_FDIM = len(ELEM_LIST) + 6 + 5 + 1
BOND_FDIM = 5
MAX_NB = 10
PAPER = os.getenv('PAPER', False)
def onek_encoding_unk(x, allowable_set):
if x not in allowable_set:
x = allowable_set[-1]
return [x == s for s in allowable_set]
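# e.g. onek_encoding_unk('N', ['C', 'N', 'unknown']) -> [False, True, False];
# anything outside the allowable set maps onto the last ('unknown') slot:
# onek_encoding_unk('Xx', ['C', 'N', 'unknown']) -> [False, False, True]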
# Note that during graph decoding they don't predict stereochemistry-related
# characteristics (i.e., chiral atoms, E/Z, cis/trans). Instead, they decode
# the 2-D graph first, then enumerate all possible 3-D forms and find the
# one with the highest score.
def atom_features(atom):
return cuda(torch.Tensor(onek_encoding_unk(atom.GetSymbol(), ELEM_LIST)
+ onek_encoding_unk(atom.GetDegree(), [0,1,2,3,4,5])
+ onek_encoding_unk(atom.GetFormalCharge(), [-1,-2,1,2,0])
+ [atom.GetIsAromatic()]))
def bond_features(bond):
bt = bond.GetBondType()
return cuda(torch.Tensor([bt == Chem.rdchem.BondType.SINGLE, bt == Chem.rdchem.BondType.DOUBLE, bt == Chem.rdchem.BondType.TRIPLE, bt == Chem.rdchem.BondType.AROMATIC, bond.IsInRing()]))
@profile
def mol2dgl(cand_batch, mol_tree_batch):
cand_graphs = []
tree_mess_source_edges = [] # map these edges from trees to...
tree_mess_target_edges = [] # these edges on candidate graphs
tree_mess_target_nodes = []
n_nodes = 0
for mol, mol_tree, ctr_node_id in cand_batch:
atom_feature_list = []
bond_feature_list = []
ctr_node = mol_tree.nodes[ctr_node_id]
ctr_bid = ctr_node['idx']
g = DGLGraph()
for atom in mol.GetAtoms():
atom_feature_list.append(atom_features(atom))
g.add_node(atom.GetIdx())
for bond in mol.GetBonds():
a1 = bond.GetBeginAtom()
a2 = bond.GetEndAtom()
begin_idx = a1.GetIdx()
end_idx = a2.GetIdx()
features = bond_features(bond)
g.add_edge(begin_idx, end_idx)
bond_feature_list.append(features)
g.add_edge(end_idx, begin_idx)
bond_feature_list.append(features)
x_nid, y_nid = a1.GetAtomMapNum(), a2.GetAtomMapNum()
# Tree node ID in the batch
x_bid = mol_tree.nodes[x_nid - 1]['idx'] if x_nid > 0 else -1
y_bid = mol_tree.nodes[y_nid - 1]['idx'] if y_nid > 0 else -1
if x_bid >= 0 and y_bid >= 0 and x_bid != y_bid:
if (x_bid, y_bid) in mol_tree_batch.edge_list:
tree_mess_target_edges.append(
(begin_idx + n_nodes, end_idx + n_nodes))
tree_mess_source_edges.append((x_bid, y_bid))
tree_mess_target_nodes.append(end_idx + n_nodes)
if (y_bid, x_bid) in mol_tree_batch.edge_list:
tree_mess_target_edges.append(
(end_idx + n_nodes, begin_idx + n_nodes))
tree_mess_source_edges.append((y_bid, x_bid))
tree_mess_target_nodes.append(begin_idx + n_nodes)
n_nodes += len(g.nodes)
atom_x = torch.stack(atom_feature_list)
g.set_n_repr({'x': atom_x})
if len(bond_feature_list) > 0:
bond_x = torch.stack(bond_feature_list)
g.set_e_repr({
'x': bond_x,
'src_x': atom_x.new(len(bond_feature_list), ATOM_FDIM).zero_()
})
cand_graphs.append(g)
return cand_graphs, tree_mess_source_edges, tree_mess_target_edges, \
tree_mess_target_nodes
mpn_loopy_bp_msg = DGLF.copy_src(src='msg', out='msg')
mpn_loopy_bp_reduce = DGLF.sum(msgs='msg', out='accum_msg')
class LoopyBPUpdate(nn.Module):
def __init__(self, hidden_size):
super(LoopyBPUpdate, self).__init__()
self.hidden_size = hidden_size
self.W_h = nn.Linear(hidden_size, hidden_size, bias=False)
def forward(self, node):
msg_input = node['msg_input']
msg_delta = self.W_h(node['accum_msg'] + node['alpha'])
msg = torch.relu(msg_input + msg_delta)
return {'msg': msg}
if PAPER:
mpn_gather_msg = [
DGLF.copy_edge(edge='msg', out='msg'),
DGLF.copy_edge(edge='alpha', out='alpha')
]
else:
mpn_gather_msg = DGLF.copy_edge(edge='msg', out='msg')
if PAPER:
mpn_gather_reduce = [
DGLF.sum(msgs='msg', out='m'),
DGLF.sum(msgs='alpha', out='accum_alpha'),
]
else:
mpn_gather_reduce = DGLF.sum(msgs='msg', out='m')
class GatherUpdate(nn.Module):
def __init__(self, hidden_size):
super(GatherUpdate, self).__init__()
self.hidden_size = hidden_size
self.W_o = nn.Linear(ATOM_FDIM + hidden_size, hidden_size)
def forward(self, node):
if PAPER:
#m = node['m']
m = node['m'] + node['accum_alpha']
else:
m = node['m'] + node['alpha']
return {
'h': torch.relu(self.W_o(torch.cat([node['x'], m], 1))),
}
class DGLJTMPN(nn.Module):
def __init__(self, hidden_size, depth):
nn.Module.__init__(self)
self.depth = depth
self.W_i = nn.Linear(ATOM_FDIM + BOND_FDIM, hidden_size, bias=False)
self.loopy_bp_updater = LoopyBPUpdate(hidden_size)
self.gather_updater = GatherUpdate(hidden_size)
self.hidden_size = hidden_size
self.n_samples_total = 0
self.n_nodes_total = 0
self.n_edges_total = 0
self.n_passes = 0
@profile
def forward(self, cand_batch, mol_tree_batch):
cand_graphs, tree_mess_src_edges, tree_mess_tgt_edges, tree_mess_tgt_nodes = \
mol2dgl(cand_batch, mol_tree_batch)
n_samples = len(cand_graphs)
cand_graphs = batch(cand_graphs)
cand_line_graph = line_graph(cand_graphs, no_backtracking=True)
n_nodes = len(cand_graphs.nodes)
n_edges = len(cand_graphs.edges)
cand_graphs = self.run(
cand_graphs, cand_line_graph, tree_mess_src_edges, tree_mess_tgt_edges,
tree_mess_tgt_nodes, mol_tree_batch)
cand_graphs = unbatch(cand_graphs)
g_repr = torch.stack([g.get_n_repr()['h'].mean(0) for g in cand_graphs], 0)
self.n_samples_total += n_samples
self.n_nodes_total += n_nodes
self.n_edges_total += n_edges
self.n_passes += 1
return g_repr
@profile
def run(self, cand_graphs, cand_line_graph, tree_mess_src_edges, tree_mess_tgt_edges,
tree_mess_tgt_nodes, mol_tree_batch):
n_nodes = len(cand_graphs.nodes)
cand_graphs.update_edge(
#*zip(*cand_graphs.edge_list),
edge_func=lambda src, dst, edge: {'src_x': src['x']},
batchable=True,
)
bond_features = cand_line_graph.get_n_repr()['x']
source_features = cand_line_graph.get_n_repr()['src_x']
features = torch.cat([source_features, bond_features], 1)
msg_input = self.W_i(features)
cand_line_graph.set_n_repr({
'msg_input': msg_input,
'msg': torch.relu(msg_input),
'accum_msg': torch.zeros_like(msg_input),
})
zero_node_state = bond_features.new(n_nodes, self.hidden_size).zero_()
cand_graphs.set_n_repr({
'm': zero_node_state.clone(),
'h': zero_node_state.clone(),
})
if PAPER:
cand_graphs.set_e_repr({
'alpha': cuda(torch.zeros(len(cand_graphs.edge_list), self.hidden_size))
})
alpha = mol_tree_batch.get_e_repr(*zip(*tree_mess_src_edges))['m']
cand_graphs.set_e_repr({'alpha': alpha}, *zip(*tree_mess_tgt_edges))
else:
alpha = mol_tree_batch.get_e_repr(*zip(*tree_mess_src_edges))['m']
node_idx = (torch.LongTensor(tree_mess_tgt_nodes)
.to(device=zero_node_state.device)[:, None]
.expand_as(alpha))
node_alpha = zero_node_state.clone().scatter_add(0, node_idx, alpha)
cand_graphs.set_n_repr({'alpha': node_alpha})
cand_graphs.update_edge(
#*zip(*cand_graphs.edge_list),
edge_func=lambda src, dst, edge: {'alpha': src['alpha']},
batchable=True,
)
for i in range(self.depth - 1):
cand_line_graph.update_all(
mpn_loopy_bp_msg,
mpn_loopy_bp_reduce,
self.loopy_bp_updater,
True
)
cand_graphs.update_all(
mpn_gather_msg,
mpn_gather_reduce,
self.gather_updater,
True
)
return cand_graphs
# File: jtnn/jtnn_dec.py
import torch
import torch.nn as nn
import torch.nn.functional as F
from .mol_tree import Vocab
from .nnutils import GRUUpdate, cuda
import copy
import itertools
from dgl import batch, line_graph
import dgl.function as DGLF
import networkx as nx
from .line_profiler_integration import profile
MAX_NB = 8
MAX_DECODE_LEN = 100
def dfs_order(forest, roots):
'''
Returns edge source, edge destination, tree ID, and whether u is generating
a new child
'''
edge_list = []
for i, root in enumerate(roots):
edge_list.append([])
# The following gives the DFS order on the edges of a tree.
for u, v, t in nx.dfs_labeled_edges(forest, root):
if u == v or t == 'nontree':
continue
elif t == 'forward':
edge_list[-1].append((u, v, i, 1))
elif t == 'reverse':
edge_list[-1].append((v, u, i, 0))
for edges in itertools.zip_longest(*edge_list):
edges = (e for e in edges if e is not None)
u, v, i, p = zip(*edges)
yield u, v, i, p
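# Illustrative example: for a single path-shaped tree 0-1-2 rooted at 0,
# this yields the forward edges (0, 1) and (1, 2) with p=1 (a new child is
# being generated), then the backtracking edges (2, 1) and (1, 0) with p=0.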
dec_tree_node_msg = DGLF.copy_edge(edge='m', out='m')
dec_tree_node_reduce = DGLF.sum(msgs='m', out='h')
def dec_tree_node_update(node):
return {'new': node['new'].clone().zero_()}
dec_tree_edge_msg = [DGLF.copy_src(src='m', out='m'), DGLF.copy_src(src='rm', out='rm')]
dec_tree_edge_reduce = [DGLF.sum(msgs='m', out='s'), DGLF.sum(msgs='rm', out='accum_rm')]
class DGLJTNNDecoder(nn.Module):
def __init__(self, vocab, hidden_size, latent_size, embedding=None):
nn.Module.__init__(self)
self.hidden_size = hidden_size
self.vocab_size = vocab.size()
self.vocab = vocab
if embedding is None:
self.embedding = nn.Embedding(self.vocab_size, hidden_size)
else:
self.embedding = embedding
self.dec_tree_edge_update = GRUUpdate(hidden_size)
self.W = nn.Linear(latent_size + hidden_size, hidden_size)
self.U = nn.Linear(latent_size + 2 * hidden_size, hidden_size)
self.W_o = nn.Linear(hidden_size, self.vocab_size)
self.U_s = nn.Linear(hidden_size, 1)
@profile
def forward(self, mol_trees, tree_vec):
'''
The training procedure which computes the prediction loss given the
ground truth tree
'''
mol_tree_batch = batch(mol_trees)
mol_tree_batch_lg = line_graph(mol_tree_batch, no_backtracking=True)
n_trees = len(mol_trees)
return self.run(mol_tree_batch, mol_tree_batch_lg, n_trees, tree_vec)
@profile
def run(self, mol_tree_batch, mol_tree_batch_lg, n_trees, tree_vec):
root_ids = mol_tree_batch.node_offset[:-1]
n_nodes = len(mol_tree_batch.nodes)
edge_list = mol_tree_batch.edge_list
n_edges = len(edge_list)
mol_tree_batch.set_n_repr({
'x': self.embedding(mol_tree_batch.get_n_repr()['wid']),
'h': cuda(torch.zeros(n_nodes, self.hidden_size)),
'new': cuda(torch.ones(n_nodes).byte()), # whether it's a newly generated node
})
mol_tree_batch.set_e_repr({
's': cuda(torch.zeros(n_edges, self.hidden_size)),
'm': cuda(torch.zeros(n_edges, self.hidden_size)),
'r': cuda(torch.zeros(n_edges, self.hidden_size)),
'z': cuda(torch.zeros(n_edges, self.hidden_size)),
'src_x': cuda(torch.zeros(n_edges, self.hidden_size)),
'dst_x': cuda(torch.zeros(n_edges, self.hidden_size)),
'rm': cuda(torch.zeros(n_edges, self.hidden_size)),
'accum_rm': cuda(torch.zeros(n_edges, self.hidden_size)),
})
mol_tree_batch.update_edge(
#*zip(*edge_list),
edge_func=lambda src, dst, edge: {'src_x': src['x'], 'dst_x': dst['x']},
batchable=True,
)
# input tensors for stop prediction (p) and label prediction (q)
p_inputs = []
p_targets = []
q_inputs = []
q_targets = []
# Predict root
mol_tree_batch.pull(
root_ids,
dec_tree_node_msg,
dec_tree_node_reduce,
dec_tree_node_update,
batchable=True,
)
# Extract hidden states and store them for stop/label prediction
h = mol_tree_batch.get_n_repr(root_ids)['h']
x = mol_tree_batch.get_n_repr(root_ids)['x']
p_inputs.append(torch.cat([x, h, tree_vec], 1))
t_set = list(range(len(root_ids)))
q_inputs.append(torch.cat([h, tree_vec], 1))
q_targets.append(mol_tree_batch.get_n_repr(root_ids)['wid'])
# Traverse the tree and predict on children
for u, v, i, p in dfs_order(mol_tree_batch, root_ids):
assert set(t_set).issuperset(i)
ip = dict(zip(i, p))
# TODO: context
p_targets.append(cuda(torch.tensor([ip.get(_i, 0) for _i in t_set])))
t_set = list(i)
eid = mol_tree_batch.get_edge_id(u, v)
mol_tree_batch_lg.pull(
eid,
dec_tree_edge_msg,
dec_tree_edge_reduce,
self.dec_tree_edge_update,
batchable=True,
)
is_new = mol_tree_batch.get_n_repr(v)['new']
mol_tree_batch.pull(
v,
dec_tree_node_msg,
dec_tree_node_reduce,
dec_tree_node_update,
batchable=True,
)
# Extract
h = mol_tree_batch.get_n_repr(v)['h']
x = mol_tree_batch.get_n_repr(v)['x']
p_inputs.append(torch.cat([x, h, tree_vec[t_set]], 1))
# Only newly generated nodes are needed for label prediction
# NOTE: The following works since the uncomputed messages are zeros.
q_inputs.append(torch.cat([h[is_new], tree_vec[t_set][is_new]], 1))
q_targets.append(mol_tree_batch.get_n_repr(v)['wid'][is_new])
p_targets.append(cuda(torch.tensor([0 for _ in t_set])))
# Batch compute the stop/label prediction losses
p_inputs = torch.cat(p_inputs, 0)
p_targets = torch.cat(p_targets, 0)
q_inputs = torch.cat(q_inputs, 0)
q_targets = torch.cat(q_targets, 0)
q = self.W_o(torch.relu(self.W(q_inputs)))
p = self.U_s(torch.relu(self.U(p_inputs)))[:, 0]
p_loss = F.binary_cross_entropy_with_logits(
p, p_targets.float(), size_average=False
) / n_trees
q_loss = F.cross_entropy(q, q_targets, size_average=False) / n_trees
p_acc = ((p > 0).long() == p_targets).sum().float() / p_targets.shape[0]
q_acc = (q.max(1)[1] == q_targets).float().sum() / q_targets.shape[0]
return q_loss, p_loss, q_acc, p_acc
# File: jtnn/jtnn_enc.py
import torch
import torch.nn as nn
from collections import deque
from .mol_tree import Vocab
from .nnutils import GRUUpdate, cuda
import itertools
import networkx as nx
from dgl import batch, unbatch, line_graph
import dgl.function as DGLF
from .line_profiler_integration import profile
MAX_NB = 8
def level_order(forest, roots):
'''
Given the forest and the list of root nodes,
returns an iterator over lists of edges ordered by depth, first bottom-up
and then top-down
'''
edge_list = []
node_depth = {}
edge_list.append([])
for root in roots:
node_depth[root] = 0
for u, v in nx.bfs_edges(forest, root):
node_depth[v] = node_depth[u] + 1
if len(edge_list) == node_depth[u]:
edge_list.append([])
edge_list[node_depth[u]].append((u, v))
for edges in reversed(edge_list):
u, v = zip(*edges)
yield v, u
for edges in edge_list:
u, v = zip(*edges)
yield u, v
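# Illustrative example: for a single path-shaped tree 0-1-2 rooted at 0,
# this yields the bottom-up edges (2 -> 1), then (1 -> 0), followed by the
# top-down edges (0 -> 1), then (1 -> 2), one depth level at a time.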
enc_tree_msg = [DGLF.copy_src(src='m', out='m'), DGLF.copy_src(src='rm', out='rm')]
enc_tree_reduce = [DGLF.sum(msgs='m', out='s'), DGLF.sum(msgs='rm', out='accum_rm')]
enc_tree_gather_msg = DGLF.copy_edge(edge='m', out='m')
enc_tree_gather_reduce = DGLF.sum(msgs='m', out='m')
class EncoderGatherUpdate(nn.Module):
def __init__(self, hidden_size):
nn.Module.__init__(self)
self.hidden_size = hidden_size
self.W = nn.Linear(2 * hidden_size, hidden_size)
def forward(self, node):
x = node['x']
m = node['m']
return {
'h': torch.relu(self.W(torch.cat([x, m], 1))),
}
class DGLJTNNEncoder(nn.Module):
def __init__(self, vocab, hidden_size, embedding=None):
nn.Module.__init__(self)
self.hidden_size = hidden_size
self.vocab_size = vocab.size()
self.vocab = vocab
if embedding is None:
self.embedding = nn.Embedding(self.vocab_size, hidden_size)
else:
self.embedding = embedding
self.enc_tree_update = GRUUpdate(hidden_size)
self.enc_tree_gather_update = EncoderGatherUpdate(hidden_size)
@profile
def forward(self, mol_trees):
mol_tree_batch = batch(mol_trees)
# Build line graph to prepare for belief propagation
mol_tree_batch_lg = line_graph(mol_tree_batch, no_backtracking=True)
return self.run(mol_tree_batch, mol_tree_batch_lg)
@profile
def run(self, mol_tree_batch, mol_tree_batch_lg):
# Since tree roots are designated as node 0, in the batched graph we can
# simply find the corresponding node IDs by looking at node_offset
root_ids = mol_tree_batch.node_offset[:-1]
n_nodes = len(mol_tree_batch.nodes)
edge_list = mol_tree_batch.edge_list
n_edges = len(edge_list)
# Assign structure embeddings to tree nodes
mol_tree_batch.set_n_repr({
'x': self.embedding(mol_tree_batch.get_n_repr()['wid']),
'h': cuda(torch.zeros(n_nodes, self.hidden_size)),
})
# Initialize the intermediate variables according to Eq (4)-(8).
# Also initialize the src_x and dst_x fields.
# TODO: context?
mol_tree_batch.set_e_repr({
's': cuda(torch.zeros(n_edges, self.hidden_size)),
'm': cuda(torch.zeros(n_edges, self.hidden_size)),
'r': cuda(torch.zeros(n_edges, self.hidden_size)),
'z': cuda(torch.zeros(n_edges, self.hidden_size)),
'src_x': cuda(torch.zeros(n_edges, self.hidden_size)),
'dst_x': cuda(torch.zeros(n_edges, self.hidden_size)),
'rm': cuda(torch.zeros(n_edges, self.hidden_size)),
'accum_rm': cuda(torch.zeros(n_edges, self.hidden_size)),
})
# Send the source/destination node features to edges
mol_tree_batch.update_edge(
#*zip(*edge_list),
edge_func=lambda src, dst, edge: {'src_x': src['x'], 'dst_x': dst['x']},
batchable=True,
)
# Message passing
# I exploited the fact that the reduce function is a sum of incoming
# messages, and the uncomputed messages are zero vectors. Essentially,
# we can always compute s_ij as the sum of incoming m_ij, whether or not
# m_ij has actually been computed.
for u, v in level_order(mol_tree_batch, root_ids):
eid = mol_tree_batch.get_edge_id(u, v)
mol_tree_batch_lg.pull(
eid,
enc_tree_msg,
enc_tree_reduce,
self.enc_tree_update,
batchable=True,
)
# Readout
mol_tree_batch.update_all(
enc_tree_gather_msg,
enc_tree_gather_reduce,
self.enc_tree_gather_update,
batchable=True,
)
root_vecs = mol_tree_batch.get_n_repr(root_ids)['h']
return mol_tree_batch, root_vecs
# File: jtnn/jtnn_vae.py
import torch
import torch.nn as nn
import torch.nn.functional as F
from .mol_tree import Vocab
from .nnutils import create_var, cuda
from .jtnn_enc import DGLJTNNEncoder
from .jtnn_dec import DGLJTNNDecoder
from .mpn import DGLMPN, mol2dgl
from .jtmpn import DGLJTMPN
from .line_profiler_integration import profile
import rdkit
import rdkit.Chem as Chem
from rdkit import DataStructs
from rdkit.Chem import AllChem
import copy, math
def dgl_set_batch_nodeID(mol_batch, vocab):
tot = 0
for mol_tree in mol_batch:
wid = []
for node in mol_tree.nodes:
mol_tree.nodes[node]['idx'] = tot
tot += 1
wid.append(vocab.get_index(mol_tree.nodes[node]['smiles']))
mol_tree.set_n_repr({'wid': cuda(torch.LongTensor(wid))})
class DGLJTNNVAE(nn.Module):
def __init__(self, vocab, hidden_size, latent_size, depth):
super(DGLJTNNVAE, self).__init__()
self.vocab = vocab
self.hidden_size = hidden_size
self.latent_size = latent_size
self.depth = depth
self.embedding = nn.Embedding(vocab.size(), hidden_size)
self.mpn = DGLMPN(hidden_size, depth)
self.jtnn = DGLJTNNEncoder(vocab, hidden_size, self.embedding)
self.decoder = DGLJTNNDecoder(
vocab, hidden_size, latent_size // 2, self.embedding)
self.jtmpn = DGLJTMPN(hidden_size, depth)
self.T_mean = nn.Linear(hidden_size, latent_size // 2)
self.T_var = nn.Linear(hidden_size, latent_size // 2)
self.G_mean = nn.Linear(hidden_size, latent_size // 2)
self.G_var = nn.Linear(hidden_size, latent_size // 2)
self.n_nodes_total = 0
self.n_passes = 0
self.n_edges_total = 0
self.n_tree_nodes_total = 0
@profile
def encode(self, mol_batch):
dgl_set_batch_nodeID(mol_batch, self.vocab)
smiles_batch = [mol_tree.smiles for mol_tree in mol_batch]
mol_graphs = mol2dgl(smiles_batch)
mol_vec = self.mpn(mol_graphs)
# mol_batch is a list of junction trees
mol_tree_batch, tree_vec = self.jtnn(mol_batch)
self.n_nodes_total += sum(len(g.nodes) for g in mol_graphs)
self.n_edges_total += sum(len(g.edges) for g in mol_graphs)
self.n_tree_nodes_total += sum(len(t.nodes) for t in mol_batch)
self.n_passes += 1
return mol_tree_batch, tree_vec, mol_vec
@profile
def forward(self, mol_batch, beta=0, e1=None, e2=None):
batch_size = len(mol_batch)
mol_tree_batch, tree_vec, mol_vec = self.encode(mol_batch)
tree_mean = self.T_mean(tree_vec)
tree_log_var = -torch.abs(self.T_var(tree_vec))
mol_mean = self.G_mean(mol_vec)
mol_log_var = -torch.abs(self.G_var(mol_vec))
self.tree_mean = tree_mean
self.tree_log_var = tree_log_var
self.mol_mean = mol_mean
self.mol_log_var = mol_log_var
z_mean = torch.cat([tree_mean, mol_mean], dim=1)
z_log_var = torch.cat([tree_log_var, mol_log_var], dim=1)
kl_loss = -0.5 * torch.sum(1.0 + z_log_var - z_mean * z_mean - torch.exp(z_log_var)) / batch_size
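# Sanity note: this is the closed-form KL divergence
# KL(N(mu, sigma^2) || N(0, I)) = -0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2),
# summed over latent dimensions and averaged over the batch.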
self.z_mean = z_mean
self.z_log_var = z_log_var
epsilon = cuda(torch.randn(batch_size, self.latent_size // 2)) if e1 is None else e1
tree_vec = tree_mean + torch.exp(tree_log_var / 2) * epsilon
epsilon = cuda(torch.randn(batch_size, self.latent_size // 2)) if e2 is None else e2
mol_vec = mol_mean + torch.exp(mol_log_var / 2) * epsilon
self.tree_vec = tree_vec
self.mol_vec = mol_vec
word_loss, topo_loss, word_acc, topo_acc = self.decoder(mol_batch, tree_vec)
assm_loss, assm_acc = self.assm(mol_batch, mol_tree_batch, mol_vec)
stereo_loss, stereo_acc = self.stereo(mol_batch, mol_vec)
self.word_loss_v = word_loss
self.topo_loss_v = topo_loss
self.assm_loss_v = assm_loss
self.stereo_loss_v = stereo_loss
all_vec = torch.cat([tree_vec, mol_vec], dim=1)
loss = word_loss + topo_loss + assm_loss + 2 * stereo_loss + beta * kl_loss
self.all_vec = all_vec
return loss, kl_loss, word_acc, topo_acc, assm_acc, stereo_acc
@profile
def assm(self, mol_batch, mol_tree_batch, mol_vec):
cands = []
batch_idx = []
for i, mol_tree in enumerate(mol_batch):
for node_id, node in mol_tree.nodes.items():
if node['is_leaf'] or len(node['cands']) == 1:
continue
cands.extend([(cand, mol_tree, node_id) for cand in node['cand_mols']])
batch_idx.extend([i] * len(node['cands']))
cand_vec = self.jtmpn(cands, mol_tree_batch)
cand_vec = self.G_mean(cand_vec)
batch_idx = cuda(torch.LongTensor(batch_idx))
mol_vec = mol_vec[batch_idx]
mol_vec = mol_vec.view(-1, 1, self.latent_size // 2)
cand_vec = cand_vec.view(-1, self.latent_size // 2, 1)
scores = (mol_vec @ cand_vec)[:, 0, 0]
cnt, tot, acc = 0, 0, 0
all_loss = []
for i, mol_tree in enumerate(mol_batch):
comp_nodes = [node_id for node_id, node in mol_tree.nodes.items()
if len(node['cands']) > 1 and not node['is_leaf']]
cnt += len(comp_nodes)
# segmented accuracy and cross entropy
for node_id in comp_nodes:
node = mol_tree.nodes[node_id]
label = node['cands'].index(node['label'])
ncand = len(node['cands'])
cur_score = scores[tot:tot+ncand]
tot += ncand
if cur_score[label].item() >= cur_score.max().item():
acc += 1
label = cuda(torch.LongTensor([label]))
all_loss.append(
F.cross_entropy(cur_score.view(1, -1), label, size_average=False))
all_loss = sum(all_loss) / len(mol_batch)
return all_loss, acc / cnt
@profile
def stereo(self, mol_batch, mol_vec):
stereo_cands, batch_idx = [], []
labels = []
for i, mol_tree in enumerate(mol_batch):
cands = mol_tree.stereo_cands
if len(cands) == 1:
continue
if mol_tree.smiles3D not in cands:
cands.append(mol_tree.smiles3D)
stereo_cands.extend(cands)
batch_idx.extend([i] * len(cands))
labels.append((cands.index(mol_tree.smiles3D), len(cands)))
if len(labels) == 0:
# Only one stereoisomer exists; do nothing
return cuda(torch.tensor(0.)), 1.
batch_idx = cuda(torch.LongTensor(batch_idx))
stereo_cands = self.mpn(mol2dgl(stereo_cands))
stereo_cands = self.G_mean(stereo_cands)
stereo_labels = mol_vec[batch_idx]
scores = F.cosine_similarity(stereo_cands, stereo_labels)
st, acc = 0, 0
all_loss = []
for label, le in labels:
cur_scores = scores[st:st+le]
if cur_scores.data[label].item() >= cur_scores.max().item():
acc += 1
label = cuda(torch.LongTensor([label]))
all_loss.append(
F.cross_entropy(cur_scores.view(1, -1), label, size_average=False))
st += le
all_loss = sum(all_loss) / len(labels)
return all_loss, acc / len(labels)
# File: jtnn/line_profiler_integration.py
'''
line_profiler integration
'''
import os
if os.getenv('PROFILE', 0):
import line_profiler
import atexit
profile = line_profiler.LineProfiler()
profile_output = os.getenv('PROFILE_OUTPUT', None)
if profile_output:
from functools import partial
atexit.register(partial(profile.dump_stats, profile_output))
else:
atexit.register(profile.print_stats)
else:
def profile(f):
return f
# File: jtnn/mol_tree.py
import rdkit
import rdkit.Chem as Chem
import copy
def get_slots(smiles):
mol = Chem.MolFromSmiles(smiles)
return [(atom.GetSymbol(), atom.GetFormalCharge(), atom.GetTotalNumHs()) for atom in mol.GetAtoms()]
class Vocab(object):
def __init__(self, smiles_list):
self.vocab = smiles_list
self.vmap = {x:i for i,x in enumerate(self.vocab)}
self.slots = [get_slots(smiles) for smiles in self.vocab]
def get_index(self, smiles):
return self.vmap[smiles]
def get_smiles(self, idx):
return self.vocab[idx]
def get_slots(self, idx):
return copy.deepcopy(self.slots[idx])
def size(self):
return len(self.vocab)
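# Illustrative example: v = Vocab(['C', 'CC', 'C1CCCCC1']) gives
# v.get_index('CC') == 1 and v.get_smiles(1) == 'CC', while v.get_slots(1)
# lists (symbol, formal charge, total H count) per atom of 'CC':
# [('C', 0, 3), ('C', 0, 3)].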
# File: jtnn/mol_tree_nx.py
from dgl import DGLGraph
import rdkit.Chem as Chem
from .chemutils import get_clique_mol, tree_decomp, get_mol, get_smiles, \
set_atommap, enum_assemble_nx, decode_stereo
class DGLMolTree(DGLGraph):
def __init__(self, smiles):
DGLGraph.__init__(self)
self.smiles = smiles
self.mol = get_mol(smiles)
# Stereo Generation
mol = Chem.MolFromSmiles(smiles)
self.smiles3D = Chem.MolToSmiles(mol, isomericSmiles=True)
self.smiles2D = Chem.MolToSmiles(mol)
self.stereo_cands = decode_stereo(self.smiles2D)
# cliques: a list of lists of atom indices
cliques, edges = tree_decomp(self.mol)
root = 0
for i, c in enumerate(cliques):
cmol = get_clique_mol(self.mol, c)
csmiles = get_smiles(cmol)
self.add_node(
i,
smiles=csmiles,
mol=get_mol(csmiles),
clique=c,
)
if min(c) == 0:
root = i
# The clique with atom ID 0 becomes root
if root > 0:
for attr in self.nodes[0]:
self.nodes[0][attr], self.nodes[root][attr] = \
self.nodes[root][attr], self.nodes[0][attr]
for _x, _y in edges:
x = 0 if _x == root else root if _x == 0 else _x
y = 0 if _y == root else root if _y == 0 else _y
self.add_edge(x, y)
self.add_edge(y, x)
for i in self.nodes:
self.nodes[i]['nid'] = i + 1
if len(self[i]) > 1: # Leaf node mol is not marked
set_atommap(self.nodes[i]['mol'], self.nodes[i]['nid'])
self.nodes[i]['is_leaf'] = (len(self[i]) == 1)
# avoiding DiGraph.size()
def treesize(self):
return len(self.nodes)
def _recover_node(self, i, original_mol):
node = self.nodes[i]
clique = []
clique.extend(node['clique'])
if not node['is_leaf']:
for cidx in node['clique']:
original_mol.GetAtomWithIdx(cidx).SetAtomMapNum(node['nid'])
for j in self[i]:
nei_node = self.nodes[j]
clique.extend(nei_node['clique'])
if nei_node['is_leaf']: # Leaf node, no need to mark
continue
for cidx in nei_node['clique']:
# allow a singleton node to override the atom mapping
if cidx not in node['clique'] or len(nei_node['clique']) == 1:
atom = original_mol.GetAtomWithIdx(cidx)
atom.SetAtomMapNum(nei_node['nid'])
clique = list(set(clique))
label_mol = get_clique_mol(original_mol, clique)
node['label'] = Chem.MolToSmiles(Chem.MolFromSmiles(get_smiles(label_mol)))
node['label_mol'] = get_mol(node['label'])
for cidx in clique:
original_mol.GetAtomWithIdx(cidx).SetAtomMapNum(0)
return node['label']
def _assemble_node(self, i):
neighbors = [self.nodes[j] for j in self[i] if self.nodes[j]['mol'].GetNumAtoms() > 1]
neighbors = sorted(neighbors, key=lambda x: x['mol'].GetNumAtoms(), reverse=True)
singletons = [self.nodes[j] for j in self[i] if self.nodes[j]['mol'].GetNumAtoms() == 1]
neighbors = singletons + neighbors
cands = enum_assemble_nx(self, i, neighbors)
if len(cands) > 0:
self.nodes[i]['cands'], self.nodes[i]['cand_mols'], _ = list(zip(*cands))
self.nodes[i]['cands'] = list(self.nodes[i]['cands'])
self.nodes[i]['cand_mols'] = list(self.nodes[i]['cand_mols'])
else:
self.nodes[i]['cands'] = []
self.nodes[i]['cand_mols'] = []
def recover(self):
for i in self.nodes:
self._recover_node(i, self.mol)
def assemble(self):
for i in self.nodes:
self._assemble_node(i)
# File: jtnn/mpn.py
import torch
import torch.nn as nn
import rdkit.Chem as Chem
import torch.nn.functional as F
from .nnutils import *
from .chemutils import get_mol
from networkx import Graph, DiGraph, line_graph, convert_node_labels_to_integers
from dgl import DGLGraph, line_graph, batch, unbatch
import dgl.function as DGLF
from functools import partial
from .line_profiler_integration import profile
ELEM_LIST = ['C', 'N', 'O', 'S', 'F', 'Si', 'P', 'Cl', 'Br', 'Mg', 'Na', 'Ca', 'Fe', 'Al', 'I', 'B', 'K', 'Se', 'Zn', 'H', 'Cu', 'Mn', 'unknown']
ATOM_FDIM = len(ELEM_LIST) + 6 + 5 + 4 + 1
BOND_FDIM = 5 + 6
MAX_NB = 6
def onek_encoding_unk(x, allowable_set):
if x not in allowable_set:
x = allowable_set[-1]
return [x == s for s in allowable_set]
def atom_features(atom):
return cuda(torch.Tensor(onek_encoding_unk(atom.GetSymbol(), ELEM_LIST)
+ onek_encoding_unk(atom.GetDegree(), [0,1,2,3,4,5])
+ onek_encoding_unk(atom.GetFormalCharge(), [-1,-2,1,2,0])
+ onek_encoding_unk(int(atom.GetChiralTag()), [0,1,2,3])
+ [atom.GetIsAromatic()]))
def bond_features(bond):
bt = bond.GetBondType()
stereo = int(bond.GetStereo())
fbond = [bt == Chem.rdchem.BondType.SINGLE, bt == Chem.rdchem.BondType.DOUBLE, bt == Chem.rdchem.BondType.TRIPLE, bt == Chem.rdchem.BondType.AROMATIC, bond.IsInRing()]
fstereo = onek_encoding_unk(stereo, [0,1,2,3,4,5])
return cuda(torch.Tensor(fbond + fstereo))
@profile
def mol2dgl(smiles_batch):
n_nodes = 0
graph_list = []
for smiles in smiles_batch:
atom_feature_list = []
bond_feature_list = []
bond_source_feature_list = []
graph = DGLGraph()
mol = get_mol(smiles)
for atom in mol.GetAtoms():
graph.add_node(atom.GetIdx())
atom_feature_list.append(atom_features(atom))
for bond in mol.GetBonds():
begin_idx = bond.GetBeginAtom().GetIdx()
end_idx = bond.GetEndAtom().GetIdx()
features = bond_features(bond)
graph.add_edge(begin_idx, end_idx)
bond_feature_list.append(features)
# set up the reverse direction
graph.add_edge(end_idx, begin_idx)
bond_feature_list.append(features)
atom_x = torch.stack(atom_feature_list)
graph.set_n_repr({'x': atom_x})
if len(bond_feature_list) > 0:
bond_x = torch.stack(bond_feature_list)
graph.set_e_repr({
'x': bond_x,
'src_x': atom_x.new(len(bond_feature_list), ATOM_FDIM).zero_()
})
graph_list.append(graph)
return graph_list
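# Illustrative example: mol2dgl(['CCO']) returns one DGLGraph with 3 atom
# nodes and 4 directed edges (each bond is added in both directions), atom
# features stored in node field 'x' and bond features in edge field 'x'.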
mpn_loopy_bp_msg = DGLF.copy_src(src='msg', out='msg')
mpn_loopy_bp_reduce = DGLF.sum(msgs='msg', out='accum_msg')
class LoopyBPUpdate(nn.Module):
def __init__(self, hidden_size):
super(LoopyBPUpdate, self).__init__()
self.hidden_size = hidden_size
self.W_h = nn.Linear(hidden_size, hidden_size, bias=False)
def forward(self, node):
msg_input = node['msg_input']
msg_delta = self.W_h(node['accum_msg'])
msg = F.relu(msg_input + msg_delta)
return {'msg': msg}
mpn_gather_msg = DGLF.copy_edge(edge='msg', out='msg')
mpn_gather_reduce = DGLF.sum(msgs='msg', out='m')
class GatherUpdate(nn.Module):
def __init__(self, hidden_size):
super(GatherUpdate, self).__init__()
self.hidden_size = hidden_size
self.W_o = nn.Linear(ATOM_FDIM + hidden_size, hidden_size)
def forward(self, node):
m = node['m']
return {
'h': F.relu(self.W_o(torch.cat([node['x'], m], 1))),
}
class DGLMPN(nn.Module):
def __init__(self, hidden_size, depth):
super(DGLMPN, self).__init__()
self.depth = depth
self.W_i = nn.Linear(ATOM_FDIM + BOND_FDIM, hidden_size, bias=False)
self.loopy_bp_updater = LoopyBPUpdate(hidden_size)
self.gather_updater = GatherUpdate(hidden_size)
self.hidden_size = hidden_size
self.n_samples_total = 0
self.n_nodes_total = 0
self.n_edges_total = 0
self.n_passes = 0
@profile
def forward(self, mol_graph_list):
n_samples = len(mol_graph_list)
mol_graph = batch(mol_graph_list)
mol_line_graph = line_graph(mol_graph, no_backtracking=True)
n_nodes = len(mol_graph.nodes)
n_edges = len(mol_graph.edges)
mol_graph = self.run(mol_graph, mol_line_graph)
mol_graph_list = unbatch(mol_graph)
g_repr = torch.stack([g.get_n_repr()['h'].mean(0) for g in mol_graph_list], 0)
self.n_samples_total += n_samples
self.n_nodes_total += n_nodes
self.n_edges_total += n_edges
self.n_passes += 1
return g_repr
@profile
def run(self, mol_graph, mol_line_graph):
n_nodes = len(mol_graph.nodes)
mol_graph.update_edge(
#*zip(*mol_graph.edge_list),
edge_func=lambda src, dst, edge: {'src_x': src['x']},
batchable=True,
)
bond_features = mol_line_graph.get_n_repr()['x']
source_features = mol_line_graph.get_n_repr()['src_x']
features = torch.cat([source_features, bond_features], 1)
msg_input = self.W_i(features)
mol_line_graph.set_n_repr({
'msg_input': msg_input,
'msg': F.relu(msg_input),
'accum_msg': torch.zeros_like(msg_input),
})
mol_graph.set_n_repr({
'm': bond_features.new(n_nodes, self.hidden_size).zero_(),
'h': bond_features.new(n_nodes, self.hidden_size).zero_(),
})
for i in range(self.depth - 1):
mol_line_graph.update_all(
mpn_loopy_bp_msg,
mpn_loopy_bp_reduce,
self.loopy_bp_updater,
True
)
mol_graph.update_all(
mpn_gather_msg,
mpn_gather_reduce,
self.gather_updater,
True
)
return mol_graph
# File: jtnn/nnutils.py
import torch
import torch.nn as nn
from torch.autograd import Variable
def create_var(tensor, requires_grad=None):
if requires_grad is None:
return Variable(tensor)
else:
return Variable(tensor, requires_grad=requires_grad)
def cuda(tensor):
if torch.cuda.is_available():
return tensor.cuda()
else:
return tensor
class GRUUpdate(nn.Module):
def __init__(self, hidden_size):
nn.Module.__init__(self)
self.hidden_size = hidden_size
self.W_z = nn.Linear(2 * hidden_size, hidden_size)
self.W_r = nn.Linear(hidden_size, hidden_size, bias=False)
self.U_r = nn.Linear(hidden_size, hidden_size)
self.W_h = nn.Linear(2 * hidden_size, hidden_size)
def forward(self, node):
src_x = node['src_x']
dst_x = node['dst_x']
s = node['s']
rm = node['accum_rm']
z = torch.sigmoid(self.W_z(torch.cat([src_x, s], 1)))
m = torch.tanh(self.W_h(torch.cat([src_x, rm], 1)))
m = (1 - z) * s + z * m
r_1 = self.W_r(dst_x)
r_2 = self.U_r(m)
r = torch.sigmoid(r_1 + r_2)
return {'m': m, 'r': r, 'z': z, 'rm': r * m}
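# Note on GRUUpdate (a reading of the code, following the tree-GRU update
# referenced as Eq. (4)-(8) in jtnn_enc.py): z is the update gate computed
# from the source node features and the sum s of incoming messages; the
# candidate message is tanh(W_h([src_x, accum_rm])), gated as
# m = (1 - z) * s + z * candidate; r is the reset gate whose product r * m
# is what downstream edges accumulate into accum_rm.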
# File: vaetrain_dgl.py
import torch
import torch.nn as nn
import torch.optim as optim
import torch.optim.lr_scheduler as lr_scheduler
from torch.utils.data import DataLoader
import math, random, sys
from optparse import OptionParser
from collections import deque
import rdkit
from jtnn import *
lg = rdkit.RDLogger.logger()
lg.setLevel(rdkit.RDLogger.CRITICAL)
parser = OptionParser()
parser.add_option("-t", "--train", dest="train", default='train', help='Training file name')
parser.add_option("-v", "--vocab", dest="vocab", default='vocab', help='Vocab file name')
parser.add_option("-s", "--save_dir", dest="save_path")
parser.add_option("-m", "--model", dest="model_path", default=None)
parser.add_option("-b", "--batch", dest="batch_size", default=40)
parser.add_option("-w", "--hidden", dest="hidden_size", default=200)
parser.add_option("-l", "--latent", dest="latent_size", default=56)
parser.add_option("-d", "--depth", dest="depth", default=3)
parser.add_option("-z", "--beta", dest="beta", default=1.0)
parser.add_option("-q", "--lr", dest="lr", default=1e-3)
opts,args = parser.parse_args()
dataset = JTNNDataset(data=opts.train, vocab=opts.vocab)
vocab = Vocab([x.strip("\r\n ") for x in open(dataset.vocab_file)])
batch_size = int(opts.batch_size)
hidden_size = int(opts.hidden_size)
latent_size = int(opts.latent_size)
depth = int(opts.depth)
beta = float(opts.beta)
lr = float(opts.lr)
model = DGLJTNNVAE(vocab, hidden_size, latent_size, depth)
if opts.model_path is not None:
model.load_state_dict(torch.load(opts.model_path))
else:
for param in model.parameters():
if param.dim() == 1:
nn.init.constant(param, 0)
else:
nn.init.xavier_normal(param)
if torch.cuda.is_available():
model = model.cuda()
print("Model #Params: %dK" % (sum([x.nelement() for x in model.parameters()]) / 1000,))
optimizer = optim.Adam(model.parameters(), lr=lr)
scheduler = lr_scheduler.ExponentialLR(optimizer, 0.9)
scheduler.step()
MAX_EPOCH = 1
PRINT_ITER = 20
@profile
def train():
dataloader = DataLoader(
dataset,
batch_size=batch_size,
shuffle=True,
num_workers=0,
collate_fn=lambda x:x,
drop_last=True)
for epoch in range(MAX_EPOCH):
word_acc,topo_acc,assm_acc,steo_acc = 0,0,0,0
for it, batch in enumerate(dataloader):
for mol_tree in batch:
for node_id, node in mol_tree.nodes.items():
if node['label'] not in node['cands']:
node['cands'].append(node['label'])
node['cand_mols'].append(node['label_mol'])
model.zero_grad()
loss, kl_div, wacc, tacc, sacc, dacc = model(batch, beta)
loss.backward()
optimizer.step()
word_acc += wacc
topo_acc += tacc
assm_acc += sacc
steo_acc += dacc
if (it + 1) % PRINT_ITER == 0:
word_acc = word_acc / PRINT_ITER * 100
topo_acc = topo_acc / PRINT_ITER * 100
assm_acc = assm_acc / PRINT_ITER * 100
steo_acc = steo_acc / PRINT_ITER * 100
print("KL: %.1f, Word: %.2f, Topo: %.2f, Assm: %.2f, Steo: %.2f" % (
kl_div, word_acc, topo_acc, assm_acc, steo_acc))
word_acc,topo_acc,assm_acc,steo_acc = 0,0,0,0
sys.stdout.flush()
if (it + 1) % 1500 == 0: #Fast annealing
scheduler.step()
print("learning rate: %.6f" % scheduler.get_lr()[0])
torch.save(model.state_dict(),
opts.save_path + "/model.iter-%d-%d" % (epoch, it + 1))
scheduler.step()
print("learning rate: %.6f" % scheduler.get_lr()[0])
torch.save(model.state_dict(), opts.save_path + "/model.iter-" + str(epoch))
if __name__ == '__main__':
train()
print('# passes:', model.n_passes)
print('Total # nodes processed:', model.n_nodes_total)
print('Total # edges processed:', model.n_edges_total)
print('Total # tree nodes processed:', model.n_tree_nodes_total)
print('Graph decoder: # passes:', model.jtmpn.n_passes)
print('Graph decoder: Total # candidates processed:', model.jtmpn.n_samples_total)
print('Graph decoder: Total # nodes processed:', model.jtmpn.n_nodes_total)
print('Graph decoder: Total # edges processed:', model.jtmpn.n_edges_total)
print('Graph encoder: # passes:', model.mpn.n_passes)
print('Graph encoder: Total # candidates processed:', model.mpn.n_samples_total)
print('Graph encoder: Total # nodes processed:', model.mpn.n_nodes_total)
print('Graph encoder: Total # edges processed:', model.mpn.n_edges_total)
"""Dataset utilities.""" """Dataset utilities."""
from __future__ import absolute_import
import os import os, sys
import hashlib import hashlib
import warnings import warnings
import zipfile import zipfile
...@@ -125,17 +126,22 @@ def extract_archive(file, target_dir): ...@@ -125,17 +126,22 @@ def extract_archive(file, target_dir):
target_dir : str target_dir : str
Target directory of the archive to be uncompressed Target directory of the archive to be uncompressed
""" """
if os.path.exists(target_dir):
return
if file.endswith('.gz') or file.endswith('.tar') or file.endswith('.tgz'): if file.endswith('.gz') or file.endswith('.tar') or file.endswith('.tgz'):
archive = tarfile.open(file, 'r') archive = tarfile.open(file, 'r')
elif file.endswith('.zip'): elif file.endswith('.zip'):
archive = zipfile.ZipFile(file, 'r') archive = zipfile.ZipFile(file, 'r')
else: else:
raise Exception('Unrecognized file type: ' + file) raise Exception('Unrecognized file type: ' + file)
print('Extracting file to {}'.format(target_dir))
archive.extractall(path=target_dir) archive.extractall(path=target_dir)
archive.close() archive.close()
def get_download_dir(): def get_download_dir():
dirname = '_download' """Get the absolute path to the download directory."""
curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
dirname = os.path.join(curr_path, '../../../_download')
if not os.path.exists(dirname): if not os.path.exists(dirname):
os.makedirs(dirname) os.makedirs(dirname)
return dirname return dirname