Unverified Commit 36c7b771 authored by Mufei Li, committed by GitHub

[LifeSci] Move to Independent Repo (#1592)

* Move LifeSci

* Remove doc
parent 94c67203
import rdkit.Chem as Chem
import torch
from collections import defaultdict
from rdkit.Chem.EnumerateStereoisomers import EnumerateStereoisomers
from scipy.sparse import csr_matrix
from scipy.sparse.csgraph import minimum_spanning_tree
from dgl import DGLGraph
ELEM_LIST = ['C', 'N', 'O', 'S', 'F', 'Si', 'P', 'Cl', 'Br', 'Mg', 'Na',
'Ca', 'Fe', 'Al', 'I', 'B', 'K', 'Se', 'Zn', 'H', 'Cu', 'Mn', 'unknown']
MST_MAX_WEIGHT = 100
MAX_NCAND = 2000
def onek_encoding_unk(x, allowable_set):
if x not in allowable_set:
x = allowable_set[-1]
return [x == s for s in allowable_set]
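# For example, onek_encoding_unk('Si', ['C', 'N', 'unknown']) falls back to the
# last ("unknown") entry and returns [False, False, True].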
def set_atommap(mol, num=0):
for atom in mol.GetAtoms():
atom.SetAtomMapNum(num)
def get_mol(smiles):
mol = Chem.MolFromSmiles(smiles)
if mol is None:
return None
Chem.Kekulize(mol)
return mol
def get_smiles(mol):
return Chem.MolToSmiles(mol, kekuleSmiles=True)
def decode_stereo(smiles2D):
mol = Chem.MolFromSmiles(smiles2D)
dec_isomers = list(EnumerateStereoisomers(mol))
dec_isomers = [Chem.MolFromSmiles(Chem.MolToSmiles(
mol, isomericSmiles=True)) for mol in dec_isomers]
smiles3D = [Chem.MolToSmiles(mol, isomericSmiles=True)
for mol in dec_isomers]
chiralN = [atom.GetIdx() for atom in dec_isomers[0].GetAtoms() if int(
atom.GetChiralTag()) > 0 and atom.GetSymbol() == "N"]
if len(chiralN) > 0:
for mol in dec_isomers:
for idx in chiralN:
mol.GetAtomWithIdx(idx).SetChiralTag(
Chem.rdchem.ChiralType.CHI_UNSPECIFIED)
smiles3D.append(Chem.MolToSmiles(mol, isomericSmiles=True))
return smiles3D
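# For example (a sketch): lactic acid 'CC(O)C(=O)O' has one stereocenter, so
# decode_stereo('CC(O)C(=O)O') should return the isomeric SMILES of both enantiomers.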
def sanitize(mol):
try:
smiles = get_smiles(mol)
mol = get_mol(smiles)
    except Exception:
        return None
return mol
def copy_atom(atom):
new_atom = Chem.Atom(atom.GetSymbol())
new_atom.SetFormalCharge(atom.GetFormalCharge())
new_atom.SetAtomMapNum(atom.GetAtomMapNum())
return new_atom
def copy_edit_mol(mol):
new_mol = Chem.RWMol(Chem.MolFromSmiles(''))
for atom in mol.GetAtoms():
new_atom = copy_atom(atom)
new_mol.AddAtom(new_atom)
for bond in mol.GetBonds():
a1 = bond.GetBeginAtom().GetIdx()
a2 = bond.GetEndAtom().GetIdx()
bt = bond.GetBondType()
new_mol.AddBond(a1, a2, bt)
return new_mol
def get_clique_mol(mol, atoms):
smiles = Chem.MolFragmentToSmiles(mol, atoms, kekuleSmiles=True)
new_mol = Chem.MolFromSmiles(smiles, sanitize=False)
new_mol = copy_edit_mol(new_mol).GetMol()
new_mol = sanitize(new_mol) # We assume this is not None
return new_mol
def tree_decomp(mol):
n_atoms = mol.GetNumAtoms()
if n_atoms == 1:
return [[0]], []
cliques = []
for bond in mol.GetBonds():
a1 = bond.GetBeginAtom().GetIdx()
a2 = bond.GetEndAtom().GetIdx()
if not bond.IsInRing():
cliques.append([a1, a2])
ssr = [list(x) for x in Chem.GetSymmSSSR(mol)]
cliques.extend(ssr)
nei_list = [[] for i in range(n_atoms)]
for i in range(len(cliques)):
for atom in cliques[i]:
nei_list[atom].append(i)
# Merge Rings with intersection > 2 atoms
for i in range(len(cliques)):
if len(cliques[i]) <= 2:
continue
for atom in cliques[i]:
for j in nei_list[atom]:
if i >= j or len(cliques[j]) <= 2:
continue
inter = set(cliques[i]) & set(cliques[j])
if len(inter) > 2:
cliques[i].extend(cliques[j])
cliques[i] = list(set(cliques[i]))
cliques[j] = []
cliques = [c for c in cliques if len(c) > 0]
nei_list = [[] for i in range(n_atoms)]
for i in range(len(cliques)):
for atom in cliques[i]:
nei_list[atom].append(i)
# Build edges and add singleton cliques
edges = defaultdict(int)
for atom in range(n_atoms):
if len(nei_list[atom]) <= 1:
continue
cnei = nei_list[atom]
bonds = [c for c in cnei if len(cliques[c]) == 2]
rings = [c for c in cnei if len(cliques[c]) > 4]
# In general, if len(cnei) >= 3, a singleton should be added, but 1 bond + 2 ring is currently not dealt with.
if len(bonds) > 2 or (len(bonds) == 2 and len(cnei) > 2):
cliques.append([atom])
c2 = len(cliques) - 1
for c1 in cnei:
edges[(c1, c2)] = 1
elif len(rings) > 2: # Multiple (n>2) complex rings
cliques.append([atom])
c2 = len(cliques) - 1
for c1 in cnei:
edges[(c1, c2)] = MST_MAX_WEIGHT - 1
else:
for i in range(len(cnei)):
for j in range(i + 1, len(cnei)):
c1, c2 = cnei[i], cnei[j]
inter = set(cliques[c1]) & set(cliques[c2])
if edges[(c1, c2)] < len(inter):
# cnei[i] < cnei[j] by construction
edges[(c1, c2)] = len(inter)
edges = [u + (MST_MAX_WEIGHT-v,) for u, v in edges.items()]
if len(edges) == 0:
return cliques, edges
# Compute Maximum Spanning Tree
row, col, data = list(zip(*edges))
n_clique = len(cliques)
clique_graph = csr_matrix((data, (row, col)), shape=(n_clique, n_clique))
junc_tree = minimum_spanning_tree(clique_graph)
row, col = junc_tree.nonzero()
edges = [(row[i], col[i]) for i in range(len(row))]
return (cliques, edges)
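# For example (a sketch): for methylcyclohexane 'CC1CCCCC1', tree_decomp should
# return a bond clique [0, 1] and a ring clique with atoms 1-6, connected by a
# single junction-tree edge since they share atom 1.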
def atom_equal(a1, a2):
return a1.GetSymbol() == a2.GetSymbol() and a1.GetFormalCharge() == a2.GetFormalCharge()
# Bond type not considered because all bonds are aromatic (so SINGLE matches DOUBLE)
def ring_bond_equal(b1, b2, reverse=False):
b1 = (b1.GetBeginAtom(), b1.GetEndAtom())
if reverse:
b2 = (b2.GetEndAtom(), b2.GetBeginAtom())
else:
b2 = (b2.GetBeginAtom(), b2.GetEndAtom())
return atom_equal(b1[0], b2[0]) and atom_equal(b1[1], b2[1])
def attach_mols_nx(ctr_mol, neighbors, prev_nodes, nei_amap):
prev_nids = [node['nid'] for node in prev_nodes]
for nei_node in prev_nodes + neighbors:
nei_id, nei_mol = nei_node['nid'], nei_node['mol']
amap = nei_amap[nei_id]
for atom in nei_mol.GetAtoms():
if atom.GetIdx() not in amap:
new_atom = copy_atom(atom)
amap[atom.GetIdx()] = ctr_mol.AddAtom(new_atom)
if nei_mol.GetNumBonds() == 0:
nei_atom = nei_mol.GetAtomWithIdx(0)
ctr_atom = ctr_mol.GetAtomWithIdx(amap[0])
ctr_atom.SetAtomMapNum(nei_atom.GetAtomMapNum())
else:
for bond in nei_mol.GetBonds():
a1 = amap[bond.GetBeginAtom().GetIdx()]
a2 = amap[bond.GetEndAtom().GetIdx()]
if ctr_mol.GetBondBetweenAtoms(a1, a2) is None:
ctr_mol.AddBond(a1, a2, bond.GetBondType())
elif nei_id in prev_nids: # father node overrides
ctr_mol.RemoveBond(a1, a2)
ctr_mol.AddBond(a1, a2, bond.GetBondType())
return ctr_mol
def local_attach_nx(ctr_mol, neighbors, prev_nodes, amap_list):
ctr_mol = copy_edit_mol(ctr_mol)
nei_amap = {nei['nid']: {} for nei in prev_nodes + neighbors}
for nei_id, ctr_atom, nei_atom in amap_list:
nei_amap[nei_id][nei_atom] = ctr_atom
ctr_mol = attach_mols_nx(ctr_mol, neighbors, prev_nodes, nei_amap)
return ctr_mol.GetMol()
# This version records idx mapping between ctr_mol and nei_mol
def enum_attach_nx(ctr_mol, nei_node, amap, singletons):
nei_mol, nei_idx = nei_node['mol'], nei_node['nid']
att_confs = []
black_list = [atom_idx for nei_id, atom_idx,
_ in amap if nei_id in singletons]
ctr_atoms = [atom for atom in ctr_mol.GetAtoms() if atom.GetIdx()
not in black_list]
ctr_bonds = [bond for bond in ctr_mol.GetBonds()]
if nei_mol.GetNumBonds() == 0: # neighbor singleton
nei_atom = nei_mol.GetAtomWithIdx(0)
used_list = [atom_idx for _, atom_idx, _ in amap]
for atom in ctr_atoms:
if atom_equal(atom, nei_atom) and atom.GetIdx() not in used_list:
new_amap = amap + [(nei_idx, atom.GetIdx(), 0)]
att_confs.append(new_amap)
elif nei_mol.GetNumBonds() == 1: # neighbor is a bond
bond = nei_mol.GetBondWithIdx(0)
bond_val = int(bond.GetBondTypeAsDouble())
b1, b2 = bond.GetBeginAtom(), bond.GetEndAtom()
for atom in ctr_atoms:
# Optimize if atom is carbon (other atoms may change valence)
if atom.GetAtomicNum() == 6 and atom.GetTotalNumHs() < bond_val:
continue
if atom_equal(atom, b1):
new_amap = amap + [(nei_idx, atom.GetIdx(), b1.GetIdx())]
att_confs.append(new_amap)
elif atom_equal(atom, b2):
new_amap = amap + [(nei_idx, atom.GetIdx(), b2.GetIdx())]
att_confs.append(new_amap)
else:
# intersection is an atom
for a1 in ctr_atoms:
for a2 in nei_mol.GetAtoms():
if atom_equal(a1, a2):
# Optimize if atom is carbon (other atoms may change valence)
if a1.GetAtomicNum() == 6 and a1.GetTotalNumHs() + a2.GetTotalNumHs() < 4:
continue
new_amap = amap + [(nei_idx, a1.GetIdx(), a2.GetIdx())]
att_confs.append(new_amap)
# intersection is a bond
if ctr_mol.GetNumBonds() > 1:
for b1 in ctr_bonds:
for b2 in nei_mol.GetBonds():
if ring_bond_equal(b1, b2):
new_amap = amap + [(nei_idx, b1.GetBeginAtom().GetIdx(), b2.GetBeginAtom(
).GetIdx()), (nei_idx, b1.GetEndAtom().GetIdx(), b2.GetEndAtom().GetIdx())]
att_confs.append(new_amap)
if ring_bond_equal(b1, b2, reverse=True):
new_amap = amap + [(nei_idx, b1.GetBeginAtom().GetIdx(), b2.GetEndAtom(
).GetIdx()), (nei_idx, b1.GetEndAtom().GetIdx(), b2.GetBeginAtom().GetIdx())]
att_confs.append(new_amap)
return att_confs
# Try rings first: Speed-Up
def enum_assemble_nx(node, neighbors, prev_nodes=[], prev_amap=[]):
all_attach_confs = []
singletons = [nei_node['nid'] for nei_node in neighbors +
prev_nodes if nei_node['mol'].GetNumAtoms() == 1]
def search(cur_amap, depth):
if len(all_attach_confs) > MAX_NCAND:
return
if depth == len(neighbors):
all_attach_confs.append(cur_amap)
return
nei_node = neighbors[depth]
cand_amap = enum_attach_nx(node['mol'], nei_node, cur_amap, singletons)
cand_smiles = set()
candidates = []
for amap in cand_amap:
cand_mol = local_attach_nx(
node['mol'], neighbors[:depth+1], prev_nodes, amap)
cand_mol = sanitize(cand_mol)
if cand_mol is None:
continue
smiles = get_smiles(cand_mol)
if smiles in cand_smiles:
continue
cand_smiles.add(smiles)
candidates.append(amap)
if len(candidates) == 0:
return []
for new_amap in candidates:
search(new_amap, depth + 1)
search(prev_amap, 0)
cand_smiles = set()
candidates = []
for amap in all_attach_confs:
cand_mol = local_attach_nx(node['mol'], neighbors, prev_nodes, amap)
cand_mol = Chem.MolFromSmiles(Chem.MolToSmiles(cand_mol))
smiles = Chem.MolToSmiles(cand_mol)
if smiles in cand_smiles:
continue
cand_smiles.add(smiles)
Chem.Kekulize(cand_mol)
candidates.append((smiles, cand_mol, amap))
return candidates
# Only used for debugging purposes
def dfs_assemble_nx(graph, cur_mol, global_amap, fa_amap, cur_node_id, fa_node_id):
cur_node = graph.nodes_dict[cur_node_id]
fa_node = graph.nodes_dict[fa_node_id] if fa_node_id is not None else None
fa_nid = fa_node['nid'] if fa_node is not None else -1
prev_nodes = [fa_node] if fa_node is not None else []
children_id = [nei for nei in graph[cur_node_id]
if graph.nodes_dict[nei]['nid'] != fa_nid]
children = [graph.nodes_dict[nei] for nei in children_id]
neighbors = [nei for nei in children if nei['mol'].GetNumAtoms() > 1]
neighbors = sorted(
neighbors, key=lambda x: x['mol'].GetNumAtoms(), reverse=True)
singletons = [nei for nei in children if nei['mol'].GetNumAtoms() == 1]
neighbors = singletons + neighbors
cur_amap = [(fa_nid, a2, a1)
for nid, a1, a2 in fa_amap if nid == cur_node['nid']]
cands = enum_assemble_nx(
graph.nodes_dict[cur_node_id], neighbors, prev_nodes, cur_amap)
if len(cands) == 0:
return
cand_smiles, _, cand_amap = zip(*cands)
label_idx = cand_smiles.index(cur_node['label'])
label_amap = cand_amap[label_idx]
for nei_id, ctr_atom, nei_atom in label_amap:
if nei_id == fa_nid:
continue
global_amap[nei_id][nei_atom] = global_amap[cur_node['nid']][ctr_atom]
# father is already attached
cur_mol = attach_mols_nx(cur_mol, children, [], global_amap)
for nei_node_id, nei_node in zip(children_id, children):
if not nei_node['is_leaf']:
dfs_assemble_nx(graph, cur_mol, global_amap,
label_amap, nei_node_id, cur_node_id)
def mol2dgl_dec(cand_batch):
# Note that during graph decoding they don't predict stereochemistry-related
# characteristics (i.e. Chiral Atoms, E-Z, Cis-Trans). Instead, they decode
# the 2-D graph first, then enumerate all possible 3-D forms and find the
# one with highest score.
def atom_features(atom):
return (torch.Tensor(onek_encoding_unk(atom.GetSymbol(), ELEM_LIST)
+ onek_encoding_unk(atom.GetDegree(),
[0, 1, 2, 3, 4, 5])
+ onek_encoding_unk(atom.GetFormalCharge(),
[-1, -2, 1, 2, 0])
+ [atom.GetIsAromatic()]))
def bond_features(bond):
bt = bond.GetBondType()
return (torch.Tensor([bt == Chem.rdchem.BondType.SINGLE, bt == Chem.rdchem.BondType.DOUBLE,
bt == Chem.rdchem.BondType.TRIPLE, bt == Chem.rdchem.BondType.AROMATIC,
bond.IsInRing()]))
cand_graphs = []
tree_mess_source_edges = [] # map these edges from trees to...
tree_mess_target_edges = [] # these edges on candidate graphs
tree_mess_target_nodes = []
n_nodes = 0
atom_x = []
bond_x = []
for mol, mol_tree, ctr_node_id in cand_batch:
n_atoms = mol.GetNumAtoms()
g = DGLGraph()
for i, atom in enumerate(mol.GetAtoms()):
assert i == atom.GetIdx()
atom_x.append(atom_features(atom))
g.add_nodes(n_atoms)
bond_src = []
bond_dst = []
for i, bond in enumerate(mol.GetBonds()):
a1 = bond.GetBeginAtom()
a2 = bond.GetEndAtom()
begin_idx = a1.GetIdx()
end_idx = a2.GetIdx()
features = bond_features(bond)
bond_src.append(begin_idx)
bond_dst.append(end_idx)
bond_x.append(features)
bond_src.append(end_idx)
bond_dst.append(begin_idx)
bond_x.append(features)
x_nid, y_nid = a1.GetAtomMapNum(), a2.GetAtomMapNum()
# Tree node ID in the batch
x_bid = mol_tree.nodes_dict[x_nid - 1]['idx'] if x_nid > 0 else -1
y_bid = mol_tree.nodes_dict[y_nid - 1]['idx'] if y_nid > 0 else -1
if x_bid >= 0 and y_bid >= 0 and x_bid != y_bid:
if mol_tree.has_edge_between(x_bid, y_bid):
tree_mess_target_edges.append(
(begin_idx + n_nodes, end_idx + n_nodes))
tree_mess_source_edges.append((x_bid, y_bid))
tree_mess_target_nodes.append(end_idx + n_nodes)
if mol_tree.has_edge_between(y_bid, x_bid):
tree_mess_target_edges.append(
(end_idx + n_nodes, begin_idx + n_nodes))
tree_mess_source_edges.append((y_bid, x_bid))
tree_mess_target_nodes.append(begin_idx + n_nodes)
n_nodes += n_atoms
g.add_edges(bond_src, bond_dst)
cand_graphs.append(g)
return cand_graphs, torch.stack(atom_x), \
torch.stack(bond_x) if len(bond_x) > 0 else torch.zeros(0), \
torch.LongTensor(tree_mess_source_edges), \
torch.LongTensor(tree_mess_target_edges), \
torch.LongTensor(tree_mess_target_nodes)
def mol2dgl_enc(smiles):
def atom_features(atom):
return (torch.Tensor(onek_encoding_unk(atom.GetSymbol(), ELEM_LIST)
+ onek_encoding_unk(atom.GetDegree(),
[0, 1, 2, 3, 4, 5])
+ onek_encoding_unk(atom.GetFormalCharge(), [-1, -2, 1, 2, 0])
+ onek_encoding_unk(int(atom.GetChiralTag()), [0, 1, 2, 3])
+ [atom.GetIsAromatic()]))
def bond_features(bond):
bt = bond.GetBondType()
stereo = int(bond.GetStereo())
fbond = [bt == Chem.rdchem.BondType.SINGLE, bt == Chem.rdchem.BondType.DOUBLE, bt ==
Chem.rdchem.BondType.TRIPLE, bt == Chem.rdchem.BondType.AROMATIC, bond.IsInRing()]
fstereo = onek_encoding_unk(stereo, [0, 1, 2, 3, 4, 5])
return (torch.Tensor(fbond + fstereo))
n_edges = 0
atom_x = []
bond_x = []
mol = get_mol(smiles)
n_atoms = mol.GetNumAtoms()
n_bonds = mol.GetNumBonds()
graph = DGLGraph()
for i, atom in enumerate(mol.GetAtoms()):
assert i == atom.GetIdx()
atom_x.append(atom_features(atom))
graph.add_nodes(n_atoms)
bond_src = []
bond_dst = []
for i, bond in enumerate(mol.GetBonds()):
begin_idx = bond.GetBeginAtom().GetIdx()
end_idx = bond.GetEndAtom().GetIdx()
features = bond_features(bond)
bond_src.append(begin_idx)
bond_dst.append(end_idx)
bond_x.append(features)
# set up the reverse direction
bond_src.append(end_idx)
bond_dst.append(begin_idx)
bond_x.append(features)
graph.add_edges(bond_src, bond_dst)
n_edges += n_bonds
return graph, torch.stack(atom_x), \
torch.stack(bond_x) if len(bond_x) > 0 else torch.zeros(0)
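# For example (a sketch): for ethanol,
#   g, atom_x, bond_x = mol2dgl_enc('CCO')
# g has 3 nodes and 4 directed edges (each of the 2 bonds is added in both
# directions); atom_x has 3 rows and bond_x has 4 rows of features.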
import dgl
import os
import torch
from dgl.data.utils import download, extract_archive, get_download_dir
from torch.utils.data import Dataset
from .mol_tree import Vocab, DGLMolTree
from .chemutils import mol2dgl_dec, mol2dgl_enc
ELEM_LIST = ['C', 'N', 'O', 'S', 'F', 'Si', 'P', 'Cl', 'Br', 'Mg', 'Na', 'Ca',
'Fe', 'Al', 'I', 'B', 'K', 'Se', 'Zn', 'H', 'Cu', 'Mn', 'unknown']
ATOM_FDIM_DEC = len(ELEM_LIST) + 6 + 5 + 1
BOND_FDIM_DEC = 5
MAX_NB = 10
PAPER = os.getenv('PAPER', False)
_url = 'https://s3-ap-southeast-1.amazonaws.com/dgl-data-cn/dataset/jtnn.zip'
def _unpack_field(examples, field):
return [e[field] for e in examples]
def _set_node_id(mol_tree, vocab):
wid = []
for i, node in enumerate(mol_tree.nodes_dict):
mol_tree.nodes_dict[node]['idx'] = i
wid.append(vocab.get_index(mol_tree.nodes_dict[node]['smiles']))
return wid
class JTNNDataset(Dataset):
def __init__(self, data, vocab, training=True):
self.dir = get_download_dir()
self.zip_file_path = '{}/jtnn.zip'.format(self.dir)
download(_url, path=self.zip_file_path)
extract_archive(self.zip_file_path, '{}/jtnn'.format(self.dir))
print('Loading data...')
if data in ['train', 'test']:
data_file = '{}/jtnn/{}.txt'.format(self.dir, data)
else:
data_file = data
with open(data_file) as f:
self.data = [line.strip("\r\n ").split()[0] for line in f]
self.vocab_file = '{}/jtnn/{}.txt'.format(self.dir, vocab)
print('Loading finished.')
print('\tNum samples:', len(self.data))
print('\tVocab file:', self.vocab_file)
self.training = training
self.vocab = Vocab([x.strip("\r\n ") for x in open(self.vocab_file)])
def __len__(self):
return len(self.data)
def __getitem__(self, idx):
smiles = self.data[idx]
mol_tree = DGLMolTree(smiles)
mol_tree.recover()
mol_tree.assemble()
wid = _set_node_id(mol_tree, self.vocab)
# prebuild the molecule graph
mol_graph, atom_x_enc, bond_x_enc = mol2dgl_enc(mol_tree.smiles)
result = {
'mol_tree': mol_tree,
'mol_graph': mol_graph,
'atom_x_enc': atom_x_enc,
'bond_x_enc': bond_x_enc,
'wid': wid,
}
if not self.training:
return result
# prebuild the candidate graph list
cands = []
for node_id, node in mol_tree.nodes_dict.items():
# fill in ground truth
if node['label'] not in node['cands']:
node['cands'].append(node['label'])
node['cand_mols'].append(node['label_mol'])
if node['is_leaf'] or len(node['cands']) == 1:
continue
cands.extend([(cand, mol_tree, node_id)
for cand in node['cand_mols']])
if len(cands) > 0:
cand_graphs, atom_x_dec, bond_x_dec, tree_mess_src_e, \
tree_mess_tgt_e, tree_mess_tgt_n = mol2dgl_dec(cands)
else:
cand_graphs = []
atom_x_dec = torch.zeros(0, ATOM_FDIM_DEC)
bond_x_dec = torch.zeros(0, BOND_FDIM_DEC)
tree_mess_src_e = torch.zeros(0, 2).long()
tree_mess_tgt_e = torch.zeros(0, 2).long()
tree_mess_tgt_n = torch.zeros(0).long()
# prebuild the stereoisomers
cands = mol_tree.stereo_cands
if len(cands) > 1:
if mol_tree.smiles3D not in cands:
cands.append(mol_tree.smiles3D)
stereo_graphs = [mol2dgl_enc(c) for c in cands]
stereo_cand_graphs, stereo_atom_x_enc, stereo_bond_x_enc = \
zip(*stereo_graphs)
stereo_atom_x_enc = torch.cat(stereo_atom_x_enc)
stereo_bond_x_enc = torch.cat(stereo_bond_x_enc)
stereo_cand_label = [(cands.index(mol_tree.smiles3D), len(cands))]
else:
stereo_cand_graphs = []
stereo_atom_x_enc = torch.zeros(0, atom_x_enc.shape[1])
stereo_bond_x_enc = torch.zeros(0, bond_x_enc.shape[1])
stereo_cand_label = []
result.update({
'cand_graphs': cand_graphs,
'atom_x_dec': atom_x_dec,
'bond_x_dec': bond_x_dec,
'tree_mess_src_e': tree_mess_src_e,
'tree_mess_tgt_e': tree_mess_tgt_e,
'tree_mess_tgt_n': tree_mess_tgt_n,
'stereo_cand_graphs': stereo_cand_graphs,
'stereo_atom_x_enc': stereo_atom_x_enc,
'stereo_bond_x_enc': stereo_bond_x_enc,
'stereo_cand_label': stereo_cand_label,
})
return result
class JTNNCollator(object):
def __init__(self, vocab, training):
self.vocab = vocab
self.training = training
@staticmethod
def _batch_and_set(graphs, atom_x, bond_x, flatten):
if flatten:
graphs = [g for f in graphs for g in f]
graph_batch = dgl.batch(graphs)
graph_batch.ndata['x'] = atom_x
graph_batch.edata.update({
'x': bond_x,
'src_x': atom_x.new(bond_x.shape[0], atom_x.shape[1]).zero_(),
})
return graph_batch
def __call__(self, examples):
# get list of trees
mol_trees = _unpack_field(examples, 'mol_tree')
wid = _unpack_field(examples, 'wid')
for _wid, mol_tree in zip(wid, mol_trees):
mol_tree.ndata['wid'] = torch.LongTensor(_wid)
# TODO: either support pickling or get around ctypes pointers using scipy
# batch molecule graphs
mol_graphs = _unpack_field(examples, 'mol_graph')
atom_x = torch.cat(_unpack_field(examples, 'atom_x_enc'))
bond_x = torch.cat(_unpack_field(examples, 'bond_x_enc'))
mol_graph_batch = self._batch_and_set(mol_graphs, atom_x, bond_x, False)
result = {
'mol_trees': mol_trees,
'mol_graph_batch': mol_graph_batch,
}
if not self.training:
return result
# batch candidate graphs
cand_graphs = _unpack_field(examples, 'cand_graphs')
cand_batch_idx = []
atom_x = torch.cat(_unpack_field(examples, 'atom_x_dec'))
bond_x = torch.cat(_unpack_field(examples, 'bond_x_dec'))
tree_mess_src_e = _unpack_field(examples, 'tree_mess_src_e')
tree_mess_tgt_e = _unpack_field(examples, 'tree_mess_tgt_e')
tree_mess_tgt_n = _unpack_field(examples, 'tree_mess_tgt_n')
n_graph_nodes = 0
n_tree_nodes = 0
for i in range(len(cand_graphs)):
tree_mess_tgt_e[i] += n_graph_nodes
tree_mess_src_e[i] += n_tree_nodes
tree_mess_tgt_n[i] += n_graph_nodes
n_graph_nodes += sum(g.number_of_nodes() for g in cand_graphs[i])
n_tree_nodes += mol_trees[i].number_of_nodes()
cand_batch_idx.extend([i] * len(cand_graphs[i]))
tree_mess_tgt_e = torch.cat(tree_mess_tgt_e)
tree_mess_src_e = torch.cat(tree_mess_src_e)
tree_mess_tgt_n = torch.cat(tree_mess_tgt_n)
cand_graph_batch = self._batch_and_set(cand_graphs, atom_x, bond_x, True)
# batch stereoisomers
stereo_cand_graphs = _unpack_field(examples, 'stereo_cand_graphs')
atom_x = torch.cat(_unpack_field(examples, 'stereo_atom_x_enc'))
bond_x = torch.cat(_unpack_field(examples, 'stereo_bond_x_enc'))
stereo_cand_batch_idx = []
for i in range(len(stereo_cand_graphs)):
stereo_cand_batch_idx.extend([i] * len(stereo_cand_graphs[i]))
if len(stereo_cand_batch_idx) > 0:
stereo_cand_labels = [
(label, length)
for ex in _unpack_field(examples, 'stereo_cand_label')
for label, length in ex
]
stereo_cand_labels, stereo_cand_lengths = zip(*stereo_cand_labels)
stereo_cand_graph_batch = self._batch_and_set(
stereo_cand_graphs, atom_x, bond_x, True)
else:
stereo_cand_labels = []
stereo_cand_lengths = []
stereo_cand_graph_batch = None
stereo_cand_batch_idx = []
result.update({
'cand_graph_batch': cand_graph_batch,
'cand_batch_idx': cand_batch_idx,
'tree_mess_tgt_e': tree_mess_tgt_e,
'tree_mess_src_e': tree_mess_src_e,
'tree_mess_tgt_n': tree_mess_tgt_n,
'stereo_cand_graph_batch': stereo_cand_graph_batch,
'stereo_cand_batch_idx': stereo_cand_batch_idx,
'stereo_cand_labels': stereo_cand_labels,
'stereo_cand_lengths': stereo_cand_lengths,
})
return result
import copy
import numpy as np
import rdkit.Chem as Chem
from dgl import DGLGraph
from .chemutils import get_clique_mol, tree_decomp, get_mol, get_smiles, \
set_atommap, enum_assemble_nx, decode_stereo
def get_slots(smiles):
mol = Chem.MolFromSmiles(smiles)
return [(atom.GetSymbol(), atom.GetFormalCharge(), atom.GetTotalNumHs()) for atom in mol.GetAtoms()]
class Vocab(object):
def __init__(self, smiles_list):
self.vocab = smiles_list
self.vmap = {x: i for i, x in enumerate(self.vocab)}
self.slots = [get_slots(smiles) for smiles in self.vocab]
def get_index(self, smiles):
return self.vmap[smiles]
def get_smiles(self, idx):
return self.vocab[idx]
def get_slots(self, idx):
return copy.deepcopy(self.slots[idx])
def size(self):
return len(self.vocab)
class DGLMolTree(DGLGraph):
def __init__(self, smiles):
DGLGraph.__init__(self)
self.nodes_dict = {}
if smiles is None:
return
self.smiles = smiles
self.mol = get_mol(smiles)
# Stereo Generation
mol = Chem.MolFromSmiles(smiles)
self.smiles3D = Chem.MolToSmiles(mol, isomericSmiles=True)
self.smiles2D = Chem.MolToSmiles(mol)
self.stereo_cands = decode_stereo(self.smiles2D)
# cliques: a list of list of atom indices
cliques, edges = tree_decomp(self.mol)
root = 0
for i, c in enumerate(cliques):
cmol = get_clique_mol(self.mol, c)
csmiles = get_smiles(cmol)
self.nodes_dict[i] = dict(
smiles=csmiles,
mol=get_mol(csmiles),
clique=c,
)
if min(c) == 0:
root = i
self.add_nodes(len(cliques))
# The clique with atom ID 0 becomes root
if root > 0:
for attr in self.nodes_dict[0]:
self.nodes_dict[0][attr], self.nodes_dict[root][attr] = \
self.nodes_dict[root][attr], self.nodes_dict[0][attr]
src = np.zeros((len(edges) * 2,), dtype='int')
dst = np.zeros((len(edges) * 2,), dtype='int')
for i, (_x, _y) in enumerate(edges):
x = 0 if _x == root else root if _x == 0 else _x
y = 0 if _y == root else root if _y == 0 else _y
src[2 * i] = x
dst[2 * i] = y
src[2 * i + 1] = y
dst[2 * i + 1] = x
self.add_edges(src, dst)
for i in self.nodes_dict:
self.nodes_dict[i]['nid'] = i + 1
if self.out_degree(i) > 1: # Leaf node mol is not marked
set_atommap(self.nodes_dict[i]['mol'], self.nodes_dict[i]['nid'])
self.nodes_dict[i]['is_leaf'] = (self.out_degree(i) == 1)
def treesize(self):
return self.number_of_nodes()
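    # For example (a sketch): DGLMolTree('CC1CCCCC1').treesize() should be 2,
    # with one node for the methyl C-C bond clique and one for the ring clique.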
def _recover_node(self, i, original_mol):
node = self.nodes_dict[i]
clique = []
clique.extend(node['clique'])
if not node['is_leaf']:
for cidx in node['clique']:
original_mol.GetAtomWithIdx(cidx).SetAtomMapNum(node['nid'])
for j in self.successors(i).numpy():
nei_node = self.nodes_dict[j]
clique.extend(nei_node['clique'])
if nei_node['is_leaf']: # Leaf node, no need to mark
continue
for cidx in nei_node['clique']:
# allow singleton node override the atom mapping
if cidx not in node['clique'] or len(nei_node['clique']) == 1:
atom = original_mol.GetAtomWithIdx(cidx)
atom.SetAtomMapNum(nei_node['nid'])
clique = list(set(clique))
label_mol = get_clique_mol(original_mol, clique)
node['label'] = Chem.MolToSmiles(Chem.MolFromSmiles(get_smiles(label_mol)))
node['label_mol'] = get_mol(node['label'])
for cidx in clique:
original_mol.GetAtomWithIdx(cidx).SetAtomMapNum(0)
return node['label']
def _assemble_node(self, i):
neighbors = [self.nodes_dict[j] for j in self.successors(i).numpy()
if self.nodes_dict[j]['mol'].GetNumAtoms() > 1]
neighbors = sorted(neighbors, key=lambda x: x['mol'].GetNumAtoms(), reverse=True)
singletons = [self.nodes_dict[j] for j in self.successors(i).numpy()
if self.nodes_dict[j]['mol'].GetNumAtoms() == 1]
neighbors = singletons + neighbors
cands = enum_assemble_nx(self.nodes_dict[i], neighbors)
if len(cands) > 0:
self.nodes_dict[i]['cands'], self.nodes_dict[i]['cand_mols'], _ = list(zip(*cands))
self.nodes_dict[i]['cands'] = list(self.nodes_dict[i]['cands'])
self.nodes_dict[i]['cand_mols'] = list(self.nodes_dict[i]['cand_mols'])
else:
self.nodes_dict[i]['cands'] = []
self.nodes_dict[i]['cand_mols'] = []
def recover(self):
for i in self.nodes_dict:
self._recover_node(i, self.mol)
def assemble(self):
for i in self.nodes_dict:
self._assemble_node(i)
import argparse
import rdkit
import torch
from dgllife.model import DGLJTNNVAE, load_pretrained
from dgllife.model.model_zoo.jtnn.nnutils import cuda
from torch.utils.data import DataLoader
from jtnn import *
def worker_init_fn(id_):
lg = rdkit.RDLogger.logger()
lg.setLevel(rdkit.RDLogger.CRITICAL)
worker_init_fn(None)
parser = argparse.ArgumentParser(description="Evaluation for JTNN",
formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument("-t", "--train", dest="train",
default='test', help='Training file name')
parser.add_argument("-v", "--vocab", dest="vocab",
default='vocab', help='Vocab file name')
parser.add_argument("-m", "--model", dest="model_path", default=None,
help="Pre-trained model to be loaded for evalutaion. If not specified,"
" would use pre-trained model from model zoo")
parser.add_argument("-w", "--hidden", dest="hidden_size", default=450,
help="Hidden size of representation vector, "
"should be consistent with pre-trained model")
parser.add_argument("-l", "--latent", dest="latent_size", default=56,
help="Latent Size of node(atom) features and edge(atom) features, "
"should be consistent with pre-trained model")
parser.add_argument("-d", "--depth", dest="depth", default=3,
help="Depth of message passing hops, "
"should be consistent with pre-trained model")
args = parser.parse_args()
dataset = JTNNDataset(data=args.train, vocab=args.vocab, training=False)
vocab_file = dataset.vocab_file
hidden_size = int(args.hidden_size)
latent_size = int(args.latent_size)
depth = int(args.depth)
model = DGLJTNNVAE(vocab_file=vocab_file,
hidden_size=hidden_size,
latent_size=latent_size,
depth=depth)
if args.model_path is not None:
model.load_state_dict(torch.load(args.model_path))
else:
model = load_pretrained("JTNN_ZINC")
model = cuda(model)
model.eval()
print("Model #Params: %dK" %
(sum([x.nelement() for x in model.parameters()]) / 1000,))
MAX_EPOCH = 100
PRINT_ITER = 20
def reconstruct():
dataset.training = False
dataloader = DataLoader(
dataset,
batch_size=1,
shuffle=False,
num_workers=0,
collate_fn=JTNNCollator(dataset.vocab, False),
drop_last=True,
worker_init_fn=worker_init_fn)
# Just an example of molecule decoding; in reality you may want to sample
# tree and molecule vectors.
acc = 0.0
tot = 0
with torch.no_grad():
for it, batch in enumerate(dataloader):
gt_smiles = batch['mol_trees'][0].smiles
# print(gt_smiles)
model.move_to_cuda(batch)
try:
_, tree_vec, mol_vec = model.encode(batch)
tree_mean = model.T_mean(tree_vec)
# Following Mueller et al.
tree_log_var = -torch.abs(model.T_var(tree_vec))
mol_mean = model.G_mean(mol_vec)
# Following Mueller et al.
mol_log_var = -torch.abs(model.G_var(mol_vec))
                epsilon = torch.randn(1, model.latent_size // 2).cuda()
                # reparameterization: std = exp(log_var / 2)
                tree_vec = tree_mean + torch.exp(tree_log_var / 2) * epsilon
                epsilon = torch.randn(1, model.latent_size // 2).cuda()
                mol_vec = mol_mean + torch.exp(mol_log_var / 2) * epsilon
dec_smiles = model.decode(tree_vec, mol_vec)
if dec_smiles == gt_smiles:
acc += 1
tot += 1
except Exception as e:
print("Failed to encode: {}".format(gt_smiles))
print(e)
if it % 20 == 1:
print("Progress {}/{}; Current Reconstruction Accuracy: {:.4f}".format(it,
len(dataloader), acc / tot))
return acc / tot
if __name__ == '__main__':
reconstruct_acc = reconstruct()
print("Reconstruction Accuracy: {}".format(reconstruct_acc))
import argparse
import rdkit
import sys
import torch
import torch.nn as nn
import torch.optim as optim
import torch.optim.lr_scheduler as lr_scheduler
from dgllife.model import DGLJTNNVAE
from dgllife.model.model_zoo.jtnn.nnutils import cuda
from torch.utils.data import DataLoader
from jtnn import *
torch.multiprocessing.set_sharing_strategy('file_system')
def worker_init_fn(id_):
lg = rdkit.RDLogger.logger()
lg.setLevel(rdkit.RDLogger.CRITICAL)
worker_init_fn(None)
parser = argparse.ArgumentParser(description="Training for JTNN",
formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument("-t", "--train", dest="train", default='train', help='Training file name')
parser.add_argument("-v", "--vocab", dest="vocab", default='vocab', help='Vocab file name')
parser.add_argument("-s", "--save_dir", dest="save_path", default='./',
help="Path to save checkpoint models, default to be current working directory")
parser.add_argument("-m", "--model", dest="model_path", default=None,
help="Path to load pre-trained model")
parser.add_argument("-b", "--batch", dest="batch_size", default=40,
help="Batch size")
parser.add_argument("-w", "--hidden", dest="hidden_size", default=200,
help="Size of representation vectors")
parser.add_argument("-l", "--latent", dest="latent_size", default=56,
help="Latent Size of node(atom) features and edge(atom) features")
parser.add_argument("-d", "--depth", dest="depth", default=3,
help="Depth of message passing hops")
parser.add_argument("-z", "--beta", dest="beta", default=1.0,
help="Coefficient of KL Divergence term")
parser.add_argument("-q", "--lr", dest="lr", default=1e-3,
help="Learning Rate")
args = parser.parse_args()
dataset = JTNNDataset(data=args.train, vocab=args.vocab, training=True)
vocab_file = dataset.vocab_file
batch_size = int(args.batch_size)
hidden_size = int(args.hidden_size)
latent_size = int(args.latent_size)
depth = int(args.depth)
beta = float(args.beta)
lr = float(args.lr)
model = DGLJTNNVAE(vocab_file=vocab_file,
hidden_size=hidden_size,
latent_size=latent_size,
depth=depth)
if args.model_path is not None:
model.load_state_dict(torch.load(args.model_path))
else:
for param in model.parameters():
if param.dim() == 1:
nn.init.constant_(param, 0)
else:
nn.init.xavier_normal_(param)
model = cuda(model)
print("Model #Params: %dK" % (sum([x.nelement() for x in model.parameters()]) / 1000,))
optimizer = optim.Adam(model.parameters(), lr=lr)
scheduler = lr_scheduler.ExponentialLR(optimizer, 0.9)
MAX_EPOCH = 100
PRINT_ITER = 20
def train():
dataset.training = True
dataloader = DataLoader(
dataset,
batch_size=batch_size,
shuffle=True,
num_workers=4,
collate_fn=JTNNCollator(dataset.vocab, True),
drop_last=True,
worker_init_fn=worker_init_fn)
for epoch in range(MAX_EPOCH):
word_acc, topo_acc, assm_acc, steo_acc = 0, 0, 0, 0
for it, batch in enumerate(dataloader):
model.zero_grad()
try:
loss, kl_div, wacc, tacc, sacc, dacc = model(batch, beta)
        except Exception:
print([t.smiles for t in batch['mol_trees']])
raise
loss.backward()
optimizer.step()
word_acc += wacc
topo_acc += tacc
assm_acc += sacc
steo_acc += dacc
if (it + 1) % PRINT_ITER == 0:
word_acc = word_acc / PRINT_ITER * 100
topo_acc = topo_acc / PRINT_ITER * 100
assm_acc = assm_acc / PRINT_ITER * 100
steo_acc = steo_acc / PRINT_ITER * 100
print("KL: %.1f, Word: %.2f, Topo: %.2f, Assm: %.2f, Steo: %.2f, Loss: %.6f" % (
kl_div, word_acc, topo_acc, assm_acc, steo_acc, loss.item()))
word_acc, topo_acc, assm_acc, steo_acc = 0, 0, 0, 0
sys.stdout.flush()
if (it + 1) % 1500 == 0: # Fast annealing
scheduler.step()
print("learning rate: %.6f" % scheduler.get_lr()[0])
torch.save(model.state_dict(),
args.save_path + "/model.iter-%d-%d" % (epoch, it + 1))
scheduler.step()
print("learning rate: %.6f" % scheduler.get_lr()[0])
torch.save(model.state_dict(), args.save_path + "/model.iter-" + str(epoch))
if __name__ == '__main__':
train()
print('# passes:', model.n_passes)
print('Total # nodes processed:', model.n_nodes_total)
print('Total # edges processed:', model.n_edges_total)
print('Total # tree nodes processed:', model.n_tree_nodes_total)
print('Graph decoder: # passes:', model.jtmpn.n_passes)
print('Graph decoder: Total # candidates processed:', model.jtmpn.n_samples_total)
print('Graph decoder: Total # nodes processed:', model.jtmpn.n_nodes_total)
print('Graph decoder: Total # edges processed:', model.jtmpn.n_edges_total)
print('Graph encoder: # passes:', model.mpn.n_passes)
print('Graph encoder: Total # candidates processed:', model.mpn.n_samples_total)
print('Graph encoder: Total # nodes processed:', model.mpn.n_nodes_total)
print('Graph encoder: Total # edges processed:', model.mpn.n_edges_total)
"""Generate vocabulary for a new dataset."""
if __name__ == '__main__':
import argparse
import os
import rdkit
from dgl.data.utils import _get_dgl_url, download, get_download_dir, extract_archive
from jtnn.mol_tree import DGLMolTree
parser = argparse.ArgumentParser('Generate vocabulary for a molecule dataset')
parser.add_argument('-d', '--data-path', type=str,
help='Path to the dataset')
parser.add_argument('-v', '--vocab', type=str,
help='Path to the vocabulary file to save')
args = parser.parse_args()
lg = rdkit.RDLogger.logger()
lg.setLevel(rdkit.RDLogger.CRITICAL)
vocab = set()
with open(args.data_path, 'r') as f:
for line in f:
smiles = line.strip()
mol = DGLMolTree(smiles)
for i in mol.nodes_dict:
vocab.add(mol.nodes_dict[i]['smiles'])
with open(args.vocab, 'w') as f:
for v in vocab:
f.write(v + '\n')
# Get the vocabulary used for the pre-trained model
default_dir = get_download_dir()
vocab_file = '{}/jtnn/{}.txt'.format(default_dir, 'vocab')
if not os.path.exists(vocab_file):
zip_file_path = '{}/jtnn.zip'.format(default_dir)
download(_get_dgl_url('dgllife/jtnn.zip'), path=zip_file_path)
extract_archive(zip_file_path, '{}/jtnn'.format(default_dir))
default_vocab = set()
with open(vocab_file, 'r') as f:
for line in f:
default_vocab.add(line.strip())
print('The new vocabulary is a subset of the default vocabulary: {}'.format(
vocab.issubset(default_vocab)))
# Property Prediction
## Classification
Classification tasks require assigning discrete labels to a molecule, e.g. whether it is toxic.
### Datasets
- **Tox21**. The ["Toxicology in the 21st Century" (Tox21)](https://tripod.nih.gov/tox21/challenge/) initiative created
a public database measuring toxicity of compounds, which has been used in the 2014 Tox21 Data Challenge. The dataset
contains qualitative toxicity measurements for 8014 compounds on 12 different targets, including nuclear receptors and
stress response pathways. Each target yields a binary prediction problem. MoleculeNet [1] randomly splits the dataset
into training, validation and test sets with an 80/10/10 ratio. By default we follow their split method.
### Models
- **Weave** [9]. Weave is one of the pioneering efforts in applying graph neural networks to molecular property prediction.
- **Graph Convolutional Network** [2], [3]. Graph Convolutional Networks (GCN) have been one of the most popular graph neural
networks and they can be easily extended for graph level prediction. MoleculeNet [1] reports baseline results of graph
convolutions over multiple datasets.
- **Graph Attention Networks** [7]. Graph Attention Networks (GATs) incorporate multi-head attention into GCNs,
explicitly modeling the interactions between adjacent atoms.
### Usage
Use `classification.py` with arguments
```
-m {GCN, GAT, Weave}, MODEL, model to use
-d {Tox21}, DATASET, dataset to use
```
If you want to use the pre-trained model, simply add `-p`.
We use GPU whenever it is available.
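For example, the following command trains a GCN on Tox21 from scratch:

```
python classification.py -m GCN -d Tox21
```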
### Performance
#### GCN on Tox21
| Source | Averaged Test ROC-AUC Score |
| ---------------- | --------------------------- |
| MoleculeNet [1] | 0.829 |
| [DeepChem example](https://github.com/deepchem/deepchem/blob/master/examples/tox21/tox21_tensorgraph_graph_conv.py) | 0.813 |
| Pretrained model | 0.833 |
Note that the dataset is randomly split so these numbers are only for reference and they do not necessarily suggest
a real difference.
#### GAT on Tox21
| Source | Averaged Test ROC-AUC Score |
| ---------------- | --------------------------- |
| Pretrained model | 0.853 |
#### Weave on Tox21
| Source | Averaged Test ROC-AUC Score |
| ---------------- | --------------------------- |
| Pretrained model | 0.8074 |
## Regression
Regression tasks require assigning continuous labels to a molecule, e.g. molecular energy.
### Datasets
- **Alchemy**. The [Alchemy Dataset](https://alchemy.tencent.com/) is introduced by Tencent Quantum Lab to facilitate the development of new
machine learning models useful for chemistry and materials science. The dataset lists 12 quantum mechanical properties of 130,000+ organic
molecules comprising up to 12 heavy atoms (C, N, O, S, F and Cl), sampled from the [GDBMedChem](http://gdb.unibe.ch/downloads/) database.
These properties have been calculated using the open-source computational chemistry program Python-based Simulation of Chemistry Framework
([PySCF](https://github.com/pyscf/pyscf)). The Alchemy dataset expands on the volume and diversity of existing molecular datasets such as QM9.
- **PubChem BioAssay Aromaticity**. The dataset is introduced in
[Pushing the Boundaries of Molecular Representation for Drug Discovery with the Graph Attention Mechanism](https://www.ncbi.nlm.nih.gov/pubmed/31408336),
for the task of predicting the number of aromatic atoms in molecules. The dataset was constructed by sampling 3945 molecules with 0-40 aromatic atoms
from the PubChem BioAssay dataset.
### Models
- **Message Passing Neural Network** [6]. Message Passing Neural Networks (MPNNs) held the best performance on
the QM9 dataset for some time.
- **SchNet** [4]. SchNet employs continuous filter convolutional layers to model quantum interactions in molecules
without requiring them to lie on grids.
- **Multilevel Graph Convolutional Neural Network** [5]. Multilevel Graph Convolutional Neural Networks (MGCN) are
hierarchical graph neural networks that extract features from conformation and spatial information and then model
the multilevel interactions between them.
- **AttentiveFP** [8]. AttentiveFP combines attention and GRU for better model capacity and shows competitive
performance across datasets.
### Usage
Use `regression.py` with arguments
```
-m {MPNN, SchNet, MGCN, AttentiveFP}, Model to use
-d {Alchemy, Aromaticity}, Dataset to use
```
If you want to use the pre-trained model, simply add `-p`. Currently we only support pre-trained models of AttentiveFP
on PubChem BioAssay Aromaticity dataset.
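For example, the following command evaluates the pre-trained AttentiveFP model on the aromaticity dataset:

```
python regression.py -m AttentiveFP -d Aromaticity -p
```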
### Performance
#### Alchemy
The Alchemy contest is still ongoing. Until the test set is fully released, we only include performance numbers
on the training and validation sets for reference.
| Model | Training MAE | Validation MAE |
| ---------- | ------------ | -------------- |
| SchNet [4] | 0.0651 | 0.0925 |
| MGCN [5] | 0.0582 | 0.0942 |
| MPNN [6] | 0.1004 | 0.1587 |
#### PubChem BioAssay Aromaticity
| Model | Test RMSE |
| --------------- | --------- |
| AttentiveFP [8] | 0.7508 |
Note that the dataset is randomly split so this number is only for reference.
## Interpretation
[8] visualizes the atom weights in readout for possible interpretation, as in the figure below.
We provide a Jupyter notebook for performing the visualization, which you can download with
`wget https://data.dgl.ai/dgllife/attentive_fp/atom_weight_visualization.ipynb`.
![](https://data.dgl.ai/dgllife/attentive_fp_vis_example.png)
## Dataset Customization
Generally we follow the practice of PyTorch.
A dataset class should implement the `__getitem__(self, index)` and `__len__(self)` methods:
```python
class CustomDataset(object):
def __init__(self):
pass
def __getitem__(self, index):
"""
Parameters
----------
index : int
Index for the datapoint.
Returns
-------
str
SMILES for the molecule
DGLGraph
Constructed DGLGraph for the molecule
1D Tensor of dtype float32
Labels of the datapoint
"""
return self.smiles[index], self.graphs[index], self.labels[index]
def __len__(self):
return len(self.smiles)
```
We provide various methods for graph construction in `dgllife.utils.mol_to_graph`. If your dataset can
be converted to a pandas dataframe, e.g. a .csv file, you may use `MoleculeCSVDataset` in
`dgllife.data.csv_dataset`.
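As a minimal sketch (the file and column names here are hypothetical, and the exact
`MoleculeCSVDataset` signature may differ across versions, so check its documentation):

```python
import pandas as pd
from dgllife.data.csv_dataset import MoleculeCSVDataset
from dgllife.utils import smiles_to_bigraph, CanonicalAtomFeaturizer

# Hypothetical .csv file with a 'smiles' column and one column per task
df = pd.read_csv('my_molecules.csv')
dataset = MoleculeCSVDataset(df,
                             smiles_to_graph=smiles_to_bigraph,
                             node_featurizer=CanonicalAtomFeaturizer(),
                             edge_featurizer=None,
                             smiles_column='smiles',
                             cache_file_path='my_molecules_dglgraph.bin')
```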
## References
[1] Wu et al. (2017) MoleculeNet: a benchmark for molecular machine learning. *Chemical Science* 9, 513-530.
[2] Duvenaud et al. (2015) Convolutional networks on graphs for learning molecular fingerprints. *Advances in neural
information processing systems (NeurIPS)*, 2224-2232.
[3] Kipf et al. (2017) Semi-Supervised Classification with Graph Convolutional Networks.
*The International Conference on Learning Representations (ICLR)*.
[4] Schütt et al. (2017) SchNet: A continuous-filter convolutional neural network for modeling quantum interactions.
*Advances in Neural Information Processing Systems (NeurIPS)*, 992-1002.
[5] Lu et al. (2019) Molecular Property Prediction: A Multilevel Quantum Interactions Modeling Perspective.
*The 33rd AAAI Conference on Artificial Intelligence*.
[6] Gilmer et al. (2017) Neural Message Passing for Quantum Chemistry. *Proceedings of the 34th International Conference on
Machine Learning*, JMLR. 1263-1272.
[7] Veličković et al. (2018) Graph Attention Networks.
*The International Conference on Learning Representations (ICLR)*.
[8] Xiong et al. (2019) Pushing the Boundaries of Molecular Representation for Drug Discovery with the Graph
Attention Mechanism. *Journal of Medicinal Chemistry*.
[9] Kearnes et al. (2016) Molecular graph convolutions: moving beyond fingerprints.
*Journal of Computer-Aided Molecular Design*.
import numpy as np
import torch
from dgllife.model import load_pretrained
from dgllife.utils import EarlyStopping, Meter
from torch.nn import BCEWithLogitsLoss
from torch.optim import Adam
from torch.utils.data import DataLoader
from utils import set_random_seed, load_dataset_for_classification, collate_molgraphs, load_model
def predict(args, model, bg):
node_feats = bg.ndata.pop(args['node_data_field']).to(args['device'])
if args.get('edge_featurizer', None) is not None:
edge_feats = bg.edata.pop(args['edge_data_field']).to(args['device'])
return model(bg, node_feats, edge_feats)
else:
return model(bg, node_feats)
def run_a_train_epoch(args, epoch, model, data_loader, loss_criterion, optimizer):
model.train()
train_meter = Meter()
for batch_id, batch_data in enumerate(data_loader):
smiles, bg, labels, masks = batch_data
labels, masks = labels.to(args['device']), masks.to(args['device'])
logits = predict(args, model, bg)
# Mask non-existing labels
loss = (loss_criterion(logits, labels) * (masks != 0).float()).mean()
optimizer.zero_grad()
loss.backward()
optimizer.step()
print('epoch {:d}/{:d}, batch {:d}/{:d}, loss {:.4f}'.format(
epoch + 1, args['num_epochs'], batch_id + 1, len(data_loader), loss.item()))
train_meter.update(logits, labels, masks)
train_score = np.mean(train_meter.compute_metric(args['metric_name']))
print('epoch {:d}/{:d}, training {} {:.4f}'.format(
epoch + 1, args['num_epochs'], args['metric_name'], train_score))
def run_an_eval_epoch(args, model, data_loader):
model.eval()
eval_meter = Meter()
with torch.no_grad():
for batch_id, batch_data in enumerate(data_loader):
smiles, bg, labels, masks = batch_data
labels = labels.to(args['device'])
logits = predict(args, model, bg)
eval_meter.update(logits, labels, masks)
return np.mean(eval_meter.compute_metric(args['metric_name']))
def main(args):
args['device'] = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
set_random_seed(args['random_seed'])
# Interchangeable with other datasets
dataset, train_set, val_set, test_set = load_dataset_for_classification(args)
train_loader = DataLoader(train_set, batch_size=args['batch_size'],
collate_fn=collate_molgraphs)
val_loader = DataLoader(val_set, batch_size=args['batch_size'],
collate_fn=collate_molgraphs)
test_loader = DataLoader(test_set, batch_size=args['batch_size'],
collate_fn=collate_molgraphs)
if args['pre_trained']:
args['num_epochs'] = 0
model = load_pretrained(args['exp'])
else:
args['n_tasks'] = dataset.n_tasks
model = load_model(args)
loss_criterion = BCEWithLogitsLoss(pos_weight=dataset.task_pos_weights.to(args['device']),
reduction='none')
optimizer = Adam(model.parameters(), lr=args['lr'])
stopper = EarlyStopping(patience=args['patience'])
model.to(args['device'])
for epoch in range(args['num_epochs']):
# Train
run_a_train_epoch(args, epoch, model, train_loader, loss_criterion, optimizer)
# Validation and early stop
val_score = run_an_eval_epoch(args, model, val_loader)
early_stop = stopper.step(val_score, model)
print('epoch {:d}/{:d}, validation {} {:.4f}, best validation {} {:.4f}'.format(
epoch + 1, args['num_epochs'], args['metric_name'],
val_score, args['metric_name'], stopper.best_score))
if early_stop:
break
if not args['pre_trained']:
stopper.load_checkpoint(model)
test_score = run_an_eval_epoch(args, model, test_loader)
print('test {} {:.4f}'.format(args['metric_name'], test_score))
if __name__ == '__main__':
import argparse
from configure import get_exp_configure
parser = argparse.ArgumentParser(description='Molecule Classification')
parser.add_argument('-m', '--model', type=str, choices=['GCN', 'GAT', 'Weave'],
help='Model to use')
parser.add_argument('-d', '--dataset', type=str, choices=['Tox21'],
help='Dataset to use')
parser.add_argument('-p', '--pre-trained', action='store_true',
help='Whether to skip training and use a pre-trained model')
args = parser.parse_args().__dict__
args['exp'] = '_'.join([args['model'], args['dataset']])
args.update(get_exp_configure(args['exp']))
main(args)
from functools import partial
# graph construction
from dgllife.utils import smiles_to_bigraph, smiles_to_complete_graph
# general featurization
from dgllife.utils import ConcatFeaturizer
# node featurization
from dgllife.utils import CanonicalAtomFeaturizer, BaseAtomFeaturizer, WeaveAtomFeaturizer, \
atom_type_one_hot, atom_degree_one_hot, atom_formal_charge, atom_num_radical_electrons, \
atom_hybridization_one_hot, atom_total_num_H_one_hot
# edge featurization
from dgllife.utils.featurizers import BaseBondFeaturizer, WeaveEdgeFeaturizer
from utils import chirality
GCN_Tox21 = {
'random_seed': 2,
'batch_size': 128,
'lr': 1e-3,
'num_epochs': 100,
'node_data_field': 'h',
'frac_train': 0.8,
'frac_val': 0.1,
'frac_test': 0.1,
'in_feats': 74,
'gcn_hidden_feats': [64, 64],
'classifier_hidden_feats': 64,
'patience': 10,
'smiles_to_graph': smiles_to_bigraph,
'node_featurizer': CanonicalAtomFeaturizer(),
'metric_name': 'roc_auc_score'
}
GAT_Tox21 = {
'random_seed': 2,
'batch_size': 128,
'lr': 1e-3,
'num_epochs': 100,
'node_data_field': 'h',
'frac_train': 0.8,
'frac_val': 0.1,
'frac_test': 0.1,
'in_feats': 74,
'gat_hidden_feats': [32, 32],
'classifier_hidden_feats': 64,
'num_heads': [4, 4],
'patience': 10,
'smiles_to_graph': smiles_to_bigraph,
'node_featurizer': CanonicalAtomFeaturizer(),
'metric_name': 'roc_auc_score'
}
Weave_Tox21 = {
'random_seed': 2,
'batch_size': 32,
'lr': 1e-3,
'num_epochs': 100,
'node_data_field': 'h',
'edge_data_field': 'e',
'frac_train': 0.8,
'frac_val': 0.1,
'frac_test': 0.1,
'num_gnn_layers': 2,
'gnn_hidden_feats': 50,
'graph_feats': 128,
'patience': 10,
'smiles_to_graph': partial(smiles_to_complete_graph, add_self_loop=True),
'node_featurizer': WeaveAtomFeaturizer(),
'edge_featurizer': WeaveEdgeFeaturizer(max_distance=2),
'metric_name': 'roc_auc_score'
}
MPNN_Alchemy = {
'random_seed': 0,
'batch_size': 16,
'num_epochs': 250,
'node_in_feats': 15,
'node_out_feats': 64,
'edge_in_feats': 5,
'edge_hidden_feats': 128,
'n_tasks': 12,
'lr': 0.0001,
'patience': 50,
'metric_name': 'mae',
'weight_decay': 0
}
SchNet_Alchemy = {
'random_seed': 0,
'batch_size': 16,
'num_epochs': 250,
'node_feats': 64,
'hidden_feats': [64, 64, 64],
'classifier_hidden_feats': 64,
'n_tasks': 12,
'lr': 0.0001,
'patience': 50,
'metric_name': 'mae',
'weight_decay': 0
}
MGCN_Alchemy = {
'random_seed': 0,
'batch_size': 16,
'num_epochs': 250,
'feats': 128,
'n_layers': 3,
'classifier_hidden_feats': 64,
'n_tasks': 12,
'lr': 0.0001,
'patience': 50,
'metric_name': 'mae',
'weight_decay': 0
}
AttentiveFP_Aromaticity = {
'random_seed': 8,
'graph_feat_size': 200,
'num_layers': 2,
'num_timesteps': 2,
'node_feat_size': 39,
'edge_feat_size': 10,
'n_tasks': 1,
'dropout': 0.2,
'weight_decay': 10 ** (-5.0),
'lr': 10 ** (-2.5),
'batch_size': 128,
'num_epochs': 800,
'frac_train': 0.8,
'frac_val': 0.1,
'frac_test': 0.1,
'patience': 80,
'metric_name': 'rmse',
'smiles_to_graph': smiles_to_bigraph,
# Follow the atom featurization in the original work
'node_featurizer': BaseAtomFeaturizer(
featurizer_funcs={'hv': ConcatFeaturizer([
partial(atom_type_one_hot, allowable_set=[
'B', 'C', 'N', 'O', 'F', 'Si', 'P', 'S', 'Cl', 'As', 'Se', 'Br', 'Te', 'I', 'At'],
encode_unknown=True),
partial(atom_degree_one_hot, allowable_set=list(range(6))),
atom_formal_charge, atom_num_radical_electrons,
partial(atom_hybridization_one_hot, encode_unknown=True),
lambda atom: [0],  # A placeholder for aromatic information
atom_total_num_H_one_hot, chirality
],
)}
),
'edge_featurizer': BaseBondFeaturizer({
'he': lambda bond: [0 for _ in range(10)]
})
}
experiment_configures = {
'GCN_Tox21': GCN_Tox21,
'GAT_Tox21': GAT_Tox21,
'Weave_Tox21': Weave_Tox21,
'MPNN_Alchemy': MPNN_Alchemy,
'SchNet_Alchemy': SchNet_Alchemy,
'MGCN_Alchemy': MGCN_Alchemy,
'AttentiveFP_Aromaticity': AttentiveFP_Aromaticity
}
def get_exp_configure(exp_name):
return experiment_configures[exp_name]
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from dgllife.model import load_pretrained
from dgllife.utils import EarlyStopping, Meter
from utils import set_random_seed, load_dataset_for_regression, collate_molgraphs, load_model
def regress(args, model, bg):
if args['model'] == 'MPNN':
h = bg.ndata.pop('n_feat')
e = bg.edata.pop('e_feat')
h, e = h.to(args['device']), e.to(args['device'])
return model(bg, h, e)
elif args['model'] in ['SchNet', 'MGCN']:
node_types = bg.ndata.pop('node_type')
edge_distances = bg.edata.pop('distance')
node_types, edge_distances = node_types.to(args['device']), \
edge_distances.to(args['device'])
return model(bg, node_types, edge_distances)
else:
atom_feats, bond_feats = bg.ndata.pop('hv'), bg.edata.pop('he')
atom_feats, bond_feats = atom_feats.to(args['device']), bond_feats.to(args['device'])
return model(bg, atom_feats, bond_feats)
def run_a_train_epoch(args, epoch, model, data_loader,
loss_criterion, optimizer):
model.train()
train_meter = Meter()
for batch_id, batch_data in enumerate(data_loader):
smiles, bg, labels, masks = batch_data
labels, masks = labels.to(args['device']), masks.to(args['device'])
prediction = regress(args, model, bg)
loss = (loss_criterion(prediction, labels) * (masks != 0).float()).mean()
optimizer.zero_grad()
loss.backward()
optimizer.step()
train_meter.update(prediction, labels, masks)
total_score = np.mean(train_meter.compute_metric(args['metric_name']))
print('epoch {:d}/{:d}, training {} {:.4f}'.format(
epoch + 1, args['num_epochs'], args['metric_name'], total_score))
def run_an_eval_epoch(args, model, data_loader):
model.eval()
eval_meter = Meter()
with torch.no_grad():
for batch_id, batch_data in enumerate(data_loader):
smiles, bg, labels, masks = batch_data
labels = labels.to(args['device'])
prediction = regress(args, model, bg)
eval_meter.update(prediction, labels, masks)
total_score = np.mean(eval_meter.compute_metric(args['metric_name']))
return total_score
def main(args):
args['device'] = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
set_random_seed(args['random_seed'])
train_set, val_set, test_set = load_dataset_for_regression(args)
train_loader = DataLoader(dataset=train_set,
batch_size=args['batch_size'],
shuffle=True,
collate_fn=collate_molgraphs)
val_loader = DataLoader(dataset=val_set,
batch_size=args['batch_size'],
shuffle=True,
collate_fn=collate_molgraphs)
if test_set is not None:
test_loader = DataLoader(dataset=test_set,
batch_size=args['batch_size'],
collate_fn=collate_molgraphs)
if args['pre_trained']:
args['num_epochs'] = 0
model = load_pretrained(args['exp'])
else:
model = load_model(args)
loss_fn = nn.MSELoss(reduction='none')
optimizer = torch.optim.Adam(model.parameters(), lr=args['lr'],
weight_decay=args['weight_decay'])
stopper = EarlyStopping(mode='lower', patience=args['patience'])
model.to(args['device'])
for epoch in range(args['num_epochs']):
# Train
run_a_train_epoch(args, epoch, model, train_loader, loss_fn, optimizer)
# Validation and early stop
val_score = run_an_eval_epoch(args, model, val_loader)
early_stop = stopper.step(val_score, model)
print('epoch {:d}/{:d}, validation {} {:.4f}, best validation {} {:.4f}'.format(
epoch + 1, args['num_epochs'], args['metric_name'], val_score,
args['metric_name'], stopper.best_score))
if early_stop:
break
if test_set is not None:
if not args['pre_trained']:
stopper.load_checkpoint(model)
test_score = run_an_eval_epoch(args, model, test_loader)
print('test {} {:.4f}'.format(args['metric_name'], test_score))
if __name__ == "__main__":
import argparse
from configure import get_exp_configure
parser = argparse.ArgumentParser(description='Molecule Regression')
parser.add_argument('-m', '--model', type=str,
choices=['MPNN', 'SchNet', 'MGCN', 'AttentiveFP'],
help='Model to use')
parser.add_argument('-d', '--dataset', type=str, choices=['Alchemy', 'Aromaticity'],
help='Dataset to use')
parser.add_argument('-p', '--pre-trained', action='store_true',
help='Whether to skip training and use a pre-trained model')
args = parser.parse_args().__dict__
args['exp'] = '_'.join([args['model'], args['dataset']])
args.update(get_exp_configure(args['exp']))
main(args)
import dgl
import numpy as np
import random
import torch
from dgllife.utils.featurizers import one_hot_encoding
from dgllife.utils.splitters import RandomSplitter
def set_random_seed(seed=0):
"""Set random seed.
Parameters
----------
seed : int
Random seed to use
"""
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed(seed)
def load_dataset_for_classification(args):
"""Load dataset for classification tasks.
Parameters
----------
args : dict
Configurations.
Returns
-------
dataset
The whole dataset.
train_set
Subset for training.
val_set
Subset for validation.
test_set
Subset for test.
"""
assert args['dataset'] in ['Tox21']
if args['dataset'] == 'Tox21':
from dgllife.data import Tox21
dataset = Tox21(smiles_to_graph=args['smiles_to_graph'],
node_featurizer=args.get('node_featurizer', None),
edge_featurizer=args.get('edge_featurizer', None))
train_set, val_set, test_set = RandomSplitter.train_val_test_split(
dataset, frac_train=args['frac_train'], frac_val=args['frac_val'],
frac_test=args['frac_test'], random_state=args['random_seed'])
return dataset, train_set, val_set, test_set
def load_dataset_for_regression(args):
"""Load dataset for regression tasks.
Parameters
----------
args : dict
Configurations.
Returns
-------
train_set
Subset for training.
val_set
Subset for validation.
test_set
Subset for test.
"""
assert args['dataset'] in ['Alchemy', 'Aromaticity']
if args['dataset'] == 'Alchemy':
from dgllife.data import TencentAlchemyDataset
train_set = TencentAlchemyDataset(mode='dev')
val_set = TencentAlchemyDataset(mode='valid')
test_set = None
if args['dataset'] == 'Aromaticity':
from dgllife.data import PubChemBioAssayAromaticity
dataset = PubChemBioAssayAromaticity(smiles_to_graph=args['smiles_to_graph'],
node_featurizer=args.get('node_featurizer', None),
edge_featurizer=args.get('edge_featurizer', None))
train_set, val_set, test_set = RandomSplitter.train_val_test_split(
dataset, frac_train=args['frac_train'], frac_val=args['frac_val'],
frac_test=args['frac_test'], random_state=args['random_seed'])
return train_set, val_set, test_set
def collate_molgraphs(data):
"""Batching a list of datapoints for dataloader.
Parameters
----------
data : list of 3-tuples or 4-tuples.
Each tuple is for a single datapoint, consisting of
a SMILES, a DGLGraph, all-task labels and optionally
a binary mask indicating the existence of labels.
Returns
-------
    smiles : list
        List of SMILES strings
bg : DGLGraph
The batched DGLGraph.
labels : Tensor of dtype float32 and shape (B, T)
Batched datapoint labels. B is len(data) and
T is the number of total tasks.
masks : Tensor of dtype float32 and shape (B, T)
Batched datapoint binary mask, indicating the
existence of labels. If binary masks are not
provided, return a tensor with ones.
"""
assert len(data[0]) in [3, 4], \
'Expect the tuple to be of length 3 or 4, got {:d}'.format(len(data[0]))
if len(data[0]) == 3:
smiles, graphs, labels = map(list, zip(*data))
masks = None
else:
smiles, graphs, labels, masks = map(list, zip(*data))
bg = dgl.batch(graphs)
bg.set_n_initializer(dgl.init.zero_initializer)
bg.set_e_initializer(dgl.init.zero_initializer)
labels = torch.stack(labels, dim=0)
if masks is None:
masks = torch.ones(labels.shape)
else:
masks = torch.stack(masks, dim=0)
return smiles, bg, labels, masks
def load_model(args):
if args['model'] == 'GCN':
from dgllife.model import GCNPredictor
model = GCNPredictor(in_feats=args['node_featurizer'].feat_size(),
hidden_feats=args['gcn_hidden_feats'],
classifier_hidden_feats=args['classifier_hidden_feats'],
n_tasks=args['n_tasks'])
if args['model'] == 'GAT':
from dgllife.model import GATPredictor
model = GATPredictor(in_feats=args['node_featurizer'].feat_size(),
hidden_feats=args['gat_hidden_feats'],
num_heads=args['num_heads'],
classifier_hidden_feats=args['classifier_hidden_feats'],
n_tasks=args['n_tasks'])
if args['model'] == 'Weave':
from dgllife.model import WeavePredictor
model = WeavePredictor(node_in_feats=args['node_featurizer'].feat_size(),
edge_in_feats=args['edge_featurizer'].feat_size(),
num_gnn_layers=args['num_gnn_layers'],
gnn_hidden_feats=args['gnn_hidden_feats'],
graph_feats=args['graph_feats'],
n_tasks=args['n_tasks'])
if args['model'] == 'AttentiveFP':
from dgllife.model import AttentiveFPPredictor
model = AttentiveFPPredictor(node_feat_size=args['node_featurizer'].feat_size(),
edge_feat_size=args['edge_featurizer'].feat_size(),
num_layers=args['num_layers'],
num_timesteps=args['num_timesteps'],
graph_feat_size=args['graph_feat_size'],
n_tasks=args['n_tasks'],
dropout=args['dropout'])
if args['model'] == 'SchNet':
from dgllife.model import SchNetPredictor
model = SchNetPredictor(node_feats=args['node_feats'],
hidden_feats=args['hidden_feats'],
classifier_hidden_feats=args['classifier_hidden_feats'],
n_tasks=args['n_tasks'])
if args['model'] == 'MGCN':
from dgllife.model import MGCNPredictor
model = MGCNPredictor(feats=args['feats'],
n_layers=args['n_layers'],
classifier_hidden_feats=args['classifier_hidden_feats'],
n_tasks=args['n_tasks'])
if args['model'] == 'MPNN':
from dgllife.model import MPNNPredictor
model = MPNNPredictor(node_in_feats=args['node_in_feats'],
edge_in_feats=args['edge_in_feats'],
node_out_feats=args['node_out_feats'],
edge_hidden_feats=args['edge_hidden_feats'],
n_tasks=args['n_tasks'])
return model
def chirality(atom):
try:
return one_hot_encoding(atom.GetProp('_CIPCode'), ['R', 'S']) + \
[atom.HasProp('_ChiralityPossible')]
    except KeyError:
return [False, False] + [atom.HasProp('_ChiralityPossible')]
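# A hypothetical usage sketch for chirality (not part of the original file):
# the '_CIPCode' property is only present after stereochemistry assignment.
#     from rdkit import Chem
#     mol = Chem.MolFromSmiles('N[C@@H](C)C(=O)O')  # L-alanine
#     Chem.AssignStereochemistry(mol, cleanIt=True, force=True)
#     atom_chirality_feats = [chirality(atom) for atom in mol.GetAtoms()]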
# A graph-convolutional neural network model for the prediction of chemical reactivity
- [paper in Chemical Science](https://pubs.rsc.org/en/content/articlelanding/2019/sc/c8sc04228d#!divAbstract)
- [authors' code](https://github.com/connorcoley/rexgen_direct)
An earlier version of the work was published in NeurIPS 2017 as
["Predicting Organic Reaction Outcomes with Weisfeiler-Lehman Network"](https://arxiv.org/abs/1709.04555) with some
slight differences in modeling.
This work proposes a template-free approach for reaction prediction with two stages:
1) Identify the reaction center (pairs of atoms that will lose a bond or form a bond)
2) Enumerate the possible combinations of bond changes and rank the corresponding candidate products
We provide a Jupyter notebook walking through a demonstration with our pre-trained models. You can
download it with `wget https://data.dgl.ai/dgllife/reaction_prediction_pretrained.ipynb` and place it
in this directory. Below we visualize a reaction prediction by the model:
![](https://data.dgl.ai/dgllife/wln_reaction.png)
## Dataset
The example by default works with reactions from USPTO (United States Patent and Trademark Office) granted patents,
collected by Lowe [1]. After removing duplicates and erroneous reactions, the authors obtain a set of 480K reactions.
The dataset is divided into 400K, 40K, and 40K for training, validation and test.
## Reaction Center Prediction
### Modeling
Reaction centers refer to the pairs of atoms that lose/form a bond in a reaction. A graph neural network
(Weisfeiler-Lehman Network in this case) is trained to update the representations of all atoms. Then we combine
pairs of atom representations to predict the likelihood for the corresponding atoms to form/lose a bond.
For evaluation, we select pairs of atoms with top-k scores for each reaction and compute the proportion of reactions
whose reaction centers have all been selected.
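As a rough, self-contained sketch of the pair-scoring idea (the MLP design and the feature
sizes below are illustrative assumptions, not the repo's exact code), one can concatenate
every two atom representations and output one logit per bond-change type:
```python
import torch
import torch.nn as nn

class PairwiseScorer(nn.Module):
    """Score every atom pair for each of the 5 possible bond changes."""
    def __init__(self, node_feats=300, num_bond_changes=5):
        super(PairwiseScorer, self).__init__()
        self.mlp = nn.Sequential(nn.Linear(2 * node_feats, node_feats),
                                 nn.ReLU(),
                                 nn.Linear(node_feats, num_bond_changes))

    def forward(self, node_reps):
        # node_reps: (V, node_feats) atom representations from a GNN
        V = node_reps.size(0)
        left = node_reps.unsqueeze(1).expand(V, V, -1)
        right = node_reps.unsqueeze(0).expand(V, V, -1)
        # (V, V, num_bond_changes) scores for all atom pairs
        return self.mlp(torch.cat([left, right], dim=-1))

scores = PairwiseScorer()(torch.randn(6, 300))
print(scores.shape)  # torch.Size([6, 6, 5])
```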
### Training with Default Options
We use GPU whenever possible. To train the model with default options, simply do
```bash
python find_reaction_center_train.py
```
Once the training process starts, the progress will be printed in the terminal as follows:
```bash
Epoch 1/50, iter 8150/20452 | loss 8.4788 | grad norm 12.9927
Epoch 1/50, iter 8200/20452 | loss 8.6722 | grad norm 14.0833
```
Every time the learning rate is decayed (specified as `'decay_every'` in `configure.py`'s `reaction_center_config`), we save a checkpoint of
the model and evaluate it on the validation set. The evaluation result is formatted as follows, where `total samples x` means
that we have trained the model on `x` samples and `acc@k` means top-k accuracy:
```bash
total samples 800000, (epoch 2/35, iter 2443/2557) | acc@12 0.9278 | acc@16 0.9419 | acc@20 0.9496 | acc@40 0.9596 | acc@80 0.9596 |
```
All model checkpoints and evaluation results can be found under `center_results`. `model_x.pkl` stores a model
checkpoint after seeing `x` training samples. `val_eval.txt` stores all
evaluation results on the validation set.
You may want to terminate the training process when the validation performance no longer improves for some time.
### Multi-GPU Training
By default we use one GPU only. We also allow multi-gpu training. To use GPUs with ids `id1,id2,...`, do
```bash
python find_reaction_center_train.py --gpus id1,id2,...
```
A summary of the training speedup with the DGL implementation is presented below.
| Item | Training time (s/epoch) | Speedup |
| ----------------------- | ----------------------- | ------- |
| Authors' implementation | 11657 | 1x |
| DGL with 1 gpu | 858 | 13.6x |
| DGL with 2 gpus | 443 | 26.3x |
| DGL with 4 gpus | 243 | 48.0x |
| DGL with 8 gpus | 134 | 87.0x |
### Evaluation
```bash
python find_reaction_center_eval.py --model-path X
```
For example, you can evaluate the model trained for 800000 samples by setting `X` to be
`center_results/model_800000.pkl`. The evaluation results will be stored at `center_results/test_eval.txt`.
For model evaluation, we can choose whether to exclude reactants not contributing heavy atoms to the product
(e.g. reagents and solvents) in top-k atom pair selection, which will make the task easier.
For the easier evaluation, do
```bash
python find_reaction_center_eval.py --easy
```
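Concretely, the easy mode keeps a reactant under consideration only if it shares atom map
numbers with the product. A minimal sketch of that filter, assuming atom-mapped SMILES:
```python
from rdkit import Chem

def contributes_to_product(reactant_smiles, product_smiles):
    """Check whether a reactant shares mapped atoms with the product."""
    reactant = Chem.MolFromSmiles(reactant_smiles)
    product = Chem.MolFromSmiles(product_smiles)
    reactant_maps = {atom.GetAtomMapNum() for atom in reactant.GetAtoms()}
    product_maps = {atom.GetAtomMapNum() for atom in product.GetAtoms()}
    return len(reactant_maps & product_maps) > 0
```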
A summary of the model performance of various settings is as follows:
| Item | Top 6 accuracy | Top 8 accuracy | Top 10 accuracy |
| --------------- | -------------- | -------------- | --------------- |
| Paper | 89.8 | 92.0 | 93.3 |
| Hard evaluation from authors' code | 87.7 | 90.6 | 92.1 |
| Easy evaluation from authors' code | 90.0 | 92.8 | 94.2 |
| Hard evaluation | 88.9 | 91.7 | 93.1 |
| Easy evaluation | 91.2 | 93.8 | 95.0 |
| Hard evaluation for model trained on 8 gpus | 88.1 | 91.0 | 92.5 |
| Easy evaluation for model trained on 8 gpus | 90.3 | 93.3 | 94.6 |
1. We are able to match the results reported by the authors' code for both single-gpu and multi-gpu training.
2. While multi-gpu training provides a great speedup, the performance with the default hyperparameters drops slightly.
### Data Pre-processing with Multi-Processing
By default we use 32 processes for data pre-processing. If you encounter an error with
`BrokenPipeError: [Errno 32] Broken pipe`, you can specify a smaller number of processes with
```bash
python find_reaction_center_train.py -np X
```
```bash
python find_reaction_center_eval.py -np X
```
where `X` is the number of processes that you would like to use.
### Pre-trained Model
We provide a pre-trained model so users do not need to train from scratch. To evaluate the pre-trained model, simply do
```bash
python find_reaction_center_eval.py
```
### Adapting to a New Dataset
New datasets should be processed such that each line corresponds to the SMILES for a reaction like below:
```bash
[CH3:14][NH2:15].[N+:1](=[O:2])([O-:3])[c:4]1[cH:5][c:6]([C:7](=[O:8])[OH:9])[cH:10][cH:11][c:12]1[Cl:13].[OH2:16]>>[N+:1](=[O:2])([O-:3])[c:4]1[cH:5][c:6]([C:7](=[O:8])[OH:9])[cH:10][cH:11][c:12]1[NH:15][CH3:14]
```
The reactants are placed before `>>` and the product is placed after `>>`. The reactants are separated by `.`.
In addition, atom mapping information is provided.
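A hypothetical helper (not shipped with this example) for checking that a line follows the
expected format:
```python
from rdkit import Chem

def check_reaction_line(line):
    """Validate a 'reactants>>product' line with atom-mapped SMILES."""
    reactants, product = line.strip().split('>>')
    for smiles in reactants.split('.') + [product]:
        mol = Chem.MolFromSmiles(smiles)
        assert mol is not None, 'Invalid SMILES: {}'.format(smiles)
        assert all(atom.GetAtomMapNum() > 0 for atom in mol.GetAtoms()), \
            'Missing atom map numbers in {}'.format(smiles)
```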
You can then train a model on new datasets with
```bash
python find_reaction_center_train.py --train-path X --val-path Y
```
where `X`, `Y` are paths to the new training/validation sets as described above.
For evaluation,
```bash
python find_reaction_center_eval.py --test-path Z
```
where `Z` is the path to the new test set as described above.
## Candidate Ranking
### Additional Dependency
In addition to RDKit, we use MolVS as an alternative for checking whether two molecules are the same after sanitization.
- [molvs](https://molvs.readthedocs.io/en/latest/)
### Modeling
For candidate ranking, we assume that a model has been trained for reaction center prediction first.
The pipeline for predicting candidate products given a reaction proceeds as follows:
1. Select top-k bond changes for atom pairs in the reactants, ranked by the model for reaction center prediction.
By default, we use k=80 and exclude reactants not contributing heavy atoms to the ground truth product in
selecting top-k bond changes as in the paper.
2. Filter out candidate bond changes for bonds that are already in the reactants.
3. Enumerate the possible combinations of up to C candidate bond changes, where C bounds the number of bond changes
(losing or forming a bond) in a reaction. A statistical analysis of USPTO suggests that setting C to 5 is enough
(see the sketch after this list).
4. Filter out invalid combinations where 1) atoms in candidate bond changes are not connected or 2) an atom pair is
predicted to have different types of bond changes
(e.g. two atoms are predicted simultaneously to form a single and double bond) or 3) valence constraints are violated.
5. Apply the candidate bond changes for each valid combination and get the corresponding candidate products.
6. Construct molecular graphs for the reactants and candidate products, featurize their atoms and bonds.
7. Apply a Weisfeiler-Lehman Network to the molecular graphs for reactants and candidate products and score them.
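To picture steps 2-3, here is an illustrative sketch, where candidate bond changes are assumed
to be `(atom1, atom2, change_type)` tuples and `existing_bonds` is a set of such tuples for
bonds already present in the reactants:
```python
from itertools import combinations

def enumerate_candidate_combos(candidate_changes, existing_bonds, max_changes=5):
    """Drop changes already in the reactants, then enumerate combos of up to max_changes."""
    kept = [change for change in candidate_changes if change not in existing_bonds]
    combos = []
    for size in range(1, max_changes + 1):
        combos.extend(combinations(kept, size))
    return combos
```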
### Training with Default Options
We use GPU whenever possible. To train the model with default options, simply do
```bash
python candidate_ranking_train.py -cmp X
```
where `X` is the path to a trained model for reaction center prediction. You can use our
pre-trained model by not specifying `-cmp`.
Once the training process starts, the progress will be printed in the terminal as follows:
```bash
Epoch 6/6, iter 16439/20061 | time 1.1124 | accuracy 0.8500 | grad norm 5.3218
Epoch 6/6, iter 16440/20061 | time 1.1124 | accuracy 0.9500 | grad norm 2.1163
```
Every time the learning rate is decayed (specified as `'decay_every'` in `configure.py`'s `candidate_ranking_config`),
we save a checkpoint of the model and evaluate it on the validation set. The evaluation result is formatted
as follows, where `total samples x` means that we have trained the model on `x` samples, `acc@k` means top-k accuracy, and
`gfound` means the proportion of reactions where the ground truth product can be recovered by the ground truth bond changes.
We perform the evaluation based on RDKit-sanitized molecule equivalence (marked with `[strict]`) and MolVS-sanitized
molecule equivalence (marked with `[molvs]`).
```bash
total samples 100000, (epoch 1/20, iter 5000/20061)
[strict] acc@1: 0.7732 acc@2: 0.8466 acc@3: 0.8763 acc@5: 0.8987 gfound 0.9864
[molvs] acc@1: 0.7779 acc@2: 0.8523 acc@3: 0.8826 acc@5: 0.9057 gfound 0.9953
```
All model checkpoints and evaluation results can be found under `candidate_results`. `model_x.pkl` stores a model
checkpoint after seeing `x` training samples in total. `val_eval.txt` stores all
evaluation results on the validation set.
You may want to terminate the training process when the validation performance no longer improves for some time.
### Evaluation
```bash
python candidate_ranking_eval.py --model-path X -cmp Y
```
where `X` is the path to a trained model for candidate ranking and `Y` is the path to a trained model
for reaction center prediction. For example, you can evaluate the model trained for 800000 samples by setting `X` to be
`candidate_results/model_800000.pkl`. The evaluation results will be stored at `candidate_results/test_eval.txt`. As
in training, you can use our pre-trained model by not specifying `-cmp`.
A summary of the model performance of various settings is as follows:
| Item | Top 1 accuracy | Top 2 accuracy | Top 3 accuracy | Top 5 accuracy |
| -------------------------- | -------------- | -------------- | -------------- | -------------- |
| Authors' strict evaluation | 85.6 | 90.5 | 92.8 | 93.4 |
| DGL's strict evaluation | 85.6 | 90.0 | 91.7 | 92.9 |
| Authors' molvs evaluation | 86.2 | 91.2 | 92.8 | 94.2 |
| DGL's molvs evaluation | 86.1 | 90.6 | 92.4 | 93.6 |
### Pre-trained Model
We provide a pre-trained model so users do not need to train from scratch. To evaluate the pre-trained model,
simply do
```bash
python candidate_ranking_eval.py
```
### Adapting to a New Dataset
You can train a model on new datasets with
```bash
python candidate_ranking_train.py --train-path X --val-path Y
```
where `X`, `Y` are paths to the new training/validation set as described in the `Reaction Center Prediction` section.
For evaluation,
```bash
python candidate_ranking_eval.py --test-path Z
```
where `Z` is the path to the new test set as described in the `Reaction Center Prediction` section.
## References
[1] D. M. Lowe, Patent reaction extraction: downloads,
https://bitbucket.org/dan2097/patent-reaction-extraction/downloads, 2014.
import torch
from dgllife.data import USPTORank, WLNRankDataset
from dgllife.model import WLNReactionRanking, load_pretrained
from torch.utils.data import DataLoader
from configure import candidate_ranking_config, reaction_center_config
from utils import mkdir_p, prepare_reaction_center, collate_rank_eval, candidate_ranking_eval
def main(args, path_to_candidate_bonds):
if args['test_path'] is None:
test_set = USPTORank(
subset='test', candidate_bond_path=path_to_candidate_bonds['test'],
max_num_change_combos_per_reaction=args['max_num_change_combos_per_reaction_eval'],
num_processes=args['num_processes'])
else:
test_set = WLNRankDataset(
raw_file_path=args['test_path'],
candidate_bond_path=path_to_candidate_bonds['test'], mode='test',
max_num_change_combos_per_reaction=args['max_num_change_combos_per_reaction_eval'],
num_processes=args['num_processes'])
test_loader = DataLoader(test_set, batch_size=1, collate_fn=collate_rank_eval,
shuffle=False, num_workers=args['num_workers'])
if args['model_path'] is None:
model = load_pretrained('wln_rank_uspto')
else:
model = WLNReactionRanking(
node_in_feats=args['node_in_feats'],
edge_in_feats=args['edge_in_feats'],
node_hidden_feats=args['hidden_size'],
num_encode_gnn_layers=args['num_encode_gnn_layers'])
model.load_state_dict(torch.load(
args['model_path'], map_location='cpu')['model_state_dict'])
model = model.to(args['device'])
prediction_summary = candidate_ranking_eval(args, model, test_loader)
with open(args['result_path'] + '/test_eval.txt', 'w') as f:
f.write(prediction_summary)
if __name__ == '__main__':
from argparse import ArgumentParser
parser = ArgumentParser(description='Candidate Ranking')
parser.add_argument('--model-path', type=str, default=None,
help='Path to saved model. If None, we will directly evaluate '
'a pretrained model on the test set.')
parser.add_argument('--result-path', type=str, default='candidate_results',
help='Path to save modeling results')
parser.add_argument('--test-path', type=str, default=None,
help='Path to a new test set. '
'If None, we will use the default test set in USPTO.')
parser.add_argument('-cmp', '--center-model-path', type=str, default=None,
help='Path to a pre-trained model for reaction center prediction. '
'By default we use the official pre-trained model. If not None, '
'the model should follow the hyperparameters specified in '
'reaction_center_config.')
parser.add_argument('-rcb', '--reaction-center-batch-size', type=int, default=200,
help='Batch size to use for preparing candidate bonds from a trained '
'model on reaction center prediction')
parser.add_argument('-np', '--num-processes', type=int, default=8,
help='Number of processes to use for data pre-processing')
parser.add_argument('-nw', '--num-workers', type=int, default=32,
help='Number of workers to use for data loading in PyTorch data loader')
args = parser.parse_args().__dict__
args.update(candidate_ranking_config)
mkdir_p(args['result_path'])
if torch.cuda.is_available():
args['device'] = torch.device('cuda:0')
else:
args['device'] = torch.device('cpu')
path_to_candidate_bonds = prepare_reaction_center(args, reaction_center_config)
main(args, path_to_candidate_bonds)
import time
import torch
from dgllife.data import USPTORank, WLNRankDataset
from dgllife.model import WLNReactionRanking
from torch.nn import CrossEntropyLoss
from torch.optim import Adam
from torch.utils.data import DataLoader
from configure import reaction_center_config, candidate_ranking_config
from utils import prepare_reaction_center, mkdir_p, set_seed, collate_rank_train, \
collate_rank_eval, candidate_ranking_eval
def main(args, path_to_candidate_bonds):
if args['train_path'] is None:
train_set = USPTORank(
subset='train', candidate_bond_path=path_to_candidate_bonds['train'],
max_num_change_combos_per_reaction=args['max_num_change_combos_per_reaction_train'],
num_processes=args['num_processes'])
else:
train_set = WLNRankDataset(
raw_file_path=args['train_path'],
candidate_bond_path=path_to_candidate_bonds['train'], mode='train',
max_num_change_combos_per_reaction=args['max_num_change_combos_per_reaction_train'],
num_processes=args['num_processes'])
train_set.ignore_large()
if args['val_path'] is None:
val_set = USPTORank(
subset='val', candidate_bond_path=path_to_candidate_bonds['val'],
max_num_change_combos_per_reaction=args['max_num_change_combos_per_reaction_eval'],
num_processes=args['num_processes'])
else:
val_set = WLNRankDataset(
raw_file_path=args['val_path'],
candidate_bond_path=path_to_candidate_bonds['val'], mode='val',
max_num_change_combos_per_reaction=args['max_num_change_combos_per_reaction_eval'],
num_processes=args['num_processes'])
train_loader = DataLoader(train_set, batch_size=args['batch_size'],
collate_fn=collate_rank_train,
shuffle=True, num_workers=args['num_workers'])
val_loader = DataLoader(val_set, batch_size=args['batch_size'],
collate_fn=collate_rank_eval,
shuffle=False, num_workers=args['num_workers'])
model = WLNReactionRanking(
node_in_feats=args['node_in_feats'],
edge_in_feats=args['edge_in_feats'],
node_hidden_feats=args['hidden_size'],
num_encode_gnn_layers=args['num_encode_gnn_layers']).to(args['device'])
criterion = CrossEntropyLoss(reduction='sum')
optimizer = Adam(model.parameters(), lr=args['lr'])
from utils import Optimizer
optimizer = Optimizer(model, args['lr'], optimizer, max_grad_norm=args['max_norm'])
acc_sum = 0
grad_norm_sum = 0
dur = []
total_samples = 0
for epoch in range(args['num_epochs']):
t0 = time.time()
model.train()
for batch_id, batch_data in enumerate(train_loader):
batch_reactant_graphs, batch_product_graphs, \
batch_combo_scores, batch_labels, batch_num_candidate_products = batch_data
batch_combo_scores = batch_combo_scores.to(args['device'])
batch_labels = batch_labels.to(args['device'])
reactant_node_feats = batch_reactant_graphs.ndata.pop('hv').to(args['device'])
reactant_edge_feats = batch_reactant_graphs.edata.pop('he').to(args['device'])
product_node_feats = batch_product_graphs.ndata.pop('hv').to(args['device'])
product_edge_feats = batch_product_graphs.edata.pop('he').to(args['device'])
pred = model(reactant_graph=batch_reactant_graphs,
reactant_node_feats=reactant_node_feats,
reactant_edge_feats=reactant_edge_feats,
product_graphs=batch_product_graphs,
product_node_feats=product_node_feats,
product_edge_feats=product_edge_feats,
candidate_scores=batch_combo_scores,
batch_num_candidate_products=batch_num_candidate_products)
# Check if the ground truth candidate has the highest score
batch_loss = 0
product_graph_start = 0
for i in range(len(batch_num_candidate_products)):
product_graph_end = product_graph_start + batch_num_candidate_products[i]
reaction_pred = pred[product_graph_start:product_graph_end, :]
acc_sum += float(reaction_pred.max(dim=0)[1].detach().cpu().data.item() == 0)
batch_loss += criterion(reaction_pred.reshape(1, -1), batch_labels[i, :])
product_graph_start = product_graph_end
grad_norm_sum += optimizer.backward_and_step(batch_loss)
total_samples += args['batch_size']
if total_samples % args['print_every'] == 0:
progress = 'Epoch {:d}/{:d}, iter {:d}/{:d} | time {:.4f} | ' \
'accuracy {:.4f} | grad norm {:.4f}'.format(
epoch + 1, args['num_epochs'],
(batch_id + 1) * args['batch_size'] // args['print_every'],
len(train_set) // args['print_every'],
(sum(dur) + time.time() - t0) / total_samples * args['print_every'],
acc_sum / args['print_every'],
grad_norm_sum / args['print_every'])
print(progress)
acc_sum = 0
grad_norm_sum = 0
if total_samples % args['decay_every'] == 0:
dur.append(time.time() - t0)
old_lr = optimizer.lr
optimizer.decay_lr(args['lr_decay_factor'])
new_lr = optimizer.lr
print('Learning rate decayed from {:.4f} to {:.4f}'.format(old_lr, new_lr))
torch.save({'model_state_dict': model.state_dict()},
args['result_path'] + '/model_{:d}.pkl'.format(total_samples))
prediction_summary = 'total samples {:d}, (epoch {:d}/{:d}, iter {:d}/{:d})\n'.format(
total_samples, epoch + 1, args['num_epochs'],
(batch_id + 1) * args['batch_size'] // args['print_every'],
len(train_set) // args['print_every']) + candidate_ranking_eval(args, model, val_loader)
print(prediction_summary)
with open(args['result_path'] + '/val_eval.txt', 'a') as f:
f.write(prediction_summary)
t0 = time.time()
model.train()
if __name__ == '__main__':
from argparse import ArgumentParser
parser = ArgumentParser(description='Candidate Ranking')
parser.add_argument('--result-path', type=str, default='candidate_results',
help='Path to save modeling results')
parser.add_argument('--train-path', type=str, default=None,
help='Path to a new training set. '
'If None, we will use the default training set in USPTO.')
parser.add_argument('--val-path', type=str, default=None,
help='Path to a new validation set. '
'If None, we will use the default validation set in USPTO.')
parser.add_argument('-cmp', '--center-model-path', type=str, default=None,
help='Path to a pre-trained model for reaction center prediction. '
'By default we use the official pre-trained model. If not None, '
'the model should follow the hyperparameters specified in '
'reaction_center_config.')
parser.add_argument('-rcb', '--reaction-center-batch-size', type=int, default=200,
help='Batch size to use for preparing candidate bonds from a trained '
'model on reaction center prediction')
parser.add_argument('-np', '--num-processes', type=int, default=8,
help='Number of processes to use for data pre-processing')
parser.add_argument('-nw', '--num-workers', type=int, default=100,
help='Number of workers to use for data loading in PyTorch data loader')
args = parser.parse_args().__dict__
args.update(candidate_ranking_config)
mkdir_p(args['result_path'])
set_seed()
if torch.cuda.is_available():
args['device'] = torch.device('cuda:0')
else:
args['device'] = torch.device('cpu')
path_to_candidate_bonds = prepare_reaction_center(args, reaction_center_config)
main(args, path_to_candidate_bonds)
# Configuration for reaction center identification
reaction_center_config = {
'batch_size': 20,
'hidden_size': 300,
'max_norm': 5.0,
'node_in_feats': 82,
'edge_in_feats': 6,
'node_pair_in_feats': 10,
'node_out_feats': 300,
'n_layers': 3,
'n_tasks': 5,
'lr': 0.001,
'num_epochs': 18,
'print_every': 50,
'decay_every': 10000, # Learning rate decay
'lr_decay_factor': 0.9,
'top_ks_val': [12, 16, 20, 40, 80],
'top_ks_test': [6, 8, 10],
'max_k': 80
}
# Configuration for candidate ranking
candidate_ranking_config = {
'batch_size': 4,
'hidden_size': 500,
'num_encode_gnn_layers': 3,
'max_norm': 50.0,
'node_in_feats': 89,
'edge_in_feats': 5,
'lr': 0.001,
'num_epochs': 6,
'print_every': 20,
'decay_every': 100000,
'lr_decay_factor': 0.9,
'top_ks': [1, 2, 3, 5],
'max_k': 10,
'max_num_change_combos_per_reaction_train': 150,
'max_num_change_combos_per_reaction_eval': 1500,
'num_candidate_bond_changes': 16
}
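# Note (our reading of the setting, not stated in the original): the ranking loss is
# summed over reactions in a batch, so gradient norms grow roughly linearly with
# batch size; the clipping threshold below is scaled to match.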
candidate_ranking_config['max_norm'] = candidate_ranking_config['max_norm'] * \
candidate_ranking_config['batch_size']
import torch
from dgllife.data import USPTOCenter, WLNCenterDataset
from dgllife.model import WLNReactionCenter, load_pretrained
from torch.utils.data import DataLoader
from utils import reaction_center_final_eval, set_seed, collate_center, mkdir_p
def main(args):
set_seed()
if torch.cuda.is_available():
args['device'] = torch.device('cuda:0')
else:
args['device'] = torch.device('cpu')
# Set current device
torch.cuda.set_device(args['device'])
if args['test_path'] is None:
test_set = USPTOCenter('test', num_processes=args['num_processes'])
else:
test_set = WLNCenterDataset(raw_file_path=args['test_path'],
mol_graph_path='test.bin',
num_processes=args['num_processes'])
test_loader = DataLoader(test_set, batch_size=args['batch_size'],
collate_fn=collate_center, shuffle=False)
if args['model_path'] is None:
model = load_pretrained('wln_center_uspto')
else:
model = WLNReactionCenter(node_in_feats=args['node_in_feats'],
edge_in_feats=args['edge_in_feats'],
node_pair_in_feats=args['node_pair_in_feats'],
node_out_feats=args['node_out_feats'],
n_layers=args['n_layers'],
n_tasks=args['n_tasks'])
model.load_state_dict(torch.load(
args['model_path'], map_location='cpu')['model_state_dict'])
model = model.to(args['device'])
print('Evaluation on the test set.')
test_result = reaction_center_final_eval(
args, args['top_ks_test'], model, test_loader, args['easy'])
print(test_result)
with open(args['result_path'] + '/test_eval.txt', 'w') as f:
f.write(test_result)
if __name__ == '__main__':
from argparse import ArgumentParser
from configure import reaction_center_config
parser = ArgumentParser(description='Reaction Center Identification -- Evaluation')
parser.add_argument('--model-path', type=str, default=None,
help='Path to saved model. If None, we will directly evaluate '
'a pretrained model on the test set.')
parser.add_argument('--result-path', type=str, default='center_results',
help='Path where we saved model training and evaluation results')
parser.add_argument('--test-path', type=str, default=None,
help='Path to a new test set. '
'If None, we will use the default test set in USPTO.')
parser.add_argument('--easy', action='store_true', default=False,
help='Whether to exclude reactants not contributing heavy atoms to the '
'product in top-k atom pair selection, which will make the '
'task easier.')
parser.add_argument('-np', '--num-processes', type=int, default=32,
help='Number of processes to use for data pre-processing')
args = parser.parse_args().__dict__
args.update(reaction_center_config)
assert args['max_k'] >= max(args['top_ks_test']), \
'Expect max_k to be no smaller than the possible options ' \
'of top_ks_test, got {:d} and {:d}'.format(args['max_k'], max(args['top_ks_test']))
mkdir_p(args['result_path'])
main(args)
import numpy as np
import time
import torch
from dgllife.data import USPTOCenter, WLNCenterDataset
from dgllife.model import WLNReactionCenter
from torch.nn import BCEWithLogitsLoss
from torch.optim import Adam
from torch.utils.data import DataLoader
from utils import collate_center, reaction_center_prediction, \
reaction_center_final_eval, mkdir_p, set_seed, synchronize, get_center_subset, \
count_parameters
def load_dataset(args):
if args['train_path'] is None:
train_set = USPTOCenter('train', num_processes=args['num_processes'])
else:
train_set = WLNCenterDataset(raw_file_path=args['train_path'],
mol_graph_path='train.bin',
num_processes=args['num_processes'])
if args['val_path'] is None:
val_set = USPTOCenter('val', num_processes=args['num_processes'])
else:
val_set = WLNCenterDataset(raw_file_path=args['val_path'],
mol_graph_path='val.bin',
num_processes=args['num_processes'])
return train_set, val_set
def main(rank, dev_id, args):
set_seed()
    # Removing the line below will result in problems for multiprocessing
if args['num_devices'] > 1:
torch.set_num_threads(1)
if dev_id == -1:
args['device'] = torch.device('cpu')
else:
args['device'] = torch.device('cuda:{}'.format(dev_id))
# Set current device
torch.cuda.set_device(args['device'])
train_set, val_set = load_dataset(args)
get_center_subset(train_set, rank, args['num_devices'])
train_loader = DataLoader(train_set, batch_size=args['batch_size'],
collate_fn=collate_center, shuffle=True)
val_loader = DataLoader(val_set, batch_size=args['batch_size'],
collate_fn=collate_center, shuffle=False)
model = WLNReactionCenter(node_in_feats=args['node_in_feats'],
edge_in_feats=args['edge_in_feats'],
node_pair_in_feats=args['node_pair_in_feats'],
node_out_feats=args['node_out_feats'],
n_layers=args['n_layers'],
n_tasks=args['n_tasks']).to(args['device'])
model.train()
if rank == 0:
print('# trainable parameters in the model: ', count_parameters(model))
criterion = BCEWithLogitsLoss(reduction='sum')
optimizer = Adam(model.parameters(), lr=args['lr'])
if args['num_devices'] <= 1:
from utils import Optimizer
optimizer = Optimizer(model, args['lr'], optimizer, max_grad_norm=args['max_norm'])
else:
from utils import MultiProcessOptimizer
optimizer = MultiProcessOptimizer(args['num_devices'], model, args['lr'],
optimizer, max_grad_norm=args['max_norm'])
total_iter = 0
rank_iter = 0
grad_norm_sum = 0
loss_sum = 0
dur = []
for epoch in range(args['num_epochs']):
t0 = time.time()
for batch_id, batch_data in enumerate(train_loader):
total_iter += args['num_devices']
rank_iter += 1
batch_reactions, batch_graph_edits, batch_mol_graphs, \
batch_complete_graphs, batch_atom_pair_labels = batch_data
labels = batch_atom_pair_labels.to(args['device'])
pred, biased_pred = reaction_center_prediction(
args['device'], model, batch_mol_graphs, batch_complete_graphs)
loss = criterion(pred, labels) / len(batch_reactions)
loss_sum += loss.cpu().detach().data.item()
grad_norm_sum += optimizer.backward_and_step(loss)
if rank_iter % args['print_every'] == 0 and rank == 0:
progress = 'Epoch {:d}/{:d}, iter {:d}/{:d} | ' \
'loss {:.4f} | grad norm {:.4f}'.format(
epoch + 1, args['num_epochs'], batch_id + 1, len(train_loader),
loss_sum / args['print_every'], grad_norm_sum / args['print_every'])
print(progress)
grad_norm_sum = 0
loss_sum = 0
if total_iter % args['decay_every'] == 0:
optimizer.decay_lr(args['lr_decay_factor'])
if total_iter % args['decay_every'] == 0 and rank == 0:
if epoch >= 1:
dur.append(time.time() - t0)
print('Training time per {:d} iterations: {:.4f}'.format(
rank_iter, np.mean(dur)))
total_samples = total_iter * args['batch_size']
prediction_summary = 'total samples {:d}, (epoch {:d}/{:d}, iter {:d}/{:d}) '.format(
total_samples, epoch + 1, args['num_epochs'], batch_id + 1, len(train_loader)) + \
reaction_center_final_eval(args, args['top_ks_val'], model, val_loader, easy=True)
print(prediction_summary)
with open(args['result_path'] + '/val_eval.txt', 'a') as f:
f.write(prediction_summary)
torch.save({'model_state_dict': model.state_dict()},
args['result_path'] + '/model_{:d}.pkl'.format(total_samples))
t0 = time.time()
model.train()
synchronize(args['num_devices'])
def run(rank, dev_id, args):
dist_init_method = 'tcp://{master_ip}:{master_port}'.format(
master_ip=args['master_ip'], master_port=args['master_port'])
torch.distributed.init_process_group(backend="nccl",
init_method=dist_init_method,
world_size=args['num_devices'],
rank=rank)
assert torch.distributed.get_rank() == rank
main(rank, dev_id, args)
if __name__ == '__main__':
from argparse import ArgumentParser
from configure import reaction_center_config
parser = ArgumentParser(description='Reaction Center Identification -- Training')
parser.add_argument('--gpus', default='0', type=str,
help='To use multi-gpu training, '
'pass multiple gpu ids with --gpus id1,id2,...')
parser.add_argument('--result-path', type=str, default='center_results',
help='Path to save modeling results')
parser.add_argument('--train-path', type=str, default=None,
help='Path to a new training set. '
'If None, we will use the default training set in USPTO.')
parser.add_argument('--val-path', type=str, default=None,
help='Path to a new validation set. '
'If None, we will use the default validation set in USPTO.')
parser.add_argument('-np', '--num-processes', type=int, default=32,
help='Number of processes to use for data pre-processing')
parser.add_argument('--master-ip', type=str, default='127.0.0.1',
help='master ip address')
parser.add_argument('--master-port', type=str, default='12345',
help='master port')
args = parser.parse_args().__dict__
args.update(reaction_center_config)
assert args['max_k'] >= max(args['top_ks_val']), \
'Expect max_k to be no smaller than the possible options ' \
'of top_ks, got {:d} and {:d}'.format(args['max_k'], max(args['top_ks_val']))
mkdir_p(args['result_path'])
devices = list(map(int, args['gpus'].split(',')))
args['num_devices'] = len(devices)
if len(devices) == 1:
device_id = devices[0] if torch.cuda.is_available() else -1
main(0, device_id, args)
else:
if (args['train_path'] is not None) or (args['val_path'] is not None):
print('First pass for constructing DGLGraphs with multiprocessing')
load_dataset(args)
        # Daemonic processes are not allowed to create subprocesses,
        # so use a single process for data pre-processing
args['num_processes'] = 1
# With multi-gpu training, the batch size increases and we need to
# increase learning rate accordingly.
args['lr'] = args['lr'] * args['num_devices']
mp = torch.multiprocessing.get_context('spawn')
procs = []
for id, device_id in enumerate(devices):
print('Preparing for gpu {:d}/{:d}'.format(id + 1, args['num_devices']))
procs.append(mp.Process(target=run, args=(
id, device_id, args), daemon=True))
procs[-1].start()
for p in procs:
p.join()
import dgl
import errno
import numpy as np
import os
import random
import torch
import torch.distributed as dist
import torch.nn as nn
from collections import defaultdict
from copy import deepcopy
from dgllife.data import USPTOCenter, WLNCenterDataset
from dgllife.model import load_pretrained, WLNReactionCenter
from rdkit import Chem
from rdkit.Chem import AllChem
from torch.nn.utils import clip_grad_norm_
from torch.utils.data import DataLoader
try:
from molvs import Standardizer
except ImportError:
    print('MolVS is not installed; it is required for candidate ranking')
def mkdir_p(path):
"""Create a folder for the given path.
Parameters
----------
path: str
Folder to create
"""
try:
os.makedirs(path)
print('Created directory {}'.format(path))
except OSError as exc:
if exc.errno == errno.EEXIST and os.path.isdir(path):
print('Directory {} already exists.'.format(path))
else:
raise
def set_seed(seed=0):
"""Fix random seed.
Parameters
----------
seed : int
Random seed to use. Default to 0.
"""
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed(seed)
def count_parameters(model):
"""Get the number of trainable parameters in the model.
Parameters
----------
model : nn.Module
The model
Returns
-------
int
Number of trainable parameters in the model
"""
return sum(p.numel() for p in model.parameters() if p.requires_grad)
def get_center_subset(dataset, subset_id, num_subsets):
"""Get subset for reaction center identification.
Parameters
----------
dataset : WLNCenterDataset
Dataset for reaction center prediction with WLN
subset_id : int
Index for the subset
num_subsets : int
Number of total subsets
"""
if num_subsets == 1:
return
total_size = len(dataset)
subset_size = total_size // num_subsets
start = subset_id * subset_size
end = (subset_id + 1) * subset_size
dataset.mols = dataset.mols[start:end]
dataset.reactions = dataset.reactions[start:end]
dataset.graph_edits = dataset.graph_edits[start:end]
dataset.reactant_mol_graphs = dataset.reactant_mol_graphs[start:end]
dataset.atom_pair_features = [None for _ in range(subset_size)]
dataset.atom_pair_labels = [None for _ in range(subset_size)]
class Optimizer(nn.Module):
"""Wrapper for optimization
Parameters
----------
model : nn.Module
Model being trained
lr : float
Initial learning rate
optimizer : torch.optim.Optimizer
model optimizer
num_accum_times : int
Number of times for accumulating gradients
max_grad_norm : float or None
If not None, gradient clipping will be performed
"""
def __init__(self, model, lr, optimizer, num_accum_times=1, max_grad_norm=None):
super(Optimizer, self).__init__()
self.model = model
self.lr = lr
self.optimizer = optimizer
self.step_count = 0
self.num_accum_times = num_accum_times
self.max_grad_norm = max_grad_norm
self._reset()
def _reset(self):
self.optimizer.zero_grad()
def _clip_grad_norm(self):
grad_norm = None
if self.max_grad_norm is not None:
grad_norm = clip_grad_norm_(self.model.parameters(),
self.max_grad_norm)
return grad_norm
def backward_and_step(self, loss):
"""Backward and update model.
Parameters
----------
loss : torch.tensor consisting of a float only
Returns
-------
        grad_norm : float
            Gradient norm. If self.max_grad_norm is None, None will be returned.
            On gradient accumulation steps without a parameter update, 0 is returned.
"""
self.step_count += 1
loss.backward()
if self.step_count % self.num_accum_times == 0:
grad_norm = self._clip_grad_norm()
self.optimizer.step()
self._reset()
return grad_norm
else:
return 0
def decay_lr(self, decay_rate):
"""Decay learning rate.
Parameters
----------
decay_rate : float
Multiply the current learning rate by the decay_rate
"""
self.lr *= decay_rate
for param_group in self.optimizer.param_groups:
param_group['lr'] = self.lr
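# A minimal usage sketch for Optimizer (illustrative, not part of the original file):
#     import torch
#     import torch.nn as nn
#     from torch.optim import Adam
#     model = nn.Linear(8, 1)
#     wrapped = Optimizer(model, lr=1e-3, optimizer=Adam(model.parameters(), lr=1e-3),
#                         num_accum_times=4, max_grad_norm=5.0)
#     for _ in range(8):
#         x, y = torch.randn(2, 8), torch.randn(2, 1)
#         # parameters update (and gradients reset) only on every 4th call
#         wrapped.backward_and_step(nn.functional.mse_loss(model(x), y))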
class MultiProcessOptimizer(Optimizer):
"""Wrapper for optimization with multiprocess
Parameters
----------
n_processes : int
Number of processes used
model : nn.Module
Model being trained
lr : float
Initial learning rate
optimizer : torch.optim.Optimizer
model optimizer
max_grad_norm : float or None
If not None, gradient clipping will be performed.
"""
def __init__(self, n_processes, model, lr, optimizer, max_grad_norm=None):
super(MultiProcessOptimizer, self).__init__(lr=lr, model=model, optimizer=optimizer,
max_grad_norm=max_grad_norm)
self.n_processes = n_processes
def _sync_gradient(self):
"""Average gradients across all subprocesses."""
for param_group in self.optimizer.param_groups:
for p in param_group['params']:
if p.requires_grad and p.grad is not None:
dist.all_reduce(p.grad.data, op=dist.ReduceOp.SUM)
p.grad.data /= self.n_processes
def backward_and_step(self, loss):
"""Backward and update model.
Parameters
----------
loss : torch.tensor consisting of a float only
Returns
-------
grad_norm : float
Gradient norm. If self.max_grad_norm is None, None will be returned.
"""
loss.backward()
self._sync_gradient()
grad_norm = self._clip_grad_norm()
self.optimizer.step()
self._reset()
return grad_norm
def synchronize(num_gpus):
"""Synchronize all processes for multi-gpu training.
Parameters
----------
num_gpus : int
Number of gpus used
"""
if num_gpus > 1:
dist.barrier()
def collate_center(data):
"""Collate multiple datapoints for reaction center prediction
Parameters
----------
data : list of 7-tuples
Each tuple is for a single datapoint, consisting of
a reaction, graph edits in the reaction, an RDKit molecule instance for all reactants,
a DGLGraph for all reactants, a complete graph for all reactants, the features for each
pair of atoms and the labels for each pair of atoms.
Returns
-------
reactions : list of str
List of reactions.
graph_edits : list of str
List of graph edits in the reactions.
batch_mol_graphs : DGLGraph
DGLGraph for a batch of molecular graphs.
batch_complete_graphs : DGLGraph
DGLGraph for a batch of complete graphs.
batch_atom_pair_labels : float32 tensor of shape (V, 10)
Labels of atom pairs in the batch of graphs.
"""
reactions, graph_edits, mol_graphs, complete_graphs, \
atom_pair_feats, atom_pair_labels = map(list, zip(*data))
batch_mol_graphs = dgl.batch(mol_graphs)
batch_mol_graphs.set_n_initializer(dgl.init.zero_initializer)
batch_mol_graphs.set_e_initializer(dgl.init.zero_initializer)
batch_complete_graphs = dgl.batch(complete_graphs)
batch_complete_graphs.set_n_initializer(dgl.init.zero_initializer)
batch_complete_graphs.set_e_initializer(dgl.init.zero_initializer)
batch_complete_graphs.edata['feats'] = torch.cat(atom_pair_feats, dim=0)
batch_atom_pair_labels = torch.cat(atom_pair_labels, dim=0)
return reactions, graph_edits, batch_mol_graphs, \
batch_complete_graphs, batch_atom_pair_labels
def reaction_center_prediction(device, model, mol_graphs, complete_graphs):
"""Perform a soft prediction on reaction center.
Parameters
----------
device : str
Device to use for computation, e.g. 'cpu', 'cuda:0'
model : nn.Module
Model for prediction.
mol_graphs : DGLGraph
DGLGraph for a batch of molecular graphs
complete_graphs : DGLGraph
DGLGraph for a batch of complete graphs
Returns
-------
scores : float32 tensor of shape (E_full, 5)
Predicted scores for each pair of atoms to perform one of the following
5 actions in reaction:
* The bond between them gets broken
* Forming a single bond
* Forming a double bond
* Forming a triple bond
* Forming an aromatic bond
biased_scores : float32 tensor of shape (E_full, 5)
        Compared to scores, a bias is added if the pair consists of the same atom.
"""
node_feats = mol_graphs.ndata.pop('hv').to(device)
edge_feats = mol_graphs.edata.pop('he').to(device)
node_pair_feats = complete_graphs.edata.pop('feats').to(device)
return model(mol_graphs, complete_graphs, node_feats, edge_feats, node_pair_feats)
bond_change_to_id = {0.0: 0, 1: 1, 2: 2, 3: 3, 1.5: 4}
id_to_bond_change = {v: k for k, v in bond_change_to_id.items()}
num_change_types = len(bond_change_to_id)
def get_candidate_bonds(reaction, preds, num_nodes, max_k, easy, include_scores=False):
"""Get candidate bonds for a reaction.
Parameters
----------
reaction : str
Reaction
preds : float32 tensor of shape (E * 5)
E for the number of edges in a complete graph and 5 for the number of possible
bond changes.
num_nodes : int
Number of nodes in the graph.
max_k : int
Maximum number of atom pairs to be selected.
easy : bool
If True, reactants not contributing atoms to the product will be excluded in
top-k atom pair selection, which will make the task easier.
include_scores : bool
Whether to include the scores for the atom pairs selected. Default to False.
Returns
-------
list of 3-tuples or 4-tuples
The first three elements in a tuple separately specify the first atom,
the second atom and the type for bond change. If include_scores is True,
the score for the prediction will be included as a fourth element.
"""
# Decide which atom-pairs will be considered.
reaction_atoms = []
reaction_bonds = defaultdict(bool)
reactants, _, product = reaction.split('>')
product_mol = Chem.MolFromSmiles(product)
product_atoms = set([atom.GetAtomMapNum() for atom in product_mol.GetAtoms()])
for reactant in reactants.split('.'):
reactant_mol = Chem.MolFromSmiles(reactant)
reactant_atoms = [atom.GetAtomMapNum() for atom in reactant_mol.GetAtoms()]
# In the hard mode, all reactant atoms will be included.
# In the easy mode, only reactants contributing atoms to the product will be included.
if (len(set(reactant_atoms) & product_atoms) > 0) or (not easy):
reaction_atoms.extend(reactant_atoms)
for bond in reactant_mol.GetBonds():
end_atoms = sorted([bond.GetBeginAtom().GetAtomMapNum(),
bond.GetEndAtom().GetAtomMapNum()])
bond = tuple(end_atoms + [bond.GetBondTypeAsDouble()])
# Bookkeep bonds already in reactants
reaction_bonds[bond] = True
candidate_bonds = []
topk_values, topk_indices = torch.topk(preds, max_k)
for j in range(max_k):
preds_j = topk_indices[j].cpu().item()
# A bond change can be either losing the bond or forming a
# single, double, triple or aromatic bond
change_id = preds_j % num_change_types
change_type = id_to_bond_change[change_id]
pair_id = preds_j // num_change_types
# Atom map numbers
atom1 = pair_id // num_nodes + 1
atom2 = pair_id % num_nodes + 1
# Avoid duplicates and an atom cannot form a bond with itself
if atom1 >= atom2:
continue
if atom1 not in reaction_atoms:
continue
if atom2 not in reaction_atoms:
continue
candidate = (int(atom1), int(atom2), float(change_type))
if reaction_bonds[candidate]:
continue
if include_scores:
candidate += (float(topk_values[j].cpu().item()),)
candidate_bonds.append(candidate)
return candidate_bonds
def reaction_center_eval(complete_graphs, preds, reactions,
graph_edits, num_correct, max_k, easy):
"""Evaluate top-k accuracies for reaction center prediction.
Parameters
----------
complete_graphs : DGLGraph
DGLGraph for a batch of complete graphs
preds : float32 tensor of shape (E_full, 5)
Soft predictions for reaction center, E_full being the number of possible
atom-pairs and 5 being the number of possible bond changes
reactions : list of str
List of reactions.
graph_edits : list of str
List of graph edits in the reactions.
num_correct : dict
Counting the number of datapoints for meeting top-k accuracies.
max_k : int
Maximum number of atom pairs to be selected. This is intended to be larger
than max(num_correct.keys()) as we will filter out many atom pairs due to
considerations such as avoiding duplicates.
easy : bool
If True, reactants not contributing atoms to the product will be excluded in
top-k atom pair selection, which will make the task easier.
"""
# 0 for losing the bond
# 1, 2, 3, 1.5 separately for forming a single, double, triple or aromatic bond.
batch_size = complete_graphs.batch_size
start = 0
for i in range(batch_size):
end = start + complete_graphs.batch_num_edges[i]
candidate_bonds = get_candidate_bonds(
reactions[i], preds[start:end, :].flatten(),
complete_graphs.batch_num_nodes[i], max_k, easy)
gold_bonds = []
gold_edits = graph_edits[i]
for edit in gold_edits.split(';'):
atom1, atom2, change_type = edit.split('-')
atom1, atom2 = int(atom1), int(atom2)
gold_bonds.append((min(atom1, atom2), max(atom1, atom2), float(change_type)))
for k in num_correct.keys():
if set(gold_bonds) <= set(candidate_bonds[:k]):
num_correct[k] += 1
start = end
def reaction_center_final_eval(args, top_ks, model, data_loader, easy):
"""Final evaluation of model performance.
args : dict
Configurations fot the experiment.
top_ks : list of int
Options for top-k evaluation
model : nn.Module
Model for reaction center prediction.
data_loader : torch.utils.data.DataLoader
Loader for fetching and batching data.
easy : bool
If True, reactants not contributing atoms to the product will be excluded in
top-k atom pair selection, which will make the task easier.
Returns
-------
msg : str
Summary of the top-k evaluation.
"""
model.eval()
num_correct = {k: 0 for k in top_ks}
for batch_id, batch_data in enumerate(data_loader):
batch_reactions, batch_graph_edits, batch_mol_graphs, \
batch_complete_graphs, batch_atom_pair_labels = batch_data
with torch.no_grad():
pred, biased_pred = reaction_center_prediction(
args['device'], model, batch_mol_graphs, batch_complete_graphs)
reaction_center_eval(batch_complete_graphs, biased_pred, batch_reactions,
batch_graph_edits, num_correct, args['max_k'], easy)
msg = '|'
for k, correct_count in num_correct.items():
msg += ' acc@{:d} {:.4f} |'.format(k, correct_count / len(data_loader.dataset))
return msg + '\n'
def output_candidate_bonds_for_a_reaction(info, max_k):
"""Prepare top-k atom pairs for each reaction as candidate bonds
Parameters
----------
info : 3-tuple for a reaction
Consists of the reaction, the scores for atom-pairs in reactants
and the number of nodes in reactants.
max_k : int
Maximum number of atom pairs to be selected.
Returns
-------
candidate_string : str
String representing candidate bonds for a reaction. Each candidate
bond is of format 'atom1 atom2 change_type score'.
"""
reaction, preds, num_nodes = info
# Note that we use the easy mode by default, which is also the
# setting in the paper.
candidate_bonds = get_candidate_bonds(reaction, preds, num_nodes, max_k,
easy=True, include_scores=True)
candidate_string = ''
for candidate in candidate_bonds:
# A 4-tuple consisting of the atom mapping number of atom 1,
# atom 2, the bond change type and the score
candidate_string += '{} {} {:.1f} {:.3f};'.format(
candidate[0], candidate[1], candidate[2], candidate[3])
candidate_string += '\n'
return candidate_string
def prepare_reaction_center(args, reaction_center_config):
"""Use a trained model for reaction center prediction to prepare candidate bonds.
Parameters
----------
args : dict
Configuration for the experiment.
reaction_center_config : dict
Configuration for the experiment on reaction center prediction.
Returns
-------
path_to_candidate_bonds : dict
Mapping 'train', 'val', 'test' to the corresponding files for candidate bonds.
"""
if args['center_model_path'] is None:
reaction_center_model = load_pretrained('wln_center_uspto').to(args['device'])
else:
reaction_center_model = WLNReactionCenter(
node_in_feats=reaction_center_config['node_in_feats'],
edge_in_feats=reaction_center_config['edge_in_feats'],
node_pair_in_feats=reaction_center_config['node_pair_in_feats'],
node_out_feats=reaction_center_config['node_out_feats'],
n_layers=reaction_center_config['n_layers'],
n_tasks=reaction_center_config['n_tasks'])
reaction_center_model.load_state_dict(
torch.load(args['center_model_path'])['model_state_dict'])
reaction_center_model = reaction_center_model.to(args['device'])
reaction_center_model.eval()
path_to_candidate_bonds = dict()
for subset in ['train', 'val', 'test']:
if '{}_path'.format(subset) not in args:
continue
path_to_candidate_bonds[subset] = args['result_path'] + \
'/{}_candidate_bonds.txt'.format(subset)
if os.path.isfile(path_to_candidate_bonds[subset]):
continue
print('Processing subset {}...'.format(subset))
print('Stage 1/3: Loading dataset...')
if args['{}_path'.format(subset)] is None:
dataset = USPTOCenter(subset, num_processes=args['num_processes'])
else:
dataset = WLNCenterDataset(raw_file_path=args['{}_path'.format(subset)],
mol_graph_path='{}.bin'.format(subset),
num_processes=args['num_processes'])
dataloader = DataLoader(dataset, batch_size=args['reaction_center_batch_size'],
collate_fn=collate_center, shuffle=False)
print('Stage 2/3: Performing model prediction...')
output_strings = []
for batch_id, batch_data in enumerate(dataloader):
print('Computing candidate bonds for batch {:d}/{:d}'.format(
batch_id + 1, len(dataloader)))
batch_reactions, batch_graph_edits, batch_mol_graphs, \
batch_complete_graphs, batch_atom_pair_labels = batch_data
with torch.no_grad():
pred, biased_pred = reaction_center_prediction(
args['device'], reaction_center_model,
batch_mol_graphs, batch_complete_graphs)
batch_size = len(batch_reactions)
start = 0
for i in range(batch_size):
end = start + batch_complete_graphs.batch_num_edges[i]
output_strings.append(output_candidate_bonds_for_a_reaction(
(batch_reactions[i], biased_pred[start:end, :].flatten(),
batch_complete_graphs.batch_num_nodes[i]), reaction_center_config['max_k']
))
start = end
print('Stage 3/3: Output candidate bonds...')
with open(path_to_candidate_bonds[subset], 'w') as f:
for candidate_string in output_strings:
f.write(candidate_string)
del dataset
del dataloader
del reaction_center_model
return path_to_candidate_bonds
def collate_rank_train(data):
"""Collate multiple datapoints for candidate product ranking during training
Parameters
----------
data : list of 3-tuples
Each tuple is for a single datapoint, consisting of DGLGraphs for reactants and candidate
products, scores for candidate products by the model for reaction center prediction,
and labels for candidate products.
Returns
-------
batch_reactant_graphs : DGLGraph
DGLGraph for a batch of batch_size reactants.
product_graphs : DGLGraph
DGLGraph for a batch of B candidate products
combo_scores : float32 tensor of shape (B, 1)
Scores for candidate products by the model for reaction center prediction.
labels : int64 tensor of shape (N, 1)
Indices for the true candidate product across reactions, which is always 0
with pre-processing. N is for the number of reactions.
batch_num_candidate_products : list of int
Number of candidate products for the reactions in this batch.
"""
batch_graphs, batch_combo_scores, batch_labels = map(list, zip(*data))
batch_reactant_graphs = dgl.batch([g_list[0] for g_list in batch_graphs])
batch_num_candidate_products = []
batch_product_graphs = []
for g_list in batch_graphs:
batch_num_candidate_products.append(len(g_list) - 1)
batch_product_graphs.extend(g_list[1:])
batch_product_graphs = dgl.batch(batch_product_graphs)
batch_combo_scores = torch.cat(batch_combo_scores, dim=0)
batch_labels = torch.cat(batch_labels, dim=0)
return batch_reactant_graphs, batch_product_graphs, batch_combo_scores, batch_labels, \
batch_num_candidate_products
def collate_rank_eval(data):
"""Collate multiple datapoints for candidate product ranking during evaluation
Parameters
----------
data : list of 3-tuples
Each tuple is for a single datapoint, consisting of DGLGraphs for reactants and candidate
products, scores for candidate products by the model for reaction center prediction,
and valid combos of candidate bond changes, one for each candidate product.
Returns
-------
    batch_reactant_graphs : DGLGraph
        DGLGraph for a batch of batch_size reactants.
        None will be returned if no valid candidate products exist.
    batch_product_graphs : DGLGraph
        DGLGraph for a batch of B candidate products.
        None will be returned if no valid candidate products exist.
    batch_combo_scores : float32 tensor of shape (B, 1)
        Scores for candidate products by the model for reaction center prediction.
        None will be returned if no valid candidate products exist.
    valid_candidate_combos_list : list of list
        valid_candidate_combos_list[i] gives the valid combos of candidate bond changes
        for the i-th reaction. valid_candidate_combos_list[i][j] is the j-th valid combo
        for the reaction, given as a list of tuples of form
        (atom1, atom2, change_type, score). atom1 and atom2 are the atom mapping
        numbers - 1 of the two end atoms. change_type can be 0, 1, 2, 3, 1.5,
        respectively for losing a bond and forming a single, double, triple, or
        aromatic bond. None will be returned if no valid candidate products exist.
    reactant_mols_list : list of rdkit.Chem.rdchem.Mol
        RDKit molecule instances for the reactants, one per reaction in the batch.
        None will be returned if no valid candidate products exist.
    real_bond_changes_list : list of list
        real_bond_changes_list[i] gives the ground-truth bond changes in the i-th
        reaction as a list of tuples. Each tuple is of form (atom1, atom2, change_type).
        atom1 and atom2 are the atom mapping numbers - 1 of the two end atoms.
        change_type can be 0, 1, 2, 3, 1.5, respectively for losing a bond and forming
        a single, double, triple, or aromatic bond. None will be returned if no valid
        candidate products exist.
    product_mols_list : list of rdkit.Chem.rdchem.Mol
        RDKit molecule instances for the products, one per reaction in the batch.
        None will be returned if no valid candidate products exist.
    batch_num_candidate_products : list of int
        Number of candidate products for the reactions in this batch.
"""
batch_graphs, batch_combo_scores, batch_valid_candidate_combos, \
batch_reactant_mols, batch_real_bond_changes, batch_product_mols = map(list, zip(*data))
batch_reactant_graphs = []
batch_product_graphs = []
combo_scores_list = []
valid_candidate_combos_list = []
reactant_mols_list = []
real_bond_changes_list = []
product_mols_list = []
batch_num_candidate_products = []
for i in range(len(batch_graphs)):
g_list = batch_graphs[i]
# No valid candidate products have been predicted
if len(g_list) == 1:
continue
batch_reactant_graphs.append(g_list[0])
batch_product_graphs.extend(g_list[1:])
combo_scores_list.append(batch_combo_scores[i])
valid_candidate_combos_list.append(batch_valid_candidate_combos[i])
reactant_mols_list.append(batch_reactant_mols[i])
real_bond_changes_list.append(batch_real_bond_changes[i])
product_mols_list.append(batch_product_mols[i])
batch_num_candidate_products.append(len(g_list) - 1)
if len(batch_product_graphs) == 0:
return None, None, None, None, None, None, None, None
batch_reactant_graphs = dgl.batch(batch_reactant_graphs)
batch_product_graphs = dgl.batch(batch_product_graphs)
batch_combo_scores = torch.cat(combo_scores_list, dim=0)
return batch_reactant_graphs, batch_product_graphs, batch_combo_scores, \
valid_candidate_combos_list, reactant_mols_list, real_bond_changes_list, \
product_mols_list, batch_num_candidate_products
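# Minimal usage sketch ('val_set' is a hypothetical dataset whose items are the
# 6-tuples described above); a batch where no reaction has valid candidate
# products collates to all Nones and should be skipped, as done in
# candidate_ranking_eval below:
#
#     from torch.utils.data import DataLoader
#     val_loader = DataLoader(val_set, batch_size=4, collate_fn=collate_rank_eval)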
def sanitize_smiles_molvs(smiles, largest_fragment=False):
"""Sanitize a SMILES with MolVS
Parameters
----------
smiles : str
SMILES string for a molecule.
largest_fragment : bool
Whether to select only the largest covalent unit in a molecule with
multiple fragments. Default to False.
Returns
-------
str
SMILES string for the sanitized molecule.
"""
standardizer = Standardizer()
standardizer.prefer_organic = True
mol = Chem.MolFromSmiles(smiles)
if mol is None:
return smiles
try:
mol = standardizer.standardize(mol) # standardize functional group reps
if largest_fragment:
mol = standardizer.largest_fragment(mol) # remove product counterions/salts/etc.
mol = standardizer.uncharge(mol) # neutralize, e.g., carboxylic acids
except Exception:
pass
return Chem.MolToSmiles(mol)
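# Example sketch; the exact canonical SMILES may vary with the RDKit/MolVS
# versions installed:
#
#     >>> sanitize_smiles_molvs('CC(=O)[O-].[Na+]', largest_fragment=True)
#     'CC(=O)O'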
def bookkeep_reactant(mol):
"""Bookkeep bonds in the reactant.
Parameters
----------
mol : rdkit.Chem.rdchem.Mol
RDKit molecule instance for reactants.
Returns
-------
pair_to_bond_type : dict
        Mapping each 2-tuple of atom indices (smaller index first) to the bond type
        between the atoms. 1, 2, 3, 1.5 respectively stand for single, double, triple
        and aromatic bonds.
"""
pair_to_bond_type = dict()
for bond in mol.GetBonds():
atom1, atom2 = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
atom1, atom2 = min(atom1, atom2), max(atom1, atom2)
type_val = bond.GetBondTypeAsDouble()
pair_to_bond_type[(atom1, atom2)] = type_val
return pair_to_bond_type
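# Example sketch:
#
#     >>> mol = Chem.MolFromSmiles('C=CC#N')
#     >>> bookkeep_reactant(mol)
#     {(0, 1): 2.0, (1, 2): 1.0, (2, 3): 3.0}

# Mapping from the change type in bond change tuples to RDKit bond types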
bond_change_to_type = {1: Chem.rdchem.BondType.SINGLE, 2: Chem.rdchem.BondType.DOUBLE,
3: Chem.rdchem.BondType.TRIPLE, 1.5: Chem.rdchem.BondType.AROMATIC}
clean_rxns_postsani = [
# two adjacent aromatic nitrogens should allow for H shift
AllChem.ReactionFromSmarts('[n;H1;+0:1]:[n;H0;+1:2]>>[n;H0;+0:1]:[n;H0;+0:2]'),
# two aromatic nitrogens separated by one should allow for H shift
AllChem.ReactionFromSmarts('[n;H1;+0:1]:[c:3]:[n;H0;+1:2]>>[n;H0;+0:1]:[*:3]:[n;H0;+0:2]'),
AllChem.ReactionFromSmarts('[#7;H0;+:1]-[O;H1;+0:2]>>[#7;H0;+:1]-[O;H0;-:2]'),
# neutralize C(=O)[O-]
AllChem.ReactionFromSmarts('[C;H0;+0:1](=[O;H0;+0:2])[O;H0;-1:3]>>[C;H0;+0:1](=[O;H0;+0:2])[O;H1;+0:3]'),
# turn neutral halogens into anions EXCEPT HCl
AllChem.ReactionFromSmarts('[I,Br,F;H1;D0;+0:1]>>[*;H0;-1:1]'),
# inexplicable nitrogen anion in reactants gets fixed in prods
AllChem.ReactionFromSmarts('[N;H0;-1:1]([C:2])[C:3]>>[N;H1;+0:1]([*:2])[*:3]'),
]
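# Example sketch of applying one of the rules above, here clean_rxns_postsani[3]
# for neutralizing a carboxylate (the output SMILES may vary with the RDKit
# version):
#
#     >>> m = Chem.MolFromSmiles('CC(=O)[O-]')
#     >>> out = clean_rxns_postsani[3].RunReactants((m,))
#     >>> _ = Chem.SanitizeMol(out[0][0])
#     >>> Chem.MolToSmiles(out[0][0])
#     'CC(=O)O'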
def edit_mol(rmol, bond_changes, keep_atom_map=False):
"""Simulate reaction via graph editing
Parameters
----------
rmol : rdkit.Chem.rdchem.Mol
RDKit molecule instance for the reactants
    bond_changes : list of 3-tuples
        Each tuple is of form (atom1, atom2, change_type). change_type can be 0, 1, 2,
        3, 1.5, respectively for losing a bond and forming a single, double, triple,
        or aromatic bond.
    keep_atom_map : bool
        Whether to keep atom mapping numbers. Default to False.
Returns
-------
    pred_smiles : list of str
        SMILES strings for the edited molecule, one per recoverable fragment.
"""
new_mol = Chem.RWMol(rmol)
# Keep track of aromatic nitrogens, which might cause explicit hydrogen issues
aromatic_nitrogen_ids = set()
aromatic_carbonyl_adj_to_aromatic_nh = dict()
aromatic_carbondeg3_adj_to_aromatic_nh0 = dict()
for atom in new_mol.GetAtoms():
if atom.GetIsAromatic() and atom.GetSymbol() == 'N':
aromatic_nitrogen_ids.add(atom.GetIdx())
for nbr in atom.GetNeighbors():
if atom.GetNumExplicitHs() == 1 and nbr.GetSymbol() == 'C' and \
nbr.GetIsAromatic() and \
any(b.GetBondTypeAsDouble() == 2 for b in nbr.GetBonds()):
aromatic_carbonyl_adj_to_aromatic_nh[nbr.GetIdx()] = atom.GetIdx()
elif atom.GetNumExplicitHs() == 0 and nbr.GetSymbol() == 'C' and \
nbr.GetIsAromatic() and len(nbr.GetBonds()) == 3:
aromatic_carbondeg3_adj_to_aromatic_nh0[nbr.GetIdx()] = atom.GetIdx()
else:
atom.SetNumExplicitHs(0)
new_mol.UpdatePropertyCache()
for atom1_id, atom2_id, change_type in bond_changes:
bond = new_mol.GetBondBetweenAtoms(atom1_id, atom2_id)
atom1 = new_mol.GetAtomWithIdx(atom1_id)
atom2 = new_mol.GetAtomWithIdx(atom2_id)
if bond is not None:
new_mol.RemoveBond(atom1_id, atom2_id)
# Are we losing a bond on an aromatic nitrogen?
if bond.GetBondTypeAsDouble() == 1.0:
if atom1_id in aromatic_nitrogen_ids:
if atom1.GetTotalNumHs() == 0:
atom1.SetNumExplicitHs(1)
elif atom1.GetFormalCharge() == 1:
atom1.SetFormalCharge(0)
elif atom2_id in aromatic_nitrogen_ids:
if atom2.GetTotalNumHs() == 0:
atom2.SetNumExplicitHs(1)
elif atom2.GetFormalCharge() == 1:
atom2.SetFormalCharge(0)
# Are we losing a c=O bond on an aromatic ring?
# If so, remove H from adjacent nH if appropriate
if bond.GetBondTypeAsDouble() == 2.0:
both_aromatic_nh_ids = [
aromatic_carbonyl_adj_to_aromatic_nh.get(atom1_id, None),
aromatic_carbonyl_adj_to_aromatic_nh.get(atom2_id, None)
]
for aromatic_nh_id in both_aromatic_nh_ids:
if aromatic_nh_id is not None:
new_mol.GetAtomWithIdx(aromatic_nh_id).SetNumExplicitHs(0)
if change_type > 0:
new_mol.AddBond(atom1_id, atom2_id, bond_change_to_type[change_type])
# Special alkylation case?
if change_type == 1:
if atom1_id in aromatic_nitrogen_ids:
if atom1.GetTotalNumHs() == 1:
atom1.SetNumExplicitHs(0)
else:
atom1.SetFormalCharge(1)
elif atom2_id in aromatic_nitrogen_ids:
if atom2.GetTotalNumHs() == 1:
atom2.SetNumExplicitHs(0)
else:
atom2.SetFormalCharge(1)
# Are we getting a c=O bond on an aromatic ring?
# If so, add H to adjacent nH0 if appropriate
if change_type == 2:
both_aromatic_nh0_ids = [
aromatic_carbondeg3_adj_to_aromatic_nh0.get(atom1_id, None),
aromatic_carbondeg3_adj_to_aromatic_nh0.get(atom2_id, None)
]
for aromatic_nh0_id in both_aromatic_nh0_ids:
if aromatic_nh0_id is not None:
new_mol.GetAtomWithIdx(aromatic_nh0_id).SetNumExplicitHs(1)
pred_mol = new_mol.GetMol()
# Clear formal charges to make molecules valid
# Note: because S and P (among others) can change valence, be more flexible
for atom in pred_mol.GetAtoms():
if not keep_atom_map:
atom.ClearProp('molAtomMapNumber')
if atom.GetSymbol() == 'N' and atom.GetFormalCharge() == 1:
# exclude negatively-charged azide
bond_vals = sum([bond.GetBondTypeAsDouble() for bond in atom.GetBonds()])
if bond_vals <= 3:
atom.SetFormalCharge(0)
elif atom.GetSymbol() == 'N' and atom.GetFormalCharge() == -1:
# handle negatively-charged azide addition
bond_vals = sum([bond.GetBondTypeAsDouble() for bond in atom.GetBonds()])
if bond_vals == 3 and any([nbr.GetSymbol() == 'N' for nbr in atom.GetNeighbors()]):
atom.SetFormalCharge(0)
elif atom.GetSymbol() == 'N':
bond_vals = sum([bond.GetBondTypeAsDouble() for bond in atom.GetBonds()])
if bond_vals == 4 and not atom.GetIsAromatic():
atom.SetFormalCharge(1)
elif atom.GetSymbol() == 'C' and atom.GetFormalCharge() != 0:
atom.SetFormalCharge(0)
elif atom.GetSymbol() == 'O' and atom.GetFormalCharge() != 0:
bond_vals = sum([bond.GetBondTypeAsDouble() for bond in atom.GetBonds()]) + \
atom.GetNumExplicitHs()
if bond_vals == 2:
atom.SetFormalCharge(0)
elif atom.GetSymbol() in ['Cl', 'Br', 'I', 'F'] and atom.GetFormalCharge() != 0:
bond_vals = sum([bond.GetBondTypeAsDouble() for bond in atom.GetBonds()])
if bond_vals == 1:
atom.SetFormalCharge(0)
elif atom.GetSymbol() == 'S' and atom.GetFormalCharge() != 0:
bond_vals = sum([bond.GetBondTypeAsDouble() for bond in atom.GetBonds()])
if bond_vals in [2, 4, 6]:
atom.SetFormalCharge(0)
elif atom.GetSymbol() == 'P':
            # quaternary phosphorus should carry positive charge with 0 H
bond_vals = [bond.GetBondTypeAsDouble() for bond in atom.GetBonds()]
if sum(bond_vals) == 4 and len(bond_vals) == 4:
atom.SetFormalCharge(1)
atom.SetNumExplicitHs(0)
elif sum(bond_vals) == 3 and len(bond_vals) == 3:
# make sure neutral
atom.SetFormalCharge(0)
elif atom.GetSymbol() == 'B':
            # quaternary boron should carry negative charge with 0 H
bond_vals = [bond.GetBondTypeAsDouble() for bond in atom.GetBonds()]
if sum(bond_vals) == 4 and len(bond_vals) == 4:
atom.SetFormalCharge(-1)
atom.SetNumExplicitHs(0)
elif atom.GetSymbol() in ['Mg', 'Zn']:
bond_vals = [bond.GetBondTypeAsDouble() for bond in atom.GetBonds()]
if sum(bond_vals) == 1 and len(bond_vals) == 1:
atom.SetFormalCharge(1)
elif atom.GetSymbol() == 'Si':
bond_vals = [bond.GetBondTypeAsDouble() for bond in atom.GetBonds()]
if sum(bond_vals) == len(bond_vals):
atom.SetNumExplicitHs(max(0, 4 - len(bond_vals)))
# Bounce to/from SMILES to try to sanitize
pred_smiles = Chem.MolToSmiles(pred_mol)
pred_list = pred_smiles.split('.')
pred_mols = [Chem.MolFromSmiles(pred_smiles) for pred_smiles in pred_list]
for i, mol in enumerate(pred_mols):
if mol is None:
continue
mol = Chem.MolFromSmiles(Chem.MolToSmiles(mol))
if mol is None:
continue
for rxn in clean_rxns_postsani:
out = rxn.RunReactants((mol,))
if out:
try:
Chem.SanitizeMol(out[0][0])
pred_mols[i] = Chem.MolFromSmiles(Chem.MolToSmiles(out[0][0]))
except Exception as e:
pass
pred_smiles = [Chem.MolToSmiles(pred_mol) for pred_mol in pred_mols if pred_mol is not None]
return pred_smiles
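# Example sketch of a substitution: break the C-Cl bond (change type 0) and form
# a C-N single bond (change type 1) in 'CCl.N', where the atom indices are
# C=0, Cl=1, N=2. The fragment order in the output may vary with the RDKit
# version:
#
#     >>> rmol = Chem.MolFromSmiles('CCl.N')
#     >>> edit_mol(rmol, [(0, 1, 0), (0, 2, 1)])
#     ['CN', 'Cl']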
def examine_topk_candidate_product(topks, topk_combos, reactant_mol,
real_bond_changes, product_mol):
"""Perform topk evaluation for predicting the product of a reaction
Parameters
----------
topks : list of int
Options for top-k evaluation, e.g. [1, 3, ...].
topk_combos : list of list
topk_combos[i] gives the combo of valid bond changes ranked i-th,
which is a list of 3-tuples. Each tuple is of form
        (atom1, atom2, change_type). atom1 and atom2 are the atom mapping numbers - 1
        of the two end atoms. change_type can be 0, 1, 2, 3, 1.5, respectively for
        losing a bond and forming a single, double, triple, or aromatic bond.
reactant_mol : rdkit.Chem.rdchem.Mol
RDKit molecule instance for the reactants.
real_bond_changes : list of tuples
        Ground truth bond changes in a reaction. Each tuple is of form
        (atom1, atom2, change_type). atom1 and atom2 are the atom mapping numbers - 1
        of the two end atoms. change_type can be 0, 1, 2, 3, 1.5, respectively for
        losing a bond and forming a single, double, triple, or aromatic bond.
product_mol : rdkit.Chem.rdchem.Mol
RDKit molecule instance for the product.
Returns
-------
    found_info : dict
        Binary values indicating whether the product can be recovered from the
        ground truth graph edits or the top-k predicted edits.
"""
found_info = defaultdict(bool)
# Avoid corrupting the RDKit molecule instances in the dataset
reactant_mol = deepcopy(reactant_mol)
product_mol = deepcopy(product_mol)
for atom in product_mol.GetAtoms():
atom.ClearProp('molAtomMapNumber')
product_smiles = Chem.MolToSmiles(product_mol)
product_smiles_sanitized = set(sanitize_smiles_molvs(product_smiles, True).split('.'))
product_smiles = set(product_smiles.split('.'))
########### Use *true* edits to try to recover product
# Generate product by modifying reactants with graph edits
pred_smiles = edit_mol(reactant_mol, real_bond_changes)
pred_smiles_sanitized = set(sanitize_smiles_molvs(smiles) for smiles in pred_smiles)
pred_smiles = set(pred_smiles)
if not product_smiles <= pred_smiles:
# Try again with kekulized form
Chem.Kekulize(reactant_mol)
pred_smiles_kek = edit_mol(reactant_mol, real_bond_changes)
pred_smiles_kek = set(pred_smiles_kek)
if not product_smiles <= pred_smiles_kek:
if product_smiles_sanitized <= pred_smiles_sanitized:
print('\nwarn: mismatch, but only due to standardization')
found_info['ground_sanitized'] = True
else:
print('\nwarn: could not regenerate product {}'.format(product_smiles))
print('sani product: {}'.format(product_smiles_sanitized))
print(Chem.MolToSmiles(reactant_mol))
print(Chem.MolToSmiles(product_mol))
print(real_bond_changes)
print('pred_smiles: {}'.format(pred_smiles))
print('pred_smiles_kek: {}'.format(pred_smiles_kek))
print('pred_smiles_sani: {}'.format(pred_smiles_sanitized))
else:
found_info['ground'] = True
found_info['ground_sanitized'] = True
else:
found_info['ground'] = True
found_info['ground_sanitized'] = True
########### Now use candidate edits to try to recover product
max_topk = max(topks)
current_rank = 0
correct_rank = max_topk + 1
sanitized_correct_rank = max_topk + 1
candidate_smiles_list = []
candidate_smiles_sanitized_list = []
for i, combo in enumerate(topk_combos):
prev_len_candidate_smiles = len(set(candidate_smiles_list))
# Generate products by modifying reactants with predicted edits.
candidate_smiles = edit_mol(reactant_mol, combo)
candidate_smiles = set(candidate_smiles)
candidate_smiles_sanitized = set(sanitize_smiles_molvs(smiles)
for smiles in candidate_smiles)
if product_smiles_sanitized <= candidate_smiles_sanitized:
sanitized_correct_rank = min(sanitized_correct_rank, current_rank + 1)
if product_smiles <= candidate_smiles:
correct_rank = min(correct_rank, current_rank + 1)
# Record unkekulized form
candidate_smiles_list.append('.'.join(candidate_smiles))
candidate_smiles_sanitized_list.append('.'.join(candidate_smiles_sanitized))
# Edit molecules with reactants kekulized. Sometimes previous editing fails due to
# RDKit sanitization error (edited molecule cannot be kekulized)
try:
Chem.Kekulize(reactant_mol)
except Exception as e:
pass
candidate_smiles = edit_mol(reactant_mol, combo)
candidate_smiles = set(candidate_smiles)
candidate_smiles_sanitized = set(sanitize_smiles_molvs(smiles)
for smiles in candidate_smiles)
if product_smiles_sanitized <= candidate_smiles_sanitized:
sanitized_correct_rank = min(sanitized_correct_rank, current_rank + 1)
if product_smiles <= candidate_smiles:
correct_rank = min(correct_rank, current_rank + 1)
# If we failed to come up with a new candidate, don't increment the counter!
if len(set(candidate_smiles_list)) > prev_len_candidate_smiles:
current_rank += 1
if correct_rank < max_topk + 1 and sanitized_correct_rank < max_topk + 1:
break
for k in topks:
if correct_rank <= k:
found_info['top_{:d}'.format(k)] = True
if sanitized_correct_rank <= k:
found_info['top_{:d}_sanitized'.format(k)] = True
return found_info
def summary_candidate_ranking_info(top_ks, found_info, data_size):
"""Get a string for summarizing the candidate ranking results
Parameters
----------
top_ks : list of int
Options for top-k evaluation, e.g. [1, 3, ...].
    found_info : dict
        Counts of correct predictions for each evaluation criterion.
    data_size : int
        Size of the dataset.
Returns
-------
string : str
String summarizing the evaluation results
"""
string = '[strict]'
for k in top_ks:
string += ' acc@{:d}: {:.4f}'.format(k, found_info['top_{:d}'.format(k)] / data_size)
string += ' gfound {:.4f}\n'.format(found_info['ground'] / data_size)
string += '[molvs]'
for k in top_ks:
string += ' acc@{:d}: {:.4f}'.format(
k, found_info['top_{:d}_sanitized'.format(k)] / data_size)
string += ' gfound {:.4f}\n'.format(found_info['ground_sanitized'] / data_size)
return string
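# Example sketch with hypothetical counts over 100 reactions:
#
#     >>> counts = {'top_1': 70.0, 'top_1_sanitized': 72.0,
#     ...           'ground': 99.0, 'ground_sanitized': 100.0}
#     >>> print(summary_candidate_ranking_info([1], counts, 100))
#     [strict] acc@1: 0.7000 gfound 0.9900
#     [molvs] acc@1: 0.7200 gfound 1.0000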
def candidate_ranking_eval(args, model, data_loader):
"""Evaluate model performance on candidate ranking.
Parameters
----------
args : dict
        Configurations for the experiment.
model : nn.Module
Model for reaction center prediction.
data_loader : torch.utils.data.DataLoader
Loader for fetching and batching data.
Returns
-------
str
String summarizing the evaluation results
"""
model.eval()
    # Record how many products can be recovered by the real graph edits (with/without sanitization)
found_info_summary = {'ground': 0, 'ground_sanitized': 0}
for k in args['top_ks']:
found_info_summary['top_{:d}'.format(k)] = 0
found_info_summary['top_{:d}_sanitized'.format(k)] = 0
total_samples = 0
for batch_id, batch_data in enumerate(data_loader):
batch_reactant_graphs, batch_product_graphs, batch_combo_scores, \
batch_valid_candidate_combos, batch_reactant_mols, batch_real_bond_changes, \
batch_product_mols, batch_num_candidate_products = batch_data
# No valid candidate products have been predicted
if batch_reactant_graphs is None:
continue
total_samples += len(batch_num_candidate_products)
batch_combo_scores = batch_combo_scores.to(args['device'])
reactant_node_feats = batch_reactant_graphs.ndata.pop('hv').to(args['device'])
reactant_edge_feats = batch_reactant_graphs.edata.pop('he').to(args['device'])
product_node_feats = batch_product_graphs.ndata.pop('hv').to(args['device'])
product_edge_feats = batch_product_graphs.edata.pop('he').to(args['device'])
# Get candidate products with top-k ranking
with torch.no_grad():
pred = model(reactant_graph=batch_reactant_graphs,
reactant_node_feats=reactant_node_feats,
reactant_edge_feats=reactant_edge_feats,
product_graphs=batch_product_graphs,
product_node_feats=product_node_feats,
product_edge_feats=product_edge_feats,
candidate_scores=batch_combo_scores,
batch_num_candidate_products=batch_num_candidate_products)
product_graph_start = 0
for i in range(len(batch_num_candidate_products)):
num_candidate_products = batch_num_candidate_products[i]
reactant_mol = batch_reactant_mols[i]
valid_candidate_combos = batch_valid_candidate_combos[i]
real_bond_changes = batch_real_bond_changes[i]
product_mol = batch_product_mols[i]
product_graph_end = product_graph_start + num_candidate_products
top_k = min(args['max_k'], num_candidate_products)
reaction_pred = pred[product_graph_start:product_graph_end, :]
topk_values, topk_indices = torch.topk(reaction_pred, top_k, dim=0)
# Filter out invalid candidate bond changes
reactant_pair_to_bond = bookkeep_reactant(reactant_mol)
topk_combos = []
            for idx in topk_indices:
                idx = idx.detach().cpu().item()
                combo = []
                for atom1, atom2, change_type, score in valid_candidate_combos[idx]:
bond_in_reactant = reactant_pair_to_bond.get((atom1, atom2), None)
if (bond_in_reactant is None and change_type > 0) or \
(bond_in_reactant is not None and bond_in_reactant != change_type):
combo.append((atom1, atom2, change_type))
topk_combos.append(combo)
batch_found_info = examine_topk_candidate_product(
args['top_ks'], topk_combos, reactant_mol, real_bond_changes, product_mol)
for k, v in batch_found_info.items():
found_info_summary[k] += float(v)
product_graph_start = product_graph_end
if total_samples % args['print_every'] == 0:
print('Iter {:d}/{:d}'.format(
total_samples // args['print_every'],
len(data_loader.dataset) // args['print_every']))
print(summary_candidate_ranking_info(
args['top_ks'], found_info_summary, total_samples))
return summary_candidate_ranking_info(args['top_ks'], found_info_summary, total_samples)
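# Minimal invocation sketch (all names hypothetical; 'model' is a trained
# ranking model and 'val_loader' is built with collate_rank_eval):
#
#     args = {'device': 'cpu', 'top_ks': [1, 3, 5], 'max_k': 5, 'print_every': 20}
#     print(candidate_ranking_eval(args, model, val_loader))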
"""DGL-based package for applications in life science."""
from . import data
from . import model
from . import utils
from .libinfo import __version__
"""Dataset classes."""
from .alchemy import *
from .csv_dataset import *
from .pdbbind import *
from .pubchem_aromaticity import *
from .tox21 import *
from .uspto import *