Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
9df8cd32
Unverified
Commit
9df8cd32
authored
Sep 10, 2019
by
Mufei Li
Committed by
GitHub
Sep 10, 2019
Browse files
[Model Zoo] Fix JTNN (#843)
* Update * Update * Update * Update * Update
parent
4e0e6697
Changes
24
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
14 additions
and
20 deletions
+14
-20
python/dgl/model_zoo/chem/jtnn/jtnn_vae.py
python/dgl/model_zoo/chem/jtnn/jtnn_vae.py
+4
-11
python/dgl/model_zoo/chem/jtnn/mol_tree.py
python/dgl/model_zoo/chem/jtnn/mol_tree.py
+0
-1
python/dgl/model_zoo/chem/jtnn/mpn.py
python/dgl/model_zoo/chem/jtnn/mpn.py
+1
-6
python/dgl/model_zoo/chem/pretrain.py
python/dgl/model_zoo/chem/pretrain.py
+9
-2
No files found.
python/dgl/model_zoo/chem/jtnn/jtnn_vae.py
View file @
9df8cd32
# pylint: disable=C0111, C0103, E1101, W0611, W0612, C0200
import
copy
import
math
import
torch
import
torch.nn
as
nn
import
torch.nn.functional
as
F
import
rdkit
import
rdkit.Chem
as
Chem
from
rdkit
import
DataStructs
from
rdkit.Chem
import
AllChem
from
dgl
import
batch
,
unbatch
from
dgl.data.utils
import
get_download_dir
...
...
@@ -23,7 +19,7 @@ from .jtnn_enc import DGLJTNNEncoder
from
.mol_tree
import
Vocab
from
.mpn
import
DGLMPN
from
.mpn
import
mol2dgl_single
as
mol2dgl_enc
from
.nnutils
import
create_var
,
cuda
,
move_dgl_to_cuda
from
.nnutils
import
cuda
,
move_dgl_to_cuda
class
DGLJTNNVAE
(
nn
.
Module
):
...
...
@@ -37,11 +33,9 @@ class DGLJTNNVAE(nn.Module):
if
vocab_file
is
None
:
vocab_file
=
'{}/jtnn/{}.txt'
.
format
(
get_download_dir
(),
'vocab'
)
self
.
vocab
=
Vocab
([
x
.
strip
(
"
\r\n
"
)
for
x
in
open
(
vocab_file
)])
else
:
self
.
vocab
=
Vocab
([
x
.
strip
(
"
\r\n
"
)
for
x
in
open
(
vocab_file
)])
self
.
vocab
=
Vocab
([
x
.
strip
(
"
\r\n
"
)
for
x
in
open
(
vocab_file
)])
else
:
self
.
vocab
=
vocab
...
...
@@ -125,7 +119,6 @@ class DGLJTNNVAE(nn.Module):
assm_loss
,
assm_acc
=
self
.
assm
(
mol_batch
,
mol_tree_batch
,
mol_vec
)
stereo_loss
,
stereo_acc
=
self
.
stereo
(
mol_batch
,
mol_vec
)
all_vec
=
torch
.
cat
([
tree_vec
,
mol_vec
],
dim
=
1
)
loss
=
word_loss
+
topo_loss
+
assm_loss
+
2
*
stereo_loss
+
beta
*
kl_loss
return
loss
,
kl_loss
,
word_acc
,
topo_acc
,
assm_acc
,
stereo_acc
...
...
python/dgl/model_zoo/chem/jtnn/mol_tree.py
View file @
9df8cd32
# pylint: disable=C0111, C0103, E1101, W0611, W0612
import
copy
import
rdkit
import
rdkit.Chem
as
Chem
...
...
python/dgl/model_zoo/chem/jtnn/mpn.py
View file @
9df8cd32
# pylint: disable=C0111, C0103, E1101, W0611, W0612
# pylint: disable=redefined-outer-name
from
functools
import
partial
import
numpy
as
np
import
rdkit.Chem
as
Chem
import
torch
import
torch.nn
as
nn
import
torch.nn.functional
as
F
import
dgl.function
as
DGLF
from
dgl
import
DGLGraph
,
batch
,
mean_nodes
,
unbatch
from
networkx
import
DiGraph
,
Graph
,
convert_node_labels_to_integers
from
dgl
import
DGLGraph
,
mean_nodes
from
.chemutils
import
get_mol
# from .nnutils import *
ELEM_LIST
=
[
'C'
,
'N'
,
'O'
,
'S'
,
'F'
,
'Si'
,
'P'
,
'Cl'
,
'Br'
,
'Mg'
,
'Na'
,
'Ca'
,
'Fe'
,
'Al'
,
'I'
,
'B'
,
'K'
,
'Se'
,
'Zn'
,
'H'
,
'Cu'
,
'Mn'
,
'unknown'
]
...
...
python/dgl/model_zoo/chem/pretrain.py
View file @
9df8cd32
"""Utilities for using pretrained models."""
import
os
import
torch
from
rdkit
import
Chem
...
...
@@ -8,7 +9,7 @@ from .dgmg import DGMG
from
.mgcn
import
MGCNModel
from
.mpnn
import
MPNNModel
from
.schnet
import
SchNet
from
...data.utils
import
_get_dgl_url
,
download
,
get_download_dir
from
...data.utils
import
_get_dgl_url
,
download
,
get_download_dir
,
extract_archive
URL
=
{
'GCN_Tox21'
:
'pre_trained/gcn_tox21.pth'
,
...
...
@@ -122,7 +123,13 @@ def load_pretrained(model_name, log=True):
model
=
MPNNModel
(
output_dim
=
12
)
elif
model_name
==
"JTNN_ZINC"
:
vocab_file
=
'{}/jtnn/{}.txt'
.
format
(
get_download_dir
(),
'vocab'
)
default_dir
=
get_download_dir
()
vocab_file
=
'{}/jtnn/{}.txt'
.
format
(
default_dir
,
'vocab'
)
if
not
os
.
path
.
exists
(
vocab_file
):
zip_file_path
=
'{}/jtnn.zip'
.
format
(
default_dir
)
download
(
'https://s3-ap-southeast-1.amazonaws.com/dgl-data-cn/dataset/jtnn.zip'
,
path
=
zip_file_path
)
extract_archive
(
zip_file_path
,
'{}/jtnn'
.
format
(
default_dir
))
model
=
DGLJTNNVAE
(
vocab_file
=
vocab_file
,
depth
=
3
,
hidden_size
=
450
,
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment