Commit 3192beb4 authored by VoVAllen's avatar VoVAllen Committed by Mufei Li

[Model zoo] JTNN model zoo (#790)

* jtnn model zoo

* poke ci

* fix line sep

* fix

* Fix import order

* fix render

* fix render

* revert

* fix

* Resolve conflict

* fix

* remove create_var

* refactor

* fix

* refactor

* readme

* format

* fix lint

* fix lint

* pylint

* lint

* fix lint

* fix lint

* add hint

* fix

* Remove vocab

* Add explanation for warning

* add directory

* Load model to cpu by default

* Update
parent 30c46251
from .mol_tree import Vocab
from .jtnn_vae import DGLJTNNVAE
from .mpn import DGLMPN
-from .nnutils import create_var, cuda
+from .nnutils import cuda
from .datautils import JTNNDataset, JTNNCollator
from .chemutils import decode_stereo
@@ -2,7 +2,7 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
from .mol_tree import Vocab
-from .nnutils import create_var, cuda, move_dgl_to_cuda
+from .nnutils import cuda, move_dgl_to_cuda
from .chemutils import set_atommap, copy_edit_mol, enum_assemble_nx, \
attach_mols_nx, decode_stereo
from .jtnn_enc import DGLJTNNEncoder
@@ -3,11 +3,6 @@ import torch.nn as nn
-from torch.autograd import Variable
import os
-def create_var(tensor, requires_grad=None):
-    if requires_grad is None:
-        return Variable(tensor)
-    else:
-        return Variable(tensor, requires_grad=requires_grad)
def cuda(tensor):
if torch.cuda.is_available() and not os.getenv('NOCUDA', None):
# Junction Tree Variational Autoencoder for Molecular Graph Generation (JTNN)
Wengong Jin, Regina Barzilay, Tommi Jaakkola.
Junction Tree Variational Autoencoder for Molecular Graph Generation.
*arXiv preprint arXiv:1802.04364*, 2018.
JTNN uses the junction tree algorithm to form a tree from the molecular graph.
The model then encodes the tree and the graph into two separate vectors `z_T` and `z_G`. Details can
be found in the original paper. The overall process is sketched below (figure from the original paper):
![image](https://user-images.githubusercontent.com/8686776/63677300-3fb6d980-c81f-11e9-8a65-57c8b03aaf52.png)
**Goal**: JTNN is an auto-encoder model that learns hidden representations of molecular graphs.
These representations can be used for downstream tasks such as property prediction or molecule optimization.
## Dataset
### ZINC
> The ZINC database is a curated collection of commercially available chemical compounds
> prepared especially for virtual screening. (introduction from Wikipedia)

Generally speaking, molecules in the ZINC dataset are more drug-like than those in many
other collections. We use ~220,000 molecules for training and 5,000 molecules for validation.
### Preprocessing
The `JTNNDataset` class processes a SMILES string into a dict containing the junction tree, the molecular
graph with encoded nodes (atoms) and edges (bonds), and other information needed by the model.
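A single preprocessed sample can be inspected as below (a minimal sketch; it assumes the code runs from the `jtnn` example directory so the `jtnn` package is importable, and the default ZINC files are downloaded automatically on first use):
```
from jtnn import JTNNDataset

# training=False skips building the candidate and stereoisomer graphs
dataset = JTNNDataset(data='train', vocab='vocab', training=False)
sample = dataset[0]
# keys: 'mol_tree', 'mol_graph', 'atom_x_enc', 'bond_x_enc', 'wid'
print(sorted(sample.keys()))
print(sample['mol_tree'].smiles)
```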
## Usage
### Training
To start training, run `python train.py`. By default, the script uses the ZINC dataset
with a preprocessed vocabulary and saves model checkpoints in the current working directory. The available options are:
```
-s SAVE_PATH, --save_dir SAVE_PATH
Path to save checkpoint models, default to be current
working directory (default: ./)
-m MODEL_PATH, --model MODEL_PATH
Path to load pre-trained model (default: None)
-b BATCH_SIZE, --batch BATCH_SIZE
Batch size (default: 40)
-w HIDDEN_SIZE, --hidden HIDDEN_SIZE
Size of representation vectors (default: 200)
-l LATENT_SIZE, --latent LATENT_SIZE
                        Latent size of node (atom) and edge (bond) features
                        (default: 56)
-d DEPTH, --depth DEPTH
Depth of message passing hops (default: 3)
-z BETA, --beta BETA Coefficient of KL Divergence term (default: 1.0)
-q LR, --lr LR Learning Rate (default: 0.001)
-T, --test Add this flag to run test mode (default: False)
```
The model is saved periodically: all training checkpoints are stored under `SAVE_PATH`,
which can be passed on the command line and defaults to the current working directory.
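For example, the following command trains with a larger batch and hidden size, saving checkpoints under `./checkpoints` (a hypothetical, pre-created directory):
```
python train.py -b 64 -w 450 -s ./checkpoints
```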
#### Dataset configuration
If you want to use your own dataset, create a file containing one SMILES string per line
and pass the file path via the `-t` or `--train` option:
```
-t TRAIN, --train TRAIN
Training file name (default: train)
```
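For example, with a hypothetical SMILES file `my_smiles.txt` in the current directory:
```
python train.py -t my_smiles.txt
```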
### Evaluation
To start evaluation, run `python reconstruct_eval.py` with the following arguments:
```
-t TRAIN, --train TRAIN
Training file name (default: test)
-m MODEL_PATH, --model MODEL_PATH
                        Pre-trained model to be loaded for evaluation. If not
                        specified, the pre-trained model from the model zoo
                        will be used (default: None)
-w HIDDEN_SIZE, --hidden HIDDEN_SIZE
Hidden size of representation vector, should be
consistent with pre-trained model (default: 450)
-l LATENT_SIZE, --latent LATENT_SIZE
                        Latent size of node (atom) and edge (bond) features,
                        should be consistent with the pre-trained model
(default: 56)
-d DEPTH, --depth DEPTH
Depth of message passing hops, should be consistent
with pre-trained model (default: 3)
```
The script prints the success rate of reconstructing the input molecules.
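For example, to evaluate a checkpoint saved by `train.py` (hypothetical path; `-w 200` matches the training default):
```
python reconstruct_eval.py -m ./model.iter-0 -w 200
```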
### Pre-trained models
Below gives the statistics of pre-trained `JTNN_ZINC` model.
| Pre-trained model | % Reconstruction Accuracy |
| ----------------- | ------------------------- |
| `JTNN_ZINC`       | 73.7                      |
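The pre-trained model can also be loaded programmatically, as done in `reconstruct_eval.py`:
```
from dgl import model_zoo

# Loads the pre-trained JTNN checkpoint; weights are mapped to CPU by default
model = model_zoo.chem.load_pretrained("JTNN_ZINC")
model.eval()
```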
### Visualization
Here we draw some "neighbors" of a given molecule by adding noise to its intermediate representations. The detailed script can be found [here](https://s3.us-east-2.amazonaws.com/dgl.ai/model_zoo/drug_discovery/jtnn/viz_neighbor_mol.ipynb). Please put this script in the current directory (`examples/pytorch/model_zoo/chem/generative_models/jtnn/`).
#### Given Molecule
![image](https://user-images.githubusercontent.com/8686776/63773593-0d37da00-c90e-11e9-8933-0abca4b430db.png)
#### Neighbor Molecules
![image](https://user-images.githubusercontent.com/8686776/63773602-1163f780-c90e-11e9-8341-5122dc0d0c82.png)
### Warnings from PyTorch 1.2
If you are using PyTorch 1.2, you may see a warning saying
`UserWarning: indexing with dtype torch.uint8 is now deprecated, please use a dtype torch.bool instead.` This comes from a behavior change in PyTorch 1.2 and can be safely ignored.
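If the warning is too noisy, it can be silenced with Python's standard `warnings` module (a sketch matching on the message prefix):
```
import warnings

# Ignore the uint8-indexing deprecation warning emitted under PyTorch 1.2
warnings.filterwarnings('ignore', message='indexing with dtype torch.uint8')
```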
from .mol_tree import Vocab
from .nnutils import cuda
from .datautils import JTNNDataset, JTNNCollator
from .chemutils import decode_stereo
import rdkit
import rdkit.Chem as Chem
import torch
from scipy.sparse import csr_matrix
from scipy.sparse.csgraph import minimum_spanning_tree
from collections import defaultdict
from rdkit.Chem.EnumerateStereoisomers import EnumerateStereoisomers, StereoEnumerationOptions
from dgl import DGLGraph
ELEM_LIST = ['C', 'N', 'O', 'S', 'F', 'Si', 'P', 'Cl', 'Br', 'Mg', 'Na',
'Ca', 'Fe', 'Al', 'I', 'B', 'K', 'Se', 'Zn', 'H', 'Cu', 'Mn', 'unknown']
MST_MAX_WEIGHT = 100
MAX_NCAND = 2000
def onek_encoding_unk(x, allowable_set):
if x not in allowable_set:
x = allowable_set[-1]
return [x == s for s in allowable_set]
def set_atommap(mol, num=0):
for atom in mol.GetAtoms():
atom.SetAtomMapNum(num)
def get_mol(smiles):
mol = Chem.MolFromSmiles(smiles)
if mol is None:
return None
Chem.Kekulize(mol)
return mol
def get_smiles(mol):
return Chem.MolToSmiles(mol, kekuleSmiles=True)
def decode_stereo(smiles2D):
mol = Chem.MolFromSmiles(smiles2D)
dec_isomers = list(EnumerateStereoisomers(mol))
dec_isomers = [Chem.MolFromSmiles(Chem.MolToSmiles(
mol, isomericSmiles=True)) for mol in dec_isomers]
smiles3D = [Chem.MolToSmiles(mol, isomericSmiles=True)
for mol in dec_isomers]
chiralN = [atom.GetIdx() for atom in dec_isomers[0].GetAtoms() if int(
atom.GetChiralTag()) > 0 and atom.GetSymbol() == "N"]
if len(chiralN) > 0:
for mol in dec_isomers:
for idx in chiralN:
mol.GetAtomWithIdx(idx).SetChiralTag(
Chem.rdchem.ChiralType.CHI_UNSPECIFIED)
smiles3D.append(Chem.MolToSmiles(mol, isomericSmiles=True))
return smiles3D
def sanitize(mol):
try:
smiles = get_smiles(mol)
mol = get_mol(smiles)
except Exception as e:
return None
return mol
def copy_atom(atom):
new_atom = Chem.Atom(atom.GetSymbol())
new_atom.SetFormalCharge(atom.GetFormalCharge())
new_atom.SetAtomMapNum(atom.GetAtomMapNum())
return new_atom
def copy_edit_mol(mol):
new_mol = Chem.RWMol(Chem.MolFromSmiles(''))
for atom in mol.GetAtoms():
new_atom = copy_atom(atom)
new_mol.AddAtom(new_atom)
for bond in mol.GetBonds():
a1 = bond.GetBeginAtom().GetIdx()
a2 = bond.GetEndAtom().GetIdx()
bt = bond.GetBondType()
new_mol.AddBond(a1, a2, bt)
return new_mol
def get_clique_mol(mol, atoms):
smiles = Chem.MolFragmentToSmiles(mol, atoms, kekuleSmiles=True)
new_mol = Chem.MolFromSmiles(smiles, sanitize=False)
new_mol = copy_edit_mol(new_mol).GetMol()
new_mol = sanitize(new_mol) # We assume this is not None
return new_mol
def tree_decomp(mol):
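    """Decompose a molecule into cliques and junction-tree edges.

    Sketch of the steps implemented below:
    1. Non-ring bonds and SSSR rings each form a clique.
    2. Rings sharing more than two atoms are merged.
    3. Atoms shared by several cliques may become singleton cliques.
    4. Cliques are connected with weights based on intersection size, and a
       maximum spanning tree (computed as a minimum spanning tree over
       MST_MAX_WEIGHT - weight) gives the tree edges.

    Returns (cliques, edges), where cliques is a list of atom-index lists.
    """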
n_atoms = mol.GetNumAtoms()
if n_atoms == 1:
return [[0]], []
cliques = []
for bond in mol.GetBonds():
a1 = bond.GetBeginAtom().GetIdx()
a2 = bond.GetEndAtom().GetIdx()
if not bond.IsInRing():
cliques.append([a1, a2])
ssr = [list(x) for x in Chem.GetSymmSSSR(mol)]
cliques.extend(ssr)
nei_list = [[] for i in range(n_atoms)]
for i in range(len(cliques)):
for atom in cliques[i]:
nei_list[atom].append(i)
# Merge Rings with intersection > 2 atoms
for i in range(len(cliques)):
if len(cliques[i]) <= 2:
continue
for atom in cliques[i]:
for j in nei_list[atom]:
if i >= j or len(cliques[j]) <= 2:
continue
inter = set(cliques[i]) & set(cliques[j])
if len(inter) > 2:
cliques[i].extend(cliques[j])
cliques[i] = list(set(cliques[i]))
cliques[j] = []
cliques = [c for c in cliques if len(c) > 0]
nei_list = [[] for i in range(n_atoms)]
for i in range(len(cliques)):
for atom in cliques[i]:
nei_list[atom].append(i)
# Build edges and add singleton cliques
edges = defaultdict(int)
for atom in range(n_atoms):
if len(nei_list[atom]) <= 1:
continue
cnei = nei_list[atom]
bonds = [c for c in cnei if len(cliques[c]) == 2]
rings = [c for c in cnei if len(cliques[c]) > 4]
# In general, if len(cnei) >= 3, a singleton should be added, but 1 bond + 2 ring is currently not dealt with.
if len(bonds) > 2 or (len(bonds) == 2 and len(cnei) > 2):
cliques.append([atom])
c2 = len(cliques) - 1
for c1 in cnei:
edges[(c1, c2)] = 1
elif len(rings) > 2: # Multiple (n>2) complex rings
cliques.append([atom])
c2 = len(cliques) - 1
for c1 in cnei:
edges[(c1, c2)] = MST_MAX_WEIGHT - 1
else:
for i in range(len(cnei)):
for j in range(i + 1, len(cnei)):
c1, c2 = cnei[i], cnei[j]
inter = set(cliques[c1]) & set(cliques[c2])
if edges[(c1, c2)] < len(inter):
# cnei[i] < cnei[j] by construction
edges[(c1, c2)] = len(inter)
edges = [u + (MST_MAX_WEIGHT-v,) for u, v in edges.items()]
if len(edges) == 0:
return cliques, edges
# Compute Maximum Spanning Tree
row, col, data = list(zip(*edges))
n_clique = len(cliques)
clique_graph = csr_matrix((data, (row, col)), shape=(n_clique, n_clique))
junc_tree = minimum_spanning_tree(clique_graph)
row, col = junc_tree.nonzero()
edges = [(row[i], col[i]) for i in range(len(row))]
return (cliques, edges)
def atom_equal(a1, a2):
return a1.GetSymbol() == a2.GetSymbol() and a1.GetFormalCharge() == a2.GetFormalCharge()
# Bond type not considered because all aromatic (so SINGLE matches DOUBLE)
def ring_bond_equal(b1, b2, reverse=False):
b1 = (b1.GetBeginAtom(), b1.GetEndAtom())
if reverse:
b2 = (b2.GetEndAtom(), b2.GetBeginAtom())
else:
b2 = (b2.GetBeginAtom(), b2.GetEndAtom())
return atom_equal(b1[0], b2[0]) and atom_equal(b1[1], b2[1])
def attach_mols_nx(ctr_mol, neighbors, prev_nodes, nei_amap):
prev_nids = [node['nid'] for node in prev_nodes]
for nei_node in prev_nodes + neighbors:
nei_id, nei_mol = nei_node['nid'], nei_node['mol']
amap = nei_amap[nei_id]
for atom in nei_mol.GetAtoms():
if atom.GetIdx() not in amap:
new_atom = copy_atom(atom)
amap[atom.GetIdx()] = ctr_mol.AddAtom(new_atom)
if nei_mol.GetNumBonds() == 0:
nei_atom = nei_mol.GetAtomWithIdx(0)
ctr_atom = ctr_mol.GetAtomWithIdx(amap[0])
ctr_atom.SetAtomMapNum(nei_atom.GetAtomMapNum())
else:
for bond in nei_mol.GetBonds():
a1 = amap[bond.GetBeginAtom().GetIdx()]
a2 = amap[bond.GetEndAtom().GetIdx()]
if ctr_mol.GetBondBetweenAtoms(a1, a2) is None:
ctr_mol.AddBond(a1, a2, bond.GetBondType())
elif nei_id in prev_nids: # father node overrides
ctr_mol.RemoveBond(a1, a2)
ctr_mol.AddBond(a1, a2, bond.GetBondType())
return ctr_mol
def local_attach_nx(ctr_mol, neighbors, prev_nodes, amap_list):
ctr_mol = copy_edit_mol(ctr_mol)
nei_amap = {nei['nid']: {} for nei in prev_nodes + neighbors}
for nei_id, ctr_atom, nei_atom in amap_list:
nei_amap[nei_id][nei_atom] = ctr_atom
ctr_mol = attach_mols_nx(ctr_mol, neighbors, prev_nodes, nei_amap)
return ctr_mol.GetMol()
# This version records idx mapping between ctr_mol and nei_mol
def enum_attach_nx(ctr_mol, nei_node, amap, singletons):
nei_mol, nei_idx = nei_node['mol'], nei_node['nid']
att_confs = []
black_list = [atom_idx for nei_id, atom_idx,
_ in amap if nei_id in singletons]
ctr_atoms = [atom for atom in ctr_mol.GetAtoms() if atom.GetIdx()
not in black_list]
ctr_bonds = [bond for bond in ctr_mol.GetBonds()]
if nei_mol.GetNumBonds() == 0: # neighbor singleton
nei_atom = nei_mol.GetAtomWithIdx(0)
used_list = [atom_idx for _, atom_idx, _ in amap]
for atom in ctr_atoms:
if atom_equal(atom, nei_atom) and atom.GetIdx() not in used_list:
new_amap = amap + [(nei_idx, atom.GetIdx(), 0)]
att_confs.append(new_amap)
elif nei_mol.GetNumBonds() == 1: # neighbor is a bond
bond = nei_mol.GetBondWithIdx(0)
bond_val = int(bond.GetBondTypeAsDouble())
b1, b2 = bond.GetBeginAtom(), bond.GetEndAtom()
for atom in ctr_atoms:
# Optimize if atom is carbon (other atoms may change valence)
if atom.GetAtomicNum() == 6 and atom.GetTotalNumHs() < bond_val:
continue
if atom_equal(atom, b1):
new_amap = amap + [(nei_idx, atom.GetIdx(), b1.GetIdx())]
att_confs.append(new_amap)
elif atom_equal(atom, b2):
new_amap = amap + [(nei_idx, atom.GetIdx(), b2.GetIdx())]
att_confs.append(new_amap)
else:
# intersection is an atom
for a1 in ctr_atoms:
for a2 in nei_mol.GetAtoms():
if atom_equal(a1, a2):
# Optimize if atom is carbon (other atoms may change valence)
if a1.GetAtomicNum() == 6 and a1.GetTotalNumHs() + a2.GetTotalNumHs() < 4:
continue
new_amap = amap + [(nei_idx, a1.GetIdx(), a2.GetIdx())]
att_confs.append(new_amap)
        # intersection is a bond
if ctr_mol.GetNumBonds() > 1:
for b1 in ctr_bonds:
for b2 in nei_mol.GetBonds():
if ring_bond_equal(b1, b2):
new_amap = amap + [(nei_idx, b1.GetBeginAtom().GetIdx(), b2.GetBeginAtom(
).GetIdx()), (nei_idx, b1.GetEndAtom().GetIdx(), b2.GetEndAtom().GetIdx())]
att_confs.append(new_amap)
if ring_bond_equal(b1, b2, reverse=True):
new_amap = amap + [(nei_idx, b1.GetBeginAtom().GetIdx(), b2.GetEndAtom(
).GetIdx()), (nei_idx, b1.GetEndAtom().GetIdx(), b2.GetBeginAtom().GetIdx())]
att_confs.append(new_amap)
return att_confs
# Try rings first: Speed-Up
def enum_assemble_nx(node, neighbors, prev_nodes=None, prev_amap=None):
    # Use None defaults to avoid shared mutable default arguments
    if prev_nodes is None:
        prev_nodes = []
    if prev_amap is None:
        prev_amap = []
all_attach_confs = []
singletons = [nei_node['nid'] for nei_node in neighbors +
prev_nodes if nei_node['mol'].GetNumAtoms() == 1]
def search(cur_amap, depth):
if len(all_attach_confs) > MAX_NCAND:
return
if depth == len(neighbors):
all_attach_confs.append(cur_amap)
return
nei_node = neighbors[depth]
cand_amap = enum_attach_nx(node['mol'], nei_node, cur_amap, singletons)
cand_smiles = set()
candidates = []
for amap in cand_amap:
cand_mol = local_attach_nx(
node['mol'], neighbors[:depth+1], prev_nodes, amap)
cand_mol = sanitize(cand_mol)
if cand_mol is None:
continue
smiles = get_smiles(cand_mol)
if smiles in cand_smiles:
continue
cand_smiles.add(smiles)
candidates.append(amap)
if len(candidates) == 0:
return []
for new_amap in candidates:
search(new_amap, depth + 1)
search(prev_amap, 0)
cand_smiles = set()
candidates = []
for amap in all_attach_confs:
cand_mol = local_attach_nx(node['mol'], neighbors, prev_nodes, amap)
cand_mol = Chem.MolFromSmiles(Chem.MolToSmiles(cand_mol))
smiles = Chem.MolToSmiles(cand_mol)
if smiles in cand_smiles:
continue
cand_smiles.add(smiles)
Chem.Kekulize(cand_mol)
candidates.append((smiles, cand_mol, amap))
return candidates
# Only used for debugging purposes
def dfs_assemble_nx(graph, cur_mol, global_amap, fa_amap, cur_node_id, fa_node_id):
cur_node = graph.nodes_dict[cur_node_id]
fa_node = graph.nodes_dict[fa_node_id] if fa_node_id is not None else None
fa_nid = fa_node['nid'] if fa_node is not None else -1
prev_nodes = [fa_node] if fa_node is not None else []
children_id = [nei for nei in graph[cur_node_id]
if graph.nodes_dict[nei]['nid'] != fa_nid]
children = [graph.nodes_dict[nei] for nei in children_id]
neighbors = [nei for nei in children if nei['mol'].GetNumAtoms() > 1]
neighbors = sorted(
neighbors, key=lambda x: x['mol'].GetNumAtoms(), reverse=True)
singletons = [nei for nei in children if nei['mol'].GetNumAtoms() == 1]
neighbors = singletons + neighbors
cur_amap = [(fa_nid, a2, a1)
for nid, a1, a2 in fa_amap if nid == cur_node['nid']]
cands = enum_assemble_nx(
graph.nodes_dict[cur_node_id], neighbors, prev_nodes, cur_amap)
if len(cands) == 0:
return
cand_smiles, _, cand_amap = zip(*cands)
label_idx = cand_smiles.index(cur_node['label'])
label_amap = cand_amap[label_idx]
for nei_id, ctr_atom, nei_atom in label_amap:
if nei_id == fa_nid:
continue
global_amap[nei_id][nei_atom] = global_amap[cur_node['nid']][ctr_atom]
# father is already attached
cur_mol = attach_mols_nx(cur_mol, children, [], global_amap)
for nei_node_id, nei_node in zip(children_id, children):
if not nei_node['is_leaf']:
dfs_assemble_nx(graph, cur_mol, global_amap,
label_amap, nei_node_id, cur_node_id)
def mol2dgl_dec(cand_batch):
# Note that during graph decoding they don't predict stereochemistry-related
# characteristics (i.e. Chiral Atoms, E-Z, Cis-Trans). Instead, they decode
# the 2-D graph first, then enumerate all possible 3-D forms and find the
# one with highest score.
def atom_features(atom):
return (torch.Tensor(onek_encoding_unk(atom.GetSymbol(), ELEM_LIST)
+ onek_encoding_unk(atom.GetDegree(),
[0, 1, 2, 3, 4, 5])
+ onek_encoding_unk(atom.GetFormalCharge(),
[-1, -2, 1, 2, 0])
+ [atom.GetIsAromatic()]))
def bond_features(bond):
bt = bond.GetBondType()
return (torch.Tensor([bt == Chem.rdchem.BondType.SINGLE, bt == Chem.rdchem.BondType.DOUBLE, bt == Chem.rdchem.BondType.TRIPLE, bt == Chem.rdchem.BondType.AROMATIC, bond.IsInRing()]))
cand_graphs = []
tree_mess_source_edges = [] # map these edges from trees to...
tree_mess_target_edges = [] # these edges on candidate graphs
tree_mess_target_nodes = []
n_nodes = 0
n_edges = 0
atom_x = []
bond_x = []
for mol, mol_tree, ctr_node_id in cand_batch:
n_atoms = mol.GetNumAtoms()
n_bonds = mol.GetNumBonds()
ctr_node = mol_tree.nodes_dict[ctr_node_id]
ctr_bid = ctr_node['idx']
g = DGLGraph()
for i, atom in enumerate(mol.GetAtoms()):
assert i == atom.GetIdx()
atom_x.append(atom_features(atom))
g.add_nodes(n_atoms)
bond_src = []
bond_dst = []
for i, bond in enumerate(mol.GetBonds()):
a1 = bond.GetBeginAtom()
a2 = bond.GetEndAtom()
begin_idx = a1.GetIdx()
end_idx = a2.GetIdx()
features = bond_features(bond)
bond_src.append(begin_idx)
bond_dst.append(end_idx)
bond_x.append(features)
bond_src.append(end_idx)
bond_dst.append(begin_idx)
bond_x.append(features)
x_nid, y_nid = a1.GetAtomMapNum(), a2.GetAtomMapNum()
# Tree node ID in the batch
x_bid = mol_tree.nodes_dict[x_nid - 1]['idx'] if x_nid > 0 else -1
y_bid = mol_tree.nodes_dict[y_nid - 1]['idx'] if y_nid > 0 else -1
if x_bid >= 0 and y_bid >= 0 and x_bid != y_bid:
if mol_tree.has_edge_between(x_bid, y_bid):
tree_mess_target_edges.append(
(begin_idx + n_nodes, end_idx + n_nodes))
tree_mess_source_edges.append((x_bid, y_bid))
tree_mess_target_nodes.append(end_idx + n_nodes)
if mol_tree.has_edge_between(y_bid, x_bid):
tree_mess_target_edges.append(
(end_idx + n_nodes, begin_idx + n_nodes))
tree_mess_source_edges.append((y_bid, x_bid))
tree_mess_target_nodes.append(begin_idx + n_nodes)
n_nodes += n_atoms
g.add_edges(bond_src, bond_dst)
cand_graphs.append(g)
return cand_graphs, torch.stack(atom_x), \
torch.stack(bond_x) if len(bond_x) > 0 else torch.zeros(0), \
torch.LongTensor(tree_mess_source_edges), \
torch.LongTensor(tree_mess_target_edges), \
torch.LongTensor(tree_mess_target_nodes)
def mol2dgl_enc(smiles):
def atom_features(atom):
return (torch.Tensor(onek_encoding_unk(atom.GetSymbol(), ELEM_LIST)
+ onek_encoding_unk(atom.GetDegree(),
[0, 1, 2, 3, 4, 5])
+ onek_encoding_unk(atom.GetFormalCharge(), [-1, -2, 1, 2, 0])
+ onek_encoding_unk(int(atom.GetChiralTag()), [0, 1, 2, 3])
+ [atom.GetIsAromatic()]))
def bond_features(bond):
bt = bond.GetBondType()
stereo = int(bond.GetStereo())
fbond = [bt == Chem.rdchem.BondType.SINGLE, bt == Chem.rdchem.BondType.DOUBLE, bt ==
Chem.rdchem.BondType.TRIPLE, bt == Chem.rdchem.BondType.AROMATIC, bond.IsInRing()]
fstereo = onek_encoding_unk(stereo, [0, 1, 2, 3, 4, 5])
return (torch.Tensor(fbond + fstereo))
n_edges = 0
atom_x = []
bond_x = []
mol = get_mol(smiles)
n_atoms = mol.GetNumAtoms()
n_bonds = mol.GetNumBonds()
graph = DGLGraph()
for i, atom in enumerate(mol.GetAtoms()):
assert i == atom.GetIdx()
atom_x.append(atom_features(atom))
graph.add_nodes(n_atoms)
bond_src = []
bond_dst = []
for i, bond in enumerate(mol.GetBonds()):
begin_idx = bond.GetBeginAtom().GetIdx()
end_idx = bond.GetEndAtom().GetIdx()
features = bond_features(bond)
bond_src.append(begin_idx)
bond_dst.append(end_idx)
bond_x.append(features)
# set up the reverse direction
bond_src.append(end_idx)
bond_dst.append(begin_idx)
bond_x.append(features)
graph.add_edges(bond_src, bond_dst)
n_edges += n_bonds
return graph, torch.stack(atom_x), \
torch.stack(bond_x) if len(bond_x) > 0 else torch.zeros(0)
import torch
from torch.utils.data import Dataset
import dgl
from dgl.data.utils import download, extract_archive, get_download_dir
import os
from .mol_tree import Vocab, DGLMolTree
from .chemutils import mol2dgl_dec, mol2dgl_enc
ELEM_LIST = ['C', 'N', 'O', 'S', 'F', 'Si', 'P', 'Cl', 'Br', 'Mg', 'Na', 'Ca', 'Fe', 'Al', 'I', 'B', 'K', 'Se', 'Zn', 'H', 'Cu', 'Mn', 'unknown']
ATOM_FDIM_DEC = len(ELEM_LIST) + 6 + 5 + 1
BOND_FDIM_DEC = 5
MAX_NB = 10
PAPER = os.getenv('PAPER', False)
_url = 'https://s3-ap-southeast-1.amazonaws.com/dgl-data-cn/dataset/jtnn.zip'
def _unpack_field(examples, field):
return [e[field] for e in examples]
def _set_node_id(mol_tree, vocab):
wid = []
for i, node in enumerate(mol_tree.nodes_dict):
mol_tree.nodes_dict[node]['idx'] = i
wid.append(vocab.get_index(mol_tree.nodes_dict[node]['smiles']))
return wid
class JTNNDataset(Dataset):
def __init__(self, data, vocab, training=True):
self.dir = get_download_dir()
        self.zip_file_path = '{}/jtnn.zip'.format(self.dir)
download(_url, path=self.zip_file_path)
extract_archive(self.zip_file_path, '{}/jtnn'.format(self.dir))
print('Loading data...')
if data in ['train', 'test']:
data_file = '{}/jtnn/{}.txt'.format(self.dir, data)
else:
data_file = data
with open(data_file) as f:
self.data = [line.strip("\r\n ").split()[0] for line in f]
self.vocab_file = '{}/jtnn/{}.txt'.format(self.dir, vocab)
print('Loading finished.')
print('\tNum samples:', len(self.data))
print('\tVocab file:', self.vocab_file)
self.training = training
self.vocab = Vocab([x.strip("\r\n ") for x in open(self.vocab_file)])
def __len__(self):
return len(self.data)
def __getitem__(self, idx):
smiles = self.data[idx]
mol_tree = DGLMolTree(smiles)
mol_tree.recover()
mol_tree.assemble()
wid = _set_node_id(mol_tree, self.vocab)
# prebuild the molecule graph
mol_graph, atom_x_enc, bond_x_enc = mol2dgl_enc(mol_tree.smiles)
result = {
'mol_tree': mol_tree,
'mol_graph': mol_graph,
'atom_x_enc': atom_x_enc,
'bond_x_enc': bond_x_enc,
'wid': wid,
}
if not self.training:
return result
# prebuild the candidate graph list
cands = []
for node_id, node in mol_tree.nodes_dict.items():
# fill in ground truth
if node['label'] not in node['cands']:
node['cands'].append(node['label'])
node['cand_mols'].append(node['label_mol'])
if node['is_leaf'] or len(node['cands']) == 1:
continue
cands.extend([(cand, mol_tree, node_id)
for cand in node['cand_mols']])
if len(cands) > 0:
cand_graphs, atom_x_dec, bond_x_dec, tree_mess_src_e, \
tree_mess_tgt_e, tree_mess_tgt_n = mol2dgl_dec(cands)
else:
cand_graphs = []
atom_x_dec = torch.zeros(0, ATOM_FDIM_DEC)
bond_x_dec = torch.zeros(0, BOND_FDIM_DEC)
tree_mess_src_e = torch.zeros(0, 2).long()
tree_mess_tgt_e = torch.zeros(0, 2).long()
tree_mess_tgt_n = torch.zeros(0).long()
# prebuild the stereoisomers
cands = mol_tree.stereo_cands
if len(cands) > 1:
if mol_tree.smiles3D not in cands:
cands.append(mol_tree.smiles3D)
stereo_graphs = [mol2dgl_enc(c) for c in cands]
stereo_cand_graphs, stereo_atom_x_enc, stereo_bond_x_enc = \
zip(*stereo_graphs)
stereo_atom_x_enc = torch.cat(stereo_atom_x_enc)
stereo_bond_x_enc = torch.cat(stereo_bond_x_enc)
stereo_cand_label = [(cands.index(mol_tree.smiles3D), len(cands))]
else:
stereo_cand_graphs = []
stereo_atom_x_enc = torch.zeros(0, atom_x_enc.shape[1])
stereo_bond_x_enc = torch.zeros(0, bond_x_enc.shape[1])
stereo_cand_label = []
result.update({
'cand_graphs': cand_graphs,
'atom_x_dec': atom_x_dec,
'bond_x_dec': bond_x_dec,
'tree_mess_src_e': tree_mess_src_e,
'tree_mess_tgt_e': tree_mess_tgt_e,
'tree_mess_tgt_n': tree_mess_tgt_n,
'stereo_cand_graphs': stereo_cand_graphs,
'stereo_atom_x_enc': stereo_atom_x_enc,
'stereo_bond_x_enc': stereo_bond_x_enc,
'stereo_cand_label': stereo_cand_label,
})
return result
class JTNNCollator(object):
def __init__(self, vocab, training):
self.vocab = vocab
self.training = training
@staticmethod
def _batch_and_set(graphs, atom_x, bond_x, flatten):
if flatten:
graphs = [g for f in graphs for g in f]
graph_batch = dgl.batch(graphs)
graph_batch.ndata['x'] = atom_x
graph_batch.edata.update({
'x': bond_x,
'src_x': atom_x.new(bond_x.shape[0], atom_x.shape[1]).zero_(),
})
return graph_batch
def __call__(self, examples):
# get list of trees
mol_trees = _unpack_field(examples, 'mol_tree')
wid = _unpack_field(examples, 'wid')
for _wid, mol_tree in zip(wid, mol_trees):
mol_tree.ndata['wid'] = torch.LongTensor(_wid)
# TODO: either support pickling or get around ctypes pointers using scipy
# batch molecule graphs
mol_graphs = _unpack_field(examples, 'mol_graph')
atom_x = torch.cat(_unpack_field(examples, 'atom_x_enc'))
bond_x = torch.cat(_unpack_field(examples, 'bond_x_enc'))
mol_graph_batch = self._batch_and_set(mol_graphs, atom_x, bond_x, False)
result = {
'mol_trees': mol_trees,
'mol_graph_batch': mol_graph_batch,
}
if not self.training:
return result
# batch candidate graphs
cand_graphs = _unpack_field(examples, 'cand_graphs')
cand_batch_idx = []
atom_x = torch.cat(_unpack_field(examples, 'atom_x_dec'))
bond_x = torch.cat(_unpack_field(examples, 'bond_x_dec'))
tree_mess_src_e = _unpack_field(examples, 'tree_mess_src_e')
tree_mess_tgt_e = _unpack_field(examples, 'tree_mess_tgt_e')
tree_mess_tgt_n = _unpack_field(examples, 'tree_mess_tgt_n')
n_graph_nodes = 0
n_tree_nodes = 0
for i in range(len(cand_graphs)):
tree_mess_tgt_e[i] += n_graph_nodes
tree_mess_src_e[i] += n_tree_nodes
tree_mess_tgt_n[i] += n_graph_nodes
n_graph_nodes += sum(g.number_of_nodes() for g in cand_graphs[i])
n_tree_nodes += mol_trees[i].number_of_nodes()
cand_batch_idx.extend([i] * len(cand_graphs[i]))
tree_mess_tgt_e = torch.cat(tree_mess_tgt_e)
tree_mess_src_e = torch.cat(tree_mess_src_e)
tree_mess_tgt_n = torch.cat(tree_mess_tgt_n)
cand_graph_batch = self._batch_and_set(cand_graphs, atom_x, bond_x, True)
# batch stereoisomers
stereo_cand_graphs = _unpack_field(examples, 'stereo_cand_graphs')
atom_x = torch.cat(_unpack_field(examples, 'stereo_atom_x_enc'))
bond_x = torch.cat(_unpack_field(examples, 'stereo_bond_x_enc'))
stereo_cand_batch_idx = []
for i in range(len(stereo_cand_graphs)):
stereo_cand_batch_idx.extend([i] * len(stereo_cand_graphs[i]))
if len(stereo_cand_batch_idx) > 0:
stereo_cand_labels = [
(label, length)
for ex in _unpack_field(examples, 'stereo_cand_label')
for label, length in ex
]
stereo_cand_labels, stereo_cand_lengths = zip(*stereo_cand_labels)
stereo_cand_graph_batch = self._batch_and_set(
stereo_cand_graphs, atom_x, bond_x, True)
else:
stereo_cand_labels = []
stereo_cand_lengths = []
stereo_cand_graph_batch = None
stereo_cand_batch_idx = []
result.update({
'cand_graph_batch': cand_graph_batch,
'cand_batch_idx': cand_batch_idx,
'tree_mess_tgt_e': tree_mess_tgt_e,
'tree_mess_src_e': tree_mess_src_e,
'tree_mess_tgt_n': tree_mess_tgt_n,
'stereo_cand_graph_batch': stereo_cand_graph_batch,
'stereo_cand_batch_idx': stereo_cand_batch_idx,
'stereo_cand_labels': stereo_cand_labels,
'stereo_cand_lengths': stereo_cand_lengths,
})
return result
import rdkit
import rdkit.Chem as Chem
import copy
import numpy as np
from dgl import DGLGraph
from .chemutils import get_clique_mol, tree_decomp, get_mol, get_smiles, \
set_atommap, enum_assemble_nx, decode_stereo
def get_slots(smiles):
mol = Chem.MolFromSmiles(smiles)
return [(atom.GetSymbol(), atom.GetFormalCharge(), atom.GetTotalNumHs()) for atom in mol.GetAtoms()]
class Vocab(object):
def __init__(self, smiles_list):
self.vocab = smiles_list
self.vmap = {x:i for i,x in enumerate(self.vocab)}
self.slots = [get_slots(smiles) for smiles in self.vocab]
def get_index(self, smiles):
return self.vmap[smiles]
def get_smiles(self, idx):
return self.vocab[idx]
def get_slots(self, idx):
return copy.deepcopy(self.slots[idx])
def size(self):
return len(self.vocab)
class DGLMolTree(DGLGraph):
def __init__(self, smiles):
DGLGraph.__init__(self)
self.nodes_dict = {}
if smiles is None:
return
self.smiles = smiles
self.mol = get_mol(smiles)
# Stereo Generation
mol = Chem.MolFromSmiles(smiles)
self.smiles3D = Chem.MolToSmiles(mol, isomericSmiles=True)
self.smiles2D = Chem.MolToSmiles(mol)
self.stereo_cands = decode_stereo(self.smiles2D)
# cliques: a list of list of atom indices
cliques, edges = tree_decomp(self.mol)
root = 0
for i, c in enumerate(cliques):
cmol = get_clique_mol(self.mol, c)
csmiles = get_smiles(cmol)
self.nodes_dict[i] = dict(
smiles=csmiles,
mol=get_mol(csmiles),
clique=c,
)
if min(c) == 0:
root = i
self.add_nodes(len(cliques))
# The clique with atom ID 0 becomes root
if root > 0:
for attr in self.nodes_dict[0]:
self.nodes_dict[0][attr], self.nodes_dict[root][attr] = \
self.nodes_dict[root][attr], self.nodes_dict[0][attr]
src = np.zeros((len(edges) * 2,), dtype='int')
dst = np.zeros((len(edges) * 2,), dtype='int')
for i, (_x, _y) in enumerate(edges):
x = 0 if _x == root else root if _x == 0 else _x
y = 0 if _y == root else root if _y == 0 else _y
src[2 * i] = x
dst[2 * i] = y
src[2 * i + 1] = y
dst[2 * i + 1] = x
self.add_edges(src, dst)
for i in self.nodes_dict:
self.nodes_dict[i]['nid'] = i + 1
if self.out_degree(i) > 1: # Leaf node mol is not marked
set_atommap(self.nodes_dict[i]['mol'], self.nodes_dict[i]['nid'])
self.nodes_dict[i]['is_leaf'] = (self.out_degree(i) == 1)
def treesize(self):
return self.number_of_nodes()
def _recover_node(self, i, original_mol):
node = self.nodes_dict[i]
clique = []
clique.extend(node['clique'])
if not node['is_leaf']:
for cidx in node['clique']:
original_mol.GetAtomWithIdx(cidx).SetAtomMapNum(node['nid'])
for j in self.successors(i).numpy():
nei_node = self.nodes_dict[j]
clique.extend(nei_node['clique'])
if nei_node['is_leaf']: # Leaf node, no need to mark
continue
for cidx in nei_node['clique']:
# allow singleton node override the atom mapping
if cidx not in node['clique'] or len(nei_node['clique']) == 1:
atom = original_mol.GetAtomWithIdx(cidx)
atom.SetAtomMapNum(nei_node['nid'])
clique = list(set(clique))
label_mol = get_clique_mol(original_mol, clique)
node['label'] = Chem.MolToSmiles(Chem.MolFromSmiles(get_smiles(label_mol)))
node['label_mol'] = get_mol(node['label'])
for cidx in clique:
original_mol.GetAtomWithIdx(cidx).SetAtomMapNum(0)
return node['label']
def _assemble_node(self, i):
neighbors = [self.nodes_dict[j] for j in self.successors(i).numpy()
if self.nodes_dict[j]['mol'].GetNumAtoms() > 1]
neighbors = sorted(neighbors, key=lambda x: x['mol'].GetNumAtoms(), reverse=True)
singletons = [self.nodes_dict[j] for j in self.successors(i).numpy()
if self.nodes_dict[j]['mol'].GetNumAtoms() == 1]
neighbors = singletons + neighbors
cands = enum_assemble_nx(self.nodes_dict[i], neighbors)
if len(cands) > 0:
self.nodes_dict[i]['cands'], self.nodes_dict[i]['cand_mols'], _ = list(zip(*cands))
self.nodes_dict[i]['cands'] = list(self.nodes_dict[i]['cands'])
self.nodes_dict[i]['cand_mols'] = list(self.nodes_dict[i]['cand_mols'])
else:
self.nodes_dict[i]['cands'] = []
self.nodes_dict[i]['cand_mols'] = []
def recover(self):
for i in self.nodes_dict:
self._recover_node(i, self.mol)
def assemble(self):
for i in self.nodes_dict:
self._assemble_node(i)
import torch
import torch.nn as nn
import os
def cuda(tensor):
if torch.cuda.is_available() and not os.getenv('NOCUDA', None):
return tensor.cuda()
else:
return tensor
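# GRUUpdate implements the GRU-style message update on the junction tree
# (one message per tree edge). Reading off the code below:
#   z  = sigmoid(W_z [x_src, s]),       s = sum of incoming messages
#   m~ = tanh(W_h [x_src, accum_rm]),   accum_rm = sum of r * m terms
#   m  = (1 - z) * s + z * m~           (updated message, update_zm)
#   r  = sigmoid(W_r x_dst + U_r m),    rm = r * m  (update_r)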
class GRUUpdate(nn.Module):
def __init__(self, hidden_size):
nn.Module.__init__(self)
self.hidden_size = hidden_size
self.W_z = nn.Linear(2 * hidden_size, hidden_size)
self.W_r = nn.Linear(hidden_size, hidden_size, bias=False)
self.U_r = nn.Linear(hidden_size, hidden_size)
self.W_h = nn.Linear(2 * hidden_size, hidden_size)
def update_zm(self, node):
src_x = node.data['src_x']
s = node.data['s']
rm = node.data['accum_rm']
z = torch.sigmoid(self.W_z(torch.cat([src_x, s], 1)))
m = torch.tanh(self.W_h(torch.cat([src_x, rm], 1)))
m = (1 - z) * s + z * m
return {'m': m, 'z': z}
def update_r(self, node, zm=None):
dst_x = node.data['dst_x']
m = node.data['m'] if zm is None else zm['m']
r_1 = self.W_r(dst_x)
r_2 = self.U_r(m)
r = torch.sigmoid(r_1 + r_2)
return {'r': r, 'rm': r * m}
def forward(self, node):
dic = self.update_zm(node)
dic.update(self.update_r(node, zm=dic))
return dic
def move_dgl_to_cuda(g):
g.ndata.update({k: cuda(g.ndata[k]) for k in g.ndata})
g.edata.update({k: cuda(g.edata[k]) for k in g.edata})
import torch
from torch.utils.data import DataLoader
import argparse
from dgl import model_zoo
import rdkit
from jtnn import *
def worker_init_fn(id_):
lg = rdkit.RDLogger.logger()
lg.setLevel(rdkit.RDLogger.CRITICAL)
worker_init_fn(None)
parser = argparse.ArgumentParser(description="Evaluation for JTNN",
formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument("-t", "--train", dest="train",
default='test', help='Training file name')
parser.add_argument("-v", "--vocab", dest="vocab",
default='vocab', help='Vocab file name')
parser.add_argument("-m", "--model", dest="model_path", default=None,
                    help="Pre-trained model to be loaded for evaluation. If not specified,"
                    " the pre-trained model from the model zoo will be used")
parser.add_argument("-w", "--hidden", dest="hidden_size", default=450,
help="Hidden size of representation vector, "
"should be consistent with pre-trained model")
parser.add_argument("-l", "--latent", dest="latent_size", default=56,
                    help="Latent size of node (atom) and edge (bond) features, "
                    "should be consistent with the pre-trained model")
parser.add_argument("-d", "--depth", dest="depth", default=3,
help="Depth of message passing hops, "
"should be consistent with pre-trained model")
args = parser.parse_args()
dataset = JTNNDataset(data=args.train, vocab=args.vocab, training=False)
vocab_file = dataset.vocab_file
hidden_size = int(args.hidden_size)
latent_size = int(args.latent_size)
depth = int(args.depth)
model = model_zoo.chem.DGLJTNNVAE(vocab_file=vocab_file,
hidden_size=hidden_size,
latent_size=latent_size,
depth=depth)
if args.model_path is not None:
model.load_state_dict(torch.load(args.model_path))
else:
model = model_zoo.chem.load_pretrained("JTNN_ZINC")
model = cuda(model)
model.eval()
print("Model #Params: %dK" %
(sum([x.nelement() for x in model.parameters()]) / 1000,))
MAX_EPOCH = 100
PRINT_ITER = 20
def reconstruct():
dataset.training = False
dataloader = DataLoader(
dataset,
batch_size=1,
shuffle=False,
num_workers=0,
collate_fn=JTNNCollator(dataset.vocab, False),
drop_last=True,
worker_init_fn=worker_init_fn)
# Just an example of molecule decoding; in reality you may want to sample
# tree and molecule vectors.
acc = 0.0
tot = 0
with torch.no_grad():
for it, batch in enumerate(dataloader):
gt_smiles = batch['mol_trees'][0].smiles
# print(gt_smiles)
model.move_to_cuda(batch)
try:
_, tree_vec, mol_vec = model.encode(batch)
tree_mean = model.T_mean(tree_vec)
# Following Mueller et al.
tree_log_var = -torch.abs(model.T_var(tree_vec))
mol_mean = model.G_mean(mol_vec)
# Following Mueller et al.
mol_log_var = -torch.abs(model.G_var(mol_vec))
                # Reparameterization: std = exp(log_var / 2); use the cuda()
                # helper so this also runs on CPU-only machines.
                epsilon = cuda(torch.randn(1, model.latent_size // 2))
                tree_vec = tree_mean + torch.exp(tree_log_var / 2) * epsilon
                epsilon = cuda(torch.randn(1, model.latent_size // 2))
                mol_vec = mol_mean + torch.exp(mol_log_var / 2) * epsilon
dec_smiles = model.decode(tree_vec, mol_vec)
if dec_smiles == gt_smiles:
acc += 1
tot += 1
except Exception as e:
print("Failed to encode: {}".format(gt_smiles))
print(e)
if it % 20 == 1:
print("Progress {}/{}; Current Reconstruction Accuracy: {:.4f}".format(it,
len(dataloader), acc / tot))
return acc / tot
if __name__ == '__main__':
reconstruct_acc = reconstruct()
print("Reconstruction Accuracy: {}".format(reconstruct_acc))
import torch
import torch.nn as nn
import torch.optim as optim
import torch.optim.lr_scheduler as lr_scheduler
from dgl import model_zoo
from torch.utils.data import DataLoader
import math, random, sys
import argparse
from collections import deque
import rdkit
from jtnn import *
torch.multiprocessing.set_sharing_strategy('file_system')
def worker_init_fn(id_):
lg = rdkit.RDLogger.logger()
lg.setLevel(rdkit.RDLogger.CRITICAL)
worker_init_fn(None)
parser = argparse.ArgumentParser(description="Training for JTNN",
formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument("-t", "--train", dest="train", default='train', help='Training file name')
parser.add_argument("-v", "--vocab", dest="vocab", default='vocab', help='Vocab file name')
parser.add_argument("-s", "--save_dir", dest="save_path", default='./',
help="Path to save checkpoint models, default to be current working directory")
parser.add_argument("-m", "--model", dest="model_path", default=None,
help="Path to load pre-trained model")
parser.add_argument("-b", "--batch", dest="batch_size", default=40,
help="Batch size")
parser.add_argument("-w", "--hidden", dest="hidden_size", default=200,
help="Size of representation vectors")
parser.add_argument("-l", "--latent", dest="latent_size", default=56,
                    help="Latent size of node (atom) and edge (bond) features")
parser.add_argument("-d", "--depth", dest="depth", default=3,
help="Depth of message passing hops")
parser.add_argument("-z", "--beta", dest="beta", default=1.0,
help="Coefficient of KL Divergence term")
parser.add_argument("-q", "--lr", dest="lr", default=1e-3,
help="Learning Rate")
parser.add_argument("-T", "--test", dest="test", action="store_true",
help="Add this flag to run test mode")
args = parser.parse_args()
dataset = JTNNDataset(data=args.train, vocab=args.vocab, training=True)
vocab_file = dataset.vocab_file
batch_size = int(args.batch_size)
hidden_size = int(args.hidden_size)
latent_size = int(args.latent_size)
depth = int(args.depth)
beta = float(args.beta)
lr = float(args.lr)
model = model_zoo.chem.DGLJTNNVAE(vocab_file=vocab_file,
hidden_size=hidden_size,
latent_size=latent_size,
depth=depth)
if args.model_path is not None:
model.load_state_dict(torch.load(args.model_path))
else:
for param in model.parameters():
if param.dim() == 1:
nn.init.constant_(param, 0)
else:
nn.init.xavier_normal_(param)
model = cuda(model)
print("Model #Params: %dK" % (sum([x.nelement() for x in model.parameters()]) / 1000,))
optimizer = optim.Adam(model.parameters(), lr=lr)
scheduler = lr_scheduler.ExponentialLR(optimizer, 0.9)
scheduler.step()
MAX_EPOCH = 100
PRINT_ITER = 20
def train():
dataset.training = True
dataloader = DataLoader(
dataset,
batch_size=batch_size,
shuffle=True,
num_workers=4,
collate_fn=JTNNCollator(dataset.vocab, True),
drop_last=True,
worker_init_fn=worker_init_fn)
for epoch in range(MAX_EPOCH):
word_acc, topo_acc, assm_acc, steo_acc = 0, 0, 0, 0
for it, batch in enumerate(dataloader):
model.zero_grad()
try:
loss, kl_div, wacc, tacc, sacc, dacc = model(batch, beta)
            except Exception:
print([t.smiles for t in batch['mol_trees']])
raise
loss.backward()
optimizer.step()
word_acc += wacc
topo_acc += tacc
assm_acc += sacc
steo_acc += dacc
if (it + 1) % PRINT_ITER == 0:
word_acc = word_acc / PRINT_ITER * 100
topo_acc = topo_acc / PRINT_ITER * 100
assm_acc = assm_acc / PRINT_ITER * 100
steo_acc = steo_acc / PRINT_ITER * 100
print("KL: %.1f, Word: %.2f, Topo: %.2f, Assm: %.2f, Steo: %.2f, Loss: %.6f" % (
kl_div, word_acc, topo_acc, assm_acc, steo_acc, loss.item()))
word_acc, topo_acc, assm_acc, steo_acc = 0, 0, 0, 0
sys.stdout.flush()
if (it + 1) % 1500 == 0: # Fast annealing
scheduler.step()
print("learning rate: %.6f" % scheduler.get_lr()[0])
torch.save(model.state_dict(),
args.save_path + "/model.iter-%d-%d" % (epoch, it + 1))
scheduler.step()
print("learning rate: %.6f" % scheduler.get_lr()[0])
torch.save(model.state_dict(), args.save_path + "/model.iter-" + str(epoch))
def test():
dataset.training = False
dataloader = DataLoader(
dataset,
batch_size=1,
shuffle=False,
num_workers=0,
        collate_fn=JTNNCollator(dataset.vocab, False),
drop_last=True,
worker_init_fn=worker_init_fn)
# Just an example of molecule decoding; in reality you may want to sample
# tree and molecule vectors.
for it, batch in enumerate(dataloader):
gt_smiles = batch['mol_trees'][0].smiles
print(gt_smiles)
model.move_to_cuda(batch)
_, tree_vec, mol_vec = model.encode(batch)
tree_vec, mol_vec, _, _ = model.sample(tree_vec, mol_vec)
smiles = model.decode(tree_vec, mol_vec)
print(smiles)
if __name__ == '__main__':
if args.test:
test()
else:
train()
print('# passes:', model.n_passes)
print('Total # nodes processed:', model.n_nodes_total)
print('Total # edges processed:', model.n_edges_total)
print('Total # tree nodes processed:', model.n_tree_nodes_total)
print('Graph decoder: # passes:', model.jtmpn.n_passes)
print('Graph decoder: Total # candidates processed:', model.jtmpn.n_samples_total)
print('Graph decoder: Total # nodes processed:', model.jtmpn.n_nodes_total)
print('Graph decoder: Total # edges processed:', model.jtmpn.n_edges_total)
print('Graph encoder: # passes:', model.mpn.n_passes)
print('Graph encoder: Total # candidates processed:', model.mpn.n_samples_total)
print('Graph encoder: Total # nodes processed:', model.mpn.n_nodes_total)
print('Graph encoder: Total # edges processed:', model.mpn.n_edges_total)
@@ -6,4 +6,5 @@ from .sch import SchNetModel
from .mgcn import MGCNModel
from .mpnn import MPNNModel
from .dgmg import DGMG
+from .jtnn import DGLJTNNVAE
from .pretrain import load_pretrained
"""JTNN Module
"""
from .chemutils import decode_stereo
from .jtnn_vae import DGLJTNNVAE
from .mol_tree import Vocab
from .mpn import DGLMPN
from .nnutils import create_var, cuda
# pylint: disable=C0111, C0103, E1101, W0611, W0612, W0703, C0200, R1710
from collections import defaultdict
import rdkit
import rdkit.Chem as Chem
from rdkit.Chem.EnumerateStereoisomers import (EnumerateStereoisomers,
StereoEnumerationOptions)
from scipy.sparse import csr_matrix
from scipy.sparse.csgraph import minimum_spanning_tree
MST_MAX_WEIGHT = 100
MAX_NCAND = 2000
def set_atommap(mol, num=0):
for atom in mol.GetAtoms():
atom.SetAtomMapNum(num)
def get_mol(smiles):
mol = Chem.MolFromSmiles(smiles)
if mol is None:
return None
Chem.Kekulize(mol)
return mol
def get_smiles(mol):
return Chem.MolToSmiles(mol, kekuleSmiles=True)
def decode_stereo(smiles2D):
mol = Chem.MolFromSmiles(smiles2D)
dec_isomers = list(EnumerateStereoisomers(mol))
dec_isomers = [Chem.MolFromSmiles(Chem.MolToSmiles(
mol, isomericSmiles=True)) for mol in dec_isomers]
smiles3D = [Chem.MolToSmiles(mol, isomericSmiles=True)
for mol in dec_isomers]
chiralN = [atom.GetIdx() for atom in dec_isomers[0].GetAtoms() if int(
atom.GetChiralTag()) > 0 and atom.GetSymbol() == "N"]
if len(chiralN) > 0:
for mol in dec_isomers:
for idx in chiralN:
mol.GetAtomWithIdx(idx).SetChiralTag(
Chem.rdchem.ChiralType.CHI_UNSPECIFIED)
smiles3D.append(Chem.MolToSmiles(mol, isomericSmiles=True))
return smiles3D
def sanitize(mol):
try:
smiles = get_smiles(mol)
mol = get_mol(smiles)
except Exception:
return None
return mol
def copy_atom(atom):
new_atom = Chem.Atom(atom.GetSymbol())
new_atom.SetFormalCharge(atom.GetFormalCharge())
new_atom.SetAtomMapNum(atom.GetAtomMapNum())
return new_atom
def copy_edit_mol(mol):
new_mol = Chem.RWMol(Chem.MolFromSmiles(''))
for atom in mol.GetAtoms():
new_atom = copy_atom(atom)
new_mol.AddAtom(new_atom)
for bond in mol.GetBonds():
a1 = bond.GetBeginAtom().GetIdx()
a2 = bond.GetEndAtom().GetIdx()
bt = bond.GetBondType()
new_mol.AddBond(a1, a2, bt)
return new_mol
def get_clique_mol(mol, atoms):
smiles = Chem.MolFragmentToSmiles(mol, atoms, kekuleSmiles=True)
new_mol = Chem.MolFromSmiles(smiles, sanitize=False)
new_mol = copy_edit_mol(new_mol).GetMol()
new_mol = sanitize(new_mol) # We assume this is not None
return new_mol
def tree_decomp(mol):
n_atoms = mol.GetNumAtoms()
if n_atoms == 1:
return [[0]], []
cliques = []
for bond in mol.GetBonds():
a1 = bond.GetBeginAtom().GetIdx()
a2 = bond.GetEndAtom().GetIdx()
if not bond.IsInRing():
cliques.append([a1, a2])
ssr = [list(x) for x in Chem.GetSymmSSSR(mol)]
cliques.extend(ssr)
nei_list = [[] for i in range(n_atoms)]
for i in range(len(cliques)):
for atom in cliques[i]:
nei_list[atom].append(i)
# Merge Rings with intersection > 2 atoms
for i in range(len(cliques)):
if len(cliques[i]) <= 2:
continue
for atom in cliques[i]:
for j in nei_list[atom]:
if i >= j or len(cliques[j]) <= 2:
continue
inter = set(cliques[i]) & set(cliques[j])
if len(inter) > 2:
cliques[i].extend(cliques[j])
cliques[i] = list(set(cliques[i]))
cliques[j] = []
cliques = [c for c in cliques if len(c) > 0]
nei_list = [[] for i in range(n_atoms)]
for i in range(len(cliques)):
for atom in cliques[i]:
nei_list[atom].append(i)
# Build edges and add singleton cliques
edges = defaultdict(int)
for atom in range(n_atoms):
if len(nei_list[atom]) <= 1:
continue
cnei = nei_list[atom]
bonds = [c for c in cnei if len(cliques[c]) == 2]
rings = [c for c in cnei if len(cliques[c]) > 4]
# In general, if len(cnei) >= 3, a singleton should be added, but 1
# bond + 2 ring is currently not dealt with.
if len(bonds) > 2 or (len(bonds) == 2 and len(cnei) > 2):
cliques.append([atom])
c2 = len(cliques) - 1
for c1 in cnei:
edges[(c1, c2)] = 1
elif len(rings) > 2: # Multiple (n>2) complex rings
cliques.append([atom])
c2 = len(cliques) - 1
for c1 in cnei:
edges[(c1, c2)] = MST_MAX_WEIGHT - 1
else:
for i in range(len(cnei)):
for j in range(i + 1, len(cnei)):
c1, c2 = cnei[i], cnei[j]
inter = set(cliques[c1]) & set(cliques[c2])
if edges[(c1, c2)] < len(inter):
# cnei[i] < cnei[j] by construction
edges[(c1, c2)] = len(inter)
edges = [u + (MST_MAX_WEIGHT - v,) for u, v in edges.items()]
if len(edges) == 0:
return cliques, edges
# Compute Maximum Spanning Tree
row, col, data = list(zip(*edges))
n_clique = len(cliques)
clique_graph = csr_matrix((data, (row, col)), shape=(n_clique, n_clique))
junc_tree = minimum_spanning_tree(clique_graph)
row, col = junc_tree.nonzero()
edges = [(row[i], col[i]) for i in range(len(row))]
return (cliques, edges)
def atom_equal(a1, a2):
return a1.GetSymbol() == a2.GetSymbol() and a1.GetFormalCharge() == a2.GetFormalCharge()
# Bond type not considered because all aromatic (so SINGLE matches DOUBLE)
def ring_bond_equal(b1, b2, reverse=False):
b1 = (b1.GetBeginAtom(), b1.GetEndAtom())
if reverse:
b2 = (b2.GetEndAtom(), b2.GetBeginAtom())
else:
b2 = (b2.GetBeginAtom(), b2.GetEndAtom())
return atom_equal(b1[0], b2[0]) and atom_equal(b1[1], b2[1])
def attach_mols_nx(ctr_mol, neighbors, prev_nodes, nei_amap):
prev_nids = [node['nid'] for node in prev_nodes]
for nei_node in prev_nodes + neighbors:
nei_id, nei_mol = nei_node['nid'], nei_node['mol']
amap = nei_amap[nei_id]
for atom in nei_mol.GetAtoms():
if atom.GetIdx() not in amap:
new_atom = copy_atom(atom)
amap[atom.GetIdx()] = ctr_mol.AddAtom(new_atom)
if nei_mol.GetNumBonds() == 0:
nei_atom = nei_mol.GetAtomWithIdx(0)
ctr_atom = ctr_mol.GetAtomWithIdx(amap[0])
ctr_atom.SetAtomMapNum(nei_atom.GetAtomMapNum())
else:
for bond in nei_mol.GetBonds():
a1 = amap[bond.GetBeginAtom().GetIdx()]
a2 = amap[bond.GetEndAtom().GetIdx()]
if ctr_mol.GetBondBetweenAtoms(a1, a2) is None:
ctr_mol.AddBond(a1, a2, bond.GetBondType())
elif nei_id in prev_nids: # father node overrides
ctr_mol.RemoveBond(a1, a2)
ctr_mol.AddBond(a1, a2, bond.GetBondType())
return ctr_mol
def local_attach_nx(ctr_mol, neighbors, prev_nodes, amap_list):
ctr_mol = copy_edit_mol(ctr_mol)
nei_amap = {nei['nid']: {} for nei in prev_nodes + neighbors}
for nei_id, ctr_atom, nei_atom in amap_list:
nei_amap[nei_id][nei_atom] = ctr_atom
ctr_mol = attach_mols_nx(ctr_mol, neighbors, prev_nodes, nei_amap)
return ctr_mol.GetMol()
# This version records idx mapping between ctr_mol and nei_mol
def enum_attach_nx(ctr_mol, nei_node, amap, singletons):
nei_mol, nei_idx = nei_node['mol'], nei_node['nid']
att_confs = []
black_list = [atom_idx for nei_id, atom_idx,
_ in amap if nei_id in singletons]
ctr_atoms = [atom for atom in ctr_mol.GetAtoms() if atom.GetIdx()
not in black_list]
ctr_bonds = [bond for bond in ctr_mol.GetBonds()]
if nei_mol.GetNumBonds() == 0: # neighbor singleton
nei_atom = nei_mol.GetAtomWithIdx(0)
used_list = [atom_idx for _, atom_idx, _ in amap]
for atom in ctr_atoms:
if atom_equal(atom, nei_atom) and atom.GetIdx() not in used_list:
new_amap = amap + [(nei_idx, atom.GetIdx(), 0)]
att_confs.append(new_amap)
elif nei_mol.GetNumBonds() == 1: # neighbor is a bond
bond = nei_mol.GetBondWithIdx(0)
bond_val = int(bond.GetBondTypeAsDouble())
b1, b2 = bond.GetBeginAtom(), bond.GetEndAtom()
for atom in ctr_atoms:
# Optimize if atom is carbon (other atoms may change valence)
if atom.GetAtomicNum() == 6 and atom.GetTotalNumHs() < bond_val:
continue
if atom_equal(atom, b1):
new_amap = amap + [(nei_idx, atom.GetIdx(), b1.GetIdx())]
att_confs.append(new_amap)
elif atom_equal(atom, b2):
new_amap = amap + [(nei_idx, atom.GetIdx(), b2.GetIdx())]
att_confs.append(new_amap)
else:
# intersection is an atom
for a1 in ctr_atoms:
for a2 in nei_mol.GetAtoms():
if atom_equal(a1, a2):
# Optimize if atom is carbon (other atoms may change
# valence)
if a1.GetAtomicNum() == 6 and a1.GetTotalNumHs() + a2.GetTotalNumHs() < 4:
continue
new_amap = amap + [(nei_idx, a1.GetIdx(), a2.GetIdx())]
att_confs.append(new_amap)
        # intersection is a bond
if ctr_mol.GetNumBonds() > 1:
for b1 in ctr_bonds:
for b2 in nei_mol.GetBonds():
if ring_bond_equal(b1, b2):
new_amap = amap + [(nei_idx,
b1.GetBeginAtom().GetIdx(),
b2.GetBeginAtom().GetIdx()),
(nei_idx,
b1.GetEndAtom().GetIdx(),
b2.GetEndAtom().GetIdx())]
att_confs.append(new_amap)
if ring_bond_equal(b1, b2, reverse=True):
new_amap = amap + [(nei_idx,
b1.GetBeginAtom().GetIdx(),
b2.GetEndAtom().GetIdx()),
(nei_idx,
b1.GetEndAtom().GetIdx(),
b2.GetBeginAtom().GetIdx())]
att_confs.append(new_amap)
return att_confs
# Try rings first: Speed-Up
def enum_assemble_nx(node, neighbors, prev_nodes=None, prev_amap=None):
if prev_nodes is None:
prev_nodes = []
if prev_amap is None:
prev_amap = []
all_attach_confs = []
singletons = [nei_node['nid'] for nei_node in neighbors +
prev_nodes if nei_node['mol'].GetNumAtoms() == 1]
def search(cur_amap, depth):
if len(all_attach_confs) > MAX_NCAND:
return None
if depth == len(neighbors):
all_attach_confs.append(cur_amap)
return None
nei_node = neighbors[depth]
cand_amap = enum_attach_nx(node['mol'], nei_node, cur_amap, singletons)
cand_smiles = set()
candidates = []
for amap in cand_amap:
cand_mol = local_attach_nx(
node['mol'], neighbors[:depth + 1], prev_nodes, amap)
cand_mol = sanitize(cand_mol)
if cand_mol is None:
continue
smiles = get_smiles(cand_mol)
if smiles in cand_smiles:
continue
cand_smiles.add(smiles)
candidates.append(amap)
if len(candidates) == 0:
return []
for new_amap in candidates:
search(new_amap, depth + 1)
search(prev_amap, 0)
cand_smiles = set()
candidates = []
for amap in all_attach_confs:
cand_mol = local_attach_nx(node['mol'], neighbors, prev_nodes, amap)
cand_mol = Chem.MolFromSmiles(Chem.MolToSmiles(cand_mol))
smiles = Chem.MolToSmiles(cand_mol)
if smiles in cand_smiles:
continue
cand_smiles.add(smiles)
Chem.Kekulize(cand_mol)
candidates.append((smiles, cand_mol, amap))
return candidates
# Only used for debugging purposes
def dfs_assemble_nx(
graph,
cur_mol,
global_amap,
fa_amap,
cur_node_id,
fa_node_id):
cur_node = graph.nodes_dict[cur_node_id]
fa_node = graph.nodes_dict[fa_node_id] if fa_node_id is not None else None
fa_nid = fa_node['nid'] if fa_node is not None else -1
prev_nodes = [fa_node] if fa_node is not None else []
children_id = [nei for nei in graph[cur_node_id]
if graph.nodes_dict[nei]['nid'] != fa_nid]
children = [graph.nodes_dict[nei] for nei in children_id]
neighbors = [nei for nei in children if nei['mol'].GetNumAtoms() > 1]
neighbors = sorted(
neighbors, key=lambda x: x['mol'].GetNumAtoms(), reverse=True)
singletons = [nei for nei in children if nei['mol'].GetNumAtoms() == 1]
neighbors = singletons + neighbors
cur_amap = [(fa_nid, a2, a1)
for nid, a1, a2 in fa_amap if nid == cur_node['nid']]
cands = enum_assemble_nx(
graph.nodes_dict[cur_node_id], neighbors, prev_nodes, cur_amap)
if len(cands) == 0:
return
cand_smiles, _, cand_amap = zip(*cands)
label_idx = cand_smiles.index(cur_node['label'])
label_amap = cand_amap[label_idx]
for nei_id, ctr_atom, nei_atom in label_amap:
if nei_id == fa_nid:
continue
global_amap[nei_id][nei_atom] = global_amap[cur_node['nid']][ctr_atom]
# father is already attached
cur_mol = attach_mols_nx(cur_mol, children, [], global_amap)
for nei_node_id, nei_node in zip(children_id, children):
if not nei_node['is_leaf']:
dfs_assemble_nx(graph, cur_mol, global_amap,
label_amap, nei_node_id, cur_node_id)
# pylint: disable=C0111, C0103, E1101, W0611, W0612, W1508
# pylint: disable=redefined-outer-name
import os
import rdkit.Chem as Chem
import torch
import torch.nn as nn
import dgl.function as DGLF
from dgl import DGLGraph, mean_nodes
from .chemutils import get_mol
from .nnutils import cuda
ELEM_LIST = ['C', 'N', 'O', 'S', 'F', 'Si', 'P', 'Cl', 'Br', 'Mg', 'Na',
'Ca', 'Fe', 'Al', 'I', 'B', 'K', 'Se', 'Zn', 'H', 'Cu', 'Mn', 'unknown']
ATOM_FDIM = len(ELEM_LIST) + 6 + 5 + 1
BOND_FDIM = 5
MAX_NB = 10
PAPER = os.getenv('PAPER', False)
def onek_encoding_unk(x, allowable_set):
if x not in allowable_set:
x = allowable_set[-1]
return [x == s for s in allowable_set]
# Note that during graph decoding they don't predict stereochemistry-related
# characteristics (i.e. Chiral Atoms, E-Z, Cis-Trans). Instead, they decode
# the 2-D graph first, then enumerate all possible 3-D forms and find the
# one with highest score.
def atom_features(atom):
return (torch.Tensor(onek_encoding_unk(atom.GetSymbol(), ELEM_LIST)
+ onek_encoding_unk(atom.GetDegree(),
[0, 1, 2, 3, 4, 5])
+ onek_encoding_unk(atom.GetFormalCharge(),
[-1, -2, 1, 2, 0])
+ [atom.GetIsAromatic()]))
def bond_features(bond):
bt = bond.GetBondType()
return torch.Tensor([bt == Chem.rdchem.BondType.SINGLE,
bt == Chem.rdchem.BondType.DOUBLE,
bt == Chem.rdchem.BondType.TRIPLE,
bt == Chem.rdchem.BondType.AROMATIC,
bond.IsInRing()])
def mol2dgl_single(cand_batch):
cand_graphs = []
tree_mess_source_edges = [] # map these edges from trees to...
tree_mess_target_edges = [] # these edges on candidate graphs
tree_mess_target_nodes = []
n_nodes = 0
n_edges = 0
atom_x = []
bond_x = []
for mol, mol_tree, ctr_node_id in cand_batch:
n_atoms = mol.GetNumAtoms()
n_bonds = mol.GetNumBonds()
ctr_node = mol_tree.nodes_dict[ctr_node_id]
ctr_bid = ctr_node['idx']
g = DGLGraph()
for i, atom in enumerate(mol.GetAtoms()):
assert i == atom.GetIdx()
atom_x.append(atom_features(atom))
g.add_nodes(n_atoms)
bond_src = []
bond_dst = []
for i, bond in enumerate(mol.GetBonds()):
a1 = bond.GetBeginAtom()
a2 = bond.GetEndAtom()
begin_idx = a1.GetIdx()
end_idx = a2.GetIdx()
features = bond_features(bond)
bond_src.append(begin_idx)
bond_dst.append(end_idx)
bond_x.append(features)
bond_src.append(end_idx)
bond_dst.append(begin_idx)
bond_x.append(features)
x_nid, y_nid = a1.GetAtomMapNum(), a2.GetAtomMapNum()
# Tree node ID in the batch
x_bid = mol_tree.nodes_dict[x_nid - 1]['idx'] if x_nid > 0 else -1
y_bid = mol_tree.nodes_dict[y_nid - 1]['idx'] if y_nid > 0 else -1
if x_bid >= 0 and y_bid >= 0 and x_bid != y_bid:
if mol_tree.has_edge_between(x_bid, y_bid):
tree_mess_target_edges.append(
(begin_idx + n_nodes, end_idx + n_nodes))
tree_mess_source_edges.append((x_bid, y_bid))
tree_mess_target_nodes.append(end_idx + n_nodes)
if mol_tree.has_edge_between(y_bid, x_bid):
tree_mess_target_edges.append(
(end_idx + n_nodes, begin_idx + n_nodes))
tree_mess_source_edges.append((y_bid, x_bid))
tree_mess_target_nodes.append(begin_idx + n_nodes)
n_nodes += n_atoms
g.add_edges(bond_src, bond_dst)
cand_graphs.append(g)
return cand_graphs, torch.stack(atom_x), \
torch.stack(bond_x) if len(bond_x) > 0 else torch.zeros(0), \
torch.LongTensor(tree_mess_source_edges), \
torch.LongTensor(tree_mess_target_edges), \
torch.LongTensor(tree_mess_target_nodes)
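# mol2dgl_single returns the candidate DGLGraphs, the stacked atom/bond
# feature tensors, and three LongTensors that map junction-tree edges to the
# candidate-graph edges/nodes receiving their messages (consumed by
# DGLJTMPN.run below).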
mpn_loopy_bp_msg = DGLF.copy_src(src='msg', out='msg')
mpn_loopy_bp_reduce = DGLF.sum(msg='msg', out='accum_msg')
class LoopyBPUpdate(nn.Module):
def __init__(self, hidden_size):
super(LoopyBPUpdate, self).__init__()
self.hidden_size = hidden_size
self.W_h = nn.Linear(hidden_size, hidden_size, bias=False)
def forward(self, node):
msg_input = node.data['msg_input']
msg_delta = self.W_h(node.data['accum_msg'] + node.data['alpha'])
msg = torch.relu(msg_input + msg_delta)
return {'msg': msg}
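# Each line-graph node is one directed bond (u, v); an update computes
#   msg_uv = ReLU(W_i [x_u ; x_uv] + W_h (sum_{(w, u), w != v} msg_wu + alpha_uv)),
# where the W_i term is cached as msg_input and alpha injects junction-tree
# messages. This is a sketch of the loopy BP update, not a verbatim quote of
# the paper's equations.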
if PAPER:
mpn_gather_msg = [
DGLF.copy_edge(edge='msg', out='msg'),
DGLF.copy_edge(edge='alpha', out='alpha')
]
else:
mpn_gather_msg = DGLF.copy_edge(edge='msg', out='msg')
if PAPER:
mpn_gather_reduce = [
DGLF.sum(msg='msg', out='m'),
DGLF.sum(msg='alpha', out='accum_alpha'),
]
else:
mpn_gather_reduce = DGLF.sum(msg='msg', out='m')
class GatherUpdate(nn.Module):
def __init__(self, hidden_size):
super(GatherUpdate, self).__init__()
self.hidden_size = hidden_size
self.W_o = nn.Linear(ATOM_FDIM + hidden_size, hidden_size)
def forward(self, node):
if PAPER:
#m = node['m']
m = node.data['m'] + node.data['accum_alpha']
else:
m = node.data['m'] + node.data['alpha']
return {
'h': torch.relu(self.W_o(torch.cat([node.data['x'], m], 1))),
}
class DGLJTMPN(nn.Module):
def __init__(self, hidden_size, depth):
nn.Module.__init__(self)
self.depth = depth
self.W_i = nn.Linear(ATOM_FDIM + BOND_FDIM, hidden_size, bias=False)
self.loopy_bp_updater = LoopyBPUpdate(hidden_size)
self.gather_updater = GatherUpdate(hidden_size)
self.hidden_size = hidden_size
self.n_samples_total = 0
self.n_nodes_total = 0
self.n_edges_total = 0
self.n_passes = 0
def forward(self, cand_batch, mol_tree_batch):
cand_graphs, tree_mess_src_edges, tree_mess_tgt_edges, tree_mess_tgt_nodes = cand_batch
n_samples = len(cand_graphs)
cand_line_graph = cand_graphs.line_graph(
backtracking=False, shared=True)
n_nodes = cand_graphs.number_of_nodes()
n_edges = cand_graphs.number_of_edges()
cand_graphs = self.run(
cand_graphs, cand_line_graph, tree_mess_src_edges, tree_mess_tgt_edges,
tree_mess_tgt_nodes, mol_tree_batch)
g_repr = mean_nodes(cand_graphs, 'h')
self.n_samples_total += n_samples
self.n_nodes_total += n_nodes
self.n_edges_total += n_edges
self.n_passes += 1
return g_repr
def run(self, cand_graphs, cand_line_graph, tree_mess_src_edges, tree_mess_tgt_edges,
tree_mess_tgt_nodes, mol_tree_batch):
n_nodes = cand_graphs.number_of_nodes()
cand_graphs.apply_edges(
func=lambda edges: {'src_x': edges.src['x']},
)
bond_features = cand_line_graph.ndata['x']
source_features = cand_line_graph.ndata['src_x']
features = torch.cat([source_features, bond_features], 1)
msg_input = self.W_i(features)
cand_line_graph.ndata.update({
'msg_input': msg_input,
'msg': torch.relu(msg_input),
'accum_msg': torch.zeros_like(msg_input),
})
zero_node_state = bond_features.new(n_nodes, self.hidden_size).zero_()
cand_graphs.ndata.update({
'm': zero_node_state.clone(),
'h': zero_node_state.clone(),
})
cand_graphs.edata['alpha'] = \
cuda(torch.zeros(cand_graphs.number_of_edges(), self.hidden_size))
cand_graphs.ndata['alpha'] = zero_node_state
if tree_mess_src_edges.shape[0] > 0:
if PAPER:
src_u, src_v = tree_mess_src_edges.unbind(1)
tgt_u, tgt_v = tree_mess_tgt_edges.unbind(1)
alpha = mol_tree_batch.edges[src_u, src_v].data['m']
cand_graphs.edges[tgt_u, tgt_v].data['alpha'] = alpha
else:
src_u, src_v = tree_mess_src_edges.unbind(1)
alpha = mol_tree_batch.edges[src_u, src_v].data['m']
node_idx = (tree_mess_tgt_nodes
.to(device=zero_node_state.device)[:, None]
.expand_as(alpha))
node_alpha = zero_node_state.clone().scatter_add(0, node_idx, alpha)
cand_graphs.ndata['alpha'] = node_alpha
cand_graphs.apply_edges(
func=lambda edges: {'alpha': edges.src['alpha']},
)
for i in range(self.depth - 1):
cand_line_graph.update_all(
mpn_loopy_bp_msg,
mpn_loopy_bp_reduce,
self.loopy_bp_updater,
)
cand_graphs.update_all(
mpn_gather_msg,
mpn_gather_reduce,
self.gather_updater,
)
return cand_graphs
# pylint: disable=C0111, C0103, E1101, W0611, W0612
import copy
import itertools
import networkx as nx
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import dgl.function as DGLF
from dgl import batch, dfs_labeled_edges_generator
from .chemutils import enum_assemble_nx, get_mol
from .mol_tree import Vocab
from .mol_tree_nx import DGLMolTree
from .nnutils import GRUUpdate, cuda
MAX_NB = 8
MAX_DECODE_LEN = 100
def dfs_order(forest, roots):
edges = dfs_labeled_edges_generator(forest, roots, has_reverse_edge=True)
for e, l in zip(*edges):
        # I exploited the fact that the reverse edge ID equals the forward
        # edge ID XOR 1 for molecule trees. Normally, I should locate reverse
        # edges using find_edges().
yield e ^ l, l
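# A worked example of the XOR trick (assuming edges are added in consecutive
# (u, v), (v, u) pairs, so IDs pair up as (2k, 2k + 1)): the label l is 0 on a
# forward step and 1 when backtracking, hence e ^ l leaves forward edges alone
# (4 ^ 0 == 4) and flips a backtracked edge to its reverse (4 ^ 1 == 5).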
dec_tree_node_msg = DGLF.copy_edge(edge='m', out='m')
dec_tree_node_reduce = DGLF.sum(msg='m', out='h')
def dec_tree_node_update(nodes):
return {'new': nodes.data['new'].clone().zero_()}
dec_tree_edge_msg = [DGLF.copy_src(
src='m', out='m'), DGLF.copy_src(src='rm', out='rm')]
dec_tree_edge_reduce = [
DGLF.sum(msg='m', out='s'), DGLF.sum(msg='rm', out='accum_rm')]
def have_slots(fa_slots, ch_slots):
if len(fa_slots) > 2 and len(ch_slots) > 2:
return True
matches = []
for i, s1 in enumerate(fa_slots):
a1, c1, h1 = s1
for j, s2 in enumerate(ch_slots):
a2, c2, h2 = s2
if a1 == a2 and c1 == c2 and (a1 != "C" or h1 + h2 >= 4):
matches.append((i, j))
if len(matches) == 0:
return False
fa_match, ch_match = list(zip(*matches))
if len(set(fa_match)) == 1 and 1 < len(fa_slots) <= 2: # never remove atom from ring
fa_slots.pop(fa_match[0])
if len(set(ch_match)) == 1 and 1 < len(ch_slots) <= 2: # never remove atom from ring
ch_slots.pop(ch_match[0])
return True
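# A "slot" is the (symbol, formal_charge, total_num_Hs) triple that
# get_slots in mol_tree.py produces per atom; two cliques may share an atom
# only if some pair of slots agrees on symbol and charge, with an extra
# valence check for carbon.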
def can_assemble(mol_tree, u, v_node_dict):
u_node_dict = mol_tree.nodes_dict[u]
u_neighbors = mol_tree.successors(u)
u_neighbors_node_dict = [
mol_tree.nodes_dict[_u]
for _u in u_neighbors
if _u in mol_tree.nodes_dict
]
neis = u_neighbors_node_dict + [v_node_dict]
for i, nei in enumerate(neis):
nei['nid'] = i
neighbors = [nei for nei in neis if nei['mol'].GetNumAtoms() > 1]
neighbors = sorted(
neighbors, key=lambda x: x['mol'].GetNumAtoms(), reverse=True)
singletons = [nei for nei in neis if nei['mol'].GetNumAtoms() == 1]
neighbors = singletons + neighbors
cands = enum_assemble_nx(u_node_dict, neighbors)
return len(cands) > 0
def create_node_dict(smiles, clique=None):
if clique is None:
clique = []
return dict(
smiles=smiles,
mol=get_mol(smiles),
clique=clique,
)
class DGLJTNNDecoder(nn.Module):
def __init__(self, vocab, hidden_size, latent_size, embedding=None):
nn.Module.__init__(self)
self.hidden_size = hidden_size
self.vocab_size = vocab.size()
self.vocab = vocab
if embedding is None:
self.embedding = nn.Embedding(self.vocab_size, hidden_size)
else:
self.embedding = embedding
self.dec_tree_edge_update = GRUUpdate(hidden_size)
self.W = nn.Linear(latent_size + hidden_size, hidden_size)
self.U = nn.Linear(latent_size + 2 * hidden_size, hidden_size)
self.W_o = nn.Linear(hidden_size, self.vocab_size)
self.U_s = nn.Linear(hidden_size, 1)
def forward(self, mol_trees, tree_vec):
'''
The training procedure which computes the prediction loss given the
ground truth tree
'''
mol_tree_batch = batch(mol_trees)
mol_tree_batch_lg = mol_tree_batch.line_graph(
backtracking=False, shared=True)
n_trees = len(mol_trees)
return self.run(mol_tree_batch, mol_tree_batch_lg, n_trees, tree_vec)
def run(self, mol_tree_batch, mol_tree_batch_lg, n_trees, tree_vec):
node_offset = np.cumsum([0] + mol_tree_batch.batch_num_nodes)
root_ids = node_offset[:-1]
n_nodes = mol_tree_batch.number_of_nodes()
n_edges = mol_tree_batch.number_of_edges()
mol_tree_batch.ndata.update({
'x': self.embedding(mol_tree_batch.ndata['wid']),
'h': cuda(torch.zeros(n_nodes, self.hidden_size)),
            # whether it's a newly generated node
'new': cuda(torch.ones(n_nodes).byte()),
})
mol_tree_batch.edata.update({
's': cuda(torch.zeros(n_edges, self.hidden_size)),
'm': cuda(torch.zeros(n_edges, self.hidden_size)),
'r': cuda(torch.zeros(n_edges, self.hidden_size)),
'z': cuda(torch.zeros(n_edges, self.hidden_size)),
'src_x': cuda(torch.zeros(n_edges, self.hidden_size)),
'dst_x': cuda(torch.zeros(n_edges, self.hidden_size)),
'rm': cuda(torch.zeros(n_edges, self.hidden_size)),
'accum_rm': cuda(torch.zeros(n_edges, self.hidden_size)),
})
mol_tree_batch.apply_edges(
func=lambda edges: {
'src_x': edges.src['x'], 'dst_x': edges.dst['x']},
)
# input tensors for stop prediction (p) and label prediction (q)
p_inputs = []
p_targets = []
q_inputs = []
q_targets = []
# Predict root
mol_tree_batch.pull(
root_ids,
dec_tree_node_msg,
dec_tree_node_reduce,
dec_tree_node_update,
)
# Extract hidden states and store them for stop/label prediction
h = mol_tree_batch.nodes[root_ids].data['h']
x = mol_tree_batch.nodes[root_ids].data['x']
p_inputs.append(torch.cat([x, h, tree_vec], 1))
        # If the out-degree is 0, we don't generate any edges at all
root_out_degrees = mol_tree_batch.out_degrees(root_ids)
q_inputs.append(torch.cat([h, tree_vec], 1))
q_targets.append(mol_tree_batch.nodes[root_ids].data['wid'])
# Traverse the tree and predict on children
for eid, p in dfs_order(mol_tree_batch, root_ids):
u, v = mol_tree_batch.find_edges(eid)
p_target_list = torch.zeros_like(root_out_degrees)
p_target_list[root_out_degrees > 0] = 1 - p
p_target_list = p_target_list[root_out_degrees >= 0]
            # Clone the snapshot; torch.tensor(tensor) would emit a
            # copy-construction warning.
            p_targets.append(p_target_list.clone())
root_out_degrees -= (root_out_degrees == 0).long()
root_out_degrees -= torch.tensor(np.isin(root_ids,
v).astype('int64'))
mol_tree_batch_lg.pull(
eid,
dec_tree_edge_msg,
dec_tree_edge_reduce,
self.dec_tree_edge_update,
)
is_new = mol_tree_batch.nodes[v].data['new']
mol_tree_batch.pull(
v,
dec_tree_node_msg,
dec_tree_node_reduce,
dec_tree_node_update,
)
# Extract
n_repr = mol_tree_batch.nodes[v].data
h = n_repr['h']
x = n_repr['x']
tree_vec_set = tree_vec[root_out_degrees >= 0]
wid = n_repr['wid']
p_inputs.append(torch.cat([x, h, tree_vec_set], 1))
# Only newly generated nodes are needed for label prediction
# NOTE: The following works since the uncomputed messages are zeros.
q_input = torch.cat([h, tree_vec_set], 1)[is_new]
q_target = wid[is_new]
if q_input.shape[0] > 0:
q_inputs.append(q_input)
q_targets.append(q_target)
p_targets.append(torch.zeros((root_out_degrees == 0).sum()).long())
# Batch compute the stop/label prediction losses
p_inputs = torch.cat(p_inputs, 0)
p_targets = cuda(torch.cat(p_targets, 0))
q_inputs = torch.cat(q_inputs, 0)
q_targets = torch.cat(q_targets, 0)
q = self.W_o(torch.relu(self.W(q_inputs)))
p = self.U_s(torch.relu(self.U(p_inputs)))[:, 0]
p_loss = F.binary_cross_entropy_with_logits(
p, p_targets.float(), size_average=False
) / n_trees
q_loss = F.cross_entropy(q, q_targets, size_average=False) / n_trees
p_acc = ((p > 0).long() == p_targets).sum().float() / \
p_targets.shape[0]
q_acc = (q.max(1)[1] == q_targets).float().sum() / q_targets.shape[0]
self.q_inputs = q_inputs
self.q_targets = q_targets
self.q = q
self.p_inputs = p_inputs
self.p_targets = p_targets
self.p = p
return q_loss, p_loss, q_acc, p_acc
def decode(self, mol_vec):
assert mol_vec.shape[0] == 1
mol_tree = DGLMolTree(None)
init_hidden = cuda(torch.zeros(1, self.hidden_size))
root_hidden = torch.cat([init_hidden, mol_vec], 1)
root_hidden = F.relu(self.W(root_hidden))
root_score = self.W_o(root_hidden)
_, root_wid = torch.max(root_score, 1)
root_wid = root_wid.view(1)
mol_tree.add_nodes(1) # root
mol_tree.nodes[0].data['wid'] = root_wid
mol_tree.nodes[0].data['x'] = self.embedding(root_wid)
mol_tree.nodes[0].data['h'] = init_hidden
mol_tree.nodes[0].data['fail'] = cuda(torch.tensor([0]))
mol_tree.nodes_dict[0] = root_node_dict = create_node_dict(
self.vocab.get_smiles(root_wid))
stack, trace = [], []
stack.append((0, self.vocab.get_slots(root_wid)))
all_nodes = {0: root_node_dict}
h = {}
first = True
new_node_id = 0
new_edge_id = 0
for step in range(MAX_DECODE_LEN):
u, u_slots = stack[-1]
udata = mol_tree.nodes[u].data
wid = udata['wid']
x = udata['x']
h = udata['h']
# Predict stop
p_input = torch.cat([x, h, mol_vec], 1)
p_score = torch.sigmoid(self.U_s(torch.relu(self.U(p_input))))
backtrack = (p_score.item() < 0.5)
if not backtrack:
                # Predict the next clique. Note that the prediction may fail
                # due to a lack of assemblable components.
mol_tree.add_nodes(1)
new_node_id += 1
v = new_node_id
mol_tree.add_edges(u, v)
uv = new_edge_id
new_edge_id += 1
if first:
mol_tree.edata.update({
's': cuda(torch.zeros(1, self.hidden_size)),
'm': cuda(torch.zeros(1, self.hidden_size)),
'r': cuda(torch.zeros(1, self.hidden_size)),
'z': cuda(torch.zeros(1, self.hidden_size)),
'src_x': cuda(torch.zeros(1, self.hidden_size)),
'dst_x': cuda(torch.zeros(1, self.hidden_size)),
'rm': cuda(torch.zeros(1, self.hidden_size)),
'accum_rm': cuda(torch.zeros(1, self.hidden_size)),
})
first = False
mol_tree.edges[uv].data['src_x'] = mol_tree.nodes[u].data['x']
                # Keeping dst_x at 0 is fine, as h on the new edge doesn't
                # depend on it.
# DGL doesn't dynamically maintain a line graph.
mol_tree_lg = mol_tree.line_graph(
backtracking=False, shared=True)
mol_tree_lg.pull(
uv,
dec_tree_edge_msg,
dec_tree_edge_reduce,
self.dec_tree_edge_update.update_zm,
)
mol_tree.pull(
v,
dec_tree_node_msg,
dec_tree_node_reduce,
)
vdata = mol_tree.nodes[v].data
h_v = vdata['h']
q_input = torch.cat([h_v, mol_vec], 1)
q_score = torch.softmax(
self.W_o(torch.relu(self.W(q_input))), -1)
_, sort_wid = torch.sort(q_score, 1, descending=True)
sort_wid = sort_wid.squeeze()
next_wid = None
for wid in sort_wid.tolist()[:5]:
slots = self.vocab.get_slots(wid)
cand_node_dict = create_node_dict(
self.vocab.get_smiles(wid))
if (have_slots(u_slots, slots) and can_assemble(mol_tree, u, cand_node_dict)):
next_wid = wid
next_slots = slots
next_node_dict = cand_node_dict
break
if next_wid is None:
                    # Failed to add an actual child; v is a spurious node,
                    # so we mark it.
vdata['fail'] = cuda(torch.tensor([1]))
backtrack = True
else:
next_wid = cuda(torch.tensor([next_wid]))
vdata['wid'] = next_wid
vdata['x'] = self.embedding(next_wid)
mol_tree.nodes_dict[v] = next_node_dict
all_nodes[v] = next_node_dict
stack.append((v, next_slots))
mol_tree.add_edge(v, u)
vu = new_edge_id
new_edge_id += 1
mol_tree.edges[uv].data['dst_x'] = mol_tree.nodes[v].data['x']
mol_tree.edges[vu].data['src_x'] = mol_tree.nodes[v].data['x']
mol_tree.edges[vu].data['dst_x'] = mol_tree.nodes[u].data['x']
# DGL doesn't dynamically maintain a line graph.
mol_tree_lg = mol_tree.line_graph(
backtracking=False, shared=True)
mol_tree_lg.apply_nodes(
self.dec_tree_edge_update.update_r,
uv
)
if backtrack:
if len(stack) == 1:
break # At root, terminate
pu, _ = stack[-2]
u_pu = mol_tree.edge_id(u, pu)
mol_tree_lg.pull(
u_pu,
dec_tree_edge_msg,
dec_tree_edge_reduce,
self.dec_tree_edge_update,
)
mol_tree.pull(
pu,
dec_tree_node_msg,
dec_tree_node_reduce,
)
stack.pop()
effective_nodes = mol_tree.filter_nodes(
lambda nodes: nodes.data['fail'] != 1)
effective_nodes, _ = torch.sort(effective_nodes)
return mol_tree, all_nodes, effective_nodes
# pylint: disable=C0111, C0103, E1101, W0611, W0612
import itertools
from collections import deque
import networkx as nx
import numpy as np
import torch
import torch.nn as nn
import dgl.function as DGLF
from dgl import batch, bfs_edges_generator, unbatch
from .mol_tree import Vocab
from .nnutils import GRUUpdate, cuda
MAX_NB = 8
def level_order(forest, roots):
edges = bfs_edges_generator(forest, roots)
_, leaves = forest.find_edges(edges[-1])
edges_back = bfs_edges_generator(forest, roots, reverse=True)
yield from reversed(edges_back)
yield from edges
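# level_order first yields the reversed backward BFS levels (edges from the
# leaves towards the roots), then the forward levels (roots towards the
# leaves), so every edge is processed only after all of its incoming messages
# are ready.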
enc_tree_msg = [DGLF.copy_src(src='m', out='m'),
DGLF.copy_src(src='rm', out='rm')]
enc_tree_reduce = [DGLF.sum(msg='m', out='s'),
DGLF.sum(msg='rm', out='accum_rm')]
enc_tree_gather_msg = DGLF.copy_edge(edge='m', out='m')
enc_tree_gather_reduce = DGLF.sum(msg='m', out='m')
class EncoderGatherUpdate(nn.Module):
def __init__(self, hidden_size):
nn.Module.__init__(self)
self.hidden_size = hidden_size
self.W = nn.Linear(2 * hidden_size, hidden_size)
def forward(self, nodes):
x = nodes.data['x']
m = nodes.data['m']
return {
'h': torch.relu(self.W(torch.cat([x, m], 1))),
}
class DGLJTNNEncoder(nn.Module):
def __init__(self, vocab, hidden_size, embedding=None):
nn.Module.__init__(self)
self.hidden_size = hidden_size
self.vocab_size = vocab.size()
self.vocab = vocab
if embedding is None:
self.embedding = nn.Embedding(self.vocab_size, hidden_size)
else:
self.embedding = embedding
self.enc_tree_update = GRUUpdate(hidden_size)
self.enc_tree_gather_update = EncoderGatherUpdate(hidden_size)
def forward(self, mol_trees):
mol_tree_batch = batch(mol_trees)
# Build line graph to prepare for belief propagation
mol_tree_batch_lg = mol_tree_batch.line_graph(
backtracking=False, shared=True)
return self.run(mol_tree_batch, mol_tree_batch_lg)
def run(self, mol_tree_batch, mol_tree_batch_lg):
        # Tree roots are designated as node 0, so in the batched graph we can
        # simply find the corresponding node IDs by looking at node_offset.
node_offset = np.cumsum([0] + mol_tree_batch.batch_num_nodes)
root_ids = node_offset[:-1]
n_nodes = mol_tree_batch.number_of_nodes()
n_edges = mol_tree_batch.number_of_edges()
# Assign structure embeddings to tree nodes
mol_tree_batch.ndata.update({
'x': self.embedding(mol_tree_batch.ndata['wid']),
'h': cuda(torch.zeros(n_nodes, self.hidden_size)),
})
# Initialize the intermediate variables according to Eq (4)-(8).
# Also initialize the src_x and dst_x fields.
# TODO: context?
mol_tree_batch.edata.update({
's': cuda(torch.zeros(n_edges, self.hidden_size)),
'm': cuda(torch.zeros(n_edges, self.hidden_size)),
'r': cuda(torch.zeros(n_edges, self.hidden_size)),
'z': cuda(torch.zeros(n_edges, self.hidden_size)),
'src_x': cuda(torch.zeros(n_edges, self.hidden_size)),
'dst_x': cuda(torch.zeros(n_edges, self.hidden_size)),
'rm': cuda(torch.zeros(n_edges, self.hidden_size)),
'accum_rm': cuda(torch.zeros(n_edges, self.hidden_size)),
})
# Send the source/destination node features to edges
mol_tree_batch.apply_edges(
func=lambda edges: {
'src_x': edges.src['x'], 'dst_x': edges.dst['x']},
)
# Message passing
        # I exploited the fact that the reduce function is a sum of incoming
        # messages, and the uncomputed messages are zero vectors. Essentially,
        # we can always compute s_ij as the sum of incoming m_ij, whether or
        # not m_ij has actually been computed.
for eid in level_order(mol_tree_batch, root_ids):
#eid = mol_tree_batch.edge_ids(u, v)
mol_tree_batch_lg.pull(
eid,
enc_tree_msg,
enc_tree_reduce,
self.enc_tree_update,
)
# Readout
mol_tree_batch.update_all(
enc_tree_gather_msg,
enc_tree_gather_reduce,
self.enc_tree_gather_update,
)
root_vecs = mol_tree_batch.nodes[root_ids].data['h']
return mol_tree_batch, root_vecs
# pylint: disable=C0111, C0103, E1101, W0611, W0612, C0200
import copy
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import rdkit
import rdkit.Chem as Chem
from rdkit import DataStructs
from rdkit.Chem import AllChem
from dgl import batch, unbatch
from dgl.data.utils import get_download_dir
from .chemutils import (attach_mols_nx, copy_edit_mol, decode_stereo,
enum_assemble_nx, set_atommap)
from .jtmpn import DGLJTMPN
from .jtmpn import mol2dgl_single as mol2dgl_dec
from .jtnn_dec import DGLJTNNDecoder
from .jtnn_enc import DGLJTNNEncoder
from .mol_tree import Vocab
from .mpn import DGLMPN
from .mpn import mol2dgl_single as mol2dgl_enc
from .nnutils import cuda, move_dgl_to_cuda
class DGLJTNNVAE(nn.Module):
def __init__(self, hidden_size, latent_size, depth, vocab=None, vocab_file=None):
super(DGLJTNNVAE, self).__init__()
        if vocab is None:
            if vocab_file is None:
                vocab_file = '{}/jtnn/{}.txt'.format(
                    get_download_dir(), 'vocab')
            self.vocab = Vocab([x.strip("\r\n ")
                                for x in open(vocab_file)])
else:
self.vocab = vocab
self.hidden_size = hidden_size
self.latent_size = latent_size
self.depth = depth
self.embedding = nn.Embedding(self.vocab.size(), hidden_size)
self.mpn = DGLMPN(hidden_size, depth)
self.jtnn = DGLJTNNEncoder(self.vocab, hidden_size, self.embedding)
self.decoder = DGLJTNNDecoder(
self.vocab, hidden_size, latent_size // 2, self.embedding)
self.jtmpn = DGLJTMPN(hidden_size, depth)
self.T_mean = nn.Linear(hidden_size, latent_size // 2)
self.T_var = nn.Linear(hidden_size, latent_size // 2)
self.G_mean = nn.Linear(hidden_size, latent_size // 2)
self.G_var = nn.Linear(hidden_size, latent_size // 2)
self.n_nodes_total = 0
self.n_passes = 0
self.n_edges_total = 0
self.n_tree_nodes_total = 0
@staticmethod
def move_to_cuda(mol_batch):
for t in mol_batch['mol_trees']:
move_dgl_to_cuda(t)
move_dgl_to_cuda(mol_batch['mol_graph_batch'])
if 'cand_graph_batch' in mol_batch:
move_dgl_to_cuda(mol_batch['cand_graph_batch'])
if mol_batch.get('stereo_cand_graph_batch') is not None:
move_dgl_to_cuda(mol_batch['stereo_cand_graph_batch'])
def encode(self, mol_batch):
mol_graphs = mol_batch['mol_graph_batch']
mol_vec = self.mpn(mol_graphs)
mol_tree_batch, tree_vec = self.jtnn(mol_batch['mol_trees'])
self.n_nodes_total += mol_graphs.number_of_nodes()
self.n_edges_total += mol_graphs.number_of_edges()
self.n_tree_nodes_total += sum(t.number_of_nodes()
for t in mol_batch['mol_trees'])
self.n_passes += 1
return mol_tree_batch, tree_vec, mol_vec
def sample(self, tree_vec, mol_vec, e1=None, e2=None):
tree_mean = self.T_mean(tree_vec)
tree_log_var = -torch.abs(self.T_var(tree_vec))
mol_mean = self.G_mean(mol_vec)
mol_log_var = -torch.abs(self.G_var(mol_vec))
epsilon = cuda(torch.randn(*tree_mean.shape)) if e1 is None else e1
tree_vec = tree_mean + torch.exp(tree_log_var / 2) * epsilon
epsilon = cuda(torch.randn(*mol_mean.shape)) if e2 is None else e2
mol_vec = mol_mean + torch.exp(mol_log_var / 2) * epsilon
z_mean = torch.cat([tree_mean, mol_mean], 1)
z_log_var = torch.cat([tree_log_var, mol_log_var], 1)
return tree_vec, mol_vec, z_mean, z_log_var
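    # sample() applies the standard VAE reparameterization: with the means and
    # log-variances predicted above, z = mean + exp(log_var / 2) * eps, where
    # eps ~ N(0, I) unless fixed noise (e1/e2) is supplied, e.g. for testing.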
def forward(self, mol_batch, beta=0, e1=None, e2=None):
self.move_to_cuda(mol_batch)
mol_trees = mol_batch['mol_trees']
batch_size = len(mol_trees)
mol_tree_batch, tree_vec, mol_vec = self.encode(mol_batch)
tree_vec, mol_vec, z_mean, z_log_var = self.sample(
tree_vec, mol_vec, e1, e2)
kl_loss = -0.5 * torch.sum(
1.0 + z_log_var - z_mean * z_mean - torch.exp(z_log_var)) / batch_size
word_loss, topo_loss, word_acc, topo_acc = self.decoder(
mol_trees, tree_vec)
assm_loss, assm_acc = self.assm(mol_batch, mol_tree_batch, mol_vec)
stereo_loss, stereo_acc = self.stereo(mol_batch, mol_vec)
all_vec = torch.cat([tree_vec, mol_vec], dim=1)
loss = word_loss + topo_loss + assm_loss + 2 * stereo_loss + beta * kl_loss
return loss, kl_loss, word_acc, topo_acc, assm_acc, stereo_acc
def assm(self, mol_batch, mol_tree_batch, mol_vec):
cands = [mol_batch['cand_graph_batch'],
mol_batch['tree_mess_src_e'],
mol_batch['tree_mess_tgt_e'],
mol_batch['tree_mess_tgt_n']]
cand_vec = self.jtmpn(cands, mol_tree_batch)
cand_vec = self.G_mean(cand_vec)
batch_idx = cuda(torch.LongTensor(mol_batch['cand_batch_idx']))
mol_vec = mol_vec[batch_idx]
mol_vec = mol_vec.view(-1, 1, self.latent_size // 2)
cand_vec = cand_vec.view(-1, self.latent_size // 2, 1)
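        # Batched dot product: each (1, d) mol_vec row multiplies its (d, 1)
        # cand_vec column, yielding one compatibility score per candidate.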
scores = (mol_vec @ cand_vec)[:, 0, 0]
cnt, tot, acc = 0, 0, 0
all_loss = []
for i, mol_tree in enumerate(mol_batch['mol_trees']):
comp_nodes = [node_id for node_id, node in mol_tree.nodes_dict.items()
if len(node['cands']) > 1 and not node['is_leaf']]
cnt += len(comp_nodes)
# segmented accuracy and cross entropy
for node_id in comp_nodes:
node = mol_tree.nodes_dict[node_id]
label = node['cands'].index(node['label'])
ncand = len(node['cands'])
cur_score = scores[tot:tot + ncand]
tot += ncand
if cur_score[label].item() >= cur_score.max().item():
acc += 1
label = cuda(torch.LongTensor([label]))
all_loss.append(
F.cross_entropy(cur_score.view(1, -1), label, size_average=False))
all_loss = sum(all_loss) / len(mol_batch['mol_trees'])
return all_loss, acc / cnt
def stereo(self, mol_batch, mol_vec):
stereo_cands = mol_batch['stereo_cand_graph_batch']
batch_idx = mol_batch['stereo_cand_batch_idx']
labels = mol_batch['stereo_cand_labels']
lengths = mol_batch['stereo_cand_lengths']
if len(labels) == 0:
# Only one stereoisomer exists; do nothing
return cuda(torch.tensor(0.)), 1.
batch_idx = cuda(torch.LongTensor(batch_idx))
stereo_cands = self.mpn(stereo_cands)
stereo_cands = self.G_mean(stereo_cands)
stereo_labels = mol_vec[batch_idx]
scores = F.cosine_similarity(stereo_cands, stereo_labels)
st, acc = 0, 0
all_loss = []
for label, le in zip(labels, lengths):
cur_scores = scores[st:st + le]
if cur_scores.data[label].item() >= cur_scores.max().item():
acc += 1
label = cuda(torch.LongTensor([label]))
all_loss.append(
F.cross_entropy(cur_scores.view(1, -1), label, size_average=False))
st += le
all_loss = sum(all_loss) / len(labels)
return all_loss, acc / len(labels)
def decode(self, tree_vec, mol_vec):
mol_tree, nodes_dict, effective_nodes = self.decoder.decode(tree_vec)
effective_nodes_list = effective_nodes.tolist()
nodes_dict = [nodes_dict[v] for v in effective_nodes_list]
for i, (node_id, node) in enumerate(zip(effective_nodes_list, nodes_dict)):
node['idx'] = i
node['nid'] = i + 1
node['is_leaf'] = True
if mol_tree.in_degree(node_id) > 1:
node['is_leaf'] = False
set_atommap(node['mol'], node['nid'])
mol_tree_sg = mol_tree.subgraph(effective_nodes)
mol_tree_sg.copy_from_parent()
mol_tree_msg, _ = self.jtnn([mol_tree_sg])
mol_tree_msg = unbatch(mol_tree_msg)[0]
mol_tree_msg.nodes_dict = nodes_dict
cur_mol = copy_edit_mol(nodes_dict[0]['mol'])
global_amap = [{}] + [{} for node in nodes_dict]
global_amap[1] = {atom.GetIdx(): atom.GetIdx()
for atom in cur_mol.GetAtoms()}
cur_mol = self.dfs_assemble(
mol_tree_msg, mol_vec, cur_mol, global_amap, [], 0, None)
if cur_mol is None:
return None
cur_mol = cur_mol.GetMol()
set_atommap(cur_mol)
cur_mol = Chem.MolFromSmiles(Chem.MolToSmiles(cur_mol))
if cur_mol is None:
return None
smiles2D = Chem.MolToSmiles(cur_mol)
stereo_cands = decode_stereo(smiles2D)
if len(stereo_cands) == 1:
return stereo_cands[0]
stereo_graphs = [mol2dgl_enc(c) for c in stereo_cands]
stereo_cand_graphs, atom_x, bond_x = \
zip(*stereo_graphs)
stereo_cand_graphs = batch(stereo_cand_graphs)
atom_x = cuda(torch.cat(atom_x))
bond_x = cuda(torch.cat(bond_x))
stereo_cand_graphs.ndata['x'] = atom_x
stereo_cand_graphs.edata['x'] = bond_x
stereo_cand_graphs.edata['src_x'] = atom_x.new(
bond_x.shape[0], atom_x.shape[1]).zero_()
stereo_vecs = self.mpn(stereo_cand_graphs)
stereo_vecs = self.G_mean(stereo_vecs)
scores = F.cosine_similarity(stereo_vecs, mol_vec)
_, max_id = scores.max(0)
return stereo_cands[max_id.item()]
def dfs_assemble(self, mol_tree_msg, mol_vec, cur_mol,
global_amap, fa_amap, cur_node_id, fa_node_id):
nodes_dict = mol_tree_msg.nodes_dict
fa_node = nodes_dict[fa_node_id] if fa_node_id is not None else None
cur_node = nodes_dict[cur_node_id]
fa_nid = fa_node['nid'] if fa_node is not None else -1
prev_nodes = [fa_node] if fa_node is not None else []
children_node_id = [v for v in mol_tree_msg.successors(cur_node_id).tolist()
if nodes_dict[v]['nid'] != fa_nid]
children = [nodes_dict[v] for v in children_node_id]
neighbors = [nei for nei in children if nei['mol'].GetNumAtoms() > 1]
neighbors = sorted(
neighbors, key=lambda x: x['mol'].GetNumAtoms(), reverse=True)
singletons = [nei for nei in children if nei['mol'].GetNumAtoms() == 1]
neighbors = singletons + neighbors
cur_amap = [(fa_nid, a2, a1)
for nid, a1, a2 in fa_amap if nid == cur_node['nid']]
cands = enum_assemble_nx(cur_node, neighbors, prev_nodes, cur_amap)
if len(cands) == 0:
return None
cand_smiles, cand_mols, cand_amap = list(zip(*cands))
cands = [(candmol, mol_tree_msg, cur_node_id) for candmol in cand_mols]
cand_graphs, atom_x, bond_x, tree_mess_src_edges, \
tree_mess_tgt_edges, tree_mess_tgt_nodes = mol2dgl_dec(
cands)
cand_graphs = batch(cand_graphs)
atom_x = cuda(atom_x)
bond_x = cuda(bond_x)
cand_graphs.ndata['x'] = atom_x
cand_graphs.edata['x'] = bond_x
cand_graphs.edata['src_x'] = atom_x.new(
bond_x.shape[0], atom_x.shape[1]).zero_()
cand_vecs = self.jtmpn(
(cand_graphs, tree_mess_src_edges,
tree_mess_tgt_edges, tree_mess_tgt_nodes),
mol_tree_msg,
)
cand_vecs = self.G_mean(cand_vecs)
mol_vec = mol_vec.squeeze()
scores = cand_vecs @ mol_vec
_, cand_idx = torch.sort(scores, descending=True)
backup_mol = Chem.RWMol(cur_mol)
for i in range(len(cand_idx)):
cur_mol = Chem.RWMol(backup_mol)
pred_amap = cand_amap[cand_idx[i].item()]
new_global_amap = copy.deepcopy(global_amap)
for nei_id, ctr_atom, nei_atom in pred_amap:
if nei_id == fa_nid:
continue
new_global_amap[nei_id][nei_atom] = new_global_amap[cur_node['nid']][ctr_atom]
cur_mol = attach_mols_nx(cur_mol, children, [], new_global_amap)
new_mol = cur_mol.GetMol()
new_mol = Chem.MolFromSmiles(Chem.MolToSmiles(new_mol))
if new_mol is None:
continue
result = True
for nei_node_id, nei_node in zip(children_node_id, children):
if nei_node['is_leaf']:
continue
cur_mol = self.dfs_assemble(
mol_tree_msg, mol_vec, cur_mol, new_global_amap, pred_amap,
nei_node_id, cur_node_id)
if cur_mol is None:
result = False
break
if result:
return cur_mol
return None
# pylint: disable=C0111, C0103, E1101, W0611, W0612
import copy
import rdkit
import rdkit.Chem as Chem
def get_slots(smiles):
mol = Chem.MolFromSmiles(smiles)
return [(atom.GetSymbol(), atom.GetFormalCharge(), atom.GetTotalNumHs())
for atom in mol.GetAtoms()]
class Vocab(object):
def __init__(self, smiles_list):
self.vocab = smiles_list
self.vmap = {x: i for i, x in enumerate(self.vocab)}
self.slots = [get_slots(smiles) for smiles in self.vocab]
def get_index(self, smiles):
return self.vmap[smiles]
def get_smiles(self, idx):
return self.vocab[idx]
def get_slots(self, idx):
return copy.deepcopy(self.slots[idx])
def size(self):
return len(self.vocab)
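# A minimal usage sketch (hypothetical two-entry vocabulary):
#   vocab = Vocab(['C', 'C1=CC=CC=C1'])
#   vocab.get_index('C')    # -> 0
#   vocab.get_smiles(1)     # -> 'C1=CC=CC=C1'
#   vocab.get_slots(0)      # -> [('C', 0, 4)] for methane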