OpenDAS / dgl / Commits / 704bcaf6

Commit 704bcaf6 (unverified), authored Feb 19, 2023 by Hongzhi (Steve), Chen; committed by GitHub, Feb 19, 2023.

    examples (#5323)

    Co-authored-by: Ubuntu <ubuntu@ip-172-31-28-63.ap-northeast-1.compute.internal>

parent 6bc82161

Changes: 332 files in this commit; this page shows 20 changed files with 553 additions and 326 deletions (+553 −326).
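The hunks below are consistent with an automated formatting pass: single quotes rewritten to double quotes, long calls exploded to one argument per line with trailing commas, and imports re-sorted and grouped. That is the kind of output black and isort produce. A hypothetical reproduction (the tool choice and configuration are assumptions, not stated in the commit message):

    black examples/pytorch
    isort examples/pytorch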
examples/pytorch/jknet/main.py                  +2    -2
examples/pytorch/jknet/model.py                 +1    -2
examples/pytorch/jtnn/jtnn/datautils.py         +7    -6
examples/pytorch/jtnn/jtnn/jtmpn.py             +151  -75
examples/pytorch/jtnn/jtnn/jtnn_dec.py          +186  -122
examples/pytorch/jtnn/jtnn/jtnn_enc.py          +2    -3
examples/pytorch/jtnn/jtnn/jtnn_vae.py          +2    -4
examples/pytorch/jtnn/jtnn/mol_tree_nx.py       +1    -2
examples/pytorch/jtnn/jtnn/mpn.py               +2    -3
examples/pytorch/jtnn/jtnn/nnutils.py           +2    -2
examples/pytorch/label_propagation/main.py      +2    -2
examples/pytorch/lda/example_20newsgroups.py    +4    -4
examples/pytorch/lda/lda_model.py               +2    -2
examples/pytorch/line_graph/gnn.py              +49   -21
examples/pytorch/line_graph/train.py            +1    -1
examples/pytorch/metapath2vec/download.py       +13   -10
examples/pytorch/metapath2vec/metapath2vec.py   +50   -25
examples/pytorch/metapath2vec/model.py          +4    -5
examples/pytorch/metapath2vec/reading_data.py   +40   -13
examples/pytorch/metapath2vec/sampler.py        +32   -22
examples/pytorch/jknet/main.py

...
@@ -7,12 +7,12 @@ import numpy as np
 import torch
 import torch.nn as nn
 import torch.optim as optim
+from dgl.data import CiteseerGraphDataset, CoraGraphDataset
 from model import JKNet
 from sklearn.model_selection import train_test_split
 from tqdm import trange
-from dgl.data import CiteseerGraphDataset, CoraGraphDataset


 def main(args):
     # Step 1: Prepare graph data and retrieve train/validation/test index ============================= #
...
examples/pytorch/jknet/model.py

+import dgl.function as fn
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-import dgl.function as fn
-
 from dgl.nn import GraphConv, JumpingKnowledge
...
examples/pytorch/jtnn/jtnn/datautils.py

-import torch
-from torch.utils.data import Dataset
 import dgl
+import torch
 from dgl.data.utils import (
     _get_dgl_url,
     download,
     extract_archive,
     get_download_dir,
 )
+from torch.utils.data import Dataset

-from .jtmpn import ATOM_FDIM as ATOM_FDIM_DEC
-from .jtmpn import BOND_FDIM as BOND_FDIM_DEC
-from .jtmpn import mol2dgl_single as mol2dgl_dec
+from .jtmpn import (
+    ATOM_FDIM as ATOM_FDIM_DEC,
+    BOND_FDIM as BOND_FDIM_DEC,
+    mol2dgl_single as mol2dgl_dec,
+)
 from .mol_tree import Vocab
 from .mol_tree_nx import DGLMolTree
 from .mpn import mol2dgl_single as mol2dgl_enc
...
examples/pytorch/jtnn/jtnn/jtmpn.py

+import os
+
+import dgl
+import dgl.function as DGLF
+import rdkit.Chem as Chem
 import torch
 import torch.nn as nn
+from dgl import line_graph, mean_nodes
+
 from .nnutils import cuda
-import rdkit.Chem as Chem
-import dgl
-from dgl import mean_nodes, line_graph
-import dgl.function as DGLF
-import os

-ELEM_LIST = ['C', 'N', 'O', 'S', 'F', 'Si', 'P', 'Cl', 'Br', 'Mg', 'Na', 'Ca', 'Fe', 'Al', 'I', 'B', 'K', 'Se', 'Zn', 'H', 'Cu', 'Mn', 'unknown']
+ELEM_LIST = [
+    "C",
+    "N",
+    "O",
+    "S",
+    "F",
+    "Si",
+    "P",
+    "Cl",
+    "Br",
+    "Mg",
+    "Na",
+    "Ca",
+    "Fe",
+    "Al",
+    "I",
+    "B",
+    "K",
+    "Se",
+    "Zn",
+    "H",
+    "Cu",
+    "Mn",
+    "unknown",
+]

 ATOM_FDIM = len(ELEM_LIST) + 6 + 5 + 1
 BOND_FDIM = 5
 MAX_NB = 10

-PAPER = os.getenv('PAPER', False)
+PAPER = os.getenv("PAPER", False)


 def onek_encoding_unk(x, allowable_set):
     if x not in allowable_set:
         x = allowable_set[-1]
     return [x == s for s in allowable_set]


 # Note that during graph decoding they don't predict stereochemistry-related
 # characteristics (i.e. Chiral Atoms, E-Z, Cis-Trans).  Instead, they decode
 # the 2-D graph first, then enumerate all possible 3-D forms and find the
 # one with highest score.
 def atom_features(atom):
-    return (torch.Tensor(onek_encoding_unk(atom.GetSymbol(), ELEM_LIST)
-            + onek_encoding_unk(atom.GetDegree(), [0, 1, 2, 3, 4, 5])
-            + onek_encoding_unk(atom.GetFormalCharge(), [-1, -2, 1, 2, 0])
-            + [atom.GetIsAromatic()]))
+    return torch.Tensor(
+        onek_encoding_unk(atom.GetSymbol(), ELEM_LIST)
+        + onek_encoding_unk(atom.GetDegree(), [0, 1, 2, 3, 4, 5])
+        + onek_encoding_unk(atom.GetFormalCharge(), [-1, -2, 1, 2, 0])
+        + [atom.GetIsAromatic()]
+    )


 def bond_features(bond):
     bt = bond.GetBondType()
-    return (torch.Tensor([bt == Chem.rdchem.BondType.SINGLE,
-                          bt == Chem.rdchem.BondType.DOUBLE,
-                          bt == Chem.rdchem.BondType.TRIPLE,
-                          bt == Chem.rdchem.BondType.AROMATIC,
-                          bond.IsInRing()]))
+    return torch.Tensor(
+        [
+            bt == Chem.rdchem.BondType.SINGLE,
+            bt == Chem.rdchem.BondType.DOUBLE,
+            bt == Chem.rdchem.BondType.TRIPLE,
+            bt == Chem.rdchem.BondType.AROMATIC,
+            bond.IsInRing(),
+        ]
+    )


 def mol2dgl_single(cand_batch):
     cand_graphs = []
     tree_mess_source_edges = []  # map these edges from trees to...
     tree_mess_target_edges = []  # these edges on candidate graphs
     tree_mess_target_nodes = []
     n_nodes = 0
     n_edges = 0
...
@@ -50,8 +89,8 @@ def mol2dgl_single(cand_batch):
         n_bonds = mol.GetNumBonds()

         ctr_node = mol_tree.nodes_dict[ctr_node_id]
-        ctr_bid = ctr_node['idx']
-        mol_tree_graph = getattr(mol_tree, 'graph', mol_tree)
+        ctr_bid = ctr_node["idx"]
+        mol_tree_graph = getattr(mol_tree, "graph", mol_tree)

         for i, atom in enumerate(mol.GetAtoms()):
             assert i == atom.GetIdx()
...
@@ -75,15 +114,19 @@ def mol2dgl_single(cand_batch):
                 x_nid, y_nid = a1.GetAtomMapNum(), a2.GetAtomMapNum()
                 # Tree node ID in the batch
-                x_bid = mol_tree.nodes_dict[x_nid - 1]['idx'] if x_nid > 0 else -1
-                y_bid = mol_tree.nodes_dict[y_nid - 1]['idx'] if y_nid > 0 else -1
+                x_bid = (
+                    mol_tree.nodes_dict[x_nid - 1]["idx"] if x_nid > 0 else -1
+                )
+                y_bid = (
+                    mol_tree.nodes_dict[y_nid - 1]["idx"] if y_nid > 0 else -1
+                )
                 if x_bid >= 0 and y_bid >= 0 and x_bid != y_bid:
                     if mol_tree_graph.has_edges_between(x_bid, y_bid):
-                        tree_mess_target_edges.append((begin_idx + n_nodes, end_idx + n_nodes))
+                        tree_mess_target_edges.append(
+                            (begin_idx + n_nodes, end_idx + n_nodes)
+                        )
                         tree_mess_source_edges.append((x_bid, y_bid))
                         tree_mess_target_nodes.append(end_idx + n_nodes)
                     if mol_tree_graph.has_edges_between(y_bid, x_bid):
-                        tree_mess_target_edges.append((end_idx + n_nodes, begin_idx + n_nodes))
+                        tree_mess_target_edges.append(
+                            (end_idx + n_nodes, begin_idx + n_nodes)
+                        )
                         tree_mess_source_edges.append((y_bid, x_bid))
                         tree_mess_target_nodes.append(begin_idx + n_nodes)
...
@@ -91,11 +134,14 @@ def mol2dgl_single(cand_batch):
         g = dgl.graph((bond_src, bond_dst), num_nodes=n_atoms)
         cand_graphs.append(g)

-    return cand_graphs, torch.stack(atom_x), \
-        torch.stack(bond_x) if len(bond_x) > 0 else torch.zeros(0), \
-        torch.LongTensor(tree_mess_source_edges), \
-        torch.LongTensor(tree_mess_target_edges), \
-        torch.LongTensor(tree_mess_target_nodes)
+    return (
+        cand_graphs,
+        torch.stack(atom_x),
+        torch.stack(bond_x) if len(bond_x) > 0 else torch.zeros(0),
+        torch.LongTensor(tree_mess_source_edges),
+        torch.LongTensor(tree_mess_target_edges),
+        torch.LongTensor(tree_mess_target_nodes),
+    )


 class LoopyBPUpdate(nn.Module):
...
@@ -106,28 +152,28 @@ class LoopyBPUpdate(nn.Module):
         self.W_h = nn.Linear(hidden_size, hidden_size, bias=False)

     def forward(self, node):
-        msg_input = node.data['msg_input']
-        msg_delta = self.W_h(node.data['accum_msg'] + node.data['alpha'])
+        msg_input = node.data["msg_input"]
+        msg_delta = self.W_h(node.data["accum_msg"] + node.data["alpha"])
         msg = torch.relu(msg_input + msg_delta)
-        return {'msg': msg}
+        return {"msg": msg}


 if PAPER:
     mpn_gather_msg = [
-        DGLF.copy_e(edge='msg', out='msg'),
-        DGLF.copy_e(edge='alpha', out='alpha')
+        DGLF.copy_e(edge="msg", out="msg"),
+        DGLF.copy_e(edge="alpha", out="alpha"),
     ]
 else:
-    mpn_gather_msg = DGLF.copy_e(edge='msg', out='msg')
+    mpn_gather_msg = DGLF.copy_e(edge="msg", out="msg")


 if PAPER:
     mpn_gather_reduce = [
-        DGLF.sum(msg='msg', out='m'),
-        DGLF.sum(msg='alpha', out='accum_alpha'),
+        DGLF.sum(msg="msg", out="m"),
+        DGLF.sum(msg="alpha", out="accum_alpha"),
     ]
 else:
-    mpn_gather_reduce = DGLF.sum(msg='msg', out='m')
+    mpn_gather_reduce = DGLF.sum(msg="msg", out="m")


 class GatherUpdate(nn.Module):
...
@@ -139,12 +185,12 @@ class GatherUpdate(nn.Module):
     def forward(self, node):
         if PAPER:
-            #m = node['m']
-            m = node.data['m'] + node.data['accum_alpha']
+            # m = node['m']
+            m = node.data["m"] + node.data["accum_alpha"]
         else:
-            m = node.data['m'] + node.data['alpha']
+            m = node.data["m"] + node.data["alpha"]
         return {
-            'h': torch.relu(self.W_o(torch.cat([node.data['x'], m], 1))),
+            "h": torch.relu(self.W_o(torch.cat([node.data["x"], m], 1))),
         }
...
@@ -166,20 +212,32 @@ class DGLJTMPN(nn.Module):
         self.n_passes = 0

     def forward(self, cand_batch, mol_tree_batch):
-        cand_graphs, tree_mess_src_edges, tree_mess_tgt_edges, tree_mess_tgt_nodes = cand_batch
+        (
+            cand_graphs,
+            tree_mess_src_edges,
+            tree_mess_tgt_edges,
+            tree_mess_tgt_nodes,
+        ) = cand_batch
         n_samples = len(cand_graphs)

-        cand_line_graph = line_graph(cand_graphs, backtracking=False, shared=True)
+        cand_line_graph = line_graph(
+            cand_graphs, backtracking=False, shared=True
+        )

         n_nodes = cand_graphs.number_of_nodes()
         n_edges = cand_graphs.number_of_edges()

         cand_graphs = self.run(
-            cand_graphs, cand_line_graph, tree_mess_src_edges,
-            tree_mess_tgt_edges, tree_mess_tgt_nodes, mol_tree_batch)
+            cand_graphs,
+            cand_line_graph,
+            tree_mess_src_edges,
+            tree_mess_tgt_edges,
+            tree_mess_tgt_nodes,
+            mol_tree_batch,
+        )

-        g_repr = mean_nodes(cand_graphs, 'h')
+        g_repr = mean_nodes(cand_graphs, "h")

         self.n_samples_total += n_samples
         self.n_nodes_total += n_nodes
...
@@ -188,33 +246,45 @@ class DGLJTMPN(nn.Module):
         return g_repr

-    def run(self, cand_graphs, cand_line_graph, tree_mess_src_edges,
-            tree_mess_tgt_edges, tree_mess_tgt_nodes, mol_tree_batch):
+    def run(
+        self,
+        cand_graphs,
+        cand_line_graph,
+        tree_mess_src_edges,
+        tree_mess_tgt_edges,
+        tree_mess_tgt_nodes,
+        mol_tree_batch,
+    ):
         n_nodes = cand_graphs.number_of_nodes()

         cand_graphs.apply_edges(
-            func=lambda edges: {'src_x': edges.src['x']},
+            func=lambda edges: {"src_x": edges.src["x"]},
         )

         cand_line_graph.ndata.update(cand_graphs.edata)

-        bond_features = cand_line_graph.ndata['x']
-        source_features = cand_line_graph.ndata['src_x']
+        bond_features = cand_line_graph.ndata["x"]
+        source_features = cand_line_graph.ndata["src_x"]

         features = torch.cat([source_features, bond_features], 1)
         msg_input = self.W_i(features)
-        cand_line_graph.ndata.update({
-            'msg_input': msg_input,
-            'msg': torch.relu(msg_input),
-            'accum_msg': torch.zeros_like(msg_input),
-        })
+        cand_line_graph.ndata.update(
+            {
+                "msg_input": msg_input,
+                "msg": torch.relu(msg_input),
+                "accum_msg": torch.zeros_like(msg_input),
+            }
+        )
         zero_node_state = bond_features.new(n_nodes, self.hidden_size).zero_()
-        cand_graphs.ndata.update({
-            'm': zero_node_state.clone(),
-            'h': zero_node_state.clone(),
-        })
+        cand_graphs.ndata.update(
+            {
+                "m": zero_node_state.clone(),
+                "h": zero_node_state.clone(),
+            }
+        )

-        cand_graphs.edata['alpha'] = \
-            cuda(torch.zeros(cand_graphs.number_of_edges(), self.hidden_size))
-        cand_graphs.ndata['alpha'] = zero_node_state
+        cand_graphs.edata["alpha"] = cuda(
+            torch.zeros(cand_graphs.number_of_edges(), self.hidden_size)
+        )
+        cand_graphs.ndata["alpha"] = zero_node_state

         if tree_mess_src_edges.shape[0] > 0:
             if PAPER:
                 src_u, src_v = tree_mess_src_edges.unbind(1)
...
@@ -222,33 +292,39 @@ class DGLJTMPN(nn.Module):
                 src_u = src_u.to(mol_tree_batch.device)
                 src_v = src_v.to(mol_tree_batch.device)
                 eid = mol_tree_batch.edge_ids(src_u, src_v)
-                alpha = mol_tree_batch.edata['m'][eid]
-                cand_graphs.edges[tgt_u, tgt_v].data['alpha'] = alpha
+                alpha = mol_tree_batch.edata["m"][eid]
+                cand_graphs.edges[tgt_u, tgt_v].data["alpha"] = alpha
             else:
                 src_u, src_v = tree_mess_src_edges.unbind(1)
                 src_u = src_u.to(mol_tree_batch.device)
                 src_v = src_v.to(mol_tree_batch.device)
                 eid = mol_tree_batch.edge_ids(src_u, src_v)
-                alpha = mol_tree_batch.edata['m'][eid]
-                node_idx = (tree_mess_tgt_nodes
-                            .to(device=zero_node_state.device)[:, None]
-                            .expand_as(alpha))
-                node_alpha = zero_node_state.clone().scatter_add(0, node_idx, alpha)
-                cand_graphs.ndata['alpha'] = node_alpha
+                alpha = mol_tree_batch.edata["m"][eid]
+                node_idx = tree_mess_tgt_nodes.to(
+                    device=zero_node_state.device
+                )[:, None].expand_as(alpha)
+                node_alpha = zero_node_state.clone().scatter_add(
+                    0, node_idx, alpha
+                )
+                cand_graphs.ndata["alpha"] = node_alpha
                 cand_graphs.apply_edges(
-                    func=lambda edges: {'alpha': edges.src['alpha']},
+                    func=lambda edges: {"alpha": edges.src["alpha"]},
                 )

         cand_line_graph.ndata.update(cand_graphs.edata)
         for i in range(self.depth - 1):
-            cand_line_graph.update_all(DGLF.copy_u('msg', 'msg'), DGLF.sum('msg', 'accum_msg'))
+            cand_line_graph.update_all(
+                DGLF.copy_u("msg", "msg"), DGLF.sum("msg", "accum_msg")
+            )
             cand_line_graph.apply_nodes(self.loopy_bp_updater)

         cand_graphs.edata.update(cand_line_graph.ndata)
-        cand_graphs.update_all(DGLF.copy_e('msg', 'msg'), DGLF.sum('msg', 'm'))
+        cand_graphs.update_all(DGLF.copy_e("msg", "msg"), DGLF.sum("msg", "m"))
         if PAPER:
-            cand_graphs.update_all(DGLF.copy_e('alpha', 'alpha'), DGLF.sum('alpha', 'accum_alpha'))
+            cand_graphs.update_all(
+                DGLF.copy_e("alpha", "alpha"), DGLF.sum("alpha", "accum_alpha")
+            )
         cand_graphs.apply_nodes(self.gather_updater)

         return cand_graphs
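As context for the quote-only changes above: onek_encoding_unk is a plain one-hot encoder with an "unknown" fallback, and its semantics are untouched by the reformat. A minimal standalone sketch (the element list is truncated here for brevity; the real ELEM_LIST has 23 entries):

    ELEM_SUBSET = ["C", "N", "O", "unknown"]  # truncated stand-in for ELEM_LIST

    def onek_encoding_unk(x, allowable_set):
        # same body as in jtmpn.py above
        if x not in allowable_set:
            x = allowable_set[-1]
        return [x == s for s in allowable_set]

    print(onek_encoding_unk("N", ELEM_SUBSET))   # [False, True, False, False]
    print(onek_encoding_unk("Zn", ELEM_SUBSET))  # unknown fallback: [False, False, False, True]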
examples/pytorch/jtnn/jtnn/jtnn_dec.py

+import dgl.function as DGLF
+import numpy as np
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from .mol_tree_nx import DGLMolTree
-from .chemutils import enum_assemble_nx, get_mol
-from .nnutils import GRUUpdate, cuda, tocpu
 from dgl import batch, dfs_labeled_edges_generator, line_graph
-import dgl.function as DGLF
-import numpy as np
+
+from .chemutils import enum_assemble_nx, get_mol
+from .mol_tree_nx import DGLMolTree
+from .nnutils import cuda, GRUUpdate, tocpu

 MAX_NB = 8
 MAX_DECODE_LEN = 100
...
@@ -21,60 +22,70 @@ def dfs_order(forest, roots):
     # using find_edges().
     yield e ^ l, l


-dec_tree_node_msg = DGLF.copy_e(edge='m', out='m')
-dec_tree_node_reduce = DGLF.sum(msg='m', out='h')
+dec_tree_node_msg = DGLF.copy_e(edge="m", out="m")
+dec_tree_node_reduce = DGLF.sum(msg="m", out="h")


 def dec_tree_node_update(nodes):
-    return {'new': nodes.data['new'].clone().zero_()}
+    return {"new": nodes.data["new"].clone().zero_()}


 def have_slots(fa_slots, ch_slots):
     if len(fa_slots) > 2 and len(ch_slots) > 2:
         return True
     matches = []
     for i, s1 in enumerate(fa_slots):
         a1, c1, h1 = s1
         for j, s2 in enumerate(ch_slots):
             a2, c2, h2 = s2
             if a1 == a2 and c1 == c2 and (a1 != "C" or h1 + h2 >= 4):
                 matches.append((i, j))

     if len(matches) == 0:
         return False

     fa_match, ch_match = list(zip(*matches))
-    if len(set(fa_match)) == 1 and 1 < len(fa_slots) <= 2:  #never remove atom from ring
+    if (
+        len(set(fa_match)) == 1 and 1 < len(fa_slots) <= 2
+    ):  # never remove atom from ring
         fa_slots.pop(fa_match[0])
-    if len(set(ch_match)) == 1 and 1 < len(ch_slots) <= 2:  #never remove atom from ring
+    if (
+        len(set(ch_match)) == 1 and 1 < len(ch_slots) <= 2
+    ):  # never remove atom from ring
         ch_slots.pop(ch_match[0])

     return True


 def can_assemble(mol_tree, u, v_node_dict):
     u_node_dict = mol_tree.nodes_dict[u]
     u_neighbors = mol_tree.graph.successors(u)
     u_neighbors_node_dict = [
-        mol_tree.nodes_dict[_u] for _u in u_neighbors if _u in mol_tree.nodes_dict
+        mol_tree.nodes_dict[_u]
+        for _u in u_neighbors
+        if _u in mol_tree.nodes_dict
     ]
     neis = u_neighbors_node_dict + [v_node_dict]
-    for i, nei in enumerate(neis):
-        nei['nid'] = i
+    for i, nei in enumerate(neis):
+        nei["nid"] = i

-    neighbors = [nei for nei in neis if nei['mol'].GetNumAtoms() > 1]
-    neighbors = sorted(neighbors, key=lambda x: x['mol'].GetNumAtoms(), reverse=True)
-    singletons = [nei for nei in neis if nei['mol'].GetNumAtoms() == 1]
+    neighbors = [nei for nei in neis if nei["mol"].GetNumAtoms() > 1]
+    neighbors = sorted(
+        neighbors, key=lambda x: x["mol"].GetNumAtoms(), reverse=True
+    )
+    singletons = [nei for nei in neis if nei["mol"].GetNumAtoms() == 1]
     neighbors = singletons + neighbors
     cands = enum_assemble_nx(u_node_dict, neighbors)
     return len(cands) > 0


 def create_node_dict(smiles, clique=[]):
     return dict(
         smiles=smiles,
         mol=get_mol(smiles),
         clique=clique,
     )


 class DGLJTNNDecoder(nn.Module):
...
@@ -98,41 +109,54 @@ class DGLJTNNDecoder(nn.Module):
         self.U_s = nn.Linear(hidden_size, 1)

     def forward(self, mol_trees, tree_vec):
-        '''
+        """
         The training procedure which computes the prediction loss given the
         ground truth tree
-        '''
+        """
         mol_tree_batch = batch(mol_trees)
-        mol_tree_batch_lg = line_graph(mol_tree_batch, backtracking=False, shared=True)
+        mol_tree_batch_lg = line_graph(
+            mol_tree_batch, backtracking=False, shared=True
+        )
         n_trees = len(mol_trees)

         return self.run(mol_tree_batch, mol_tree_batch_lg, n_trees, tree_vec)

     def run(self, mol_tree_batch, mol_tree_batch_lg, n_trees, tree_vec):
-        node_offset = np.cumsum(np.insert(mol_tree_batch.batch_num_nodes().cpu().numpy(), 0, 0))
+        node_offset = np.cumsum(
+            np.insert(mol_tree_batch.batch_num_nodes().cpu().numpy(), 0, 0)
+        )
         root_ids = node_offset[:-1]
         n_nodes = mol_tree_batch.number_of_nodes()
         n_edges = mol_tree_batch.number_of_edges()

-        mol_tree_batch.ndata.update({
-            'x': self.embedding(mol_tree_batch.ndata['wid']),
-            'h': cuda(torch.zeros(n_nodes, self.hidden_size)),
-            'new': cuda(torch.ones(n_nodes).bool()),  # whether it's newly generated node
-        })
+        mol_tree_batch.ndata.update(
+            {
+                "x": self.embedding(mol_tree_batch.ndata["wid"]),
+                "h": cuda(torch.zeros(n_nodes, self.hidden_size)),
+                "new": cuda(
+                    torch.ones(n_nodes).bool()
+                ),  # whether it's newly generated node
+            }
+        )

-        mol_tree_batch.edata.update({
-            's': cuda(torch.zeros(n_edges, self.hidden_size)),
-            'm': cuda(torch.zeros(n_edges, self.hidden_size)),
-            'r': cuda(torch.zeros(n_edges, self.hidden_size)),
-            'z': cuda(torch.zeros(n_edges, self.hidden_size)),
-            'src_x': cuda(torch.zeros(n_edges, self.hidden_size)),
-            'dst_x': cuda(torch.zeros(n_edges, self.hidden_size)),
-            'rm': cuda(torch.zeros(n_edges, self.hidden_size)),
-            'accum_rm': cuda(torch.zeros(n_edges, self.hidden_size)),
-        })
+        mol_tree_batch.edata.update(
+            {
+                "s": cuda(torch.zeros(n_edges, self.hidden_size)),
+                "m": cuda(torch.zeros(n_edges, self.hidden_size)),
+                "r": cuda(torch.zeros(n_edges, self.hidden_size)),
+                "z": cuda(torch.zeros(n_edges, self.hidden_size)),
+                "src_x": cuda(torch.zeros(n_edges, self.hidden_size)),
+                "dst_x": cuda(torch.zeros(n_edges, self.hidden_size)),
+                "rm": cuda(torch.zeros(n_edges, self.hidden_size)),
+                "accum_rm": cuda(torch.zeros(n_edges, self.hidden_size)),
+            }
+        )

         mol_tree_batch.apply_edges(
-            func=lambda edges: {'src_x': edges.src['x'], 'dst_x': edges.dst['x']},
+            func=lambda edges: {
+                "src_x": edges.src["x"],
+                "dst_x": edges.dst["x"],
+            },
         )

         # input tensors for stop prediction (p) and label prediction (q)
...
@@ -142,16 +166,16 @@ class DGLJTNNDecoder(nn.Module):
         q_targets = []

         # Predict root
-        mol_tree_batch.pull(root_ids, DGLF.copy_e('m', 'm'), DGLF.sum('m', 'h'))
+        mol_tree_batch.pull(root_ids, DGLF.copy_e("m", "m"), DGLF.sum("m", "h"))
         mol_tree_batch.apply_nodes(dec_tree_node_update, v=root_ids)

         # Extract hidden states and store them for stop/label prediction
-        h = mol_tree_batch.nodes[root_ids].data['h']
-        x = mol_tree_batch.nodes[root_ids].data['x']
+        h = mol_tree_batch.nodes[root_ids].data["h"]
+        x = mol_tree_batch.nodes[root_ids].data["x"]
         p_inputs.append(torch.cat([x, h, tree_vec], 1))
         # If the out degree is 0 we don't generate any edges at all
         root_out_degrees = mol_tree_batch.out_degrees(root_ids)
         q_inputs.append(torch.cat([h, tree_vec], 1))
-        q_targets.append(mol_tree_batch.nodes[root_ids].data['wid'])
+        q_targets.append(mol_tree_batch.nodes[root_ids].data["wid"])

         # Traverse the tree and predict on children
         for eid, p in dfs_order(mol_tree_batch, root_ids):
...
@@ -160,29 +184,35 @@ class DGLJTNNDecoder(nn.Module):
             u, v = mol_tree_batch.find_edges(eid)

             p_target_list = torch.zeros_like(root_out_degrees)
-            p_target_list[root_out_degrees > 0] = (1 - p)
+            p_target_list[root_out_degrees > 0] = 1 - p
             p_target_list = p_target_list[root_out_degrees >= 0]
             p_targets.append(torch.tensor(p_target_list))

             root_out_degrees -= (root_out_degrees == 0).long()
-            root_out_degrees -= torch.tensor(np.isin(root_ids, v.cpu().numpy())).to(root_out_degrees)
+            root_out_degrees -= torch.tensor(
+                np.isin(root_ids, v.cpu().numpy())
+            ).to(root_out_degrees)

             mol_tree_batch_lg.ndata.update(mol_tree_batch.edata)
-            mol_tree_batch_lg.pull(eid, DGLF.copy_u('m', 'm'), DGLF.sum('m', 's'))
-            mol_tree_batch_lg.pull(eid, DGLF.copy_u('rm', 'rm'), DGLF.sum('rm', 'accum_rm'))
+            mol_tree_batch_lg.pull(
+                eid, DGLF.copy_u("m", "m"), DGLF.sum("m", "s")
+            )
+            mol_tree_batch_lg.pull(
+                eid, DGLF.copy_u("rm", "rm"), DGLF.sum("rm", "accum_rm")
+            )
             mol_tree_batch_lg.apply_nodes(self.dec_tree_edge_update, v=eid)
             mol_tree_batch.edata.update(mol_tree_batch_lg.ndata)

-            is_new = mol_tree_batch.nodes[v].data['new']
-            mol_tree_batch.pull(v, DGLF.copy_e('m', 'm'), DGLF.sum('m', 'h'))
+            is_new = mol_tree_batch.nodes[v].data["new"]
+            mol_tree_batch.pull(v, DGLF.copy_e("m", "m"), DGLF.sum("m", "h"))
             mol_tree_batch.apply_nodes(dec_tree_node_update, v=v)

             # Extract
             n_repr = mol_tree_batch.nodes[v].data
-            h = n_repr['h']
-            x = n_repr['x']
+            h = n_repr["h"]
+            x = n_repr["x"]
             tree_vec_set = tree_vec[root_out_degrees >= 0]
-            wid = n_repr['wid']
+            wid = n_repr["wid"]
             p_inputs.append(torch.cat([x, h, tree_vec_set], 1))
             # Only newly generated nodes are needed for label prediction
             # NOTE: The following works since the uncomputed messages are zeros.
...
@@ -192,10 +222,13 @@ class DGLJTNNDecoder(nn.Module):
             if q_input.shape[0] > 0:
                 q_inputs.append(q_input)
                 q_targets.append(q_target)
-        p_targets.append(torch.zeros((root_out_degrees == 0).sum(),
-                                     device=root_out_degrees.device, dtype=torch.int64))
+        p_targets.append(
+            torch.zeros(
+                (root_out_degrees == 0).sum(),
+                device=root_out_degrees.device,
+                dtype=torch.int64,
+            )
+        )

         # Batch compute the stop/label prediction losses
         p_inputs = torch.cat(p_inputs, 0)
...
@@ -206,9 +239,12 @@ class DGLJTNNDecoder(nn.Module):
         q = self.W_o(torch.relu(self.W(q_inputs)))
         p = self.U_s(torch.relu(self.U(p_inputs)))[:, 0]

-        p_loss = F.binary_cross_entropy_with_logits(p, p_targets.float(), size_average=False) / n_trees
+        p_loss = (
+            F.binary_cross_entropy_with_logits(
+                p, p_targets.float(), size_average=False
+            )
+            / n_trees
+        )
         q_loss = F.cross_entropy(q, q_targets, size_average=False) / n_trees
         p_acc = ((p > 0).long() == p_targets).sum().float() / p_targets.shape[0]
         q_acc = (q.max(1)[1] == q_targets).float().sum() / q_targets.shape[0]
...
@@ -237,13 +273,14 @@ class DGLJTNNDecoder(nn.Module):
         _, root_wid = torch.max(root_score, 1)
         root_wid = root_wid.view(1)

-        mol_tree_graph.add_nodes(1)  # root
-        mol_tree_graph.ndata['wid'] = root_wid
-        mol_tree_graph.ndata['x'] = self.embedding(root_wid)
-        mol_tree_graph.ndata['h'] = init_hidden
-        mol_tree_graph.ndata['fail'] = cuda(torch.tensor([0]))
+        mol_tree_graph.add_nodes(1)  # root
+        mol_tree_graph.ndata["wid"] = root_wid
+        mol_tree_graph.ndata["x"] = self.embedding(root_wid)
+        mol_tree_graph.ndata["h"] = init_hidden
+        mol_tree_graph.ndata["fail"] = cuda(torch.tensor([0]))

-        mol_tree.nodes_dict[0] = root_node_dict = create_node_dict(self.vocab.get_smiles(root_wid))
+        mol_tree.nodes_dict[0] = root_node_dict = create_node_dict(
+            self.vocab.get_smiles(root_wid)
+        )

         stack, trace = [], []
         stack.append((0, self.vocab.get_slots(root_wid)))
...
@@ -256,13 +293,13 @@ class DGLJTNNDecoder(nn.Module):
         for step in range(MAX_DECODE_LEN):
             u, u_slots = stack[-1]
-            x = mol_tree_graph.ndata['x'][u:u + 1]
-            h = mol_tree_graph.ndata['h'][u:u + 1]
+            x = mol_tree_graph.ndata["x"][u : u + 1]
+            h = mol_tree_graph.ndata["h"][u : u + 1]

             # Predict stop
             p_input = torch.cat([x, h, mol_vec], 1)
             p_score = torch.sigmoid(self.U_s(torch.relu(self.U(p_input))))
-            backtrack = (p_score.item() < 0.5)
+            backtrack = p_score.item() < 0.5

             if not backtrack:
                 # Predict next clique. Note that the prediction may fail due
...
@@ -273,49 +310,61 @@ class DGLJTNNDecoder(nn.Module):
                 mol_tree_graph.add_edges(u, v)
                 uv = new_edge_id
                 new_edge_id += 1

                 if first:
-                    mol_tree_graph.edata.update({
-                        's': cuda(torch.zeros(1, self.hidden_size)),
-                        'm': cuda(torch.zeros(1, self.hidden_size)),
-                        'r': cuda(torch.zeros(1, self.hidden_size)),
-                        'z': cuda(torch.zeros(1, self.hidden_size)),
-                        'src_x': cuda(torch.zeros(1, self.hidden_size)),
-                        'dst_x': cuda(torch.zeros(1, self.hidden_size)),
-                        'rm': cuda(torch.zeros(1, self.hidden_size)),
-                        'accum_rm': cuda(torch.zeros(1, self.hidden_size)),
-                    })
+                    mol_tree_graph.edata.update(
+                        {
+                            "s": cuda(torch.zeros(1, self.hidden_size)),
+                            "m": cuda(torch.zeros(1, self.hidden_size)),
+                            "r": cuda(torch.zeros(1, self.hidden_size)),
+                            "z": cuda(torch.zeros(1, self.hidden_size)),
+                            "src_x": cuda(torch.zeros(1, self.hidden_size)),
+                            "dst_x": cuda(torch.zeros(1, self.hidden_size)),
+                            "rm": cuda(torch.zeros(1, self.hidden_size)),
+                            "accum_rm": cuda(torch.zeros(1, self.hidden_size)),
+                        }
+                    )
                     first = False

-                mol_tree_graph.edata['src_x'][uv] = mol_tree_graph.ndata['x'][u]
+                mol_tree_graph.edata["src_x"][uv] = mol_tree_graph.ndata["x"][u]
                 # keeping dst_x 0 is fine as h on new edge doesn't depend on that.

                 # DGL doesn't dynamically maintain a line graph.
-                mol_tree_graph_lg = line_graph(mol_tree_graph, backtracking=False, shared=True)
+                mol_tree_graph_lg = line_graph(
+                    mol_tree_graph, backtracking=False, shared=True
+                )
-                mol_tree_graph_lg.pull(uv, DGLF.copy_u('m', 'm'), DGLF.sum('m', 's'))
-                mol_tree_graph_lg.pull(uv, DGLF.copy_u('rm', 'rm'), DGLF.sum('rm', 'accum_rm'))
-                mol_tree_graph_lg.apply_nodes(self.dec_tree_edge_update.update_zm, v=uv)
+                mol_tree_graph_lg.pull(
+                    uv, DGLF.copy_u("m", "m"), DGLF.sum("m", "s")
+                )
+                mol_tree_graph_lg.pull(
+                    uv, DGLF.copy_u("rm", "rm"), DGLF.sum("rm", "accum_rm")
+                )
+                mol_tree_graph_lg.apply_nodes(
+                    self.dec_tree_edge_update.update_zm, v=uv
+                )
                 mol_tree_graph.edata.update(mol_tree_graph_lg.ndata)

-                mol_tree_graph.pull(v, DGLF.copy_e('m', 'm'), DGLF.sum('m', 'h'))
+                mol_tree_graph.pull(
+                    v, DGLF.copy_e("m", "m"), DGLF.sum("m", "h")
+                )

-                h_v = mol_tree_graph.ndata['h'][v:v + 1]
+                h_v = mol_tree_graph.ndata["h"][v : v + 1]
                 q_input = torch.cat([h_v, mol_vec], 1)
-                q_score = torch.softmax(self.W_o(torch.relu(self.W(q_input))), -1)
+                q_score = torch.softmax(
+                    self.W_o(torch.relu(self.W(q_input))), -1
+                )
                 _, sort_wid = torch.sort(q_score, 1, descending=True)
                 sort_wid = sort_wid.squeeze()

                 next_wid = None
                 for wid in sort_wid.tolist()[:5]:
                     slots = self.vocab.get_slots(wid)
-                    cand_node_dict = create_node_dict(self.vocab.get_smiles(wid))
-                    if (have_slots(u_slots, slots) and can_assemble(mol_tree, u, cand_node_dict)):
+                    cand_node_dict = create_node_dict(
+                        self.vocab.get_smiles(wid)
+                    )
+                    if have_slots(u_slots, slots) and can_assemble(
+                        mol_tree, u, cand_node_dict
+                    ):
                         next_wid = wid
                         next_slots = slots
                         next_node_dict = cand_node_dict
...
@@ -324,44 +373,59 @@ class DGLJTNNDecoder(nn.Module):
                 if next_wid is None:
                     # Failed adding an actual children; v is a spurious node
                     # and we mark it.
-                    mol_tree_graph.ndata['fail'][v] = cuda(torch.tensor([1]))
+                    mol_tree_graph.ndata["fail"][v] = cuda(torch.tensor([1]))
                     backtrack = True
                 else:
                     next_wid = cuda(torch.tensor([next_wid]))
-                    mol_tree_graph.ndata['wid'][v] = next_wid
-                    mol_tree_graph.ndata['x'][v] = self.embedding(next_wid)
+                    mol_tree_graph.ndata["wid"][v] = next_wid
+                    mol_tree_graph.ndata["x"][v] = self.embedding(next_wid)
                     mol_tree.nodes_dict[v] = next_node_dict
                     all_nodes[v] = next_node_dict
                     stack.append((v, next_slots))
                     mol_tree_graph.add_edges(v, u)
                     vu = new_edge_id
                     new_edge_id += 1
-                    mol_tree_graph.edata['dst_x'][uv] = mol_tree_graph.ndata['x'][v]
-                    mol_tree_graph.edata['src_x'][vu] = mol_tree_graph.ndata['x'][v]
-                    mol_tree_graph.edata['dst_x'][vu] = mol_tree_graph.ndata['x'][u]
+                    mol_tree_graph.edata["dst_x"][uv] = mol_tree_graph.ndata["x"][v]
+                    mol_tree_graph.edata["src_x"][vu] = mol_tree_graph.ndata["x"][v]
+                    mol_tree_graph.edata["dst_x"][vu] = mol_tree_graph.ndata["x"][u]

                     # DGL doesn't dynamically maintain a line graph.
-                    mol_tree_graph_lg = line_graph(mol_tree_graph, backtracking=False, shared=True)
+                    mol_tree_graph_lg = line_graph(
+                        mol_tree_graph, backtracking=False, shared=True
+                    )
                     mol_tree_graph_lg.apply_nodes(
-                        self.dec_tree_edge_update.update_r, uv)
+                        self.dec_tree_edge_update.update_r, uv
+                    )
                     mol_tree_graph.edata.update(mol_tree_graph_lg.ndata)

             if backtrack:
                 if len(stack) == 1:
                     break  # At root, terminate

                 pu, _ = stack[-2]
                 u_pu = mol_tree_graph.edge_ids(u, pu)

-                mol_tree_graph_lg.pull(u_pu, DGLF.copy_u('m', 'm'), DGLF.sum('m', 's'))
-                mol_tree_graph_lg.pull(u_pu, DGLF.copy_u('rm', 'rm'), DGLF.sum('rm', 'accum_rm'))
+                mol_tree_graph_lg.pull(
+                    u_pu, DGLF.copy_u("m", "m"), DGLF.sum("m", "s")
+                )
+                mol_tree_graph_lg.pull(
+                    u_pu, DGLF.copy_u("rm", "rm"), DGLF.sum("rm", "accum_rm")
+                )
                 mol_tree_graph_lg.apply_nodes(self.dec_tree_edge_update, v=u_pu)
                 mol_tree_graph.edata.update(mol_tree_graph_lg.ndata)
-                mol_tree_graph.pull(pu, DGLF.copy_e('m', 'm'), DGLF.sum('m', 'h'))
+                mol_tree_graph.pull(
+                    pu, DGLF.copy_e("m", "m"), DGLF.sum("m", "h")
+                )
                 stack.pop()

-        effective_nodes = mol_tree_graph.filter_nodes(lambda nodes: nodes.data['fail'] != 1)
+        effective_nodes = mol_tree_graph.filter_nodes(
+            lambda nodes: nodes.data["fail"] != 1
+        )
         effective_nodes, _ = torch.sort(effective_nodes)
         return mol_tree, all_nodes, effective_nodes
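One behavioral detail worth keeping in mind while reading the decode path above: candidate cliques are tried in descending q_score order, at most five of them, and the first one that passes have_slots and can_assemble wins. A self-contained sketch of that sort-and-slice pattern (toy scores, not the model's):

    import torch

    q_score = torch.softmax(torch.tensor([[0.1, 2.0, 0.5, 1.2]]), -1)
    _, sort_wid = torch.sort(q_score, 1, descending=True)
    print(sort_wid.squeeze().tolist()[:5])  # [1, 3, 2, 0] -- try best-scoring vocabulary ids first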
examples/pytorch/jtnn/jtnn/jtnn_enc.py

+import dgl.function as DGLF
 import numpy as np
 import torch
 import torch.nn as nn
-import dgl.function as DGLF
 from dgl import batch, bfs_edges_generator, line_graph
-from .nnutils import GRUUpdate, cuda, tocpu
+
+from .nnutils import cuda, GRUUpdate, tocpu

 MAX_NB = 8
...
examples/pytorch/jtnn/jtnn/jtnn_vae.py

...
@@ -14,12 +14,10 @@ from .chemutils import (
     enum_assemble_nx,
     set_atommap,
 )
-from .jtmpn import DGLJTMPN
-from .jtmpn import mol2dgl_single as mol2dgl_dec
+from .jtmpn import DGLJTMPN, mol2dgl_single as mol2dgl_dec
 from .jtnn_dec import DGLJTNNDecoder
 from .jtnn_enc import DGLJTNNEncoder
-from .mpn import DGLMPN
-from .mpn import mol2dgl_single as mol2dgl_enc
+from .mpn import DGLMPN, mol2dgl_single as mol2dgl_enc
 from .nnutils import cuda
...
examples/pytorch/jtnn/jtnn/mol_tree_nx.py

+import dgl
 import numpy as np
 import rdkit.Chem as Chem
-import dgl

 from .chemutils import (
     decode_stereo,
     enum_assemble_nx,
...
examples/pytorch/jtnn/jtnn/mpn.py

+import dgl
+import dgl.function as DGLF
 import rdkit.Chem as Chem
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-import dgl
-import dgl.function as DGLF
 from dgl import line_graph, mean_nodes

 from .chemutils import get_mol
...
examples/pytorch/jtnn/jtnn/nnutils.py

 import os

+import dgl
 import torch
 import torch.nn as nn
-import dgl


 def cuda(x):
     if torch.cuda.is_available() and not os.getenv("NOCUDA", None):
...
examples/pytorch/label_propagation/main.py

 import argparse

+import torch
+
 import dgl
-import torch
 from dgl.data import CiteseerGraphDataset, CoraGraphDataset, PubmedGraphDataset
 from dgl.nn import LabelPropagation
...
examples/pytorch/lda/example_20newsgroups.py

...
@@ -20,18 +20,18 @@
 import warnings
 from time import time

+import dgl
 import matplotlib.pyplot as plt
 import numpy as np
 import scipy.sparse as ss
 import torch
+from dgl import function as fn
 from lda_model import LatentDirichletAllocation as LDAModel
 from sklearn.datasets import fetch_20newsgroups
-from sklearn.decomposition import NMF, LatentDirichletAllocation
+from sklearn.decomposition import LatentDirichletAllocation, NMF
 from sklearn.feature_extraction.text import CountVectorizer, TfidfVectorizer

-import dgl
-from dgl import function as fn
-
 n_samples = 2000
 n_features = 1000
 n_components = 10
...
examples/pytorch/lda/lda_model.py

...
@@ -23,12 +23,12 @@ import io
 import os
 import warnings

+import dgl
 import numpy as np
 import scipy as sp
 import torch

-import dgl
-
 try:
     from functools import cached_property
 except ImportError:
...
examples/pytorch/line_graph/gnn.py

 import copy
 import itertools

 import dgl
 import dgl.function as fn
 import networkx as nx
+import numpy as np
 import torch as th
 import torch.nn as nn
 import torch.nn.functional as F
-import numpy as np


 class GNNModule(nn.Module):
     def __init__(self, in_feats, out_feats, radius):
...
@@ -15,14 +17,22 @@ class GNNModule(nn.Module):
         self.radius = radius

         new_linear = lambda: nn.Linear(in_feats, out_feats)
-        new_linear_list = lambda: nn.ModuleList([new_linear() for i in range(radius)])
+        new_linear_list = lambda: nn.ModuleList(
+            [new_linear() for i in range(radius)]
+        )

-        self.theta_x, self.theta_deg, self.theta_y = \
-            new_linear(), new_linear(), new_linear()
+        self.theta_x, self.theta_deg, self.theta_y = (
+            new_linear(),
+            new_linear(),
+            new_linear(),
+        )
         self.theta_list = new_linear_list()

-        self.gamma_y, self.gamma_deg, self.gamma_x = \
-            new_linear(), new_linear(), new_linear()
+        self.gamma_y, self.gamma_deg, self.gamma_x = (
+            new_linear(),
+            new_linear(),
+            new_linear(),
+        )
         self.gamma_list = new_linear_list()

         self.bn_x = nn.BatchNorm1d(out_feats)
...
@@ -30,43 +40,61 @@ class GNNModule(nn.Module):
     def aggregate(self, g, z):
         z_list = []
-        g.ndata['z'] = z
-        g.update_all(fn.copy_u(u='z', out='m'), fn.sum(msg='m', out='z'))
-        z_list.append(g.ndata['z'])
+        g.ndata["z"] = z
+        g.update_all(fn.copy_u(u="z", out="m"), fn.sum(msg="m", out="z"))
+        z_list.append(g.ndata["z"])
         for i in range(self.radius - 1):
-            for j in range(2 ** i):
-                g.update_all(fn.copy_u(u='z', out='m'), fn.sum(msg='m', out='z'))
-            z_list.append(g.ndata['z'])
+            for j in range(2**i):
+                g.update_all(
+                    fn.copy_u(u="z", out="m"), fn.sum(msg="m", out="z")
+                )
+            z_list.append(g.ndata["z"])
         return z_list

     def forward(self, g, lg, x, y, deg_g, deg_lg, pm_pd):
         pmpd_x = F.embedding(pm_pd, x)

-        sum_x = sum(theta(z) for theta, z in zip(self.theta_list, self.aggregate(g, x)))
+        sum_x = sum(
+            theta(z) for theta, z in zip(self.theta_list, self.aggregate(g, x))
+        )

-        g.edata['y'] = y
-        g.update_all(fn.copy_e(e='y', out='m'), fn.sum('m', 'pmpd_y'))
-        pmpd_y = g.ndata.pop('pmpd_y')
+        g.edata["y"] = y
+        g.update_all(fn.copy_e(e="y", out="m"), fn.sum("m", "pmpd_y"))
+        pmpd_y = g.ndata.pop("pmpd_y")

-        x = self.theta_x(x) + self.theta_deg(deg_g * x) + sum_x + self.theta_y(pmpd_y)
+        x = (
+            self.theta_x(x)
+            + self.theta_deg(deg_g * x)
+            + sum_x
+            + self.theta_y(pmpd_y)
+        )
         n = self.out_feats // 2
         x = th.cat([x[:, :n], F.relu(x[:, n:])], 1)
         x = self.bn_x(x)

-        sum_y = sum(gamma(z) for gamma, z in zip(self.gamma_list, self.aggregate(lg, y)))
+        sum_y = sum(
+            gamma(z) for gamma, z in zip(self.gamma_list, self.aggregate(lg, y))
+        )

-        y = self.gamma_y(y) + self.gamma_deg(deg_lg * y) + sum_y + self.gamma_x(pmpd_x)
+        y = (
+            self.gamma_y(y)
+            + self.gamma_deg(deg_lg * y)
+            + sum_y
+            + self.gamma_x(pmpd_x)
+        )
         y = th.cat([y[:, :n], F.relu(y[:, n:])], 1)
         y = self.bn_y(y)

         return x, y


 class GNN(nn.Module):
     def __init__(self, feats, radius, n_classes):
         super(GNN, self).__init__()
         self.linear = nn.Linear(feats[-1], n_classes)
-        self.module_list = nn.ModuleList([GNNModule(m, n, radius)
-                                          for m, n in zip(feats[:-1], feats[1:])])
+        self.module_list = nn.ModuleList(
+            [GNNModule(m, n, radius) for m, n in zip(feats[:-1], feats[1:])]
+        )

     def forward(self, g, lg, deg_g, deg_lg, pm_pd):
         x, y = deg_g, deg_lg
...
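The aggregate method above chains fn.copy_u/fn.sum rounds, doubling the number of one-hop propagations per entry of z_list (1, 2, 4, ... applications of the sum), which is how it collects multi-scale neighborhood features. A toy, runnable illustration of a single round, assuming a local DGL install (graph and features are made up):

    import dgl
    import dgl.function as fn
    import torch

    g = dgl.graph(([0, 1, 2], [1, 2, 0]))  # directed 3-cycle
    g.ndata["z"] = torch.eye(3)
    g.update_all(fn.copy_u(u="z", out="m"), fn.sum(msg="m", out="z"))
    print(g.ndata["z"])  # each node's feature replaced by the sum over its in-neighbors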
examples/pytorch/line_graph/train.py

...
@@ -16,9 +16,9 @@ import numpy as np
 import torch as th
 import torch.nn.functional as F
 import torch.optim as optim
-from torch.utils.data import DataLoader
 from dgl.data import SBMMixtureDataset
+from torch.utils.data import DataLoader

 parser = argparse.ArgumentParser()
 parser.add_argument("--batch-size", type=int, help="Batch size", default=1)
...
examples/pytorch/metapath2vec/download.py

 import os

 import torch as th
 import torch.nn as nn
 import tqdm
...
@@ -20,35 +21,37 @@ class PBar(object):
 class AminerDataset(object):
     """
     Download Aminer Dataset from Amazon S3 bucket.
     """
-    def __init__(self, path):
-        self.url = 'https://data.dgl.ai/dataset/aminer.zip'
+
+    def __init__(self, path):
+        self.url = "https://data.dgl.ai/dataset/aminer.zip"

-        if not os.path.exists(os.path.join(path, 'aminer.txt')):
-            print('File not found. Downloading from', self.url)
-            self._download_and_extract(path, 'aminer.zip')
-        self.fn = os.path.join(path, 'aminer.txt')
+        if not os.path.exists(os.path.join(path, "aminer.txt")):
+            print("File not found. Downloading from", self.url)
+            self._download_and_extract(path, "aminer.zip")
+        self.fn = os.path.join(path, "aminer.txt")

     def _download_and_extract(self, path, filename):
         import shutil, zipfile, zlib
-        from tqdm import tqdm
         import urllib.request

+        from tqdm import tqdm
+
         fn = os.path.join(path, filename)

         with PBar() as pb:
             urllib.request.urlretrieve(self.url, fn, pb)
-        print('Download finished. Unzipping the file...')
+        print("Download finished. Unzipping the file...")

         with zipfile.ZipFile(fn) as zf:
             zf.extractall(path)
-        print('Unzip finished.')
+        print("Unzip finished.")


 class CustomDataset(object):
     """
     Custom dataset generated by sampler.py (e.g. NetDBIS)
     """
+
     def __init__(self, path):
         self.fn = path
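The PBar object passed to urlretrieve above acts as a progress callback; urllib invokes it as reporthook(block_count, block_size, total_size). PBar's own body is not part of this diff, so here is only a hypothetical minimal stand-in for what such a hook looks like:

    def reporthook(block_count, block_size, total_size):
        # crude progress line; total_size can be -1 if the server omits Content-Length
        done = block_count * block_size
        print(f"\r{done}/{total_size} bytes", end="")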
examples/pytorch/metapath2vec/metapath2vec.py

-import torch
 import argparse
+
+import torch
 import torch.optim as optim
+from download import AminerDataset, CustomDataset
+from model import SkipGramModel
+from reading_data import DataReader, Metapath2vecDataset
 from torch.utils.data import DataLoader
 from tqdm import tqdm
-from reading_data import DataReader, Metapath2vecDataset
-from model import SkipGramModel
-from download import AminerDataset, CustomDataset


 class Metapath2VecTrainer:
     def __init__(self, args):
...
@@ -18,8 +19,13 @@ class Metapath2VecTrainer:
             dataset = CustomDataset(args.path)
         self.data = DataReader(dataset, args.min_count, args.care_type)
         dataset = Metapath2vecDataset(self.data, args.window_size)
-        self.dataloader = DataLoader(dataset, batch_size=args.batch_size,
-                                     shuffle=True, num_workers=args.num_workers,
-                                     collate_fn=dataset.collate)
+        self.dataloader = DataLoader(
+            dataset,
+            batch_size=args.batch_size,
+            shuffle=True,
+            num_workers=args.num_workers,
+            collate_fn=dataset.collate,
+        )

         self.output_file_name = args.output_file
         self.emb_size = len(self.data.word2id)
...
@@ -35,15 +41,17 @@ class Metapath2VecTrainer:
             self.skip_gram_model.cuda()

     def train(self):
-        optimizer = optim.SparseAdam(list(self.skip_gram_model.parameters()), lr=self.initial_lr)
-        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, len(self.dataloader))
+        optimizer = optim.SparseAdam(
+            list(self.skip_gram_model.parameters()), lr=self.initial_lr
+        )
+        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
+            optimizer, len(self.dataloader)
+        )

         for iteration in range(self.iterations):
             print("\n\n\nIteration: " + str(iteration + 1))
             running_loss = 0.0
             for i, sample_batched in enumerate(tqdm(self.dataloader)):
                 if len(sample_batched[0]) > 1:
                     pos_u = sample_batched[0].to(self.device)
                     pos_v = sample_batched[1].to(self.device)
...
@@ -59,23 +67,40 @@ class Metapath2VecTrainer:
                 if i > 0 and i % 500 == 0:
                     print(" Loss: " + str(running_loss))

-        self.skip_gram_model.save_embedding(self.data.id2word, self.output_file_name)
+        self.skip_gram_model.save_embedding(
+            self.data.id2word, self.output_file_name
+        )


-if __name__ == '__main__':
+if __name__ == "__main__":
     parser = argparse.ArgumentParser(description="Metapath2vec")
-    #parser.add_argument('--input_file', type=str, help="input_file")
-    parser.add_argument('--aminer', action='store_true', help='Use AMiner dataset')
-    parser.add_argument('--path', type=str, help="input_path")
-    parser.add_argument('--output_file', type=str, help='output_file')
-    parser.add_argument('--dim', default=128, type=int, help="embedding dimensions")
-    parser.add_argument('--window_size', default=7, type=int, help="context window size")
-    parser.add_argument('--iterations', default=5, type=int, help="iterations")
-    parser.add_argument('--batch_size', default=50, type=int, help="batch size")
-    parser.add_argument('--care_type', default=0, type=int, help="if 1, heterogeneous negative sampling, else normal negative sampling")
-    parser.add_argument('--initial_lr', default=0.025, type=float, help="learning rate")
-    parser.add_argument('--min_count', default=5, type=int, help="min count")
-    parser.add_argument('--num_workers', default=16, type=int, help="number of workers")
+    # parser.add_argument('--input_file', type=str, help="input_file")
+    parser.add_argument(
+        "--aminer", action="store_true", help="Use AMiner dataset"
+    )
+    parser.add_argument("--path", type=str, help="input_path")
+    parser.add_argument("--output_file", type=str, help="output_file")
+    parser.add_argument(
+        "--dim", default=128, type=int, help="embedding dimensions"
+    )
+    parser.add_argument(
+        "--window_size", default=7, type=int, help="context window size"
+    )
+    parser.add_argument("--iterations", default=5, type=int, help="iterations")
+    parser.add_argument("--batch_size", default=50, type=int, help="batch size")
+    parser.add_argument(
+        "--care_type",
+        default=0,
+        type=int,
+        help="if 1, heterogeneous negative sampling, else normal negative sampling",
+    )
+    parser.add_argument(
+        "--initial_lr", default=0.025, type=float, help="learning rate"
+    )
+    parser.add_argument("--min_count", default=5, type=int, help="min count")
+    parser.add_argument(
+        "--num_workers", default=16, type=int, help="number of workers"
+    )
     args = parser.parse_args()
     m2v = Metapath2VecTrainer(args)
     m2v.train()
View file @
704bcaf6
...
...
@@ -10,7 +10,6 @@ from torch.nn import init
class
SkipGramModel
(
nn
.
Module
):
def
__init__
(
self
,
emb_size
,
emb_dimension
):
super
(
SkipGramModel
,
self
).
__init__
()
self
.
emb_size
=
emb_size
...
...
@@ -39,8 +38,8 @@ class SkipGramModel(nn.Module):
def
save_embedding
(
self
,
id2word
,
file_name
):
embedding
=
self
.
u_embeddings
.
weight
.
cpu
().
data
.
numpy
()
with
open
(
file_name
,
'w'
)
as
f
:
f
.
write
(
'
%d %d
\n
'
%
(
len
(
id2word
),
self
.
emb_dimension
))
with
open
(
file_name
,
"w"
)
as
f
:
f
.
write
(
"
%d %d
\n
"
%
(
len
(
id2word
),
self
.
emb_dimension
))
for
wid
,
w
in
id2word
.
items
():
e
=
' '
.
join
(
map
(
lambda
x
:
str
(
x
),
embedding
[
wid
]))
f
.
write
(
'%s %s
\n
'
%
(
w
,
e
))
\ No newline at end of file
e
=
" "
.
join
(
map
(
lambda
x
:
str
(
x
),
embedding
[
wid
]))
f
.
write
(
"%s %s
\n
"
%
(
w
,
e
))
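save_embedding above writes the classic word2vec text format: a "count dim" header line followed by one "word v1 v2 ..." line per vocabulary entry. A minimal reader sketch (it assumes a file produced by that method; the path is a placeholder):

    with open("embeddings.txt") as f:  # hypothetical output of save_embedding
        n_words, dim = map(int, f.readline().split())
        vectors = {}
        for line in f:
            word, *values = line.rstrip("\n").split(" ")
            vectors[word] = [float(v) for v in values]
    assert len(vectors) == n_words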
examples/pytorch/metapath2vec/reading_data.py

 import numpy as np
 import torch
-from torch.utils.data import Dataset
 from download import AminerDataset
+from torch.utils.data import Dataset

 np.random.seed(12345)


 class DataReader:
     NEGATIVE_TABLE_SIZE = 1e8

     def __init__(self, dataset, min_count, care_type):
         self.negatives = []
         self.discards = []
         self.negpos = 0
...
@@ -35,7 +36,11 @@ class DataReader:
                     word_frequency[word] = word_frequency.get(word, 0) + 1

                     if self.token_count % 1000000 == 0:
-                        print("Read " + str(int(self.token_count / 1000000)) + "M words.")
+                        print(
+                            "Read "
+                            + str(int(self.token_count / 1000000))
+                            + "M words."
+                        )

         wid = 0
         for w, c in word_frequency.items():
...
@@ -71,15 +76,18 @@ class DataReader:
     def getNegatives(self, target, size):  # TODO check equality with target
         if self.care_type == 0:
-            response = self.negatives[self.negpos:self.negpos + size]
+            response = self.negatives[self.negpos : self.negpos + size]
             self.negpos = (self.negpos + size) % len(self.negatives)
             if len(response) != size:
-                return np.concatenate((response, self.negatives[0:self.negpos]))
+                return np.concatenate(
+                    (response, self.negatives[0 : self.negpos])
+                )
         return response


 # -----------------------------------------------------------------------------------------------------------------


 class Metapath2vecDataset(Dataset):
     def __init__(self, data, window_size):
         # read in data, window_size and input filename
...
@@ -103,25 +111,44 @@ class Metapath2vecDataset(Dataset):
                 words = line.split()

                 if len(words) > 1:
-                    word_ids = [self.data.word2id[w] for w in words if
-                                w in self.data.word2id and np.random.rand() < self.data.discards[self.data.word2id[w]]]
+                    word_ids = [
+                        self.data.word2id[w]
+                        for w in words
+                        if w in self.data.word2id
+                        and np.random.rand()
+                        < self.data.discards[self.data.word2id[w]]
+                    ]

                     pair_catch = []
                     for i, u in enumerate(word_ids):
-                        for j, v in enumerate(word_ids[max(i - self.window_size, 0):i + self.window_size]):
+                        for j, v in enumerate(
+                            word_ids[
+                                max(i - self.window_size, 0) : i
+                                + self.window_size
+                            ]
+                        ):
                             assert u < self.data.word_count
                             assert v < self.data.word_count
                             if i == j:
                                 continue
-                            pair_catch.append((u, v, self.data.getNegatives(v, 5)))
+                            pair_catch.append(
+                                (u, v, self.data.getNegatives(v, 5))
+                            )
                     return pair_catch

     @staticmethod
     def collate(batches):
         all_u = [u for batch in batches for u, _, _ in batch if len(batch) > 0]
         all_v = [v for batch in batches for _, v, _ in batch if len(batch) > 0]
-        all_neg_v = [neg_v for batch in batches for _, _, neg_v in batch if len(batch) > 0]
-        return torch.LongTensor(all_u), torch.LongTensor(all_v), torch.LongTensor(all_neg_v)
+        all_neg_v = [
+            neg_v for batch in batches for _, _, neg_v in batch if len(batch) > 0
+        ]
+        return (
+            torch.LongTensor(all_u),
+            torch.LongTensor(all_v),
+            torch.LongTensor(all_neg_v),
+        )
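getNegatives above draws a window from a precomputed negative-sample table and wraps around its end; the reformat makes the slice-and-wrap arithmetic easier to see. A toy trace of the wrap-around case:

    import numpy as np

    negatives = np.arange(10)  # toy stand-in for the sampled negative table
    negpos, size = 8, 5
    response = negatives[negpos : negpos + size]  # only [8, 9]: the slice hit the end
    negpos = (negpos + size) % len(negatives)     # cursor wraps to 3
    if len(response) != size:
        response = np.concatenate((response, negatives[0:negpos]))
    print(response)  # [8 9 0 1 2]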
examples/pytorch/metapath2vec/sampler.py

+import numpy as np
+import os
 import random
+import sys
 import time
+import tqdm

 import dgl
-import sys
-import os
-import numpy as np
-import tqdm

 num_walks_per_node = 1000
 walk_length = 100
 path = sys.argv[1]


 def construct_graph():
     paper_ids = []
     paper_names = []
...
@@ -31,7 +33,7 @@ def construct_graph():
     while True:
         w = f_4.readline()
         if not w:
-            break;
+            break
         w = w.strip().split()
         identity = int(w[0])
         conf_ids.append(identity)
...
@@ -39,10 +41,10 @@ def construct_graph():
     while True:
         v = f_5.readline()
         if not v:
-            break;
+            break
         v = v.strip().split()
         identity = int(v[0])
-        paper_name = 'p' + ''.join(v[1:])
+        paper_name = "p" + "".join(v[1:])
         paper_ids.append(identity)
         paper_names.append(paper_name)
     f_3.close()
...
@@ -60,41 +62,49 @@ def construct_graph():
     f_1 = open(os.path.join(path, "paper_author.txt"), "r")
     f_2 = open(os.path.join(path, "paper_conf.txt"), "r")
     for x in f_1:
-        x = x.split('\t')
+        x = x.split("\t")
         x[0] = int(x[0])
-        x[1] = int(x[1].strip('\n'))
+        x[1] = int(x[1].strip("\n"))
         paper_author_src.append(paper_ids_invmap[x[0]])
         paper_author_dst.append(author_ids_invmap[x[1]])
     for y in f_2:
-        y = y.split('\t')
+        y = y.split("\t")
         y[0] = int(y[0])
-        y[1] = int(y[1].strip('\n'))
+        y[1] = int(y[1].strip("\n"))
         paper_conf_src.append(paper_ids_invmap[y[0]])
         paper_conf_dst.append(conf_ids_invmap[y[1]])
     f_1.close()
     f_2.close()

-    hg = dgl.heterograph({
-        ('paper', 'pa', 'author'): (paper_author_src, paper_author_dst),
-        ('author', 'ap', 'paper'): (paper_author_dst, paper_author_src),
-        ('paper', 'pc', 'conf'): (paper_conf_src, paper_conf_dst),
-        ('conf', 'cp', 'paper'): (paper_conf_dst, paper_conf_src)})
+    hg = dgl.heterograph(
+        {
+            ("paper", "pa", "author"): (paper_author_src, paper_author_dst),
+            ("author", "ap", "paper"): (paper_author_dst, paper_author_src),
+            ("paper", "pc", "conf"): (paper_conf_src, paper_conf_dst),
+            ("conf", "cp", "paper"): (paper_conf_dst, paper_conf_src),
+        }
+    )
     return hg, author_names, conf_names, paper_names


-#"conference - paper - Author - paper - conference" metapath sampling
+# "conference - paper - Author - paper - conference" metapath sampling
 def generate_metapath():
     output_path = open(os.path.join(path, "output_path.txt"), "w")
     count = 0

     hg, author_names, conf_names, paper_names = construct_graph()

-    for conf_idx in tqdm.trange(hg.number_of_nodes('conf')):
+    for conf_idx in tqdm.trange(hg.number_of_nodes("conf")):
         traces, _ = dgl.sampling.random_walk(
-            hg, [conf_idx] * num_walks_per_node, metapath=['cp', 'pa', 'ap', 'pc'] * walk_length)
+            hg,
+            [conf_idx] * num_walks_per_node,
+            metapath=["cp", "pa", "ap", "pc"] * walk_length,
+        )
         for tr in traces:
-            outline = ' '.join(
-                (conf_names if i % 4 == 0 else author_names)[tr[i]]
-                for i in range(0, len(tr), 2))  # skip paper
+            outline = " ".join(
+                (conf_names if i % 4 == 0 else author_names)[tr[i]]
+                for i in range(0, len(tr), 2)
+            )  # skip paper
             print(outline, file=output_path)
     output_path.close()
...
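generate_metapath above relies on dgl.sampling.random_walk with a repeated ["cp", "pa", "ap", "pc"] metapath. A toy heterograph with the same canonical edge types shows the call shape (the node IDs here are made up):

    import dgl

    hg = dgl.heterograph(
        {
            ("paper", "pa", "author"): ([0, 1], [0, 1]),
            ("author", "ap", "paper"): ([0, 1], [0, 1]),
            ("paper", "pc", "conf"): ([0, 1], [0, 0]),
            ("conf", "cp", "paper"): ([0, 0], [0, 1]),
        }
    )
    traces, types = dgl.sampling.random_walk(hg, [0, 0], metapath=["cp", "pa", "ap", "pc"])
    print(traces)  # each row: conf -> paper -> author -> paper -> conf node IDs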