Unverified commit 9df8cd32, authored by Mufei Li, committed by GitHub
Browse files

[Model Zoo] Fix JTNN (#843)

* Update

* Update

* Update

* Update

* Update
parent 4e0e6697
# pylint: disable=C0111, C0103, E1101, W0611, W0612, C0200
import copy
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import rdkit
import rdkit.Chem as Chem
from rdkit import DataStructs
from rdkit.Chem import AllChem
from dgl import batch, unbatch
from dgl.data.utils import get_download_dir
......@@ -23,7 +19,7 @@ from .jtnn_enc import DGLJTNNEncoder
from .mol_tree import Vocab
from .mpn import DGLMPN
from .mpn import mol2dgl_single as mol2dgl_enc
from .nnutils import create_var, cuda, move_dgl_to_cuda
from .nnutils import cuda, move_dgl_to_cuda
class DGLJTNNVAE(nn.Module):
......@@ -37,11 +33,9 @@ class DGLJTNNVAE(nn.Module):
if vocab_file is None:
vocab_file = '{}/jtnn/{}.txt'.format(
get_download_dir(), 'vocab')
self.vocab = Vocab([x.strip("\r\n ")
for x in open(vocab_file)])
else:
self.vocab = Vocab([x.strip("\r\n ")
for x in open(vocab_file)])
self.vocab = Vocab([x.strip("\r\n ")
for x in open(vocab_file)])
else:
self.vocab = vocab
......@@ -125,7 +119,6 @@ class DGLJTNNVAE(nn.Module):
assm_loss, assm_acc = self.assm(mol_batch, mol_tree_batch, mol_vec)
stereo_loss, stereo_acc = self.stereo(mol_batch, mol_vec)
all_vec = torch.cat([tree_vec, mol_vec], dim=1)
loss = word_loss + topo_loss + assm_loss + 2 * stereo_loss + beta * kl_loss
return loss, kl_loss, word_acc, topo_acc, assm_acc, stereo_acc
......
# pylint: disable=C0111, C0103, E1101, W0611, W0612
import copy
import rdkit
import rdkit.Chem as Chem
......
# pylint: disable=C0111, C0103, E1101, W0611, W0612
# pylint: disable=redefined-outer-name
from functools import partial
import numpy as np
import rdkit.Chem as Chem
import torch
import torch.nn as nn
import torch.nn.functional as F
import dgl.function as DGLF
from dgl import DGLGraph, batch, mean_nodes, unbatch
from networkx import DiGraph, Graph, convert_node_labels_to_integers
from dgl import DGLGraph, mean_nodes
from .chemutils import get_mol
# from .nnutils import *
# Chemical element symbols followed by a trailing catch-all 'unknown' entry
# (22 symbols + 'unknown' = 23 strings).
# NOTE(review): presumably indexed to encode an atom's element, with 'unknown'
# as the fallback for symbols not listed -- the consumer is not visible in this
# excerpt, so confirm before reordering or extending (positions may matter).
ELEM_LIST = ['C', 'N', 'O', 'S', 'F', 'Si', 'P', 'Cl', 'Br', 'Mg', 'Na',
'Ca', 'Fe', 'Al', 'I', 'B', 'K', 'Se', 'Zn', 'H', 'Cu', 'Mn', 'unknown']
......
"""Utilities for using pretrained models."""
import os
import torch
from rdkit import Chem
......@@ -8,7 +9,7 @@ from .dgmg import DGMG
from .mgcn import MGCNModel
from .mpnn import MPNNModel
from .schnet import SchNet
from ...data.utils import _get_dgl_url, download, get_download_dir
from ...data.utils import _get_dgl_url, download, get_download_dir, extract_archive
URL = {
'GCN_Tox21' : 'pre_trained/gcn_tox21.pth',
......@@ -122,7 +123,13 @@ def load_pretrained(model_name, log=True):
model = MPNNModel(output_dim=12)
elif model_name == "JTNN_ZINC":
vocab_file = '{}/jtnn/{}.txt'.format(get_download_dir(), 'vocab')
default_dir = get_download_dir()
vocab_file = '{}/jtnn/{}.txt'.format(default_dir, 'vocab')
if not os.path.exists(vocab_file):
zip_file_path = '{}/jtnn.zip'.format(default_dir)
download('https://s3-ap-southeast-1.amazonaws.com/dgl-data-cn/dataset/jtnn.zip',
path=zip_file_path)
extract_archive(zip_file_path, '{}/jtnn'.format(default_dir))
model = DGLJTNNVAE(vocab_file=vocab_file,
depth=3,
hidden_size=450,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment