"docs/source/vscode:/vscode.git/clone" did not exist on "ca265374bb0c8b01babf4d59c00ba5ad20adfb5a"
Commit 3192beb4 authored by VoVAllen, committed by Mufei Li

[Model zoo] JTNN model zoo (#790)

* jtnn model zoo

* poke ci

* fix line sep

* fix

* Fix import order

* fix render

* fix render

* revert

* fix

* Resolve conflict

* fix

* remove create_var

* refactor

* fix

* refactor

* readme

* format

* fix lint

* fix lint

* pylint

* lint

* fix lint

* fix lint

* add hint

* fix

* Remove vocab

* Add explanation for warning

* add directory

* Load model to cpu by default

* Update
parent 30c46251
# pylint: disable=C0111, C0103, E1101, W0611, W0612
import numpy as np
import rdkit.Chem as Chem
from dgl import DGLGraph

from .chemutils import (decode_stereo, enum_assemble_nx, get_clique_mol,
                        get_mol, get_smiles, set_atommap, tree_decomp)


class DGLMolTree(DGLGraph):
    def __init__(self, smiles):
        DGLGraph.__init__(self)
        self.nodes_dict = {}

        if smiles is None:
            return

        self.smiles = smiles
        self.mol = get_mol(smiles)

        # Stereo generation
        mol = Chem.MolFromSmiles(smiles)
        self.smiles3D = Chem.MolToSmiles(mol, isomericSmiles=True)
        self.smiles2D = Chem.MolToSmiles(mol)
        self.stereo_cands = decode_stereo(self.smiles2D)

        # cliques: a list of lists of atom indices
        cliques, edges = tree_decomp(self.mol)

        root = 0
        for i, c in enumerate(cliques):
            cmol = get_clique_mol(self.mol, c)
            csmiles = get_smiles(cmol)
            self.nodes_dict[i] = dict(
                smiles=csmiles,
                mol=get_mol(csmiles),
                clique=c,
            )
            if min(c) == 0:
                root = i

        self.add_nodes(len(cliques))

        # The clique containing atom 0 becomes the root
        if root > 0:
            for attr in self.nodes_dict[0]:
                self.nodes_dict[0][attr], self.nodes_dict[root][attr] = \
                    self.nodes_dict[root][attr], self.nodes_dict[0][attr]

        src = np.zeros((len(edges) * 2,), dtype='int')
        dst = np.zeros((len(edges) * 2,), dtype='int')
        for i, (_x, _y) in enumerate(edges):
            x = 0 if _x == root else root if _x == 0 else _x
            y = 0 if _y == root else root if _y == 0 else _y
            src[2 * i] = x
            dst[2 * i] = y
            src[2 * i + 1] = y
            dst[2 * i + 1] = x
        self.add_edges(src, dst)

        for i in self.nodes_dict:
            self.nodes_dict[i]['nid'] = i + 1
            if self.out_degree(i) > 1:  # Leaf node mol is not marked
                set_atommap(self.nodes_dict[i]['mol'],
                            self.nodes_dict[i]['nid'])
            self.nodes_dict[i]['is_leaf'] = (self.out_degree(i) == 1)

    def treesize(self):
        return self.number_of_nodes()

    def _recover_node(self, i, original_mol):
        node = self.nodes_dict[i]

        clique = []
        clique.extend(node['clique'])
        if not node['is_leaf']:
            for cidx in node['clique']:
                original_mol.GetAtomWithIdx(cidx).SetAtomMapNum(node['nid'])

        for j in self.successors(i).numpy():
            nei_node = self.nodes_dict[j]
            clique.extend(nei_node['clique'])
            if nei_node['is_leaf']:  # Leaf node, no need to mark
                continue
            for cidx in nei_node['clique']:
                # allow singleton nodes to override the atom mapping
                if cidx not in node['clique'] or len(nei_node['clique']) == 1:
                    atom = original_mol.GetAtomWithIdx(cidx)
                    atom.SetAtomMapNum(nei_node['nid'])

        clique = list(set(clique))
        label_mol = get_clique_mol(original_mol, clique)
        node['label'] = Chem.MolToSmiles(
            Chem.MolFromSmiles(get_smiles(label_mol)))
        node['label_mol'] = get_mol(node['label'])

        for cidx in clique:
            original_mol.GetAtomWithIdx(cidx).SetAtomMapNum(0)

        return node['label']

    def _assemble_node(self, i):
        neighbors = [self.nodes_dict[j] for j in self.successors(i).numpy()
                     if self.nodes_dict[j]['mol'].GetNumAtoms() > 1]
        neighbors = sorted(
            neighbors, key=lambda x: x['mol'].GetNumAtoms(), reverse=True)
        singletons = [self.nodes_dict[j] for j in self.successors(i).numpy()
                      if self.nodes_dict[j]['mol'].GetNumAtoms() == 1]
        neighbors = singletons + neighbors

        cands = enum_assemble_nx(self.nodes_dict[i], neighbors)

        if len(cands) > 0:
            self.nodes_dict[i]['cands'], self.nodes_dict[i]['cand_mols'], _ = list(
                zip(*cands))
            self.nodes_dict[i]['cands'] = list(self.nodes_dict[i]['cands'])
            self.nodes_dict[i]['cand_mols'] = list(
                self.nodes_dict[i]['cand_mols'])
        else:
            self.nodes_dict[i]['cands'] = []
            self.nodes_dict[i]['cand_mols'] = []

    def recover(self):
        for i in self.nodes_dict:
            self._recover_node(i, self.mol)

    def assemble(self):
        for i in self.nodes_dict:
            self._assemble_node(i)
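
A minimal usage sketch for the class above, not part of the commit itself: it assumes RDKit and the sibling chemutils helpers are importable, and the SMILES string is an arbitrary example.

tree = DGLMolTree('C1CCCCC1CO')      # arbitrary example molecule
print(tree.treesize())               # number of cliques in the junction tree
print(tree.nodes_dict[0]['smiles'])  # SMILES of the root clique
tree.recover()    # label each node with its clique-plus-neighbors subgraph
tree.assemble()   # enumerate candidate attachment configurations per node
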
# pylint: disable=C0111, C0103, E1101, W0611, W0612
# pylint: disable=redefined-outer-name
from functools import partial

import numpy as np
import rdkit.Chem as Chem
import torch
import torch.nn as nn
import torch.nn.functional as F

import dgl.function as DGLF
from dgl import DGLGraph, batch, mean_nodes, unbatch
from networkx import DiGraph, Graph, convert_node_labels_to_integers

from .chemutils import get_mol
# from .nnutils import *

ELEM_LIST = ['C', 'N', 'O', 'S', 'F', 'Si', 'P', 'Cl', 'Br', 'Mg', 'Na',
             'Ca', 'Fe', 'Al', 'I', 'B', 'K', 'Se', 'Zn', 'H', 'Cu', 'Mn', 'unknown']

ATOM_FDIM = len(ELEM_LIST) + 6 + 5 + 4 + 1
BOND_FDIM = 5 + 6
MAX_NB = 6


def onek_encoding_unk(x, allowable_set):
    if x not in allowable_set:
        x = allowable_set[-1]
    return [x == s for s in allowable_set]


def atom_features(atom):
    return (torch.Tensor(onek_encoding_unk(atom.GetSymbol(), ELEM_LIST)
                         + onek_encoding_unk(atom.GetDegree(),
                                             [0, 1, 2, 3, 4, 5])
                         + onek_encoding_unk(atom.GetFormalCharge(),
                                             [-1, -2, 1, 2, 0])
                         + onek_encoding_unk(int(atom.GetChiralTag()), [0, 1, 2, 3])
                         + [atom.GetIsAromatic()]))


def bond_features(bond):
    bt = bond.GetBondType()
    stereo = int(bond.GetStereo())
    fbond = [bt == Chem.rdchem.BondType.SINGLE, bt == Chem.rdchem.BondType.DOUBLE,
             bt == Chem.rdchem.BondType.TRIPLE, bt == Chem.rdchem.BondType.AROMATIC,
             bond.IsInRing()]
    fstereo = onek_encoding_unk(stereo, [0, 1, 2, 3, 4, 5])
    return torch.Tensor(fbond + fstereo)


def mol2dgl_single(smiles):
    n_edges = 0

    atom_x = []
    bond_x = []

    mol = get_mol(smiles)
    n_atoms = mol.GetNumAtoms()
    n_bonds = mol.GetNumBonds()
    graph = DGLGraph()
    for i, atom in enumerate(mol.GetAtoms()):
        assert i == atom.GetIdx()
        atom_x.append(atom_features(atom))
    graph.add_nodes(n_atoms)

    bond_src = []
    bond_dst = []
    for i, bond in enumerate(mol.GetBonds()):
        begin_idx = bond.GetBeginAtom().GetIdx()
        end_idx = bond.GetEndAtom().GetIdx()
        features = bond_features(bond)
        bond_src.append(begin_idx)
        bond_dst.append(end_idx)
        bond_x.append(features)
        # set up the reverse direction
        bond_src.append(end_idx)
        bond_dst.append(begin_idx)
        bond_x.append(features)
    graph.add_edges(bond_src, bond_dst)

    n_edges += n_bonds
    return graph, torch.stack(atom_x), \
        torch.stack(bond_x) if len(bond_x) > 0 else torch.zeros(0)


mpn_loopy_bp_msg = DGLF.copy_src(src='msg', out='msg')
mpn_loopy_bp_reduce = DGLF.sum(msg='msg', out='accum_msg')


class LoopyBPUpdate(nn.Module):
    def __init__(self, hidden_size):
        super(LoopyBPUpdate, self).__init__()
        self.hidden_size = hidden_size

        self.W_h = nn.Linear(hidden_size, hidden_size, bias=False)

    def forward(self, nodes):
        msg_input = nodes.data['msg_input']
        msg_delta = self.W_h(nodes.data['accum_msg'])
        msg = F.relu(msg_input + msg_delta)
        return {'msg': msg}


mpn_gather_msg = DGLF.copy_edge(edge='msg', out='msg')
mpn_gather_reduce = DGLF.sum(msg='msg', out='m')


class GatherUpdate(nn.Module):
    def __init__(self, hidden_size):
        super(GatherUpdate, self).__init__()
        self.hidden_size = hidden_size

        self.W_o = nn.Linear(ATOM_FDIM + hidden_size, hidden_size)

    def forward(self, nodes):
        m = nodes.data['m']
        return {
            'h': F.relu(self.W_o(torch.cat([nodes.data['x'], m], 1))),
        }


class DGLMPN(nn.Module):
    def __init__(self, hidden_size, depth):
        super(DGLMPN, self).__init__()

        self.depth = depth

        self.W_i = nn.Linear(ATOM_FDIM + BOND_FDIM, hidden_size, bias=False)

        self.loopy_bp_updater = LoopyBPUpdate(hidden_size)
        self.gather_updater = GatherUpdate(hidden_size)
        self.hidden_size = hidden_size

        self.n_samples_total = 0
        self.n_nodes_total = 0
        self.n_edges_total = 0
        self.n_passes = 0

    def forward(self, mol_graph):
        n_samples = mol_graph.batch_size

        mol_line_graph = mol_graph.line_graph(backtracking=False, shared=True)

        n_nodes = mol_graph.number_of_nodes()
        n_edges = mol_graph.number_of_edges()

        mol_graph = self.run(mol_graph, mol_line_graph)

        # TODO: replace with unbatch or readout
        g_repr = mean_nodes(mol_graph, 'h')

        self.n_samples_total += n_samples
        self.n_nodes_total += n_nodes
        self.n_edges_total += n_edges
        self.n_passes += 1

        return g_repr

    def run(self, mol_graph, mol_line_graph):
        n_nodes = mol_graph.number_of_nodes()

        mol_graph.apply_edges(
            func=lambda edges: {'src_x': edges.src['x']},
        )

        e_repr = mol_line_graph.ndata
        bond_features = e_repr['x']
        source_features = e_repr['src_x']

        features = torch.cat([source_features, bond_features], 1)
        msg_input = self.W_i(features)
        mol_line_graph.ndata.update({
            'msg_input': msg_input,
            'msg': F.relu(msg_input),
            'accum_msg': torch.zeros_like(msg_input),
        })
        mol_graph.ndata.update({
            'm': bond_features.new(n_nodes, self.hidden_size).zero_(),
            'h': bond_features.new(n_nodes, self.hidden_size).zero_(),
        })

        for i in range(self.depth - 1):
            mol_line_graph.update_all(
                mpn_loopy_bp_msg,
                mpn_loopy_bp_reduce,
                self.loopy_bp_updater,
            )

        mol_graph.update_all(
            mpn_gather_msg,
            mpn_gather_reduce,
            self.gather_updater,
        )

        return mol_graph
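
A minimal sketch of driving DGLMPN end to end, assuming the DGL 0.3-era API this file is written against; the SMILES string and hyperparameters are arbitrary examples, not values from the commit.

g, atom_x, bond_x = mol2dgl_single('c1ccccc1O')  # arbitrary example SMILES
g.ndata['x'] = atom_x     # per-atom features, consumed by GatherUpdate
g.edata['x'] = bond_x     # per-bond features, shared with the line graph
mpn = DGLMPN(hidden_size=450, depth=3)
g_repr = mpn(batch([g]))  # one hidden_size-dim vector per molecule
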
# pylint: disable=C0111, C0103, E1101, W0611, W0612
import os

import torch
import torch.nn as nn
from torch.autograd import Variable


def create_var(tensor, requires_grad=None):
    if requires_grad is None:
        return Variable(tensor)
    else:
        return Variable(tensor, requires_grad=requires_grad)


def cuda(tensor):
    if torch.cuda.is_available() and not os.getenv('NOCUDA', None):
        return tensor.cuda()
    else:
        return tensor


class GRUUpdate(nn.Module):
    def __init__(self, hidden_size):
        nn.Module.__init__(self)
        self.hidden_size = hidden_size

        self.W_z = nn.Linear(2 * hidden_size, hidden_size)
        self.W_r = nn.Linear(hidden_size, hidden_size, bias=False)
        self.U_r = nn.Linear(hidden_size, hidden_size)
        self.W_h = nn.Linear(2 * hidden_size, hidden_size)

    def update_zm(self, node):
        src_x = node.data['src_x']
        s = node.data['s']
        rm = node.data['accum_rm']
        z = torch.sigmoid(self.W_z(torch.cat([src_x, s], 1)))
        m = torch.tanh(self.W_h(torch.cat([src_x, rm], 1)))
        m = (1 - z) * s + z * m
        return {'m': m, 'z': z}

    def update_r(self, node, zm=None):
        dst_x = node.data['dst_x']
        m = node.data['m'] if zm is None else zm['m']
        r_1 = self.W_r(dst_x)
        r_2 = self.U_r(m)
        r = torch.sigmoid(r_1 + r_2)
        return {'r': r, 'rm': r * m}

    def forward(self, node):
        dic = self.update_zm(node)
        dic.update(self.update_r(node, zm=dic))
        return dic


def move_dgl_to_cuda(g):
    g.ndata.update({k: cuda(g.ndata[k]) for k in g.ndata})
    g.edata.update({k: cuda(g.edata[k]) for k in g.edata})
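
A minimal sketch of the device helpers above: NOCUDA is the opt-out environment variable checked by cuda(), and the DGLGraph import is added here only for the sketch.

from dgl import DGLGraph

h = cuda(torch.zeros(4, 16))       # on GPU only if available and NOCUDA is unset
g = DGLGraph()
g.add_nodes(4)
g.ndata['h'] = torch.randn(4, 16)
move_dgl_to_cuda(g)                # passes every ndata/edata tensor through cuda()
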
"""Utilities for using pretrained models.""" """Utilities for using pretrained models."""
import torch import torch
from rdkit import Chem
from .dgmg import DGMG from .dgmg import DGMG
from .gcn import GCNClassifier from .gcn import GCNClassifier
from . import DGLJTNNVAE
from .mgcn import MGCNModel from .mgcn import MGCNModel
from .mpnn import MPNNModel from .mpnn import MPNNModel
from .sch import SchNetModel from .sch import SchNetModel
from ...data.utils import _get_dgl_url, download from ...data.utils import _get_dgl_url, download, get_download_dir
URL = { URL = {
'GCN_Tox21' : 'pre_trained/gcn_tox21.pth', 'GCN_Tox21' : 'pre_trained/gcn_tox21.pth',
...@@ -15,18 +18,16 @@ URL = { ...@@ -15,18 +18,16 @@ URL = {
'DGMG_ChEMBL_canonical' : 'pre_trained/dgmg_ChEMBL_canonical.pth', 'DGMG_ChEMBL_canonical' : 'pre_trained/dgmg_ChEMBL_canonical.pth',
'DGMG_ChEMBL_random' : 'pre_trained/dgmg_ChEMBL_random.pth', 'DGMG_ChEMBL_random' : 'pre_trained/dgmg_ChEMBL_random.pth',
'DGMG_ZINC_canonical' : 'pre_trained/dgmg_ZINC_canonical.pth', 'DGMG_ZINC_canonical' : 'pre_trained/dgmg_ZINC_canonical.pth',
'DGMG_ZINC_random' : 'pre_trained/dgmg_ZINC_random.pth' 'DGMG_ZINC_random' : 'pre_trained/dgmg_ZINC_random.pth',
'JTNN_ZINC':'pre_trained/JTNN_ZINC.pth'
} }
try:
from rdkit import Chem
except ImportError:
pass
def download_and_load_checkpoint(model_name, model, model_postfix, def download_and_load_checkpoint(model_name, model, model_postfix,
local_pretrained_path='pre_trained.pth', log=True): local_pretrained_path='pre_trained.pth', log=True):
"""Download pretrained model checkpoint """Download pretrained model checkpoint
The model will be loaded to CPU.
Parameters Parameters
---------- ----------
model_name : str model_name : str
...@@ -48,7 +49,7 @@ def download_and_load_checkpoint(model_name, model, model_postfix, ...@@ -48,7 +49,7 @@ def download_and_load_checkpoint(model_name, model, model_postfix,
url_to_pretrained = _get_dgl_url(model_postfix) url_to_pretrained = _get_dgl_url(model_postfix)
local_pretrained_path = '_'.join([model_name, local_pretrained_path]) local_pretrained_path = '_'.join([model_name, local_pretrained_path])
download(url_to_pretrained, path=local_pretrained_path, log=log) download(url_to_pretrained, path=local_pretrained_path, log=log)
checkpoint = torch.load(local_pretrained_path) checkpoint = torch.load(local_pretrained_path, map_location='cpu')
model.load_state_dict(checkpoint['model_state_dict']) model.load_state_dict(checkpoint['model_state_dict'])
return model return model
...@@ -67,13 +68,14 @@ def load_pretrained(model_name, log=True): ...@@ -67,13 +68,14 @@ def load_pretrained(model_name, log=True):
model model
""" """
if model_name not in URL: if model_name not in URL:
return RuntimeError("Cannot find a pretrained model with name {}".format(model_name)) raise RuntimeError("Cannot find a pretrained model with name {}".format(model_name))
if model_name == 'GCN_Tox21': if model_name == 'GCN_Tox21':
model = GCNClassifier(in_feats=74, model = GCNClassifier(in_feats=74,
gcn_hidden_feats=[64, 64], gcn_hidden_feats=[64, 64],
n_tasks=12, n_tasks=12,
classifier_hidden_feats=64) classifier_hidden_feats=64)
elif model_name.startswith('DGMG'): elif model_name.startswith('DGMG'):
if model_name.startswith('DGMG_ChEMBL'): if model_name.startswith('DGMG_ChEMBL'):
atom_types = ['O', 'Cl', 'C', 'S', 'F', 'Br', 'N'] atom_types = ['O', 'Cl', 'C', 'S', 'F', 'Br', 'N']
...@@ -88,13 +90,23 @@ def load_pretrained(model_name, log=True): ...@@ -88,13 +90,23 @@ def load_pretrained(model_name, log=True):
node_hidden_size=128, node_hidden_size=128,
num_prop_rounds=2, num_prop_rounds=2,
dropout=0.2) dropout=0.2)
elif model_name == 'MGCN_Alchemy': elif model_name == 'MGCN_Alchemy':
model = MGCNModel(norm=True, output_dim=12) model = MGCNModel(norm=True, output_dim=12)
elif model_name == 'SCHNET_Alchemy': elif model_name == 'SCHNET_Alchemy':
model = SchNetModel(norm=True, output_dim=12) model = SchNetModel(norm=True, output_dim=12)
elif model_name == 'MPNN_Alchemy': elif model_name == 'MPNN_Alchemy':
model = MPNNModel(output_dim=12) model = MPNNModel(output_dim=12)
elif model_name == "JTNN_ZINC":
vocab_file = '{}/jtnn/{}.txt'.format(get_download_dir(), 'vocab')
model = DGLJTNNVAE(vocab_file=vocab_file,
depth=3,
hidden_size=450,
latent_size=56)
if log: if log:
print('Pretrained model loaded') print('Pretrained model loaded')
......
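
With this change the JTNN checkpoint loads like the other zoo models. A minimal sketch; the dgl.model_zoo.chem import path is an assumption based on where these helpers live in the package.

from dgl.model_zoo.chem import load_pretrained  # import path assumed

model = load_pretrained('JTNN_ZINC')  # downloads the checkpoint and loads it on CPU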