Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
68a978d4
Unverified
Commit
68a978d4
authored
Jul 30, 2020
by
Quan (Andy) Gan
Committed by
GitHub
Jul 30, 2020
Browse files
[Model] Fixes JTNN for 0.5 (#1879)
* jtnn and fixes * make metagraph a method * fix test * fix
parent
faa1dc56
Changes
19
Hide whitespace changes
Inline
Side-by-side
Showing
19 changed files
with
175 additions
and
165 deletions
+175
-165
examples/pytorch/jtnn/README.md
examples/pytorch/jtnn/README.md
+1
-1
examples/pytorch/jtnn/jtnn/datautils.py
examples/pytorch/jtnn/jtnn/datautils.py
+5
-5
examples/pytorch/jtnn/jtnn/jtmpn.py
examples/pytorch/jtnn/jtnn/jtmpn.py
+22
-14
examples/pytorch/jtnn/jtnn/jtnn_dec.py
examples/pytorch/jtnn/jtnn/jtnn_dec.py
+52
-39
examples/pytorch/jtnn/jtnn/jtnn_enc.py
examples/pytorch/jtnn/jtnn/jtnn_enc.py
+11
-4
examples/pytorch/jtnn/jtnn/jtnn_vae.py
examples/pytorch/jtnn/jtnn/jtnn_vae.py
+17
-19
examples/pytorch/jtnn/jtnn/mol_tree_nx.py
examples/pytorch/jtnn/jtnn/mol_tree_nx.py
+10
-11
examples/pytorch/jtnn/jtnn/mpn.py
examples/pytorch/jtnn/jtnn/mpn.py
+6
-6
examples/pytorch/jtnn/jtnn/nnutils.py
examples/pytorch/jtnn/jtnn/nnutils.py
+15
-6
examples/pytorch/jtnn/vaetrain_dgl.py
examples/pytorch/jtnn/vaetrain_dgl.py
+3
-2
python/dgl/_ffi/object.py
python/dgl/_ffi/object.py
+7
-1
python/dgl/heterograph.py
python/dgl/heterograph.py
+13
-46
python/dgl/sampling/pinsage.py
python/dgl/sampling/pinsage.py
+1
-1
src/array/cuda/csr_transpose.cc
src/array/cuda/csr_transpose.cc
+3
-1
tests/compute/test_heterograph.py
tests/compute/test_heterograph.py
+1
-1
tests/compute/test_pickle.py
tests/compute/test_pickle.py
+2
-2
tests/compute/test_shared_mem.py
tests/compute/test_shared_mem.py
+3
-3
tests/test_utils/checks.py
tests/test_utils/checks.py
+2
-2
tutorials/basics/5_hetero.py
tutorials/basics/5_hetero.py
+1
-1
No files found.
examples/pytorch/jtnn/README.md
View file @
68a978d4
...
...
@@ -6,7 +6,7 @@ This is a direct modification from https://github.com/wengong-jin/icml18-jtnn
Dependencies
--------------
*
PyTorch 0.4.1+
*
RDKit
*
RDKit
=2018.09.3.0
*
requests
How to run
...
...
examples/pytorch/jtnn/jtnn/datautils.py
View file @
68a978d4
...
...
@@ -84,9 +84,9 @@ class JTNNDataset(Dataset):
cand_graphs
=
[]
atom_x_dec
=
torch
.
zeros
(
0
,
ATOM_FDIM_DEC
)
bond_x_dec
=
torch
.
zeros
(
0
,
BOND_FDIM_DEC
)
tree_mess_src_e
=
torch
.
zeros
(
0
,
2
).
long
()
tree_mess_tgt_e
=
torch
.
zeros
(
0
,
2
).
long
()
tree_mess_tgt_n
=
torch
.
zeros
(
0
).
long
()
tree_mess_src_e
=
torch
.
zeros
(
0
,
2
).
int
()
tree_mess_tgt_e
=
torch
.
zeros
(
0
,
2
).
int
()
tree_mess_tgt_n
=
torch
.
zeros
(
0
).
int
()
# prebuild the stereoisomers
cands
=
mol_tree
.
stereo_cands
...
...
@@ -143,7 +143,7 @@ class JTNNCollator(object):
mol_trees
=
_unpack_field
(
examples
,
'mol_tree'
)
wid
=
_unpack_field
(
examples
,
'wid'
)
for
_wid
,
mol_tree
in
zip
(
wid
,
mol_trees
):
mol_tree
.
ndata
[
'wid'
]
=
torch
.
LongTensor
(
_wid
)
mol_tree
.
graph
.
ndata
[
'wid'
]
=
torch
.
LongTensor
(
_wid
)
# TODO: either support pickling or get around ctypes pointers using scipy
# batch molecule graphs
...
...
@@ -176,7 +176,7 @@ class JTNNCollator(object):
tree_mess_src_e
[
i
]
+=
n_tree_nodes
tree_mess_tgt_n
[
i
]
+=
n_graph_nodes
n_graph_nodes
+=
sum
(
g
.
number_of_nodes
()
for
g
in
cand_graphs
[
i
])
n_tree_nodes
+=
mol_trees
[
i
].
number_of_nodes
()
n_tree_nodes
+=
mol_trees
[
i
].
graph
.
number_of_nodes
()
cand_batch_idx
.
extend
([
i
]
*
len
(
cand_graphs
[
i
]))
tree_mess_tgt_e
=
torch
.
cat
(
tree_mess_tgt_e
)
tree_mess_src_e
=
torch
.
cat
(
tree_mess_src_e
)
...
...
examples/pytorch/jtnn/jtnn/jtmpn.py
View file @
68a978d4
import
torch
import
torch.nn
as
nn
from
.nnutils
import
cuda
from
.nnutils
import
cuda
,
line_graph
import
rdkit.Chem
as
Chem
from
dgl
import
DGLGraph
,
mean_nodes
import
dgl
from
dgl
import
mean_nodes
import
dgl.function
as
DGLF
import
os
...
...
@@ -50,12 +51,11 @@ def mol2dgl_single(cand_batch):
ctr_node
=
mol_tree
.
nodes_dict
[
ctr_node_id
]
ctr_bid
=
ctr_node
[
'idx'
]
g
=
DGLGraph
(
)
mol_tree_graph
=
getattr
(
mol_tree
,
'graph'
,
mol_tree
)
for
i
,
atom
in
enumerate
(
mol
.
GetAtoms
()):
assert
i
==
atom
.
GetIdx
()
atom_x
.
append
(
atom_features
(
atom
))
g
.
add_nodes
(
n_atoms
)
bond_src
=
[]
bond_dst
=
[]
...
...
@@ -78,24 +78,24 @@ def mol2dgl_single(cand_batch):
x_bid
=
mol_tree
.
nodes_dict
[
x_nid
-
1
][
'idx'
]
if
x_nid
>
0
else
-
1
y_bid
=
mol_tree
.
nodes_dict
[
y_nid
-
1
][
'idx'
]
if
y_nid
>
0
else
-
1
if
x_bid
>=
0
and
y_bid
>=
0
and
x_bid
!=
y_bid
:
if
mol_tree
.
has_edge_between
(
x_bid
,
y_bid
):
if
mol_tree
_graph
.
has_edge
s
_between
(
x_bid
,
y_bid
):
tree_mess_target_edges
.
append
((
begin_idx
+
n_nodes
,
end_idx
+
n_nodes
))
tree_mess_source_edges
.
append
((
x_bid
,
y_bid
))
tree_mess_target_nodes
.
append
(
end_idx
+
n_nodes
)
if
mol_tree
.
has_edge_between
(
y_bid
,
x_bid
):
if
mol_tree
_graph
.
has_edge
s
_between
(
y_bid
,
x_bid
):
tree_mess_target_edges
.
append
((
end_idx
+
n_nodes
,
begin_idx
+
n_nodes
))
tree_mess_source_edges
.
append
((
y_bid
,
x_bid
))
tree_mess_target_nodes
.
append
(
begin_idx
+
n_nodes
)
n_nodes
+=
n_atoms
g
.
add_edges
(
bond_src
,
bond_dst
)
g
=
dgl
.
graph
(
(
bond_src
,
bond_dst
)
,
num_nodes
=
n_atoms
)
cand_graphs
.
append
(
g
)
return
cand_graphs
,
torch
.
stack
(
atom_x
),
\
torch
.
stack
(
bond_x
)
if
len
(
bond_x
)
>
0
else
torch
.
zeros
(
0
),
\
torch
.
Long
Tensor
(
tree_mess_source_edges
),
\
torch
.
Long
Tensor
(
tree_mess_target_edges
),
\
torch
.
Long
Tensor
(
tree_mess_target_nodes
)
torch
.
Int
Tensor
(
tree_mess_source_edges
),
\
torch
.
Int
Tensor
(
tree_mess_target_edges
),
\
torch
.
Int
Tensor
(
tree_mess_target_nodes
)
mpn_loopy_bp_msg
=
DGLF
.
copy_src
(
src
=
'msg'
,
out
=
'msg'
)
...
...
@@ -174,7 +174,7 @@ class DGLJTMPN(nn.Module):
n_samples
=
len
(
cand_graphs
)
cand_line_graph
=
cand
_graph
s
.
line
_graph
(
backtracking
=
False
,
shared
=
True
)
cand_line_graph
=
line
_graph
(
cand
_graph
s
,
backtracking
=
False
,
shared
=
True
)
n_nodes
=
cand_graphs
.
number_of_nodes
()
n_edges
=
cand_graphs
.
number_of_edges
()
...
...
@@ -222,20 +222,27 @@ class DGLJTMPN(nn.Module):
if
PAPER
:
src_u
,
src_v
=
tree_mess_src_edges
.
unbind
(
1
)
tgt_u
,
tgt_v
=
tree_mess_tgt_edges
.
unbind
(
1
)
alpha
=
mol_tree_batch
.
edges
[
src_u
,
src_v
].
data
[
'm'
]
src_u
=
src_u
.
to
(
mol_tree_batch
.
device
)
src_v
=
src_v
.
to
(
mol_tree_batch
.
device
)
eid
=
mol_tree_batch
.
edge_ids
(
src_u
.
int
(),
src_v
.
int
()).
long
()
alpha
=
mol_tree_batch
.
edata
[
'm'
][
eid
]
cand_graphs
.
edges
[
tgt_u
,
tgt_v
].
data
[
'alpha'
]
=
alpha
else
:
src_u
,
src_v
=
tree_mess_src_edges
.
unbind
(
1
)
alpha
=
mol_tree_batch
.
edges
[
src_u
,
src_v
].
data
[
'm'
]
src_u
=
src_u
.
to
(
mol_tree_batch
.
device
)
src_v
=
src_v
.
to
(
mol_tree_batch
.
device
)
eid
=
mol_tree_batch
.
edge_ids
(
src_u
.
int
(),
src_v
.
int
()).
long
()
alpha
=
mol_tree_batch
.
edata
[
'm'
][
eid
]
node_idx
=
(
tree_mess_tgt_nodes
.
to
(
device
=
zero_node_state
.
device
)[:,
None
]
.
expand_as
(
alpha
))
node_alpha
=
zero_node_state
.
clone
().
scatter_add
(
0
,
node_idx
,
alpha
)
node_alpha
=
zero_node_state
.
clone
().
scatter_add
(
0
,
node_idx
.
long
()
,
alpha
)
cand_graphs
.
ndata
[
'alpha'
]
=
node_alpha
cand_graphs
.
apply_edges
(
func
=
lambda
edges
:
{
'alpha'
:
edges
.
src
[
'alpha'
]},
)
cand_line_graph
.
ndata
.
update
(
cand_graphs
.
edata
)
for
i
in
range
(
self
.
depth
-
1
):
cand_line_graph
.
update_all
(
mpn_loopy_bp_msg
,
...
...
@@ -243,6 +250,7 @@ class DGLJTMPN(nn.Module):
self
.
loopy_bp_updater
,
)
cand_graphs
.
edata
.
update
(
cand_line_graph
.
ndata
)
cand_graphs
.
update_all
(
mpn_gather_msg
,
mpn_gather_reduce
,
...
...
examples/pytorch/jtnn/jtnn/jtnn_dec.py
View file @
68a978d4
...
...
@@ -3,7 +3,7 @@ import torch.nn as nn
import
torch.nn.functional
as
F
from
.mol_tree_nx
import
DGLMolTree
from
.chemutils
import
enum_assemble_nx
,
get_mol
from
.nnutils
import
GRUUpdate
,
cuda
from
.nnutils
import
GRUUpdate
,
cuda
,
line_graph
,
tocpu
from
dgl
import
batch
,
dfs_labeled_edges_generator
import
dgl.function
as
DGLF
import
numpy
as
np
...
...
@@ -13,6 +13,7 @@ MAX_DECODE_LEN = 100
def
dfs_order
(
forest
,
roots
):
forest
=
tocpu
(
forest
)
edges
=
dfs_labeled_edges_generator
(
forest
,
roots
,
has_reverse_edge
=
True
)
for
e
,
l
in
zip
(
*
edges
):
# I exploited the fact that the reverse edge ID equal to 1 xor forward
...
...
@@ -55,7 +56,7 @@ def have_slots(fa_slots, ch_slots):
def
can_assemble
(
mol_tree
,
u
,
v_node_dict
):
u_node_dict
=
mol_tree
.
nodes_dict
[
u
]
u_neighbors
=
mol_tree
.
successors
(
u
)
u_neighbors
=
mol_tree
.
graph
.
successors
(
u
)
u_neighbors_node_dict
=
[
mol_tree
.
nodes_dict
[
_u
]
for
_u
in
u_neighbors
...
...
@@ -106,13 +107,13 @@ class DGLJTNNDecoder(nn.Module):
ground truth tree
'''
mol_tree_batch
=
batch
(
mol_trees
)
mol_tree_batch_lg
=
mol_tree_batch
.
line_graph
(
backtracking
=
False
,
shared
=
True
)
mol_tree_batch_lg
=
line_graph
(
mol_tree_batch
,
backtracking
=
False
,
shared
=
True
)
n_trees
=
len
(
mol_trees
)
return
self
.
run
(
mol_tree_batch
,
mol_tree_batch_lg
,
n_trees
,
tree_vec
)
def
run
(
self
,
mol_tree_batch
,
mol_tree_batch_lg
,
n_trees
,
tree_vec
):
node_offset
=
np
.
cumsum
(
[
0
]
+
mol_tree_batch
.
batch_num_nodes
)
node_offset
=
np
.
cumsum
(
np
.
insert
(
mol_tree_batch
.
batch_num_nodes
().
cpu
().
numpy
(),
0
,
0
)
)
root_ids
=
node_offset
[:
-
1
]
n_nodes
=
mol_tree_batch
.
number_of_nodes
()
n_edges
=
mol_tree_batch
.
number_of_edges
()
...
...
@@ -120,7 +121,7 @@ class DGLJTNNDecoder(nn.Module):
mol_tree_batch
.
ndata
.
update
({
'x'
:
self
.
embedding
(
mol_tree_batch
.
ndata
[
'wid'
]),
'h'
:
cuda
(
torch
.
zeros
(
n_nodes
,
self
.
hidden_size
)),
'new'
:
cuda
(
torch
.
ones
(
n_nodes
).
b
yte
()),
# whether it's newly generated node
'new'
:
cuda
(
torch
.
ones
(
n_nodes
).
b
ool
()),
# whether it's newly generated node
})
mol_tree_batch
.
edata
.
update
({
...
...
@@ -162,22 +163,26 @@ class DGLJTNNDecoder(nn.Module):
# Traverse the tree and predict on children
for
eid
,
p
in
dfs_order
(
mol_tree_batch
,
root_ids
):
eid
=
eid
.
to
(
mol_tree_batch
.
device
)
p
=
p
.
to
(
mol_tree_batch
.
device
)
u
,
v
=
mol_tree_batch
.
find_edges
(
eid
)
p_target_list
=
torch
.
zeros_like
(
root_out_degrees
)
p_target_list
[
root_out_degrees
>
0
]
=
1
-
p
p_target_list
[
root_out_degrees
>
0
]
=
(
1
-
p
).
int
()
p_target_list
=
p_target_list
[
root_out_degrees
>=
0
]
p_targets
.
append
(
torch
.
tensor
(
p_target_list
))
root_out_degrees
-=
(
root_out_degrees
==
0
).
long
()
root_out_degrees
-=
torch
.
tensor
(
np
.
isin
(
root_ids
,
v
).
astype
(
'int64'
)
)
root_out_degrees
-=
(
root_out_degrees
==
0
).
int
()
root_out_degrees
-=
torch
.
tensor
(
np
.
isin
(
root_ids
,
v
.
cpu
().
numpy
())).
to
(
root_out_degrees
)
mol_tree_batch_lg
.
ndata
.
update
(
mol_tree_batch
.
edata
)
mol_tree_batch_lg
.
pull
(
eid
,
dec_tree_edge_msg
,
dec_tree_edge_reduce
,
self
.
dec_tree_edge_update
,
)
mol_tree_batch
.
edata
.
update
(
mol_tree_batch_lg
.
ndata
)
is_new
=
mol_tree_batch
.
nodes
[
v
].
data
[
'new'
]
mol_tree_batch
.
pull
(
v
,
...
...
@@ -185,6 +190,7 @@ class DGLJTNNDecoder(nn.Module):
dec_tree_node_reduce
,
dec_tree_node_update
,
)
# Extract
n_repr
=
mol_tree_batch
.
nodes
[
v
].
data
h
=
n_repr
[
'h'
]
...
...
@@ -200,7 +206,10 @@ class DGLJTNNDecoder(nn.Module):
if
q_input
.
shape
[
0
]
>
0
:
q_inputs
.
append
(
q_input
)
q_targets
.
append
(
q_target
)
p_targets
.
append
(
torch
.
zeros
((
root_out_degrees
==
0
).
sum
()).
long
())
p_targets
.
append
(
torch
.
zeros
(
(
root_out_degrees
==
0
).
sum
(),
device
=
root_out_degrees
.
device
,
dtype
=
torch
.
int32
))
# Batch compute the stop/label prediction losses
p_inputs
=
torch
.
cat
(
p_inputs
,
0
)
...
...
@@ -231,6 +240,8 @@ class DGLJTNNDecoder(nn.Module):
assert
mol_vec
.
shape
[
0
]
==
1
mol_tree
=
DGLMolTree
(
None
)
mol_tree
.
graph
=
mol_tree
.
graph
.
to
(
mol_vec
.
device
)
mol_tree_graph
=
mol_tree
.
graph
init_hidden
=
cuda
(
torch
.
zeros
(
1
,
self
.
hidden_size
))
...
...
@@ -240,11 +251,11 @@ class DGLJTNNDecoder(nn.Module):
_
,
root_wid
=
torch
.
max
(
root_score
,
1
)
root_wid
=
root_wid
.
view
(
1
)
mol_tree
.
add_nodes
(
1
)
# root
mol_tree
.
nodes
[
0
].
data
[
'wid'
]
=
root_wid
mol_tree
.
nodes
[
0
].
data
[
'x'
]
=
self
.
embedding
(
root_wid
)
mol_tree
.
nodes
[
0
].
data
[
'h'
]
=
init_hidden
mol_tree
.
nodes
[
0
].
data
[
'fail'
]
=
cuda
(
torch
.
tensor
([
0
]))
mol_tree
_graph
.
add_nodes
(
1
)
# root
mol_tree
_graph
.
n
data
[
'wid'
]
=
root_wid
mol_tree
_graph
.
n
data
[
'x'
]
=
self
.
embedding
(
root_wid
)
mol_tree
_graph
.
n
data
[
'h'
]
=
init_hidden
mol_tree
_graph
.
n
data
[
'fail'
]
=
cuda
(
torch
.
tensor
([
0
]))
mol_tree
.
nodes_dict
[
0
]
=
root_node_dict
=
create_node_dict
(
self
.
vocab
.
get_smiles
(
root_wid
))
...
...
@@ -259,27 +270,27 @@ class DGLJTNNDecoder(nn.Module):
for
step
in
range
(
MAX_DECODE_LEN
):
u
,
u_slots
=
stack
[
-
1
]
udata
=
mol_tree
.
nodes
[
u
].
data
x
=
udata
[
'x'
]
h
=
udata
[
'h'
]
x
=
mol_tree_graph
.
ndata
[
'x'
][
u
:
u
+
1
]
h
=
mol_tree_graph
.
ndata
[
'h'
][
u
:
u
+
1
]
# Predict stop
p_input
=
torch
.
cat
([
x
,
h
,
mol_vec
],
1
)
p_score
=
torch
.
sigmoid
(
self
.
U_s
(
torch
.
relu
(
self
.
U
(
p_input
))))
p_score
[:]
=
0
backtrack
=
(
p_score
.
item
()
<
0.5
)
if
not
backtrack
:
# Predict next clique. Note that the prediction may fail due
# to lack of assemblable components
mol_tree
.
add_nodes
(
1
)
mol_tree
_graph
.
add_nodes
(
1
)
new_node_id
+=
1
v
=
new_node_id
mol_tree
.
add_edges
(
u
,
v
)
mol_tree
_graph
.
add_edges
(
u
,
v
)
uv
=
new_edge_id
new_edge_id
+=
1
if
first
:
mol_tree
.
edata
.
update
({
mol_tree
_graph
.
edata
.
update
({
's'
:
cuda
(
torch
.
zeros
(
1
,
self
.
hidden_size
)),
'm'
:
cuda
(
torch
.
zeros
(
1
,
self
.
hidden_size
)),
'r'
:
cuda
(
torch
.
zeros
(
1
,
self
.
hidden_size
)),
...
...
@@ -291,26 +302,26 @@ class DGLJTNNDecoder(nn.Module):
})
first
=
False
mol_tree
.
edges
[
uv
].
data
[
'src_x'
]
=
mol_tree
.
nodes
[
u
].
data
[
'x'
]
mol_tree
_graph
.
e
data
[
'src_x'
]
[
uv
]
=
mol_tree
_graph
.
n
data
[
'x'
]
[
u
]
# keeping dst_x 0 is fine as h on new edge doesn't depend on that.
# DGL doesn't dynamically maintain a line graph.
mol_tree_
lg
=
mol_tree
.
line
_graph
(
backtracking
=
False
,
shared
=
True
)
mol_tree_
graph_lg
=
line_graph
(
mol_tree_graph
,
backtracking
=
False
,
shared
=
True
)
mol_tree_lg
.
pull
(
mol_tree_
graph_
lg
.
pull
(
uv
,
dec_tree_edge_msg
,
dec_tree_edge_reduce
,
self
.
dec_tree_edge_update
.
update_zm
,
)
mol_tree
.
pull
(
mol_tree_graph
.
edata
.
update
(
mol_tree_graph_lg
.
ndata
)
mol_tree_graph
.
pull
(
v
,
dec_tree_node_msg
,
dec_tree_node_reduce
,
)
vdata
=
mol_tree
.
nodes
[
v
].
data
h_v
=
vdata
[
'h'
]
h_v
=
mol_tree_graph
.
ndata
[
'h'
][
v
:
v
+
1
]
q_input
=
torch
.
cat
([
h_v
,
mol_vec
],
1
)
q_score
=
torch
.
softmax
(
self
.
W_o
(
torch
.
relu
(
self
.
W
(
q_input
))),
-
1
)
_
,
sort_wid
=
torch
.
sort
(
q_score
,
1
,
descending
=
True
)
...
...
@@ -329,49 +340,51 @@ class DGLJTNNDecoder(nn.Module):
if
next_wid
is
None
:
# Failed adding an actual children; v is a spurious node
# and we mark it.
v
data
[
'fail'
]
=
cuda
(
torch
.
tensor
([
1
]))
mol_tree_graph
.
n
data
[
'fail'
]
[
v
]
=
cuda
(
torch
.
tensor
([
1
]))
backtrack
=
True
else
:
next_wid
=
cuda
(
torch
.
tensor
([
next_wid
]))
v
data
[
'wid'
]
=
next_wid
v
data
[
'x'
]
=
self
.
embedding
(
next_wid
)
mol_tree_graph
.
n
data
[
'wid'
]
[
v
]
=
next_wid
mol_tree_graph
.
n
data
[
'x'
]
[
v
]
=
self
.
embedding
(
next_wid
)
mol_tree
.
nodes_dict
[
v
]
=
next_node_dict
all_nodes
[
v
]
=
next_node_dict
stack
.
append
((
v
,
next_slots
))
mol_tree
.
add_edge
(
v
,
u
)
mol_tree
_graph
.
add_edge
s
(
v
,
u
)
vu
=
new_edge_id
new_edge_id
+=
1
mol_tree
.
edges
[
uv
].
data
[
'dst_x'
]
=
mol_tree
.
nodes
[
v
].
data
[
'x'
]
mol_tree
.
edges
[
vu
].
data
[
'src_x'
]
=
mol_tree
.
nodes
[
v
].
data
[
'x'
]
mol_tree
.
edges
[
vu
].
data
[
'dst_x'
]
=
mol_tree
.
nodes
[
u
].
data
[
'x'
]
mol_tree
_graph
.
e
data
[
'dst_x'
]
[
uv
]
=
mol_tree
_graph
.
n
data
[
'x'
]
[
v
]
mol_tree
_graph
.
e
data
[
'src_x'
]
[
vu
]
=
mol_tree
_graph
.
n
data
[
'x'
]
[
v
]
mol_tree
_graph
.
e
data
[
'dst_x'
]
[
vu
]
=
mol_tree
_graph
.
n
data
[
'x'
]
[
u
]
# DGL doesn't dynamically maintain a line graph.
mol_tree_
lg
=
mol_tree
.
line
_graph
(
backtracking
=
False
,
shared
=
True
)
mol_tree_lg
.
apply_nodes
(
mol_tree_
graph_lg
=
line_graph
(
mol_tree_graph
,
backtracking
=
False
,
shared
=
True
)
mol_tree_
graph_
lg
.
apply_nodes
(
self
.
dec_tree_edge_update
.
update_r
,
uv
)
mol_tree_graph
.
edata
.
update
(
mol_tree_graph_lg
.
ndata
)
if
backtrack
:
if
len
(
stack
)
==
1
:
break
# At root, terminate
pu
,
_
=
stack
[
-
2
]
u_pu
=
mol_tree
.
edge_id
(
u
,
pu
)
u_pu
=
mol_tree
_graph
.
edge_id
(
u
,
pu
)
mol_tree_lg
.
pull
(
mol_tree_
graph_
lg
.
pull
(
u_pu
,
dec_tree_edge_msg
,
dec_tree_edge_reduce
,
self
.
dec_tree_edge_update
,
)
mol_tree
.
pull
(
mol_tree_graph
.
edata
.
update
(
mol_tree_graph_lg
.
ndata
)
mol_tree_graph
.
pull
(
pu
,
dec_tree_node_msg
,
dec_tree_node_reduce
,
)
stack
.
pop
()
effective_nodes
=
mol_tree
.
filter_nodes
(
lambda
nodes
:
nodes
.
data
[
'fail'
]
!=
1
)
effective_nodes
=
mol_tree
_graph
.
filter_nodes
(
lambda
nodes
:
nodes
.
data
[
'fail'
]
!=
1
)
effective_nodes
,
_
=
torch
.
sort
(
effective_nodes
)
return
mol_tree
,
all_nodes
,
effective_nodes
examples/pytorch/jtnn/jtnn/jtnn_enc.py
View file @
68a978d4
import
torch
import
torch.nn
as
nn
from
.nnutils
import
GRUUpdate
,
cuda
from
.nnutils
import
GRUUpdate
,
cuda
,
line_graph
,
tocpu
from
dgl
import
batch
,
bfs_edges_generator
import
dgl.function
as
DGLF
import
numpy
as
np
...
...
@@ -8,7 +8,11 @@ import numpy as np
MAX_NB
=
8
def
level_order
(
forest
,
roots
):
forest
=
tocpu
(
forest
)
edges
=
bfs_edges_generator
(
forest
,
roots
)
if
len
(
edges
)
==
0
:
# no edges in the tree; do not perform loopy BP
return
_
,
leaves
=
forest
.
find_edges
(
edges
[
-
1
])
edges_back
=
bfs_edges_generator
(
forest
,
roots
,
reverse
=
True
)
yield
from
reversed
(
edges_back
)
...
...
@@ -53,14 +57,14 @@ class DGLJTNNEncoder(nn.Module):
mol_tree_batch
=
batch
(
mol_trees
)
# Build line graph to prepare for belief propagation
mol_tree_batch_lg
=
mol_tree_batch
.
line_graph
(
backtracking
=
False
,
shared
=
True
)
mol_tree_batch_lg
=
line_graph
(
mol_tree_batch
,
backtracking
=
False
,
shared
=
True
)
return
self
.
run
(
mol_tree_batch
,
mol_tree_batch_lg
)
def
run
(
self
,
mol_tree_batch
,
mol_tree_batch_lg
):
# Since tree roots are designated to 0. In the batched graph we can
# simply find the corresponding node ID by looking at node_offset
node_offset
=
np
.
cumsum
(
[
0
]
+
mol_tree_batch
.
batch_num_nodes
)
node_offset
=
np
.
cumsum
(
np
.
insert
(
mol_tree_batch
.
batch_num_nodes
().
cpu
().
numpy
(),
0
,
0
)
)
root_ids
=
node_offset
[:
-
1
]
n_nodes
=
mol_tree_batch
.
number_of_nodes
()
n_edges
=
mol_tree_batch
.
number_of_edges
()
...
...
@@ -68,6 +72,7 @@ class DGLJTNNEncoder(nn.Module):
# Assign structure embeddings to tree nodes
mol_tree_batch
.
ndata
.
update
({
'x'
:
self
.
embedding
(
mol_tree_batch
.
ndata
[
'wid'
]),
'm'
:
cuda
(
torch
.
zeros
(
n_nodes
,
self
.
hidden_size
)),
'h'
:
cuda
(
torch
.
zeros
(
n_nodes
,
self
.
hidden_size
)),
})
...
...
@@ -95,16 +100,18 @@ class DGLJTNNEncoder(nn.Module):
# messages, and the uncomputed messages are zero vectors. Essentially,
# we can always compute s_ij as the sum of incoming m_ij, no matter
# if m_ij is actually computed or not.
mol_tree_batch_lg
.
ndata
.
update
(
mol_tree_batch
.
edata
)
for
eid
in
level_order
(
mol_tree_batch
,
root_ids
):
#eid = mol_tree_batch.edge_ids(u, v)
mol_tree_batch_lg
.
pull
(
eid
,
eid
.
to
(
mol_tree_batch_lg
.
device
)
,
enc_tree_msg
,
enc_tree_reduce
,
self
.
enc_tree_update
,
)
# Readout
mol_tree_batch
.
edata
.
update
(
mol_tree_batch_lg
.
ndata
)
mol_tree_batch
.
update_all
(
enc_tree_gather_msg
,
enc_tree_gather_reduce
,
...
...
examples/pytorch/jtnn/jtnn/jtnn_vae.py
View file @
68a978d4
import
torch
import
torch.nn
as
nn
import
torch.nn.functional
as
F
from
.nnutils
import
cuda
,
move_dgl_to_cuda
from
.nnutils
import
cuda
from
.chemutils
import
set_atommap
,
copy_edit_mol
,
enum_assemble_nx
,
\
attach_mols_nx
,
decode_stereo
from
.jtnn_enc
import
DGLJTNNEncoder
...
...
@@ -44,24 +44,24 @@ class DGLJTNNVAE(nn.Module):
@
staticmethod
def
move_to_cuda
(
mol_batch
):
for
t
in
mol_batch
[
'mol_trees'
]:
mo
ve_dgl_to_cuda
(
t
)
for
i
in
range
(
len
(
mol_batch
[
'mol_trees'
]
))
:
mo
l_batch
[
'mol_trees'
][
i
].
graph
=
cuda
(
mol_batch
[
'mol_trees'
][
i
].
graph
)
mo
ve_dgl_to_
cuda
(
mol_batch
[
'mol_graph_batch'
])
mo
l_batch
[
'mol_graph_batch'
]
=
cuda
(
mol_batch
[
'mol_graph_batch'
])
if
'cand_graph_batch'
in
mol_batch
:
mo
ve_dgl_to_
cuda
(
mol_batch
[
'cand_graph_batch'
])
mo
l_batch
[
'cand_graph_batch'
]
=
cuda
(
mol_batch
[
'cand_graph_batch'
])
if
mol_batch
.
get
(
'stereo_cand_graph_batch'
)
is
not
None
:
mo
ve_dgl_to_
cuda
(
mol_batch
[
'stereo_cand_graph_batch'
])
mo
l_batch
[
'stereo_cand_graph_batch'
]
=
cuda
(
mol_batch
[
'stereo_cand_graph_batch'
])
def
encode
(
self
,
mol_batch
):
mol_graphs
=
mol_batch
[
'mol_graph_batch'
]
mol_vec
=
self
.
mpn
(
mol_graphs
)
mol_tree_batch
,
tree_vec
=
self
.
jtnn
(
mol_batch
[
'mol_trees'
])
mol_tree_batch
,
tree_vec
=
self
.
jtnn
(
[
t
.
graph
for
t
in
mol_batch
[
'mol_trees'
]
]
)
self
.
n_nodes_total
+=
mol_graphs
.
number_of_nodes
()
self
.
n_edges_total
+=
mol_graphs
.
number_of_edges
()
self
.
n_tree_nodes_total
+=
sum
(
t
.
number_of_nodes
()
for
t
in
mol_batch
[
'mol_trees'
])
self
.
n_tree_nodes_total
+=
sum
(
t
.
graph
.
number_of_nodes
()
for
t
in
mol_batch
[
'mol_trees'
])
self
.
n_passes
+=
1
return
mol_tree_batch
,
tree_vec
,
mol_vec
...
...
@@ -93,7 +93,7 @@ class DGLJTNNVAE(nn.Module):
tree_vec
,
mol_vec
,
z_mean
,
z_log_var
=
self
.
sample
(
tree_vec
,
mol_vec
,
e1
,
e2
)
kl_loss
=
-
0.5
*
torch
.
sum
(
1.0
+
z_log_var
-
z_mean
*
z_mean
-
torch
.
exp
(
z_log_var
))
/
batch_size
word_loss
,
topo_loss
,
word_acc
,
topo_acc
=
self
.
decoder
(
mol_trees
,
tree_vec
)
word_loss
,
topo_loss
,
word_acc
,
topo_acc
=
self
.
decoder
(
[
t
.
graph
for
t
in
mol_trees
]
,
tree_vec
)
assm_loss
,
assm_acc
=
self
.
assm
(
mol_batch
,
mol_tree_batch
,
mol_vec
)
stereo_loss
,
stereo_acc
=
self
.
stereo
(
mol_batch
,
mol_vec
)
...
...
@@ -103,9 +103,9 @@ class DGLJTNNVAE(nn.Module):
def
assm
(
self
,
mol_batch
,
mol_tree_batch
,
mol_vec
):
cands
=
[
mol_batch
[
'cand_graph_batch'
],
mol_batch
[
'tree_mess_src_e'
],
mol_batch
[
'tree_mess_tgt_e'
],
mol_batch
[
'tree_mess_tgt_n'
]]
cuda
(
mol_batch
[
'tree_mess_src_e'
]
)
,
cuda
(
mol_batch
[
'tree_mess_tgt_e'
]
)
,
cuda
(
mol_batch
[
'tree_mess_tgt_n'
]
)
]
cand_vec
=
self
.
jtmpn
(
cands
,
mol_tree_batch
)
cand_vec
=
self
.
G_mean
(
cand_vec
)
...
...
@@ -179,12 +179,11 @@ class DGLJTNNVAE(nn.Module):
node
[
'idx'
]
=
i
node
[
'nid'
]
=
i
+
1
node
[
'is_leaf'
]
=
True
if
mol_tree
.
in_degree
(
node_id
)
>
1
:
if
mol_tree
.
graph
.
in_degree
s
(
node_id
)
>
1
:
node
[
'is_leaf'
]
=
False
set_atommap
(
node
[
'mol'
],
node
[
'nid'
])
mol_tree_sg
=
mol_tree
.
subgraph
(
effective_nodes
)
mol_tree_sg
.
copy_from_parent
()
mol_tree_sg
=
mol_tree
.
graph
.
subgraph
(
effective_nodes
.
int
().
to
(
tree_vec
.
device
))
mol_tree_msg
,
_
=
self
.
jtnn
([
mol_tree_sg
])
mol_tree_msg
=
unbatch
(
mol_tree_msg
)[
0
]
mol_tree_msg
.
nodes_dict
=
nodes_dict
...
...
@@ -210,7 +209,7 @@ class DGLJTNNVAE(nn.Module):
stereo_graphs
=
[
mol2dgl_enc
(
c
)
for
c
in
stereo_cands
]
stereo_cand_graphs
,
atom_x
,
bond_x
=
\
zip
(
*
stereo_graphs
)
stereo_cand_graphs
=
batch
(
stereo_cand_graphs
)
stereo_cand_graphs
=
cuda
(
batch
(
stereo_cand_graphs
)
)
atom_x
=
cuda
(
torch
.
cat
(
atom_x
))
bond_x
=
cuda
(
torch
.
cat
(
bond_x
))
stereo_cand_graphs
.
ndata
[
'x'
]
=
atom_x
...
...
@@ -248,9 +247,8 @@ class DGLJTNNVAE(nn.Module):
cands
=
[(
candmol
,
mol_tree_msg
,
cur_node_id
)
for
candmol
in
cand_mols
]
cand_graphs
,
atom_x
,
bond_x
,
tree_mess_src_edges
,
\
tree_mess_tgt_edges
,
tree_mess_tgt_nodes
=
mol2dgl_dec
(
cands
)
cand_graphs
=
batch
(
cand_graphs
)
tree_mess_tgt_edges
,
tree_mess_tgt_nodes
=
mol2dgl_dec
(
cands
)
cand_graphs
=
batch
([
g
.
to
(
mol_vec
.
device
)
for
g
in
cand_graphs
])
atom_x
=
cuda
(
atom_x
)
bond_x
=
cuda
(
bond_x
)
cand_graphs
.
ndata
[
'x'
]
=
atom_x
...
...
examples/pytorch/jtnn/jtnn/mol_tree_nx.py
View file @
68a978d4
from
dgl
import
DGLGraph
import
dgl
import
rdkit.Chem
as
Chem
from
.chemutils
import
get_clique_mol
,
tree_decomp
,
get_mol
,
get_smiles
,
\
set_atommap
,
enum_assemble_nx
,
decode_stereo
import
numpy
as
np
class
DGLMolTree
(
DGLGraph
):
class
DGLMolTree
(
object
):
def
__init__
(
self
,
smiles
):
DGLGraph
.
__init__
(
self
)
self
.
nodes_dict
=
{}
if
smiles
is
None
:
self
.
graph
=
dgl
.
graph
(([],
[]))
return
self
.
smiles
=
smiles
...
...
@@ -34,7 +34,6 @@ class DGLMolTree(DGLGraph):
)
if
min
(
c
)
==
0
:
root
=
i
self
.
add_nodes
(
len
(
cliques
))
# The clique with atom ID 0 becomes root
if
root
>
0
:
...
...
@@ -51,16 +50,16 @@ class DGLMolTree(DGLGraph):
dst
[
2
*
i
]
=
y
src
[
2
*
i
+
1
]
=
y
dst
[
2
*
i
+
1
]
=
x
self
.
add_edges
(
src
,
dst
)
self
.
graph
=
dgl
.
graph
((
src
,
dst
),
num_nodes
=
len
(
cliques
)
)
for
i
in
self
.
nodes_dict
:
self
.
nodes_dict
[
i
][
'nid'
]
=
i
+
1
if
self
.
out_degree
(
i
)
>
1
:
# Leaf node mol is not marked
if
self
.
graph
.
out_degree
s
(
i
)
>
1
:
# Leaf node mol is not marked
set_atommap
(
self
.
nodes_dict
[
i
][
'mol'
],
self
.
nodes_dict
[
i
][
'nid'
])
self
.
nodes_dict
[
i
][
'is_leaf'
]
=
(
self
.
out_degree
(
i
)
==
1
)
self
.
nodes_dict
[
i
][
'is_leaf'
]
=
(
self
.
graph
.
out_degree
s
(
i
)
==
1
)
def
treesize
(
self
):
return
self
.
number_of_nodes
()
return
self
.
graph
.
number_of_nodes
()
def
_recover_node
(
self
,
i
,
original_mol
):
node
=
self
.
nodes_dict
[
i
]
...
...
@@ -71,7 +70,7 @@ class DGLMolTree(DGLGraph):
for
cidx
in
node
[
'clique'
]:
original_mol
.
GetAtomWithIdx
(
cidx
).
SetAtomMapNum
(
node
[
'nid'
])
for
j
in
self
.
successors
(
i
).
numpy
():
for
j
in
self
.
graph
.
successors
(
i
).
numpy
():
nei_node
=
self
.
nodes_dict
[
j
]
clique
.
extend
(
nei_node
[
'clique'
])
if
nei_node
[
'is_leaf'
]:
# Leaf node, no need to mark
...
...
@@ -93,10 +92,10 @@ class DGLMolTree(DGLGraph):
return
node
[
'label'
]
def
_assemble_node
(
self
,
i
):
neighbors
=
[
self
.
nodes_dict
[
j
]
for
j
in
self
.
successors
(
i
).
numpy
()
neighbors
=
[
self
.
nodes_dict
[
j
]
for
j
in
self
.
graph
.
successors
(
i
).
numpy
()
if
self
.
nodes_dict
[
j
][
'mol'
].
GetNumAtoms
()
>
1
]
neighbors
=
sorted
(
neighbors
,
key
=
lambda
x
:
x
[
'mol'
].
GetNumAtoms
(),
reverse
=
True
)
singletons
=
[
self
.
nodes_dict
[
j
]
for
j
in
self
.
successors
(
i
).
numpy
()
singletons
=
[
self
.
nodes_dict
[
j
]
for
j
in
self
.
graph
.
successors
(
i
).
numpy
()
if
self
.
nodes_dict
[
j
][
'mol'
].
GetNumAtoms
()
==
1
]
neighbors
=
singletons
+
neighbors
...
...
examples/pytorch/jtnn/jtnn/mpn.py
View file @
68a978d4
...
...
@@ -3,8 +3,10 @@ import torch.nn as nn
import
rdkit.Chem
as
Chem
import
torch.nn.functional
as
F
from
.chemutils
import
get_mol
from
dgl
import
DGLGraph
,
mean_nodes
import
dgl
from
dgl
import
mean_nodes
import
dgl.function
as
DGLF
from
.nnutils
import
line_graph
ELEM_LIST
=
[
'C'
,
'N'
,
'O'
,
'S'
,
'F'
,
'Si'
,
'P'
,
'Cl'
,
'Br'
,
'Mg'
,
'Na'
,
'Ca'
,
'Fe'
,
'Al'
,
'I'
,
'B'
,
'K'
,
'Se'
,
'Zn'
,
'H'
,
'Cu'
,
'Mn'
,
'unknown'
]
...
...
@@ -41,11 +43,9 @@ def mol2dgl_single(smiles):
mol
=
get_mol
(
smiles
)
n_atoms
=
mol
.
GetNumAtoms
()
n_bonds
=
mol
.
GetNumBonds
()
graph
=
DGLGraph
()
for
i
,
atom
in
enumerate
(
mol
.
GetAtoms
()):
assert
i
==
atom
.
GetIdx
()
atom_x
.
append
(
atom_features
(
atom
))
graph
.
add_nodes
(
n_atoms
)
bond_src
=
[]
bond_dst
=
[]
...
...
@@ -60,8 +60,7 @@ def mol2dgl_single(smiles):
bond_src
.
append
(
end_idx
)
bond_dst
.
append
(
begin_idx
)
bond_x
.
append
(
features
)
graph
.
add_edges
(
bond_src
,
bond_dst
)
graph
=
dgl
.
graph
((
bond_src
,
bond_dst
),
num_nodes
=
n_atoms
)
n_edges
+=
n_bonds
return
graph
,
torch
.
stack
(
atom_x
),
\
torch
.
stack
(
bond_x
)
if
len
(
bond_x
)
>
0
else
torch
.
zeros
(
0
)
...
...
@@ -123,7 +122,7 @@ class DGLMPN(nn.Module):
def
forward
(
self
,
mol_graph
):
n_samples
=
mol_graph
.
batch_size
mol_line_graph
=
mol_graph
.
line_graph
(
backtracking
=
False
,
shared
=
True
)
mol_line_graph
=
line_graph
(
mol_graph
,
backtracking
=
False
,
shared
=
True
)
n_nodes
=
mol_graph
.
number_of_nodes
()
n_edges
=
mol_graph
.
number_of_edges
()
...
...
@@ -170,6 +169,7 @@ class DGLMPN(nn.Module):
self
.
loopy_bp_updater
,
)
mol_graph
.
edata
.
update
(
mol_line_graph
.
ndata
)
mol_graph
.
update_all
(
mpn_gather_msg
,
mpn_gather_reduce
,
...
...
examples/pytorch/jtnn/jtnn/nnutils.py
View file @
68a978d4
import
torch
import
torch.nn
as
nn
import
os
import
dgl
def
cuda
(
tensor
):
def
cuda
(
x
):
if
torch
.
cuda
.
is_available
()
and
not
os
.
getenv
(
'NOCUDA'
,
None
):
return
tensor
.
cuda
()
return
x
.
to
(
torch
.
device
(
'cuda'
))
# works for both DGLGraph and tensor
else
:
return
tensor
...
...
@@ -42,7 +43,15 @@ class GRUUpdate(nn.Module):
dic
.
update
(
self
.
update_r
(
node
,
zm
=
dic
))
return
dic
def
move_dgl_to_cuda
(
g
):
g
.
ndata
.
update
({
k
:
cuda
(
g
.
ndata
[
k
])
for
k
in
g
.
ndata
})
g
.
edata
.
update
({
k
:
cuda
(
g
.
edata
[
k
])
for
k
in
g
.
edata
})
def
tocpu
(
g
):
src
,
dst
=
g
.
edges
()
src
=
src
.
cpu
()
dst
=
dst
.
cpu
()
return
dgl
.
graph
((
src
,
dst
),
num_nodes
=
g
.
number_of_nodes
())
def
line_graph
(
g
,
backtracking
=
True
,
shared
=
False
):
#g2 = tocpu(g)
g2
=
dgl
.
line_graph
(
g
,
backtracking
,
shared
)
#g2 = g2.to(g.device)
g2
.
ndata
.
update
(
g
.
edata
)
return
g2
examples/pytorch/jtnn/vaetrain_dgl.py
View file @
68a978d4
...
...
@@ -8,6 +8,7 @@ import math, random, sys
from
optparse
import
OptionParser
from
collections
import
deque
import
rdkit
import
tqdm
from
jtnn
import
*
...
...
@@ -69,7 +70,7 @@ def train():
dataset
,
batch_size
=
batch_size
,
shuffle
=
True
,
num_workers
=
4
,
num_workers
=
0
,
collate_fn
=
JTNNCollator
(
vocab
,
True
),
drop_last
=
True
,
worker_init_fn
=
worker_init_fn
)
...
...
@@ -77,7 +78,7 @@ def train():
for
epoch
in
range
(
MAX_EPOCH
):
word_acc
,
topo_acc
,
assm_acc
,
steo_acc
=
0
,
0
,
0
,
0
for
it
,
batch
in
enumerate
(
dataloader
):
for
it
,
batch
in
tqdm
.
tqdm
(
enumerate
(
dataloader
)
,
total
=
2000
)
:
model
.
zero_grad
()
try
:
loss
,
kl_div
,
wacc
,
tacc
,
sacc
,
dacc
=
model
(
batch
,
beta
)
...
...
python/dgl/_ffi/object.py
View file @
68a978d4
...
...
@@ -28,7 +28,13 @@ def _new_object(cls):
class
ObjectBase
(
_ObjectBase
):
"""ObjectBase is the base class of all DGL CAPI object."""
"""ObjectBase is the base class of all DGL CAPI object.
The core attribute is ``handle``, which is a C raw pointer. It must be initialized
via ``__init_handle_by_constructor__``.
Note that the same handle **CANNOT** be shared across multiple ObjectBase instances.
"""
def
__dir__
(
self
):
plist
=
ctypes
.
POINTER
(
ctypes
.
c_char_p
)()
size
=
ctypes
.
c_uint
()
...
...
python/dgl/heterograph.py
View file @
68a978d4
...
...
@@ -268,9 +268,6 @@ class DGLHeteroGraph(object):
self
.
_etype2canonical
[
ety
]
=
self
.
_canonical_etypes
[
i
]
self
.
_etypes_invmap
=
{
t
:
i
for
i
,
t
in
enumerate
(
self
.
_canonical_etypes
)}
# Cached metagraph in networkx
self
.
_nx_metagraph
=
None
# node and edge frame
if
node_frames
is
None
:
node_frames
=
[
None
]
*
len
(
self
.
_ntypes
)
...
...
@@ -286,27 +283,12 @@ class DGLHeteroGraph(object):
for
i
,
frame
in
enumerate
(
edge_frames
)]
self
.
_edge_frames
=
edge_frames
def __getstate__(self):
    """Return the picklable state of the heterograph.

    The state is a 6-tuple:
    ``(graph index, metainfo, node frames, edge frames,
    batched node counts, batched edge counts)``,
    where ``metainfo`` bundles the type-name tables and their inverse
    lookup maps. The companion ``__setstate__`` dispatches on
    ``len(state) == 6`` to recognize this format (DGL >= 0.5), so the
    tuple length and field order here must stay in sync with it.
    """
    # Pack all type-related bookkeeping into one tuple so __setstate__
    # can restore it with a single unpacking assignment.
    metainfo = (self._ntypes, self._etypes, self._canonical_etypes,
                self._srctypes_invmap, self._dsttypes_invmap,
                self._is_unibipartite, self._etype2canonical,
                self._etypes_invmap)
    return (self._graph, metainfo, self._node_frames, self._edge_frames,
            self._batch_num_nodes, self._batch_num_edges)
def
__setstate__
(
self
,
state
):
# Compatibility check
# TODO: version the storage
if
isinstance
(
state
,
tuple
)
and
len
(
state
)
==
6
:
# DGL >= 0.5
#TODO(minjie): too many states in python; should clean up and lower to C
self
.
_nx_metagraph
=
None
(
self
.
_graph
,
metainfo
,
self
.
_node_frames
,
self
.
_edge_frames
,
self
.
_batch_num_nodes
,
self
.
_batch_num_edges
)
=
state
(
self
.
_ntypes
,
self
.
_etypes
,
self
.
_canonical_etypes
,
self
.
_srctypes_invmap
,
self
.
_dsttypes_invmap
,
self
.
_is_unibipartite
,
self
.
_etype2canonical
,
self
.
_etypes_invmap
)
=
metainfo
if
isinstance
(
state
,
dict
):
# Since 0.5 we use the default __dict__ method
self
.
__dict__
.
update
(
state
)
elif
isinstance
(
state
,
tuple
)
and
len
(
state
)
==
5
:
# DGL == 0.4.3
dgl_warning
(
"The object is pickled with DGL == 0.4.3. "
...
...
@@ -337,7 +319,7 @@ class DGLHeteroGraph(object):
for
i
in
range
(
len
(
self
.
ntypes
))}
nedge_dict
=
{
self
.
canonical_etypes
[
i
]
:
self
.
_graph
.
number_of_edges
(
i
)
for
i
in
range
(
len
(
self
.
etypes
))}
meta
=
str
(
self
.
metagraph
.
edges
(
keys
=
True
))
meta
=
str
(
self
.
metagraph
()
.
edges
(
keys
=
True
))
return
ret
.
format
(
node
=
nnode_dict
,
edge
=
nedge_dict
,
meta
=
meta
)
def
__copy__
(
self
):
...
...
@@ -345,20 +327,7 @@ class DGLHeteroGraph(object):
#TODO(minjie): too many states in python; should clean up and lower to C
cls
=
type
(
self
)
obj
=
cls
.
__new__
(
cls
)
obj
.
_graph
=
self
.
_graph
obj
.
_batch_num_nodes
=
self
.
_batch_num_nodes
obj
.
_batch_num_edges
=
self
.
_batch_num_edges
obj
.
_ntypes
=
self
.
_ntypes
obj
.
_etypes
=
self
.
_etypes
obj
.
_canonical_etypes
=
self
.
_canonical_etypes
obj
.
_srctypes_invmap
=
self
.
_srctypes_invmap
obj
.
_dsttypes_invmap
=
self
.
_dsttypes_invmap
obj
.
_is_unibipartite
=
self
.
_is_unibipartite
obj
.
_etype2canonical
=
self
.
_etype2canonical
obj
.
_etypes_invmap
=
self
.
_etypes_invmap
obj
.
_nx_metagraph
=
self
.
_nx_metagraph
obj
.
_node_frames
=
self
.
_node_frames
obj
.
_edge_frames
=
self
.
_edge_frames
obj
.
__dict__
.
update
(
self
.
__dict__
)
return
obj
#################################################################
...
...
@@ -975,7 +944,6 @@ class DGLHeteroGraph(object):
else
:
return
self
.
ntypes
@
property
def
metagraph
(
self
):
"""Return the metagraph as networkx.MultiDiGraph.
...
...
@@ -992,7 +960,7 @@ class DGLHeteroGraph(object):
>>> follows_g = dgl.graph(([0, 1], [1, 2]), 'user', 'follows')
>>> plays_g = dgl.bipartite(([0, 1, 1, 2], [0, 0, 1, 1]), 'user', 'plays', 'game')
>>> g = dgl.hetero_from_relations([follows_g, plays_g])
>>> meta_g = g.metagraph
>>> meta_g = g.metagraph
()
The metagraph then has two nodes and two edges.
...
...
@@ -1005,13 +973,12 @@ class DGLHeteroGraph(object):
>>> meta_g.number_of_edges()
2
"""
if
self
.
_nx_metagraph
is
None
:
nx_graph
=
self
.
_graph
.
metagraph
.
to_networkx
()
self
.
_nx_metagraph
=
nx
.
MultiDiGraph
()
for
u_v
in
nx_graph
.
edges
:
srctype
,
etype
,
dsttype
=
self
.
canonical_etypes
[
nx_graph
.
edges
[
u_v
][
'id'
]]
self
.
_nx_metagraph
.
add_edge
(
srctype
,
dsttype
,
etype
)
return
self
.
_nx_metagraph
nx_graph
=
self
.
_graph
.
metagraph
.
to_networkx
()
nx_metagraph
=
nx
.
MultiDiGraph
()
for
u_v
in
nx_graph
.
edges
:
srctype
,
etype
,
dsttype
=
self
.
canonical_etypes
[
nx_graph
.
edges
[
u_v
][
'id'
]]
nx_metagraph
.
add_edge
(
srctype
,
dsttype
,
etype
)
return
nx_metagraph
def
to_canonical_etype
(
self
,
etype
):
"""Convert edge type to canonical etype: (srctype, etype, dsttype).
...
...
@@ -5282,7 +5249,7 @@ class DGLBlock(DGLHeteroGraph):
for
ntype
in
self
.
dsttypes
}
nedge_dict
=
{
etype
:
self
.
number_of_edges
(
etype
)
for
etype
in
self
.
canonical_etypes
}
meta
=
str
(
self
.
metagraph
.
edges
(
keys
=
True
))
meta
=
str
(
self
.
metagraph
()
.
edges
(
keys
=
True
))
return
ret
.
format
(
srcnode
=
nsrcnode_dict
,
dstnode
=
ndstnode_dict
,
edge
=
nedge_dict
,
meta
=
meta
)
...
...
python/dgl/sampling/pinsage.py
View file @
68a978d4
...
...
@@ -211,7 +211,7 @@ class PinSAGESampler(RandomWalkNeighborSampler):
"""
def
__init__
(
self
,
G
,
ntype
,
other_type
,
random_walk_length
,
random_walk_restart_prob
,
num_random_walks
,
num_neighbors
,
weight_column
=
'weights'
):
metagraph
=
G
.
metagraph
metagraph
=
G
.
metagraph
()
fw_etype
=
list
(
metagraph
[
ntype
][
other_type
])[
0
]
bw_etype
=
list
(
metagraph
[
other_type
][
ntype
])[
0
]
super
().
__init__
(
G
,
random_walk_length
,
...
...
src/array/cuda/csr_transpose.cc
View file @
68a978d4
...
...
@@ -33,7 +33,9 @@ CSRMatrix CSRTranspose(CSRMatrix csr) {
const
int32_t
*
indices_ptr
=
static_cast
<
int32_t
*>
(
indices
->
data
);
const
void
*
data_ptr
=
data
->
data
;
NDArray
t_indptr
=
aten
::
NewIdArray
(
csr
.
num_cols
+
1
,
ctx
,
bits
);
// (BarclayII) csr2csc doesn't seem to clear the content of cscColPtr if nnz == 0.
// We need to do it ourselves.
NDArray
t_indptr
=
aten
::
Full
(
0
,
csr
.
num_cols
+
1
,
bits
,
ctx
);
NDArray
t_indices
=
aten
::
NewIdArray
(
nnz
,
ctx
,
bits
);
NDArray
t_data
=
aten
::
NewIdArray
(
nnz
,
ctx
,
bits
);
int32_t
*
t_indptr_ptr
=
static_cast
<
int32_t
*>
(
t_indptr
->
data
);
...
...
tests/compute/test_heterograph.py
View file @
68a978d4
...
...
@@ -224,7 +224,7 @@ def test_query(idtype):
assert
set
(
canonical_etypes
)
==
set
(
g
.
canonical_etypes
)
# metagraph
mg
=
g
.
metagraph
mg
=
g
.
metagraph
()
assert
set
(
g
.
ntypes
)
==
set
(
mg
.
nodes
)
etype_triplets
=
[(
u
,
v
,
e
)
for
u
,
v
,
e
in
mg
.
edges
(
keys
=
True
)]
assert
set
([
...
...
tests/compute/test_pickle.py
View file @
68a978d4
...
...
@@ -34,8 +34,8 @@ def _assert_is_identical_hetero(g, g2):
assert
g
.
canonical_etypes
==
g2
.
canonical_etypes
# check if two metagraphs are identical
for
edges
,
features
in
g
.
metagraph
.
edges
(
keys
=
True
).
items
():
assert
g2
.
metagraph
.
edges
(
keys
=
True
)[
edges
]
==
features
for
edges
,
features
in
g
.
metagraph
()
.
edges
(
keys
=
True
).
items
():
assert
g2
.
metagraph
()
.
edges
(
keys
=
True
)[
edges
]
==
features
# check if node ID spaces and feature spaces are equal
for
ntype
in
g
.
ntypes
:
...
...
tests/compute/test_shared_mem.py
View file @
68a978d4
...
...
@@ -35,8 +35,8 @@ def _assert_is_identical_hetero(g, g2):
assert
g
.
canonical_etypes
==
g2
.
canonical_etypes
# check if two metagraphs are identical
for
edges
,
features
in
g
.
metagraph
.
edges
(
keys
=
True
).
items
():
assert
g2
.
metagraph
.
edges
(
keys
=
True
)[
edges
]
==
features
for
edges
,
features
in
g
.
metagraph
()
.
edges
(
keys
=
True
).
items
():
assert
g2
.
metagraph
()
.
edges
(
keys
=
True
)[
edges
]
==
features
# check if node ID spaces and feature spaces are equal
for
ntype
in
g
.
ntypes
:
...
...
@@ -89,4 +89,4 @@ def test_copy_from_gpu():
if
__name__
==
"__main__"
:
test_single_process
(
F
.
int64
)
test_multi_process
(
F
.
int32
)
test_copy_from_gpu
()
\ No newline at end of file
test_copy_from_gpu
()
tests/test_utils/checks.py
View file @
68a978d4
...
...
@@ -17,8 +17,8 @@ def check_graph_equal(g1, g2, *,
assert
g1
.
batch_size
==
g2
.
batch_size
# check if two metagraphs are identical
for
edges
,
features
in
g1
.
metagraph
.
edges
(
keys
=
True
).
items
():
assert
g2
.
metagraph
.
edges
(
keys
=
True
)[
edges
]
==
features
for
edges
,
features
in
g1
.
metagraph
()
.
edges
(
keys
=
True
).
items
():
assert
g2
.
metagraph
()
.
edges
(
keys
=
True
)[
edges
]
==
features
for
nty
in
g1
.
ntypes
:
assert
g1
.
number_of_nodes
(
nty
)
==
g2
.
number_of_nodes
(
nty
)
...
...
tutorials/basics/5_hetero.py
View file @
68a978d4
...
...
@@ -234,7 +234,7 @@ def plot_graph(nxg):
ag
.
layout
(
'dot'
)
ag
.
draw
(
'graph.png'
)
plot_graph
(
G
.
metagraph
)
plot_graph
(
G
.
metagraph
()
)
###############################################################################
# Learning tasks associated with heterographs
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment