Unverified commit 9df8cd32, authored by Mufei Li and committed by GitHub
Browse files

[Model Zoo] Fix JTNN (#843)

* Update

* Update

* Update

* Update

* Update
parent 4e0e6697
# pylint: disable=C0111, C0103, E1101, W0611, W0612, C0200 # pylint: disable=C0111, C0103, E1101, W0611, W0612, C0200
import copy import copy
import math
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
import rdkit
import rdkit.Chem as Chem import rdkit.Chem as Chem
from rdkit import DataStructs
from rdkit.Chem import AllChem
from dgl import batch, unbatch from dgl import batch, unbatch
from dgl.data.utils import get_download_dir from dgl.data.utils import get_download_dir
...@@ -23,7 +19,7 @@ from .jtnn_enc import DGLJTNNEncoder ...@@ -23,7 +19,7 @@ from .jtnn_enc import DGLJTNNEncoder
from .mol_tree import Vocab from .mol_tree import Vocab
from .mpn import DGLMPN from .mpn import DGLMPN
from .mpn import mol2dgl_single as mol2dgl_enc from .mpn import mol2dgl_single as mol2dgl_enc
from .nnutils import create_var, cuda, move_dgl_to_cuda from .nnutils import cuda, move_dgl_to_cuda
class DGLJTNNVAE(nn.Module): class DGLJTNNVAE(nn.Module):
...@@ -37,11 +33,9 @@ class DGLJTNNVAE(nn.Module): ...@@ -37,11 +33,9 @@ class DGLJTNNVAE(nn.Module):
if vocab_file is None: if vocab_file is None:
vocab_file = '{}/jtnn/{}.txt'.format( vocab_file = '{}/jtnn/{}.txt'.format(
get_download_dir(), 'vocab') get_download_dir(), 'vocab')
self.vocab = Vocab([x.strip("\r\n ")
for x in open(vocab_file)]) self.vocab = Vocab([x.strip("\r\n ")
else: for x in open(vocab_file)])
self.vocab = Vocab([x.strip("\r\n ")
for x in open(vocab_file)])
else: else:
self.vocab = vocab self.vocab = vocab
...@@ -125,7 +119,6 @@ class DGLJTNNVAE(nn.Module): ...@@ -125,7 +119,6 @@ class DGLJTNNVAE(nn.Module):
assm_loss, assm_acc = self.assm(mol_batch, mol_tree_batch, mol_vec) assm_loss, assm_acc = self.assm(mol_batch, mol_tree_batch, mol_vec)
stereo_loss, stereo_acc = self.stereo(mol_batch, mol_vec) stereo_loss, stereo_acc = self.stereo(mol_batch, mol_vec)
all_vec = torch.cat([tree_vec, mol_vec], dim=1)
loss = word_loss + topo_loss + assm_loss + 2 * stereo_loss + beta * kl_loss loss = word_loss + topo_loss + assm_loss + 2 * stereo_loss + beta * kl_loss
return loss, kl_loss, word_acc, topo_acc, assm_acc, stereo_acc return loss, kl_loss, word_acc, topo_acc, assm_acc, stereo_acc
......
# pylint: disable=C0111, C0103, E1101, W0611, W0612 # pylint: disable=C0111, C0103, E1101, W0611, W0612
import copy import copy
import rdkit
import rdkit.Chem as Chem import rdkit.Chem as Chem
......
# pylint: disable=C0111, C0103, E1101, W0611, W0612 # pylint: disable=C0111, C0103, E1101, W0611, W0612
# pylint: disable=redefined-outer-name # pylint: disable=redefined-outer-name
from functools import partial
import numpy as np
import rdkit.Chem as Chem import rdkit.Chem as Chem
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
import dgl.function as DGLF import dgl.function as DGLF
from dgl import DGLGraph, batch, mean_nodes, unbatch from dgl import DGLGraph, mean_nodes
from networkx import DiGraph, Graph, convert_node_labels_to_integers
from .chemutils import get_mol from .chemutils import get_mol
# from .nnutils import *
ELEM_LIST = ['C', 'N', 'O', 'S', 'F', 'Si', 'P', 'Cl', 'Br', 'Mg', 'Na', ELEM_LIST = ['C', 'N', 'O', 'S', 'F', 'Si', 'P', 'Cl', 'Br', 'Mg', 'Na',
'Ca', 'Fe', 'Al', 'I', 'B', 'K', 'Se', 'Zn', 'H', 'Cu', 'Mn', 'unknown'] 'Ca', 'Fe', 'Al', 'I', 'B', 'K', 'Se', 'Zn', 'H', 'Cu', 'Mn', 'unknown']
......
"""Utilities for using pretrained models.""" """Utilities for using pretrained models."""
import os
import torch import torch
from rdkit import Chem from rdkit import Chem
...@@ -8,7 +9,7 @@ from .dgmg import DGMG ...@@ -8,7 +9,7 @@ from .dgmg import DGMG
from .mgcn import MGCNModel from .mgcn import MGCNModel
from .mpnn import MPNNModel from .mpnn import MPNNModel
from .schnet import SchNet from .schnet import SchNet
from ...data.utils import _get_dgl_url, download, get_download_dir from ...data.utils import _get_dgl_url, download, get_download_dir, extract_archive
URL = { URL = {
'GCN_Tox21' : 'pre_trained/gcn_tox21.pth', 'GCN_Tox21' : 'pre_trained/gcn_tox21.pth',
...@@ -122,7 +123,13 @@ def load_pretrained(model_name, log=True): ...@@ -122,7 +123,13 @@ def load_pretrained(model_name, log=True):
model = MPNNModel(output_dim=12) model = MPNNModel(output_dim=12)
elif model_name == "JTNN_ZINC": elif model_name == "JTNN_ZINC":
vocab_file = '{}/jtnn/{}.txt'.format(get_download_dir(), 'vocab') default_dir = get_download_dir()
vocab_file = '{}/jtnn/{}.txt'.format(default_dir, 'vocab')
if not os.path.exists(vocab_file):
zip_file_path = '{}/jtnn.zip'.format(default_dir)
download('https://s3-ap-southeast-1.amazonaws.com/dgl-data-cn/dataset/jtnn.zip',
path=zip_file_path)
extract_archive(zip_file_path, '{}/jtnn'.format(default_dir))
model = DGLJTNNVAE(vocab_file=vocab_file, model = DGLJTNNVAE(vocab_file=vocab_file,
depth=3, depth=3,
hidden_size=450, hidden_size=450,
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment