Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
cbbbbde7
Unverified
Commit
cbbbbde7
authored
Oct 15, 2020
by
Quan (Andy) Gan
Committed by
GitHub
Oct 15, 2020
Browse files
[Bug] fix multiple bugs in JTNN example (#2220)
* [Bug] fix multiple bugs in JTNN example * remove debug code
parent
d628f5a2
Changes
5
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
10 additions
and
9 deletions
+10
-9
examples/pytorch/jtnn/jtnn/jtmpn.py
examples/pytorch/jtnn/jtnn/jtmpn.py
+1
-0
examples/pytorch/jtnn/jtnn/jtnn_dec.py
examples/pytorch/jtnn/jtnn/jtnn_dec.py
+5
-6
examples/pytorch/jtnn/jtnn/jtnn_enc.py
examples/pytorch/jtnn/jtnn/jtnn_enc.py
+2
-2
examples/pytorch/jtnn/jtnn/mpn.py
examples/pytorch/jtnn/jtnn/mpn.py
+1
-0
examples/pytorch/jtnn/vaetrain_dgl.py
examples/pytorch/jtnn/vaetrain_dgl.py
+1
-1
No files found.
examples/pytorch/jtnn/jtnn/jtmpn.py
View file @
cbbbbde7
...
...
@@ -195,6 +195,7 @@ class DGLJTMPN(nn.Module):
cand_graphs
.
apply_edges
(
func
=
lambda
edges
:
{
'src_x'
:
edges
.
src
[
'x'
]},
)
cand_line_graph
.
ndata
.
update
(
cand_graphs
.
edata
)
bond_features
=
cand_line_graph
.
ndata
[
'x'
]
source_features
=
cand_line_graph
.
ndata
[
'src_x'
]
...
...
examples/pytorch/jtnn/jtnn/jtnn_dec.py
View file @
cbbbbde7
...
...
@@ -143,7 +143,7 @@ class DGLJTNNDecoder(nn.Module):
# Predict root
mol_tree_batch
.
pull
(
root_ids
,
DGLF
.
copy_e
(
'm'
,
'm'
),
DGLF
.
sum
(
'm'
,
'h'
))
mol_tree_batch
.
apply_nodes
(
dec_tree_node_update
)
mol_tree_batch
.
apply_nodes
(
dec_tree_node_update
,
v
=
root_ids
)
# Extract hidden states and store them for stop/label prediction
h
=
mol_tree_batch
.
nodes
[
root_ids
].
data
[
'h'
]
x
=
mol_tree_batch
.
nodes
[
root_ids
].
data
[
'x'
]
...
...
@@ -170,12 +170,12 @@ class DGLJTNNDecoder(nn.Module):
mol_tree_batch_lg
.
ndata
.
update
(
mol_tree_batch
.
edata
)
mol_tree_batch_lg
.
pull
(
eid
,
DGLF
.
copy_u
(
'm'
,
'm'
),
DGLF
.
sum
(
'm'
,
's'
))
mol_tree_batch_lg
.
pull
(
eid
,
DGLF
.
copy_u
(
'rm'
,
'rm'
),
DGLF
.
sum
(
'rm'
,
'accum_rm'
))
mol_tree_batch_lg
.
apply_nodes
(
self
.
dec_tree_edge_update
)
mol_tree_batch_lg
.
apply_nodes
(
self
.
dec_tree_edge_update
,
v
=
eid
)
mol_tree_batch
.
edata
.
update
(
mol_tree_batch_lg
.
ndata
)
is_new
=
mol_tree_batch
.
nodes
[
v
].
data
[
'new'
]
mol_tree_batch
.
pull
(
v
,
DGLF
.
copy_e
(
'm'
,
'm'
),
DGLF
.
sum
(
'm'
,
'h'
))
mol_tree_batch
.
apply_nodes
(
dec_tree_node_update
)
mol_tree_batch
.
apply_nodes
(
dec_tree_node_update
,
v
=
v
)
# Extract
n_repr
=
mol_tree_batch
.
nodes
[
v
].
data
...
...
@@ -262,7 +262,6 @@ class DGLJTNNDecoder(nn.Module):
# Predict stop
p_input
=
torch
.
cat
([
x
,
h
,
mol_vec
],
1
)
p_score
=
torch
.
sigmoid
(
self
.
U_s
(
torch
.
relu
(
self
.
U
(
p_input
))))
p_score
[:]
=
0
backtrack
=
(
p_score
.
item
()
<
0.5
)
if
not
backtrack
:
...
...
@@ -302,7 +301,7 @@ class DGLJTNNDecoder(nn.Module):
uv
,
DGLF
.
copy_u
(
'rm'
,
'rm'
),
DGLF
.
sum
(
'rm'
,
'accum_rm'
))
mol_tree_graph_lg
.
apply_nodes
(
self
.
dec_tree_edge_update
.
update_zm
)
mol_tree_graph_lg
.
apply_nodes
(
self
.
dec_tree_edge_update
.
update_zm
,
v
=
uv
)
mol_tree_graph
.
edata
.
update
(
mol_tree_graph_lg
.
ndata
)
mol_tree_graph
.
pull
(
v
,
DGLF
.
copy_e
(
'm'
,
'm'
),
DGLF
.
sum
(
'm'
,
'h'
))
...
...
@@ -358,7 +357,7 @@ class DGLJTNNDecoder(nn.Module):
mol_tree_graph_lg
.
pull
(
u_pu
,
DGLF
.
copy_u
(
'm'
,
'm'
),
DGLF
.
sum
(
'm'
,
's'
))
mol_tree_graph_lg
.
pull
(
u_pu
,
DGLF
.
copy_u
(
'rm'
,
'rm'
),
DGLF
.
sum
(
'rm'
,
'accum_rm'
))
mol_tree_graph_lg
.
apply_nodes
(
self
.
dec_tree_edge_update
)
mol_tree_graph_lg
.
apply_nodes
(
self
.
dec_tree_edge_update
,
v
=
u_pu
)
mol_tree_graph
.
edata
.
update
(
mol_tree_graph_lg
.
ndata
)
mol_tree_graph
.
pull
(
pu
,
DGLF
.
copy_e
(
'm'
,
'm'
),
DGLF
.
sum
(
'm'
,
'h'
))
stack
.
pop
()
...
...
examples/pytorch/jtnn/jtnn/jtnn_enc.py
View file @
cbbbbde7
...
...
@@ -99,8 +99,8 @@ class DGLJTNNEncoder(nn.Module):
for
eid
in
level_order
(
mol_tree_batch
,
root_ids
):
eid
=
eid
.
to
(
mol_tree_batch_lg
.
device
)
mol_tree_batch_lg
.
pull
(
eid
,
DGLF
.
copy_u
(
'm'
,
'm'
),
DGLF
.
sum
(
'm'
,
's'
))
mol_tree_batch_lg
.
pull
(
eid
,
DGLF
.
copy_u
(
'rm'
,
'rm'
),
DGLF
.
sum
(
'rm'
,
'rm'
))
mol_tree_batch_lg
.
apply_nodes
(
self
.
enc_tree_update
)
mol_tree_batch_lg
.
pull
(
eid
,
DGLF
.
copy_u
(
'rm'
,
'rm'
),
DGLF
.
sum
(
'rm'
,
'
accum_
rm'
))
mol_tree_batch_lg
.
apply_nodes
(
self
.
enc_tree_update
,
v
=
eid
)
# Readout
mol_tree_batch
.
edata
.
update
(
mol_tree_batch_lg
.
ndata
)
...
...
examples/pytorch/jtnn/jtnn/mpn.py
View file @
cbbbbde7
...
...
@@ -136,6 +136,7 @@ class DGLMPN(nn.Module):
mol_graph
.
apply_edges
(
func
=
lambda
edges
:
{
'src_x'
:
edges
.
src
[
'x'
]},
)
mol_line_graph
.
ndata
.
update
(
mol_graph
.
edata
)
e_repr
=
mol_line_graph
.
ndata
bond_features
=
e_repr
[
'x'
]
...
...
examples/pytorch/jtnn/vaetrain_dgl.py
View file @
cbbbbde7
...
...
@@ -70,7 +70,7 @@ def train():
dataset
,
batch_size
=
batch_size
,
shuffle
=
True
,
num_workers
=
0
,
num_workers
=
4
,
collate_fn
=
JTNNCollator
(
vocab
,
True
),
drop_last
=
True
,
worker_init_fn
=
worker_init_fn
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment