Commit 3192beb4 authored by VoVAllen, committed by Mufei Li

[Model zoo] JTNN model zoo (#790)

* jtnn model zoo

* poke ci

* fix line sep

* fix

* Fix import order

* fix render

* fix render

* revert

* fix

* Resolve conflict

* fix

* remove create_var

* refactor

* fix

* refactor

* readme

* format

* fix lint

* fix lint

* pylint

* lint

* fix lint

* fix lint

* add hint

* fix

* Remove vocab

* Add explanation for warning

* add directory

* Load model to cpu by default

* Update
parent 30c46251
# pylint: disable=C0111, C0103, E1101, W0611, W0612
import numpy as np
import rdkit.Chem as Chem

from dgl import DGLGraph

from .chemutils import (decode_stereo, enum_assemble_nx, get_clique_mol,
                        get_mol, get_smiles, set_atommap, tree_decomp)

class DGLMolTree(DGLGraph):
    def __init__(self, smiles):
        DGLGraph.__init__(self)
        self.nodes_dict = {}

        if smiles is None:
            return

        self.smiles = smiles
        self.mol = get_mol(smiles)

        # Stereo generation
        mol = Chem.MolFromSmiles(smiles)
        self.smiles3D = Chem.MolToSmiles(mol, isomericSmiles=True)
        self.smiles2D = Chem.MolToSmiles(mol)
        self.stereo_cands = decode_stereo(self.smiles2D)

        # cliques: a list of lists of atom indices
        cliques, edges = tree_decomp(self.mol)
        root = 0
        for i, c in enumerate(cliques):
            cmol = get_clique_mol(self.mol, c)
            csmiles = get_smiles(cmol)
            self.nodes_dict[i] = dict(
                smiles=csmiles,
                mol=get_mol(csmiles),
                clique=c,
            )
            if min(c) == 0:
                root = i

        self.add_nodes(len(cliques))

        # The clique containing atom 0 becomes the root: swap the attributes
        # of node 0 and the root node so the root always has node ID 0.
        if root > 0:
            for attr in self.nodes_dict[0]:
                self.nodes_dict[0][attr], self.nodes_dict[root][attr] = \
                    self.nodes_dict[root][attr], self.nodes_dict[0][attr]

        src = np.zeros((len(edges) * 2,), dtype='int')
        dst = np.zeros((len(edges) * 2,), dtype='int')
        for i, (_x, _y) in enumerate(edges):
            # Relabel endpoints to account for the 0 <-> root swap above.
            x = 0 if _x == root else root if _x == 0 else _x
            y = 0 if _y == root else root if _y == 0 else _y
            src[2 * i] = x
            dst[2 * i] = y
            src[2 * i + 1] = y
            dst[2 * i + 1] = x
        self.add_edges(src, dst)

        for i in self.nodes_dict:
            self.nodes_dict[i]['nid'] = i + 1
            if self.out_degree(i) > 1:  # Leaf node mol is not marked
                set_atommap(self.nodes_dict[i]['mol'],
                            self.nodes_dict[i]['nid'])
            self.nodes_dict[i]['is_leaf'] = (self.out_degree(i) == 1)

    def treesize(self):
        return self.number_of_nodes()

    def _recover_node(self, i, original_mol):
        node = self.nodes_dict[i]

        clique = []
        clique.extend(node['clique'])

        if not node['is_leaf']:
            for cidx in node['clique']:
                original_mol.GetAtomWithIdx(cidx).SetAtomMapNum(node['nid'])

        for j in self.successors(i).numpy():
            nei_node = self.nodes_dict[j]
            clique.extend(nei_node['clique'])
            if nei_node['is_leaf']:  # Leaf node, no need to mark
                continue
            for cidx in nei_node['clique']:
                # Allow a singleton node to override the atom mapping
                if cidx not in node['clique'] or len(nei_node['clique']) == 1:
                    atom = original_mol.GetAtomWithIdx(cidx)
                    atom.SetAtomMapNum(nei_node['nid'])

        clique = list(set(clique))
        label_mol = get_clique_mol(original_mol, clique)
        node['label'] = Chem.MolToSmiles(
            Chem.MolFromSmiles(get_smiles(label_mol)))
        node['label_mol'] = get_mol(node['label'])

        for cidx in clique:
            original_mol.GetAtomWithIdx(cidx).SetAtomMapNum(0)

        return node['label']

    def _assemble_node(self, i):
        neighbors = [self.nodes_dict[j] for j in self.successors(i).numpy()
                     if self.nodes_dict[j]['mol'].GetNumAtoms() > 1]
        neighbors = sorted(
            neighbors, key=lambda x: x['mol'].GetNumAtoms(), reverse=True)
        singletons = [self.nodes_dict[j] for j in self.successors(i).numpy()
                      if self.nodes_dict[j]['mol'].GetNumAtoms() == 1]
        neighbors = singletons + neighbors

        cands = enum_assemble_nx(self.nodes_dict[i], neighbors)

        if len(cands) > 0:
            self.nodes_dict[i]['cands'], self.nodes_dict[i]['cand_mols'], _ = \
                list(zip(*cands))
            self.nodes_dict[i]['cands'] = list(self.nodes_dict[i]['cands'])
            self.nodes_dict[i]['cand_mols'] = list(
                self.nodes_dict[i]['cand_mols'])
        else:
            self.nodes_dict[i]['cands'] = []
            self.nodes_dict[i]['cand_mols'] = []

    def recover(self):
        for i in self.nodes_dict:
            self._recover_node(i, self.mol)

    def assemble(self):
        for i in self.nodes_dict:
            self._assemble_node(i)
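
A short, hypothetical usage sketch of DGLMolTree (added for illustration, not part of the commit; it assumes RDKit is installed, and the SMILES string is arbitrary):

    tree = DGLMolTree('C1CCC1CO')        # a cyclobutane ring with a -CH2OH tail
    print(tree.treesize())               # number of cliques found by tree_decomp
    tree.recover()                       # attach a ground-truth 'label' SMILES to each node
    tree.assemble()                      # enumerate candidate attachment configurations
    print(tree.nodes_dict[0]['smiles'])  # SMILES of the root clique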

# pylint: disable=C0111, C0103, E1101, W0611, W0612
# pylint: disable=redefined-outer-name
from functools import partial

import numpy as np
import rdkit.Chem as Chem
import torch
import torch.nn as nn
import torch.nn.functional as F

import dgl.function as DGLF
from dgl import DGLGraph, batch, mean_nodes, unbatch
from networkx import DiGraph, Graph, convert_node_labels_to_integers

from .chemutils import get_mol
# from .nnutils import *

ELEM_LIST = ['C', 'N', 'O', 'S', 'F', 'Si', 'P', 'Cl', 'Br', 'Mg', 'Na',
             'Ca', 'Fe', 'Al', 'I', 'B', 'K', 'Se', 'Zn', 'H', 'Cu', 'Mn',
             'unknown']

ATOM_FDIM = len(ELEM_LIST) + 6 + 5 + 4 + 1
BOND_FDIM = 5 + 6
MAX_NB = 6

def onek_encoding_unk(x, allowable_set):
    if x not in allowable_set:
        x = allowable_set[-1]
    return [x == s for s in allowable_set]
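
# Illustrative calls (not part of the original file): anything outside the
# allowable set is mapped to the last slot, which ELEM_LIST reserves for
# 'unknown'.
#   onek_encoding_unk('N', ['C', 'N', 'O'])   # -> [False, True, False]
#   onek_encoding_unk('Xe', ['C', 'N', 'O'])  # -> [False, False, True]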

def atom_features(atom):
    return (torch.Tensor(onek_encoding_unk(atom.GetSymbol(), ELEM_LIST)
                         + onek_encoding_unk(atom.GetDegree(),
                                             [0, 1, 2, 3, 4, 5])
                         + onek_encoding_unk(atom.GetFormalCharge(),
                                             [-1, -2, 1, 2, 0])
                         + onek_encoding_unk(int(atom.GetChiralTag()),
                                             [0, 1, 2, 3])
                         + [atom.GetIsAromatic()]))


def bond_features(bond):
    bt = bond.GetBondType()
    stereo = int(bond.GetStereo())
    fbond = [bt == Chem.rdchem.BondType.SINGLE,
             bt == Chem.rdchem.BondType.DOUBLE,
             bt == Chem.rdchem.BondType.TRIPLE,
             bt == Chem.rdchem.BondType.AROMATIC,
             bond.IsInRing()]
    fstereo = onek_encoding_unk(stereo, [0, 1, 2, 3, 4, 5])
    return torch.Tensor(fbond + fstereo)
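
# Hedged sanity check (illustrative, not part of the original file): the
# feature sizes line up with the constants above, assuming RDKit parses
# the SMILES.
#   mol = Chem.MolFromSmiles('CC=O')                             # acetaldehyde
#   atom_features(mol.GetAtomWithIdx(0)).shape[0] == ATOM_FDIM   # 39
#   bond_features(mol.GetBondWithIdx(0)).shape[0] == BOND_FDIM   # 11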

def mol2dgl_single(smiles):
    n_edges = 0

    atom_x = []
    bond_x = []

    mol = get_mol(smiles)
    n_atoms = mol.GetNumAtoms()
    n_bonds = mol.GetNumBonds()
    graph = DGLGraph()
    for i, atom in enumerate(mol.GetAtoms()):
        assert i == atom.GetIdx()
        atom_x.append(atom_features(atom))
    graph.add_nodes(n_atoms)

    bond_src = []
    bond_dst = []
    for i, bond in enumerate(mol.GetBonds()):
        begin_idx = bond.GetBeginAtom().GetIdx()
        end_idx = bond.GetEndAtom().GetIdx()
        features = bond_features(bond)
        bond_src.append(begin_idx)
        bond_dst.append(end_idx)
        bond_x.append(features)
        # Set up the reverse direction
        bond_src.append(end_idx)
        bond_dst.append(begin_idx)
        bond_x.append(features)
    graph.add_edges(bond_src, bond_dst)

    n_edges += n_bonds
    return graph, torch.stack(atom_x), \
        torch.stack(bond_x) if len(bond_x) > 0 else torch.zeros(0)
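
# Illustrative usage (not part of the original file; benzene is arbitrary):
#   g, atom_x, bond_x = mol2dgl_single('c1ccccc1')
#   atom_x.shape   # torch.Size([6, 39]):  6 atoms x ATOM_FDIM features
#   bond_x.shape   # torch.Size([12, 11]): 12 directed edges x BOND_FDIM
# Each chemical bond contributes two directed edges, one per direction, so
# the graph ends up with 2 * mol.GetNumBonds() edges.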

mpn_loopy_bp_msg = DGLF.copy_src(src='msg', out='msg')
mpn_loopy_bp_reduce = DGLF.sum(msg='msg', out='accum_msg')


class LoopyBPUpdate(nn.Module):
    def __init__(self, hidden_size):
        super(LoopyBPUpdate, self).__init__()
        self.hidden_size = hidden_size
        self.W_h = nn.Linear(hidden_size, hidden_size, bias=False)

    def forward(self, nodes):
        msg_input = nodes.data['msg_input']
        msg_delta = self.W_h(nodes.data['accum_msg'])
        msg = F.relu(msg_input + msg_delta)
        return {'msg': msg}

mpn_gather_msg = DGLF.copy_edge(edge='msg', out='msg')
mpn_gather_reduce = DGLF.sum(msg='msg', out='m')


class GatherUpdate(nn.Module):
    def __init__(self, hidden_size):
        super(GatherUpdate, self).__init__()
        self.hidden_size = hidden_size
        self.W_o = nn.Linear(ATOM_FDIM + hidden_size, hidden_size)

    def forward(self, nodes):
        m = nodes.data['m']
        return {
            'h': F.relu(self.W_o(torch.cat([nodes.data['x'], m], 1))),
        }

class DGLMPN(nn.Module):
    def __init__(self, hidden_size, depth):
        super(DGLMPN, self).__init__()

        self.depth = depth

        self.W_i = nn.Linear(ATOM_FDIM + BOND_FDIM, hidden_size, bias=False)

        self.loopy_bp_updater = LoopyBPUpdate(hidden_size)
        self.gather_updater = GatherUpdate(hidden_size)
        self.hidden_size = hidden_size

        self.n_samples_total = 0
        self.n_nodes_total = 0
        self.n_edges_total = 0
        self.n_passes = 0

    def forward(self, mol_graph):
        n_samples = mol_graph.batch_size

        mol_line_graph = mol_graph.line_graph(backtracking=False, shared=True)

        n_nodes = mol_graph.number_of_nodes()
        n_edges = mol_graph.number_of_edges()

        mol_graph = self.run(mol_graph, mol_line_graph)

        # TODO: replace with unbatch or readout
        g_repr = mean_nodes(mol_graph, 'h')

        self.n_samples_total += n_samples
        self.n_nodes_total += n_nodes
        self.n_edges_total += n_edges
        self.n_passes += 1

        return g_repr

    def run(self, mol_graph, mol_line_graph):
        n_nodes = mol_graph.number_of_nodes()

        mol_graph.apply_edges(
            func=lambda edges: {'src_x': edges.src['x']},
        )

        e_repr = mol_line_graph.ndata
        bond_features = e_repr['x']
        source_features = e_repr['src_x']

        features = torch.cat([source_features, bond_features], 1)
        msg_input = self.W_i(features)
        mol_line_graph.ndata.update({
            'msg_input': msg_input,
            'msg': F.relu(msg_input),
            'accum_msg': torch.zeros_like(msg_input),
        })
        mol_graph.ndata.update({
            'm': bond_features.new(n_nodes, self.hidden_size).zero_(),
            'h': bond_features.new(n_nodes, self.hidden_size).zero_(),
        })

        for i in range(self.depth - 1):
            mol_line_graph.update_all(
                mpn_loopy_bp_msg,
                mpn_loopy_bp_reduce,
                self.loopy_bp_updater,
            )

        mol_graph.update_all(
            mpn_gather_msg,
            mpn_gather_reduce,
            self.gather_updater,
        )

        return mol_graph
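
Taken together, the encoder can be exercised end to end. The sketch below is illustrative and not part of the commit; hidden_size=450 and depth=3 simply mirror the JTNN hyperparameters used by load_pretrained further down, and the SMILES string is arbitrary:

    # Messages live on the line graph (one node per directed bond); after
    # depth - 1 loopy-BP rounds they are gathered back onto the atoms.
    g, atom_x, bond_x = mol2dgl_single('CCO')
    g.ndata['x'] = atom_x
    g.edata['x'] = bond_x
    mpn = DGLMPN(hidden_size=450, depth=3)
    g_repr = mpn(batch([g]))  # shape (1, 450): one vector per molecule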

# pylint: disable=C0111, C0103, E1101, W0611, W0612
import os

import torch
import torch.nn as nn
from torch.autograd import Variable


def create_var(tensor, requires_grad=None):
    # NOTE: since PyTorch 0.4, Variable is a no-op wrapper around Tensor;
    # this helper is kept for compatibility with the original JTNN code.
    if requires_grad is None:
        return Variable(tensor)
    else:
        return Variable(tensor, requires_grad=requires_grad)


def cuda(tensor):
    if torch.cuda.is_available() and not os.getenv('NOCUDA', None):
        return tensor.cuda()
    else:
        return tensor

class GRUUpdate(nn.Module):
    def __init__(self, hidden_size):
        nn.Module.__init__(self)
        self.hidden_size = hidden_size

        self.W_z = nn.Linear(2 * hidden_size, hidden_size)
        self.W_r = nn.Linear(hidden_size, hidden_size, bias=False)
        self.U_r = nn.Linear(hidden_size, hidden_size)
        self.W_h = nn.Linear(2 * hidden_size, hidden_size)

    def update_zm(self, node):
        src_x = node.data['src_x']
        s = node.data['s']
        rm = node.data['accum_rm']
        z = torch.sigmoid(self.W_z(torch.cat([src_x, s], 1)))
        m = torch.tanh(self.W_h(torch.cat([src_x, rm], 1)))
        m = (1 - z) * s + z * m
        return {'m': m, 'z': z}

    def update_r(self, node, zm=None):
        dst_x = node.data['dst_x']
        m = node.data['m'] if zm is None else zm['m']
        r_1 = self.W_r(dst_x)
        r_2 = self.U_r(m)
        r = torch.sigmoid(r_1 + r_2)
        return {'r': r, 'rm': r * m}

    def forward(self, node):
        dic = self.update_zm(node)
        dic.update(self.update_r(node, zm=dic))
        return dic
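
# In equation form, the update above (a restatement, added for clarity) is:
#   z  = sigmoid(W_z [x_src ; s])
#   m  = (1 - z) * s + z * tanh(W_h [x_src ; accum_rm])
#   r  = sigmoid(W_r x_dst + U_r m)
#   rm = r * m
# forward() chains update_zm and update_r, so one call per message-passing
# round refreshes m, z, r, and rm together.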

def move_dgl_to_cuda(g):
    g.ndata.update({k: cuda(g.ndata[k]) for k in g.ndata})
    g.edata.update({k: cuda(g.edata[k]) for k in g.edata})
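
A hedged sketch of these helpers (illustrative, not part of the commit; assumes dgl is importable):

    from dgl import DGLGraph

    # cuda() is a soft device move: it returns the tensor unchanged when no
    # GPU is available or when the NOCUDA environment variable is set.
    g = DGLGraph()
    g.add_nodes(2)
    g.ndata['h'] = torch.zeros(2, 450)
    move_dgl_to_cuda(g)  # passes every ndata/edata tensor through cuda()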
"""Utilities for using pretrained models."""
import torch
from rdkit import Chem
from .dgmg import DGMG
from .gcn import GCNClassifier
from . import DGLJTNNVAE
from .mgcn import MGCNModel
from .mpnn import MPNNModel
from .sch import SchNetModel
from ...data.utils import _get_dgl_url, download
from ...data.utils import _get_dgl_url, download, get_download_dir
URL = {
'GCN_Tox21' : 'pre_trained/gcn_tox21.pth',
    ...
    'DGMG_ChEMBL_canonical' : 'pre_trained/dgmg_ChEMBL_canonical.pth',
    'DGMG_ChEMBL_random' : 'pre_trained/dgmg_ChEMBL_random.pth',
    'DGMG_ZINC_canonical' : 'pre_trained/dgmg_ZINC_canonical.pth',
    'DGMG_ZINC_random' : 'pre_trained/dgmg_ZINC_random.pth',
    'JTNN_ZINC' : 'pre_trained/JTNN_ZINC.pth'
}

def download_and_load_checkpoint(model_name, model, model_postfix,
                                 local_pretrained_path='pre_trained.pth', log=True):
    """Download a pretrained model checkpoint.

    The model will be loaded to CPU.

    Parameters
    ----------
    model_name : str
    ...
    """
    url_to_pretrained = _get_dgl_url(model_postfix)
    local_pretrained_path = '_'.join([model_name, local_pretrained_path])
    download(url_to_pretrained, path=local_pretrained_path, log=log)
    checkpoint = torch.load(local_pretrained_path, map_location='cpu')
    model.load_state_dict(checkpoint['model_state_dict'])
    return model

def load_pretrained(model_name, log=True):
    """...
    model
    """
    if model_name not in URL:
        raise RuntimeError("Cannot find a pretrained model with name {}".format(model_name))

    if model_name == 'GCN_Tox21':
        model = GCNClassifier(in_feats=74,
                              gcn_hidden_feats=[64, 64],
                              n_tasks=12,
                              classifier_hidden_feats=64)

    elif model_name.startswith('DGMG'):
        if model_name.startswith('DGMG_ChEMBL'):
            atom_types = ['O', 'Cl', 'C', 'S', 'F', 'Br', 'N']
        ...
                     node_hidden_size=128,
                     num_prop_rounds=2,
                     dropout=0.2)

    elif model_name == 'MGCN_Alchemy':
        model = MGCNModel(norm=True, output_dim=12)

    elif model_name == 'SCHNET_Alchemy':
        model = SchNetModel(norm=True, output_dim=12)

    elif model_name == 'MPNN_Alchemy':
        model = MPNNModel(output_dim=12)

    elif model_name == "JTNN_ZINC":
        vocab_file = '{}/jtnn/{}.txt'.format(get_download_dir(), 'vocab')
        model = DGLJTNNVAE(vocab_file=vocab_file,
                           depth=3,
                           hidden_size=450,
                           latent_size=56)

    if log:
        print('Pretrained model loaded')
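
    # Illustrative usage (not part of the original file):
    #   model = load_pretrained('JTNN_ZINC')
    # This fetches pre_trained/JTNN_ZINC.pth per the URL table above into
    # JTNN_ZINC_pre_trained.pth, loads it on CPU, and assumes the JTNN
    # vocabulary already sits at <download_dir>/jtnn/vocab.txt.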
    ...