Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
12d70630
Commit
12d70630
authored
Jan 31, 2019
by
Quan (Andy) Gan
Committed by
Minjie Wang
Jan 31, 2019
Browse files
[Hotfix] fixing zero shaped tensor problems for PyTorch 1.0.0 in JTNN example (#371)
parent
dedfd908
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
11 additions
and
5 deletions
+11
-5
examples/pytorch/jtnn/jtnn/datautils.py
examples/pytorch/jtnn/jtnn/datautils.py
+5
-3
examples/pytorch/jtnn/jtnn/jtnn_dec.py
examples/pytorch/jtnn/jtnn/jtnn_dec.py
+6
-2
No files found.
examples/pytorch/jtnn/jtnn/datautils.py
View file @
12d70630
...
...
@@ -9,6 +9,8 @@ from .mol_tree import Vocab
from
.mpn
import
mol2dgl_single
as
mol2dgl_enc
from
.jtmpn
import
mol2dgl_single
as
mol2dgl_dec
from
.jtmpn
import
ATOM_FDIM
as
ATOM_FDIM_DEC
from
.jtmpn
import
BOND_FDIM
as
BOND_FDIM_DEC
_url
=
'https://www.dropbox.com/s/4ypr0e0abcbsvoh/jtnn.zip?dl=1'
...
...
@@ -82,11 +84,11 @@ class JTNNDataset(Dataset):
tree_mess_tgt_e
,
tree_mess_tgt_n
=
mol2dgl_dec
(
cands
)
else
:
cand_graphs
=
[]
atom_x_dec
=
torch
.
zeros
(
0
,
atom_x_enc
.
shape
[
1
]
)
bond_x_dec
=
torch
.
zeros
(
0
,
bond_x_enc
.
shape
[
1
]
)
atom_x_dec
=
torch
.
zeros
(
0
,
ATOM_FDIM_DEC
)
bond_x_dec
=
torch
.
zeros
(
0
,
BOND_FDIM_DEC
)
tree_mess_src_e
=
torch
.
zeros
(
0
,
2
).
long
()
tree_mess_tgt_e
=
torch
.
zeros
(
0
,
2
).
long
()
tree_mess_tgt_n
=
torch
.
zeros
(
0
,
2
).
long
()
tree_mess_tgt_n
=
torch
.
zeros
(
0
).
long
()
# prebuild the stereoisomers
cands
=
mol_tree
.
stereo_cands
...
...
examples/pytorch/jtnn/jtnn/jtnn_dec.py
View file @
12d70630
...
...
@@ -199,8 +199,12 @@ class DGLJTNNDecoder(nn.Module):
p_inputs
.
append
(
torch
.
cat
([
x
,
h
,
tree_vec_set
],
1
))
# Only newly generated nodes are needed for label prediction
# NOTE: The following works since the uncomputed messages are zeros.
q_inputs
.
append
(
torch
.
cat
([
h
,
tree_vec_set
],
1
)[
is_new
])
q_targets
.
append
(
wid
[
is_new
])
q_input
=
torch
.
cat
([
h
,
tree_vec_set
],
1
)[
is_new
]
q_target
=
wid
[
is_new
]
if
q_input
.
shape
[
0
]
>
0
:
q_inputs
.
append
(
q_input
)
q_targets
.
append
(
q_target
)
p_targets
.
append
(
torch
.
zeros
((
root_out_degrees
==
0
).
sum
()).
long
())
# Batch compute the stop/label prediction losses
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment