Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
9df8cd32
Unverified
Commit
9df8cd32
authored
Sep 10, 2019
by
Mufei Li
Committed by
GitHub
Sep 10, 2019
Browse files
[Model Zoo] Fix JTNN (#843)
* Update * Update * Update * Update * Update
parent
4e0e6697
Changes
24
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
14 additions
and
20 deletions
+14
-20
python/dgl/model_zoo/chem/jtnn/jtnn_vae.py
python/dgl/model_zoo/chem/jtnn/jtnn_vae.py
+4
-11
python/dgl/model_zoo/chem/jtnn/mol_tree.py
python/dgl/model_zoo/chem/jtnn/mol_tree.py
+0
-1
python/dgl/model_zoo/chem/jtnn/mpn.py
python/dgl/model_zoo/chem/jtnn/mpn.py
+1
-6
python/dgl/model_zoo/chem/pretrain.py
python/dgl/model_zoo/chem/pretrain.py
+9
-2
No files found.
python/dgl/model_zoo/chem/jtnn/jtnn_vae.py
View file @
9df8cd32
# pylint: disable=C0111, C0103, E1101, W0611, W0612, C0200
# pylint: disable=C0111, C0103, E1101, W0611, W0612, C0200
import
copy
import
copy
import
math
import
torch
import
torch
import
torch.nn
as
nn
import
torch.nn
as
nn
import
torch.nn.functional
as
F
import
torch.nn.functional
as
F
import
rdkit
import
rdkit.Chem
as
Chem
import
rdkit.Chem
as
Chem
from
rdkit
import
DataStructs
from
rdkit.Chem
import
AllChem
from
dgl
import
batch
,
unbatch
from
dgl
import
batch
,
unbatch
from
dgl.data.utils
import
get_download_dir
from
dgl.data.utils
import
get_download_dir
...
@@ -23,7 +19,7 @@ from .jtnn_enc import DGLJTNNEncoder
...
@@ -23,7 +19,7 @@ from .jtnn_enc import DGLJTNNEncoder
from
.mol_tree
import
Vocab
from
.mol_tree
import
Vocab
from
.mpn
import
DGLMPN
from
.mpn
import
DGLMPN
from
.mpn
import
mol2dgl_single
as
mol2dgl_enc
from
.mpn
import
mol2dgl_single
as
mol2dgl_enc
from
.nnutils
import
create_var
,
cuda
,
move_dgl_to_cuda
from
.nnutils
import
cuda
,
move_dgl_to_cuda
class
DGLJTNNVAE
(
nn
.
Module
):
class
DGLJTNNVAE
(
nn
.
Module
):
...
@@ -37,11 +33,9 @@ class DGLJTNNVAE(nn.Module):
...
@@ -37,11 +33,9 @@ class DGLJTNNVAE(nn.Module):
if
vocab_file
is
None
:
if
vocab_file
is
None
:
vocab_file
=
'{}/jtnn/{}.txt'
.
format
(
vocab_file
=
'{}/jtnn/{}.txt'
.
format
(
get_download_dir
(),
'vocab'
)
get_download_dir
(),
'vocab'
)
self
.
vocab
=
Vocab
([
x
.
strip
(
"
\r\n
"
)
for
x
in
open
(
vocab_file
)])
self
.
vocab
=
Vocab
([
x
.
strip
(
"
\r\n
"
)
else
:
for
x
in
open
(
vocab_file
)])
self
.
vocab
=
Vocab
([
x
.
strip
(
"
\r\n
"
)
for
x
in
open
(
vocab_file
)])
else
:
else
:
self
.
vocab
=
vocab
self
.
vocab
=
vocab
...
@@ -125,7 +119,6 @@ class DGLJTNNVAE(nn.Module):
...
@@ -125,7 +119,6 @@ class DGLJTNNVAE(nn.Module):
assm_loss
,
assm_acc
=
self
.
assm
(
mol_batch
,
mol_tree_batch
,
mol_vec
)
assm_loss
,
assm_acc
=
self
.
assm
(
mol_batch
,
mol_tree_batch
,
mol_vec
)
stereo_loss
,
stereo_acc
=
self
.
stereo
(
mol_batch
,
mol_vec
)
stereo_loss
,
stereo_acc
=
self
.
stereo
(
mol_batch
,
mol_vec
)
all_vec
=
torch
.
cat
([
tree_vec
,
mol_vec
],
dim
=
1
)
loss
=
word_loss
+
topo_loss
+
assm_loss
+
2
*
stereo_loss
+
beta
*
kl_loss
loss
=
word_loss
+
topo_loss
+
assm_loss
+
2
*
stereo_loss
+
beta
*
kl_loss
return
loss
,
kl_loss
,
word_acc
,
topo_acc
,
assm_acc
,
stereo_acc
return
loss
,
kl_loss
,
word_acc
,
topo_acc
,
assm_acc
,
stereo_acc
...
...
python/dgl/model_zoo/chem/jtnn/mol_tree.py
View file @
9df8cd32
# pylint: disable=C0111, C0103, E1101, W0611, W0612
# pylint: disable=C0111, C0103, E1101, W0611, W0612
import
copy
import
copy
import
rdkit
import
rdkit.Chem
as
Chem
import
rdkit.Chem
as
Chem
...
...
python/dgl/model_zoo/chem/jtnn/mpn.py
View file @
9df8cd32
# pylint: disable=C0111, C0103, E1101, W0611, W0612
# pylint: disable=C0111, C0103, E1101, W0611, W0612
# pylint: disable=redefined-outer-name
# pylint: disable=redefined-outer-name
from
functools
import
partial
import
numpy
as
np
import
rdkit.Chem
as
Chem
import
rdkit.Chem
as
Chem
import
torch
import
torch
import
torch.nn
as
nn
import
torch.nn
as
nn
import
torch.nn.functional
as
F
import
torch.nn.functional
as
F
import
dgl.function
as
DGLF
import
dgl.function
as
DGLF
from
dgl
import
DGLGraph
,
batch
,
mean_nodes
,
unbatch
from
dgl
import
DGLGraph
,
mean_nodes
from
networkx
import
DiGraph
,
Graph
,
convert_node_labels_to_integers
from
.chemutils
import
get_mol
from
.chemutils
import
get_mol
# from .nnutils import *
ELEM_LIST
=
[
'C'
,
'N'
,
'O'
,
'S'
,
'F'
,
'Si'
,
'P'
,
'Cl'
,
'Br'
,
'Mg'
,
'Na'
,
ELEM_LIST
=
[
'C'
,
'N'
,
'O'
,
'S'
,
'F'
,
'Si'
,
'P'
,
'Cl'
,
'Br'
,
'Mg'
,
'Na'
,
'Ca'
,
'Fe'
,
'Al'
,
'I'
,
'B'
,
'K'
,
'Se'
,
'Zn'
,
'H'
,
'Cu'
,
'Mn'
,
'unknown'
]
'Ca'
,
'Fe'
,
'Al'
,
'I'
,
'B'
,
'K'
,
'Se'
,
'Zn'
,
'H'
,
'Cu'
,
'Mn'
,
'unknown'
]
...
...
python/dgl/model_zoo/chem/pretrain.py
View file @
9df8cd32
"""Utilities for using pretrained models."""
"""Utilities for using pretrained models."""
import
os
import
torch
import
torch
from
rdkit
import
Chem
from
rdkit
import
Chem
...
@@ -8,7 +9,7 @@ from .dgmg import DGMG
...
@@ -8,7 +9,7 @@ from .dgmg import DGMG
from
.mgcn
import
MGCNModel
from
.mgcn
import
MGCNModel
from
.mpnn
import
MPNNModel
from
.mpnn
import
MPNNModel
from
.schnet
import
SchNet
from
.schnet
import
SchNet
from
...data.utils
import
_get_dgl_url
,
download
,
get_download_dir
from
...data.utils
import
_get_dgl_url
,
download
,
get_download_dir
,
extract_archive
URL
=
{
URL
=
{
'GCN_Tox21'
:
'pre_trained/gcn_tox21.pth'
,
'GCN_Tox21'
:
'pre_trained/gcn_tox21.pth'
,
...
@@ -122,7 +123,13 @@ def load_pretrained(model_name, log=True):
...
@@ -122,7 +123,13 @@ def load_pretrained(model_name, log=True):
model
=
MPNNModel
(
output_dim
=
12
)
model
=
MPNNModel
(
output_dim
=
12
)
elif
model_name
==
"JTNN_ZINC"
:
elif
model_name
==
"JTNN_ZINC"
:
vocab_file
=
'{}/jtnn/{}.txt'
.
format
(
get_download_dir
(),
'vocab'
)
default_dir
=
get_download_dir
()
vocab_file
=
'{}/jtnn/{}.txt'
.
format
(
default_dir
,
'vocab'
)
if
not
os
.
path
.
exists
(
vocab_file
):
zip_file_path
=
'{}/jtnn.zip'
.
format
(
default_dir
)
download
(
'https://s3-ap-southeast-1.amazonaws.com/dgl-data-cn/dataset/jtnn.zip'
,
path
=
zip_file_path
)
extract_archive
(
zip_file_path
,
'{}/jtnn'
.
format
(
default_dir
))
model
=
DGLJTNNVAE
(
vocab_file
=
vocab_file
,
model
=
DGLJTNNVAE
(
vocab_file
=
vocab_file
,
depth
=
3
,
depth
=
3
,
hidden_size
=
450
,
hidden_size
=
450
,
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment