Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
704bcaf6
Unverified
Commit
704bcaf6
authored
Feb 19, 2023
by
Hongzhi (Steve), Chen
Committed by
GitHub
Feb 19, 2023
Browse files
examples (#5323)
Co-authored-by:
Ubuntu
<
ubuntu@ip-172-31-28-63.ap-northeast-1.compute.internal
>
parent
6bc82161
Changes
332
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
553 additions
and
326 deletions
+553
-326
examples/pytorch/jknet/main.py
examples/pytorch/jknet/main.py
+2
-2
examples/pytorch/jknet/model.py
examples/pytorch/jknet/model.py
+1
-2
examples/pytorch/jtnn/jtnn/datautils.py
examples/pytorch/jtnn/jtnn/datautils.py
+7
-6
examples/pytorch/jtnn/jtnn/jtmpn.py
examples/pytorch/jtnn/jtnn/jtmpn.py
+151
-75
examples/pytorch/jtnn/jtnn/jtnn_dec.py
examples/pytorch/jtnn/jtnn/jtnn_dec.py
+186
-122
examples/pytorch/jtnn/jtnn/jtnn_enc.py
examples/pytorch/jtnn/jtnn/jtnn_enc.py
+2
-3
examples/pytorch/jtnn/jtnn/jtnn_vae.py
examples/pytorch/jtnn/jtnn/jtnn_vae.py
+2
-4
examples/pytorch/jtnn/jtnn/mol_tree_nx.py
examples/pytorch/jtnn/jtnn/mol_tree_nx.py
+1
-2
examples/pytorch/jtnn/jtnn/mpn.py
examples/pytorch/jtnn/jtnn/mpn.py
+2
-3
examples/pytorch/jtnn/jtnn/nnutils.py
examples/pytorch/jtnn/jtnn/nnutils.py
+2
-2
examples/pytorch/label_propagation/main.py
examples/pytorch/label_propagation/main.py
+2
-2
examples/pytorch/lda/example_20newsgroups.py
examples/pytorch/lda/example_20newsgroups.py
+4
-4
examples/pytorch/lda/lda_model.py
examples/pytorch/lda/lda_model.py
+2
-2
examples/pytorch/line_graph/gnn.py
examples/pytorch/line_graph/gnn.py
+49
-21
examples/pytorch/line_graph/train.py
examples/pytorch/line_graph/train.py
+1
-1
examples/pytorch/metapath2vec/download.py
examples/pytorch/metapath2vec/download.py
+13
-10
examples/pytorch/metapath2vec/metapath2vec.py
examples/pytorch/metapath2vec/metapath2vec.py
+50
-25
examples/pytorch/metapath2vec/model.py
examples/pytorch/metapath2vec/model.py
+4
-5
examples/pytorch/metapath2vec/reading_data.py
examples/pytorch/metapath2vec/reading_data.py
+40
-13
examples/pytorch/metapath2vec/sampler.py
examples/pytorch/metapath2vec/sampler.py
+32
-22
No files found.
examples/pytorch/jknet/main.py
View file @
704bcaf6
...
@@ -7,12 +7,12 @@ import numpy as np
...
@@ -7,12 +7,12 @@ import numpy as np
import
torch
import
torch
import
torch.nn
as
nn
import
torch.nn
as
nn
import
torch.optim
as
optim
import
torch.optim
as
optim
from
dgl.data
import
CiteseerGraphDataset
,
CoraGraphDataset
from
model
import
JKNet
from
model
import
JKNet
from
sklearn.model_selection
import
train_test_split
from
sklearn.model_selection
import
train_test_split
from
tqdm
import
trange
from
tqdm
import
trange
from
dgl.data
import
CiteseerGraphDataset
,
CoraGraphDataset
def
main
(
args
):
def
main
(
args
):
# Step 1: Prepare graph data and retrieve train/validation/test index ============================= #
# Step 1: Prepare graph data and retrieve train/validation/test index ============================= #
...
...
examples/pytorch/jknet/model.py
View file @
704bcaf6
import
dgl.function
as
fn
import
torch
import
torch
import
torch.nn
as
nn
import
torch.nn
as
nn
import
torch.nn.functional
as
F
import
torch.nn.functional
as
F
import
dgl.function
as
fn
from
dgl.nn
import
GraphConv
,
JumpingKnowledge
from
dgl.nn
import
GraphConv
,
JumpingKnowledge
...
...
examples/pytorch/jtnn/jtnn/datautils.py
View file @
704bcaf6
import
torch
from
torch.utils.data
import
Dataset
import
dgl
import
dgl
import
torch
from
dgl.data.utils
import
(
from
dgl.data.utils
import
(
_get_dgl_url
,
_get_dgl_url
,
download
,
download
,
extract_archive
,
extract_archive
,
get_download_dir
,
get_download_dir
,
)
)
from
torch.utils.data
import
Dataset
from
.jtmpn
import
ATOM_FDIM
as
ATOM_FDIM_DEC
from
.jtmpn
import
(
from
.jtmpn
import
BOND_FDIM
as
BOND_FDIM_DEC
ATOM_FDIM
as
ATOM_FDIM_DEC
,
from
.jtmpn
import
mol2dgl_single
as
mol2dgl_dec
BOND_FDIM
as
BOND_FDIM_DEC
,
mol2dgl_single
as
mol2dgl_dec
,
)
from
.mol_tree
import
Vocab
from
.mol_tree
import
Vocab
from
.mol_tree_nx
import
DGLMolTree
from
.mol_tree_nx
import
DGLMolTree
from
.mpn
import
mol2dgl_single
as
mol2dgl_enc
from
.mpn
import
mol2dgl_single
as
mol2dgl_enc
...
...
examples/pytorch/jtnn/jtnn/jtmpn.py
View file @
704bcaf6
import
os
import
dgl
import
dgl.function
as
DGLF
import
rdkit.Chem
as
Chem
import
torch
import
torch
import
torch.nn
as
nn
import
torch.nn
as
nn
from
dgl
import
line_graph
,
mean_nodes
from
.nnutils
import
cuda
from
.nnutils
import
cuda
import
rdkit.Chem
as
Chem
import
dgl
from
dgl
import
mean_nodes
,
line_graph
import
dgl.function
as
DGLF
import
os
ELEM_LIST
=
[
'C'
,
'N'
,
'O'
,
'S'
,
'F'
,
'Si'
,
'P'
,
'Cl'
,
'Br'
,
'Mg'
,
'Na'
,
'Ca'
,
ELEM_LIST
=
[
'Fe'
,
'Al'
,
'I'
,
'B'
,
'K'
,
'Se'
,
'Zn'
,
'H'
,
'Cu'
,
'Mn'
,
'unknown'
]
"C"
,
"N"
,
"O"
,
"S"
,
"F"
,
"Si"
,
"P"
,
"Cl"
,
"Br"
,
"Mg"
,
"Na"
,
"Ca"
,
"Fe"
,
"Al"
,
"I"
,
"B"
,
"K"
,
"Se"
,
"Zn"
,
"H"
,
"Cu"
,
"Mn"
,
"unknown"
,
]
ATOM_FDIM
=
len
(
ELEM_LIST
)
+
6
+
5
+
1
ATOM_FDIM
=
len
(
ELEM_LIST
)
+
6
+
5
+
1
BOND_FDIM
=
5
BOND_FDIM
=
5
MAX_NB
=
10
MAX_NB
=
10
PAPER
=
os
.
getenv
(
'PAPER'
,
False
)
PAPER
=
os
.
getenv
(
"PAPER"
,
False
)
def
onek_encoding_unk
(
x
,
allowable_set
):
def
onek_encoding_unk
(
x
,
allowable_set
):
if
x
not
in
allowable_set
:
if
x
not
in
allowable_set
:
x
=
allowable_set
[
-
1
]
x
=
allowable_set
[
-
1
]
return
[
x
==
s
for
s
in
allowable_set
]
return
[
x
==
s
for
s
in
allowable_set
]
# Note that during graph decoding they don't predict stereochemistry-related
# Note that during graph decoding they don't predict stereochemistry-related
# characteristics (i.e. Chiral Atoms, E-Z, Cis-Trans). Instead, they decode
# characteristics (i.e. Chiral Atoms, E-Z, Cis-Trans). Instead, they decode
# the 2-D graph first, then enumerate all possible 3-D forms and find the
# the 2-D graph first, then enumerate all possible 3-D forms and find the
# one with highest score.
# one with highest score.
def
atom_features
(
atom
):
def
atom_features
(
atom
):
return
(
torch
.
Tensor
(
onek_encoding_unk
(
atom
.
GetSymbol
(),
ELEM_LIST
)
return
torch
.
Tensor
(
+
onek_encoding_unk
(
atom
.
GetDegree
(),
[
0
,
1
,
2
,
3
,
4
,
5
])
onek_encoding_unk
(
atom
.
GetSymbol
(),
ELEM_LIST
)
+
onek_encoding_unk
(
atom
.
GetFormalCharge
(),
[
-
1
,
-
2
,
1
,
2
,
0
])
+
onek_encoding_unk
(
atom
.
GetDegree
(),
[
0
,
1
,
2
,
3
,
4
,
5
])
+
[
atom
.
GetIsAromatic
()]))
+
onek_encoding_unk
(
atom
.
GetFormalCharge
(),
[
-
1
,
-
2
,
1
,
2
,
0
])
+
[
atom
.
GetIsAromatic
()]
)
def
bond_features
(
bond
):
def
bond_features
(
bond
):
bt
=
bond
.
GetBondType
()
bt
=
bond
.
GetBondType
()
return
(
torch
.
Tensor
([
bt
==
Chem
.
rdchem
.
BondType
.
SINGLE
,
bt
==
Chem
.
rdchem
.
BondType
.
DOUBLE
,
bt
==
Chem
.
rdchem
.
BondType
.
TRIPLE
,
bt
==
Chem
.
rdchem
.
BondType
.
AROMATIC
,
bond
.
IsInRing
()]))
return
torch
.
Tensor
(
[
bt
==
Chem
.
rdchem
.
BondType
.
SINGLE
,
bt
==
Chem
.
rdchem
.
BondType
.
DOUBLE
,
bt
==
Chem
.
rdchem
.
BondType
.
TRIPLE
,
bt
==
Chem
.
rdchem
.
BondType
.
AROMATIC
,
bond
.
IsInRing
(),
]
)
def
mol2dgl_single
(
cand_batch
):
def
mol2dgl_single
(
cand_batch
):
cand_graphs
=
[]
cand_graphs
=
[]
tree_mess_source_edges
=
[]
# map these edges from trees to...
tree_mess_source_edges
=
[]
# map these edges from trees to...
tree_mess_target_edges
=
[]
# these edges on candidate graphs
tree_mess_target_edges
=
[]
# these edges on candidate graphs
tree_mess_target_nodes
=
[]
tree_mess_target_nodes
=
[]
n_nodes
=
0
n_nodes
=
0
n_edges
=
0
n_edges
=
0
...
@@ -50,8 +89,8 @@ def mol2dgl_single(cand_batch):
...
@@ -50,8 +89,8 @@ def mol2dgl_single(cand_batch):
n_bonds
=
mol
.
GetNumBonds
()
n_bonds
=
mol
.
GetNumBonds
()
ctr_node
=
mol_tree
.
nodes_dict
[
ctr_node_id
]
ctr_node
=
mol_tree
.
nodes_dict
[
ctr_node_id
]
ctr_bid
=
ctr_node
[
'
idx
'
]
ctr_bid
=
ctr_node
[
"
idx
"
]
mol_tree_graph
=
getattr
(
mol_tree
,
'
graph
'
,
mol_tree
)
mol_tree_graph
=
getattr
(
mol_tree
,
"
graph
"
,
mol_tree
)
for
i
,
atom
in
enumerate
(
mol
.
GetAtoms
()):
for
i
,
atom
in
enumerate
(
mol
.
GetAtoms
()):
assert
i
==
atom
.
GetIdx
()
assert
i
==
atom
.
GetIdx
()
...
@@ -75,15 +114,19 @@ def mol2dgl_single(cand_batch):
...
@@ -75,15 +114,19 @@ def mol2dgl_single(cand_batch):
x_nid
,
y_nid
=
a1
.
GetAtomMapNum
(),
a2
.
GetAtomMapNum
()
x_nid
,
y_nid
=
a1
.
GetAtomMapNum
(),
a2
.
GetAtomMapNum
()
# Tree node ID in the batch
# Tree node ID in the batch
x_bid
=
mol_tree
.
nodes_dict
[
x_nid
-
1
][
'
idx
'
]
if
x_nid
>
0
else
-
1
x_bid
=
mol_tree
.
nodes_dict
[
x_nid
-
1
][
"
idx
"
]
if
x_nid
>
0
else
-
1
y_bid
=
mol_tree
.
nodes_dict
[
y_nid
-
1
][
'
idx
'
]
if
y_nid
>
0
else
-
1
y_bid
=
mol_tree
.
nodes_dict
[
y_nid
-
1
][
"
idx
"
]
if
y_nid
>
0
else
-
1
if
x_bid
>=
0
and
y_bid
>=
0
and
x_bid
!=
y_bid
:
if
x_bid
>=
0
and
y_bid
>=
0
and
x_bid
!=
y_bid
:
if
mol_tree_graph
.
has_edges_between
(
x_bid
,
y_bid
):
if
mol_tree_graph
.
has_edges_between
(
x_bid
,
y_bid
):
tree_mess_target_edges
.
append
((
begin_idx
+
n_nodes
,
end_idx
+
n_nodes
))
tree_mess_target_edges
.
append
(
(
begin_idx
+
n_nodes
,
end_idx
+
n_nodes
)
)
tree_mess_source_edges
.
append
((
x_bid
,
y_bid
))
tree_mess_source_edges
.
append
((
x_bid
,
y_bid
))
tree_mess_target_nodes
.
append
(
end_idx
+
n_nodes
)
tree_mess_target_nodes
.
append
(
end_idx
+
n_nodes
)
if
mol_tree_graph
.
has_edges_between
(
y_bid
,
x_bid
):
if
mol_tree_graph
.
has_edges_between
(
y_bid
,
x_bid
):
tree_mess_target_edges
.
append
((
end_idx
+
n_nodes
,
begin_idx
+
n_nodes
))
tree_mess_target_edges
.
append
(
(
end_idx
+
n_nodes
,
begin_idx
+
n_nodes
)
)
tree_mess_source_edges
.
append
((
y_bid
,
x_bid
))
tree_mess_source_edges
.
append
((
y_bid
,
x_bid
))
tree_mess_target_nodes
.
append
(
begin_idx
+
n_nodes
)
tree_mess_target_nodes
.
append
(
begin_idx
+
n_nodes
)
...
@@ -91,11 +134,14 @@ def mol2dgl_single(cand_batch):
...
@@ -91,11 +134,14 @@ def mol2dgl_single(cand_batch):
g
=
dgl
.
graph
((
bond_src
,
bond_dst
),
num_nodes
=
n_atoms
)
g
=
dgl
.
graph
((
bond_src
,
bond_dst
),
num_nodes
=
n_atoms
)
cand_graphs
.
append
(
g
)
cand_graphs
.
append
(
g
)
return
cand_graphs
,
torch
.
stack
(
atom_x
),
\
return
(
torch
.
stack
(
bond_x
)
if
len
(
bond_x
)
>
0
else
torch
.
zeros
(
0
),
\
cand_graphs
,
torch
.
LongTensor
(
tree_mess_source_edges
),
\
torch
.
stack
(
atom_x
),
torch
.
LongTensor
(
tree_mess_target_edges
),
\
torch
.
stack
(
bond_x
)
if
len
(
bond_x
)
>
0
else
torch
.
zeros
(
0
),
torch
.
LongTensor
(
tree_mess_target_nodes
)
torch
.
LongTensor
(
tree_mess_source_edges
),
torch
.
LongTensor
(
tree_mess_target_edges
),
torch
.
LongTensor
(
tree_mess_target_nodes
),
)
class
LoopyBPUpdate
(
nn
.
Module
):
class
LoopyBPUpdate
(
nn
.
Module
):
...
@@ -106,28 +152,28 @@ class LoopyBPUpdate(nn.Module):
...
@@ -106,28 +152,28 @@ class LoopyBPUpdate(nn.Module):
self
.
W_h
=
nn
.
Linear
(
hidden_size
,
hidden_size
,
bias
=
False
)
self
.
W_h
=
nn
.
Linear
(
hidden_size
,
hidden_size
,
bias
=
False
)
def
forward
(
self
,
node
):
def
forward
(
self
,
node
):
msg_input
=
node
.
data
[
'
msg_input
'
]
msg_input
=
node
.
data
[
"
msg_input
"
]
msg_delta
=
self
.
W_h
(
node
.
data
[
'
accum_msg
'
]
+
node
.
data
[
'
alpha
'
])
msg_delta
=
self
.
W_h
(
node
.
data
[
"
accum_msg
"
]
+
node
.
data
[
"
alpha
"
])
msg
=
torch
.
relu
(
msg_input
+
msg_delta
)
msg
=
torch
.
relu
(
msg_input
+
msg_delta
)
return
{
'
msg
'
:
msg
}
return
{
"
msg
"
:
msg
}
if
PAPER
:
if
PAPER
:
mpn_gather_msg
=
[
mpn_gather_msg
=
[
DGLF
.
copy_e
(
edge
=
'
msg
'
,
out
=
'
msg
'
),
DGLF
.
copy_e
(
edge
=
"
msg
"
,
out
=
"
msg
"
),
DGLF
.
copy_e
(
edge
=
'
alpha
'
,
out
=
'
alpha
'
)
DGLF
.
copy_e
(
edge
=
"
alpha
"
,
out
=
"
alpha
"
),
]
]
else
:
else
:
mpn_gather_msg
=
DGLF
.
copy_e
(
edge
=
'
msg
'
,
out
=
'
msg
'
)
mpn_gather_msg
=
DGLF
.
copy_e
(
edge
=
"
msg
"
,
out
=
"
msg
"
)
if
PAPER
:
if
PAPER
:
mpn_gather_reduce
=
[
mpn_gather_reduce
=
[
DGLF
.
sum
(
msg
=
'
msg
'
,
out
=
'm'
),
DGLF
.
sum
(
msg
=
"
msg
"
,
out
=
"m"
),
DGLF
.
sum
(
msg
=
'
alpha
'
,
out
=
'
accum_alpha
'
),
DGLF
.
sum
(
msg
=
"
alpha
"
,
out
=
"
accum_alpha
"
),
]
]
else
:
else
:
mpn_gather_reduce
=
DGLF
.
sum
(
msg
=
'
msg
'
,
out
=
'm'
)
mpn_gather_reduce
=
DGLF
.
sum
(
msg
=
"
msg
"
,
out
=
"m"
)
class
GatherUpdate
(
nn
.
Module
):
class
GatherUpdate
(
nn
.
Module
):
...
@@ -139,12 +185,12 @@ class GatherUpdate(nn.Module):
...
@@ -139,12 +185,12 @@ class GatherUpdate(nn.Module):
def
forward
(
self
,
node
):
def
forward
(
self
,
node
):
if
PAPER
:
if
PAPER
:
#m = node['m']
#
m = node['m']
m
=
node
.
data
[
'm'
]
+
node
.
data
[
'
accum_alpha
'
]
m
=
node
.
data
[
"m"
]
+
node
.
data
[
"
accum_alpha
"
]
else
:
else
:
m
=
node
.
data
[
'm'
]
+
node
.
data
[
'
alpha
'
]
m
=
node
.
data
[
"m"
]
+
node
.
data
[
"
alpha
"
]
return
{
return
{
'h'
:
torch
.
relu
(
self
.
W_o
(
torch
.
cat
([
node
.
data
[
'x'
],
m
],
1
))),
"h"
:
torch
.
relu
(
self
.
W_o
(
torch
.
cat
([
node
.
data
[
"x"
],
m
],
1
))),
}
}
...
@@ -166,20 +212,32 @@ class DGLJTMPN(nn.Module):
...
@@ -166,20 +212,32 @@ class DGLJTMPN(nn.Module):
self
.
n_passes
=
0
self
.
n_passes
=
0
def
forward
(
self
,
cand_batch
,
mol_tree_batch
):
def
forward
(
self
,
cand_batch
,
mol_tree_batch
):
cand_graphs
,
tree_mess_src_edges
,
tree_mess_tgt_edges
,
tree_mess_tgt_nodes
=
cand_batch
(
cand_graphs
,
tree_mess_src_edges
,
tree_mess_tgt_edges
,
tree_mess_tgt_nodes
,
)
=
cand_batch
n_samples
=
len
(
cand_graphs
)
n_samples
=
len
(
cand_graphs
)
cand_line_graph
=
line_graph
(
cand_graphs
,
backtracking
=
False
,
shared
=
True
)
cand_line_graph
=
line_graph
(
cand_graphs
,
backtracking
=
False
,
shared
=
True
)
n_nodes
=
cand_graphs
.
number_of_nodes
()
n_nodes
=
cand_graphs
.
number_of_nodes
()
n_edges
=
cand_graphs
.
number_of_edges
()
n_edges
=
cand_graphs
.
number_of_edges
()
cand_graphs
=
self
.
run
(
cand_graphs
=
self
.
run
(
cand_graphs
,
cand_line_graph
,
tree_mess_src_edges
,
tree_mess_tgt_edges
,
cand_graphs
,
tree_mess_tgt_nodes
,
mol_tree_batch
)
cand_line_graph
,
tree_mess_src_edges
,
tree_mess_tgt_edges
,
tree_mess_tgt_nodes
,
mol_tree_batch
,
)
g_repr
=
mean_nodes
(
cand_graphs
,
'h'
)
g_repr
=
mean_nodes
(
cand_graphs
,
"h"
)
self
.
n_samples_total
+=
n_samples
self
.
n_samples_total
+=
n_samples
self
.
n_nodes_total
+=
n_nodes
self
.
n_nodes_total
+=
n_nodes
...
@@ -188,33 +246,45 @@ class DGLJTMPN(nn.Module):
...
@@ -188,33 +246,45 @@ class DGLJTMPN(nn.Module):
return
g_repr
return
g_repr
def
run
(
self
,
cand_graphs
,
cand_line_graph
,
tree_mess_src_edges
,
tree_mess_tgt_edges
,
def
run
(
tree_mess_tgt_nodes
,
mol_tree_batch
):
self
,
cand_graphs
,
cand_line_graph
,
tree_mess_src_edges
,
tree_mess_tgt_edges
,
tree_mess_tgt_nodes
,
mol_tree_batch
,
):
n_nodes
=
cand_graphs
.
number_of_nodes
()
n_nodes
=
cand_graphs
.
number_of_nodes
()
cand_graphs
.
apply_edges
(
cand_graphs
.
apply_edges
(
func
=
lambda
edges
:
{
'
src_x
'
:
edges
.
src
[
'x'
]},
func
=
lambda
edges
:
{
"
src_x
"
:
edges
.
src
[
"x"
]},
)
)
cand_line_graph
.
ndata
.
update
(
cand_graphs
.
edata
)
cand_line_graph
.
ndata
.
update
(
cand_graphs
.
edata
)
bond_features
=
cand_line_graph
.
ndata
[
'x'
]
bond_features
=
cand_line_graph
.
ndata
[
"x"
]
source_features
=
cand_line_graph
.
ndata
[
'
src_x
'
]
source_features
=
cand_line_graph
.
ndata
[
"
src_x
"
]
features
=
torch
.
cat
([
source_features
,
bond_features
],
1
)
features
=
torch
.
cat
([
source_features
,
bond_features
],
1
)
msg_input
=
self
.
W_i
(
features
)
msg_input
=
self
.
W_i
(
features
)
cand_line_graph
.
ndata
.
update
({
cand_line_graph
.
ndata
.
update
(
'msg_input'
:
msg_input
,
{
'msg'
:
torch
.
relu
(
msg_input
),
"msg_input"
:
msg_input
,
'accum_msg'
:
torch
.
zeros_like
(
msg_input
),
"msg"
:
torch
.
relu
(
msg_input
),
})
"accum_msg"
:
torch
.
zeros_like
(
msg_input
),
}
)
zero_node_state
=
bond_features
.
new
(
n_nodes
,
self
.
hidden_size
).
zero_
()
zero_node_state
=
bond_features
.
new
(
n_nodes
,
self
.
hidden_size
).
zero_
()
cand_graphs
.
ndata
.
update
({
cand_graphs
.
ndata
.
update
(
'm'
:
zero_node_state
.
clone
(),
{
'h'
:
zero_node_state
.
clone
(),
"m"
:
zero_node_state
.
clone
(),
})
"h"
:
zero_node_state
.
clone
(),
}
cand_graphs
.
edata
[
'alpha'
]
=
\
)
cuda
(
torch
.
zeros
(
cand_graphs
.
number_of_edges
(),
self
.
hidden_size
))
cand_graphs
.
ndata
[
'alpha'
]
=
zero_node_state
cand_graphs
.
edata
[
"alpha"
]
=
cuda
(
torch
.
zeros
(
cand_graphs
.
number_of_edges
(),
self
.
hidden_size
)
)
cand_graphs
.
ndata
[
"alpha"
]
=
zero_node_state
if
tree_mess_src_edges
.
shape
[
0
]
>
0
:
if
tree_mess_src_edges
.
shape
[
0
]
>
0
:
if
PAPER
:
if
PAPER
:
src_u
,
src_v
=
tree_mess_src_edges
.
unbind
(
1
)
src_u
,
src_v
=
tree_mess_src_edges
.
unbind
(
1
)
...
@@ -222,33 +292,39 @@ class DGLJTMPN(nn.Module):
...
@@ -222,33 +292,39 @@ class DGLJTMPN(nn.Module):
src_u
=
src_u
.
to
(
mol_tree_batch
.
device
)
src_u
=
src_u
.
to
(
mol_tree_batch
.
device
)
src_v
=
src_v
.
to
(
mol_tree_batch
.
device
)
src_v
=
src_v
.
to
(
mol_tree_batch
.
device
)
eid
=
mol_tree_batch
.
edge_ids
(
src_u
,
src_v
)
eid
=
mol_tree_batch
.
edge_ids
(
src_u
,
src_v
)
alpha
=
mol_tree_batch
.
edata
[
'm'
][
eid
]
alpha
=
mol_tree_batch
.
edata
[
"m"
][
eid
]
cand_graphs
.
edges
[
tgt_u
,
tgt_v
].
data
[
'
alpha
'
]
=
alpha
cand_graphs
.
edges
[
tgt_u
,
tgt_v
].
data
[
"
alpha
"
]
=
alpha
else
:
else
:
src_u
,
src_v
=
tree_mess_src_edges
.
unbind
(
1
)
src_u
,
src_v
=
tree_mess_src_edges
.
unbind
(
1
)
src_u
=
src_u
.
to
(
mol_tree_batch
.
device
)
src_u
=
src_u
.
to
(
mol_tree_batch
.
device
)
src_v
=
src_v
.
to
(
mol_tree_batch
.
device
)
src_v
=
src_v
.
to
(
mol_tree_batch
.
device
)
eid
=
mol_tree_batch
.
edge_ids
(
src_u
,
src_v
)
eid
=
mol_tree_batch
.
edge_ids
(
src_u
,
src_v
)
alpha
=
mol_tree_batch
.
edata
[
'm'
][
eid
]
alpha
=
mol_tree_batch
.
edata
[
"m"
][
eid
]
node_idx
=
(
tree_mess_tgt_nodes
node_idx
=
tree_mess_tgt_nodes
.
to
(
.
to
(
device
=
zero_node_state
.
device
)[:,
None
]
device
=
zero_node_state
.
device
.
expand_as
(
alpha
))
)[:,
None
].
expand_as
(
alpha
)
node_alpha
=
zero_node_state
.
clone
().
scatter_add
(
0
,
node_idx
,
alpha
)
node_alpha
=
zero_node_state
.
clone
().
scatter_add
(
cand_graphs
.
ndata
[
'alpha'
]
=
node_alpha
0
,
node_idx
,
alpha
)
cand_graphs
.
ndata
[
"alpha"
]
=
node_alpha
cand_graphs
.
apply_edges
(
cand_graphs
.
apply_edges
(
func
=
lambda
edges
:
{
'
alpha
'
:
edges
.
src
[
'
alpha
'
]},
func
=
lambda
edges
:
{
"
alpha
"
:
edges
.
src
[
"
alpha
"
]},
)
)
cand_line_graph
.
ndata
.
update
(
cand_graphs
.
edata
)
cand_line_graph
.
ndata
.
update
(
cand_graphs
.
edata
)
for
i
in
range
(
self
.
depth
-
1
):
for
i
in
range
(
self
.
depth
-
1
):
cand_line_graph
.
update_all
(
DGLF
.
copy_u
(
'msg'
,
'msg'
),
DGLF
.
sum
(
'msg'
,
'accum_msg'
))
cand_line_graph
.
update_all
(
DGLF
.
copy_u
(
"msg"
,
"msg"
),
DGLF
.
sum
(
"msg"
,
"accum_msg"
)
)
cand_line_graph
.
apply_nodes
(
self
.
loopy_bp_updater
)
cand_line_graph
.
apply_nodes
(
self
.
loopy_bp_updater
)
cand_graphs
.
edata
.
update
(
cand_line_graph
.
ndata
)
cand_graphs
.
edata
.
update
(
cand_line_graph
.
ndata
)
cand_graphs
.
update_all
(
DGLF
.
copy_e
(
'
msg
'
,
'
msg
'
),
DGLF
.
sum
(
'
msg
'
,
'm'
))
cand_graphs
.
update_all
(
DGLF
.
copy_e
(
"
msg
"
,
"
msg
"
),
DGLF
.
sum
(
"
msg
"
,
"m"
))
if
PAPER
:
if
PAPER
:
cand_graphs
.
update_all
(
DGLF
.
copy_e
(
'alpha'
,
'alpha'
),
DGLF
.
sum
(
'alpha'
,
'accum_alpha'
))
cand_graphs
.
update_all
(
DGLF
.
copy_e
(
"alpha"
,
"alpha"
),
DGLF
.
sum
(
"alpha"
,
"accum_alpha"
)
)
cand_graphs
.
apply_nodes
(
self
.
gather_updater
)
cand_graphs
.
apply_nodes
(
self
.
gather_updater
)
return
cand_graphs
return
cand_graphs
examples/pytorch/jtnn/jtnn/jtnn_dec.py
View file @
704bcaf6
import
dgl.function
as
DGLF
import
numpy
as
np
import
torch
import
torch
import
torch.nn
as
nn
import
torch.nn
as
nn
import
torch.nn.functional
as
F
import
torch.nn.functional
as
F
from
.mol_tree_nx
import
DGLMolTree
from
.chemutils
import
enum_assemble_nx
,
get_mol
from
.nnutils
import
GRUUpdate
,
cuda
,
tocpu
from
dgl
import
batch
,
dfs_labeled_edges_generator
,
line_graph
from
dgl
import
batch
,
dfs_labeled_edges_generator
,
line_graph
import
dgl.function
as
DGLF
import
numpy
as
np
from
.chemutils
import
enum_assemble_nx
,
get_mol
from
.mol_tree_nx
import
DGLMolTree
from
.nnutils
import
cuda
,
GRUUpdate
,
tocpu
MAX_NB
=
8
MAX_NB
=
8
MAX_DECODE_LEN
=
100
MAX_DECODE_LEN
=
100
...
@@ -21,60 +22,70 @@ def dfs_order(forest, roots):
...
@@ -21,60 +22,70 @@ def dfs_order(forest, roots):
# using find_edges().
# using find_edges().
yield
e
^
l
,
l
yield
e
^
l
,
l
dec_tree_node_msg
=
DGLF
.
copy_e
(
edge
=
'm'
,
out
=
'm'
)
dec_tree_node_reduce
=
DGLF
.
sum
(
msg
=
'm'
,
out
=
'h'
)
dec_tree_node_msg
=
DGLF
.
copy_e
(
edge
=
"m"
,
out
=
"m"
)
dec_tree_node_reduce
=
DGLF
.
sum
(
msg
=
"m"
,
out
=
"h"
)
def
dec_tree_node_update
(
nodes
):
def
dec_tree_node_update
(
nodes
):
return
{
'
new
'
:
nodes
.
data
[
'
new
'
].
clone
().
zero_
()}
return
{
"
new
"
:
nodes
.
data
[
"
new
"
].
clone
().
zero_
()}
def
have_slots
(
fa_slots
,
ch_slots
):
def
have_slots
(
fa_slots
,
ch_slots
):
if
len
(
fa_slots
)
>
2
and
len
(
ch_slots
)
>
2
:
if
len
(
fa_slots
)
>
2
and
len
(
ch_slots
)
>
2
:
return
True
return
True
matches
=
[]
matches
=
[]
for
i
,
s1
in
enumerate
(
fa_slots
):
for
i
,
s1
in
enumerate
(
fa_slots
):
a1
,
c1
,
h1
=
s1
a1
,
c1
,
h1
=
s1
for
j
,
s2
in
enumerate
(
ch_slots
):
for
j
,
s2
in
enumerate
(
ch_slots
):
a2
,
c2
,
h2
=
s2
a2
,
c2
,
h2
=
s2
if
a1
==
a2
and
c1
==
c2
and
(
a1
!=
"C"
or
h1
+
h2
>=
4
):
if
a1
==
a2
and
c1
==
c2
and
(
a1
!=
"C"
or
h1
+
h2
>=
4
):
matches
.
append
(
(
i
,
j
)
)
matches
.
append
((
i
,
j
))
if
len
(
matches
)
==
0
:
return
False
if
len
(
matches
)
==
0
:
return
False
fa_match
,
ch_match
=
list
(
zip
(
*
matches
))
fa_match
,
ch_match
=
list
(
zip
(
*
matches
))
if
len
(
set
(
fa_match
))
==
1
and
1
<
len
(
fa_slots
)
<=
2
:
#never remove atom from ring
if
(
len
(
set
(
fa_match
))
==
1
and
1
<
len
(
fa_slots
)
<=
2
):
# never remove atom from ring
fa_slots
.
pop
(
fa_match
[
0
])
fa_slots
.
pop
(
fa_match
[
0
])
if
len
(
set
(
ch_match
))
==
1
and
1
<
len
(
ch_slots
)
<=
2
:
#never remove atom from ring
if
(
len
(
set
(
ch_match
))
==
1
and
1
<
len
(
ch_slots
)
<=
2
):
# never remove atom from ring
ch_slots
.
pop
(
ch_match
[
0
])
ch_slots
.
pop
(
ch_match
[
0
])
return
True
return
True
def
can_assemble
(
mol_tree
,
u
,
v_node_dict
):
def
can_assemble
(
mol_tree
,
u
,
v_node_dict
):
u_node_dict
=
mol_tree
.
nodes_dict
[
u
]
u_node_dict
=
mol_tree
.
nodes_dict
[
u
]
u_neighbors
=
mol_tree
.
graph
.
successors
(
u
)
u_neighbors
=
mol_tree
.
graph
.
successors
(
u
)
u_neighbors_node_dict
=
[
u_neighbors_node_dict
=
[
mol_tree
.
nodes_dict
[
_u
]
mol_tree
.
nodes_dict
[
_u
]
for
_u
in
u_neighbors
for
_u
in
u_neighbors
if
_u
in
mol_tree
.
nodes_dict
if
_u
in
mol_tree
.
nodes_dict
]
]
neis
=
u_neighbors_node_dict
+
[
v_node_dict
]
neis
=
u_neighbors_node_dict
+
[
v_node_dict
]
for
i
,
nei
in
enumerate
(
neis
):
for
i
,
nei
in
enumerate
(
neis
):
nei
[
'nid'
]
=
i
nei
[
"nid"
]
=
i
neighbors
=
[
nei
for
nei
in
neis
if
nei
[
'mol'
].
GetNumAtoms
()
>
1
]
neighbors
=
[
nei
for
nei
in
neis
if
nei
[
"mol"
].
GetNumAtoms
()
>
1
]
neighbors
=
sorted
(
neighbors
,
key
=
lambda
x
:
x
[
'mol'
].
GetNumAtoms
(),
reverse
=
True
)
neighbors
=
sorted
(
singletons
=
[
nei
for
nei
in
neis
if
nei
[
'mol'
].
GetNumAtoms
()
==
1
]
neighbors
,
key
=
lambda
x
:
x
[
"mol"
].
GetNumAtoms
(),
reverse
=
True
)
singletons
=
[
nei
for
nei
in
neis
if
nei
[
"mol"
].
GetNumAtoms
()
==
1
]
neighbors
=
singletons
+
neighbors
neighbors
=
singletons
+
neighbors
cands
=
enum_assemble_nx
(
u_node_dict
,
neighbors
)
cands
=
enum_assemble_nx
(
u_node_dict
,
neighbors
)
return
len
(
cands
)
>
0
return
len
(
cands
)
>
0
def
create_node_dict
(
smiles
,
clique
=
[]):
def
create_node_dict
(
smiles
,
clique
=
[]):
return
dict
(
return
dict
(
smiles
=
smiles
,
smiles
=
smiles
,
mol
=
get_mol
(
smiles
),
mol
=
get_mol
(
smiles
),
clique
=
clique
,
clique
=
clique
,
)
)
class
DGLJTNNDecoder
(
nn
.
Module
):
class
DGLJTNNDecoder
(
nn
.
Module
):
...
@@ -98,41 +109,54 @@ class DGLJTNNDecoder(nn.Module):
...
@@ -98,41 +109,54 @@ class DGLJTNNDecoder(nn.Module):
self
.
U_s
=
nn
.
Linear
(
hidden_size
,
1
)
self
.
U_s
=
nn
.
Linear
(
hidden_size
,
1
)
def
forward
(
self
,
mol_trees
,
tree_vec
):
def
forward
(
self
,
mol_trees
,
tree_vec
):
'''
"""
The training procedure which computes the prediction loss given the
The training procedure which computes the prediction loss given the
ground truth tree
ground truth tree
'''
"""
mol_tree_batch
=
batch
(
mol_trees
)
mol_tree_batch
=
batch
(
mol_trees
)
mol_tree_batch_lg
=
line_graph
(
mol_tree_batch
,
backtracking
=
False
,
shared
=
True
)
mol_tree_batch_lg
=
line_graph
(
mol_tree_batch
,
backtracking
=
False
,
shared
=
True
)
n_trees
=
len
(
mol_trees
)
n_trees
=
len
(
mol_trees
)
return
self
.
run
(
mol_tree_batch
,
mol_tree_batch_lg
,
n_trees
,
tree_vec
)
return
self
.
run
(
mol_tree_batch
,
mol_tree_batch_lg
,
n_trees
,
tree_vec
)
def
run
(
self
,
mol_tree_batch
,
mol_tree_batch_lg
,
n_trees
,
tree_vec
):
def
run
(
self
,
mol_tree_batch
,
mol_tree_batch_lg
,
n_trees
,
tree_vec
):
node_offset
=
np
.
cumsum
(
np
.
insert
(
mol_tree_batch
.
batch_num_nodes
().
cpu
().
numpy
(),
0
,
0
))
node_offset
=
np
.
cumsum
(
np
.
insert
(
mol_tree_batch
.
batch_num_nodes
().
cpu
().
numpy
(),
0
,
0
)
)
root_ids
=
node_offset
[:
-
1
]
root_ids
=
node_offset
[:
-
1
]
n_nodes
=
mol_tree_batch
.
number_of_nodes
()
n_nodes
=
mol_tree_batch
.
number_of_nodes
()
n_edges
=
mol_tree_batch
.
number_of_edges
()
n_edges
=
mol_tree_batch
.
number_of_edges
()
mol_tree_batch
.
ndata
.
update
({
mol_tree_batch
.
ndata
.
update
(
'x'
:
self
.
embedding
(
mol_tree_batch
.
ndata
[
'wid'
]),
{
'h'
:
cuda
(
torch
.
zeros
(
n_nodes
,
self
.
hidden_size
)),
"x"
:
self
.
embedding
(
mol_tree_batch
.
ndata
[
"wid"
]),
'new'
:
cuda
(
torch
.
ones
(
n_nodes
).
bool
()),
# whether it's newly generated node
"h"
:
cuda
(
torch
.
zeros
(
n_nodes
,
self
.
hidden_size
)),
})
"new"
:
cuda
(
torch
.
ones
(
n_nodes
).
bool
()
mol_tree_batch
.
edata
.
update
({
),
# whether it's newly generated node
's'
:
cuda
(
torch
.
zeros
(
n_edges
,
self
.
hidden_size
)),
}
'm'
:
cuda
(
torch
.
zeros
(
n_edges
,
self
.
hidden_size
)),
)
'r'
:
cuda
(
torch
.
zeros
(
n_edges
,
self
.
hidden_size
)),
'z'
:
cuda
(
torch
.
zeros
(
n_edges
,
self
.
hidden_size
)),
mol_tree_batch
.
edata
.
update
(
'src_x'
:
cuda
(
torch
.
zeros
(
n_edges
,
self
.
hidden_size
)),
{
'dst_x'
:
cuda
(
torch
.
zeros
(
n_edges
,
self
.
hidden_size
)),
"s"
:
cuda
(
torch
.
zeros
(
n_edges
,
self
.
hidden_size
)),
'rm'
:
cuda
(
torch
.
zeros
(
n_edges
,
self
.
hidden_size
)),
"m"
:
cuda
(
torch
.
zeros
(
n_edges
,
self
.
hidden_size
)),
'accum_rm'
:
cuda
(
torch
.
zeros
(
n_edges
,
self
.
hidden_size
)),
"r"
:
cuda
(
torch
.
zeros
(
n_edges
,
self
.
hidden_size
)),
})
"z"
:
cuda
(
torch
.
zeros
(
n_edges
,
self
.
hidden_size
)),
"src_x"
:
cuda
(
torch
.
zeros
(
n_edges
,
self
.
hidden_size
)),
"dst_x"
:
cuda
(
torch
.
zeros
(
n_edges
,
self
.
hidden_size
)),
"rm"
:
cuda
(
torch
.
zeros
(
n_edges
,
self
.
hidden_size
)),
"accum_rm"
:
cuda
(
torch
.
zeros
(
n_edges
,
self
.
hidden_size
)),
}
)
mol_tree_batch
.
apply_edges
(
mol_tree_batch
.
apply_edges
(
func
=
lambda
edges
:
{
'src_x'
:
edges
.
src
[
'x'
],
'dst_x'
:
edges
.
dst
[
'x'
]},
func
=
lambda
edges
:
{
"src_x"
:
edges
.
src
[
"x"
],
"dst_x"
:
edges
.
dst
[
"x"
],
},
)
)
# input tensors for stop prediction (p) and label prediction (q)
# input tensors for stop prediction (p) and label prediction (q)
...
@@ -142,16 +166,16 @@ class DGLJTNNDecoder(nn.Module):
...
@@ -142,16 +166,16 @@ class DGLJTNNDecoder(nn.Module):
q_targets
=
[]
q_targets
=
[]
# Predict root
# Predict root
mol_tree_batch
.
pull
(
root_ids
,
DGLF
.
copy_e
(
'm'
,
'm'
),
DGLF
.
sum
(
'm'
,
'h'
))
mol_tree_batch
.
pull
(
root_ids
,
DGLF
.
copy_e
(
"m"
,
"m"
),
DGLF
.
sum
(
"m"
,
"h"
))
mol_tree_batch
.
apply_nodes
(
dec_tree_node_update
,
v
=
root_ids
)
mol_tree_batch
.
apply_nodes
(
dec_tree_node_update
,
v
=
root_ids
)
# Extract hidden states and store them for stop/label prediction
# Extract hidden states and store them for stop/label prediction
h
=
mol_tree_batch
.
nodes
[
root_ids
].
data
[
'h'
]
h
=
mol_tree_batch
.
nodes
[
root_ids
].
data
[
"h"
]
x
=
mol_tree_batch
.
nodes
[
root_ids
].
data
[
'x'
]
x
=
mol_tree_batch
.
nodes
[
root_ids
].
data
[
"x"
]
p_inputs
.
append
(
torch
.
cat
([
x
,
h
,
tree_vec
],
1
))
p_inputs
.
append
(
torch
.
cat
([
x
,
h
,
tree_vec
],
1
))
# If the out degree is 0 we don't generate any edges at all
# If the out degree is 0 we don't generate any edges at all
root_out_degrees
=
mol_tree_batch
.
out_degrees
(
root_ids
)
root_out_degrees
=
mol_tree_batch
.
out_degrees
(
root_ids
)
q_inputs
.
append
(
torch
.
cat
([
h
,
tree_vec
],
1
))
q_inputs
.
append
(
torch
.
cat
([
h
,
tree_vec
],
1
))
q_targets
.
append
(
mol_tree_batch
.
nodes
[
root_ids
].
data
[
'
wid
'
])
q_targets
.
append
(
mol_tree_batch
.
nodes
[
root_ids
].
data
[
"
wid
"
])
# Traverse the tree and predict on children
# Traverse the tree and predict on children
for
eid
,
p
in
dfs_order
(
mol_tree_batch
,
root_ids
):
for
eid
,
p
in
dfs_order
(
mol_tree_batch
,
root_ids
):
...
@@ -160,29 +184,35 @@ class DGLJTNNDecoder(nn.Module):
...
@@ -160,29 +184,35 @@ class DGLJTNNDecoder(nn.Module):
u
,
v
=
mol_tree_batch
.
find_edges
(
eid
)
u
,
v
=
mol_tree_batch
.
find_edges
(
eid
)
p_target_list
=
torch
.
zeros_like
(
root_out_degrees
)
p_target_list
=
torch
.
zeros_like
(
root_out_degrees
)
p_target_list
[
root_out_degrees
>
0
]
=
(
1
-
p
)
p_target_list
[
root_out_degrees
>
0
]
=
1
-
p
p_target_list
=
p_target_list
[
root_out_degrees
>=
0
]
p_target_list
=
p_target_list
[
root_out_degrees
>=
0
]
p_targets
.
append
(
torch
.
tensor
(
p_target_list
))
p_targets
.
append
(
torch
.
tensor
(
p_target_list
))
root_out_degrees
-=
(
root_out_degrees
==
0
).
long
()
root_out_degrees
-=
(
root_out_degrees
==
0
).
long
()
root_out_degrees
-=
torch
.
tensor
(
np
.
isin
(
root_ids
,
v
.
cpu
().
numpy
())).
to
(
root_out_degrees
)
root_out_degrees
-=
torch
.
tensor
(
np
.
isin
(
root_ids
,
v
.
cpu
().
numpy
())
).
to
(
root_out_degrees
)
mol_tree_batch_lg
.
ndata
.
update
(
mol_tree_batch
.
edata
)
mol_tree_batch_lg
.
ndata
.
update
(
mol_tree_batch
.
edata
)
mol_tree_batch_lg
.
pull
(
eid
,
DGLF
.
copy_u
(
'm'
,
'm'
),
DGLF
.
sum
(
'm'
,
's'
))
mol_tree_batch_lg
.
pull
(
mol_tree_batch_lg
.
pull
(
eid
,
DGLF
.
copy_u
(
'rm'
,
'rm'
),
DGLF
.
sum
(
'rm'
,
'accum_rm'
))
eid
,
DGLF
.
copy_u
(
"m"
,
"m"
),
DGLF
.
sum
(
"m"
,
"s"
)
)
mol_tree_batch_lg
.
pull
(
eid
,
DGLF
.
copy_u
(
"rm"
,
"rm"
),
DGLF
.
sum
(
"rm"
,
"accum_rm"
)
)
mol_tree_batch_lg
.
apply_nodes
(
self
.
dec_tree_edge_update
,
v
=
eid
)
mol_tree_batch_lg
.
apply_nodes
(
self
.
dec_tree_edge_update
,
v
=
eid
)
mol_tree_batch
.
edata
.
update
(
mol_tree_batch_lg
.
ndata
)
mol_tree_batch
.
edata
.
update
(
mol_tree_batch_lg
.
ndata
)
is_new
=
mol_tree_batch
.
nodes
[
v
].
data
[
'
new
'
]
is_new
=
mol_tree_batch
.
nodes
[
v
].
data
[
"
new
"
]
mol_tree_batch
.
pull
(
v
,
DGLF
.
copy_e
(
'm'
,
'm'
),
DGLF
.
sum
(
'm'
,
'h'
))
mol_tree_batch
.
pull
(
v
,
DGLF
.
copy_e
(
"m"
,
"m"
),
DGLF
.
sum
(
"m"
,
"h"
))
mol_tree_batch
.
apply_nodes
(
dec_tree_node_update
,
v
=
v
)
mol_tree_batch
.
apply_nodes
(
dec_tree_node_update
,
v
=
v
)
# Extract
# Extract
n_repr
=
mol_tree_batch
.
nodes
[
v
].
data
n_repr
=
mol_tree_batch
.
nodes
[
v
].
data
h
=
n_repr
[
'h'
]
h
=
n_repr
[
"h"
]
x
=
n_repr
[
'x'
]
x
=
n_repr
[
"x"
]
tree_vec_set
=
tree_vec
[
root_out_degrees
>=
0
]
tree_vec_set
=
tree_vec
[
root_out_degrees
>=
0
]
wid
=
n_repr
[
'
wid
'
]
wid
=
n_repr
[
"
wid
"
]
p_inputs
.
append
(
torch
.
cat
([
x
,
h
,
tree_vec_set
],
1
))
p_inputs
.
append
(
torch
.
cat
([
x
,
h
,
tree_vec_set
],
1
))
# Only newly generated nodes are needed for label prediction
# Only newly generated nodes are needed for label prediction
# NOTE: The following works since the uncomputed messages are zeros.
# NOTE: The following works since the uncomputed messages are zeros.
...
@@ -192,10 +222,13 @@ class DGLJTNNDecoder(nn.Module):
...
@@ -192,10 +222,13 @@ class DGLJTNNDecoder(nn.Module):
if
q_input
.
shape
[
0
]
>
0
:
if
q_input
.
shape
[
0
]
>
0
:
q_inputs
.
append
(
q_input
)
q_inputs
.
append
(
q_input
)
q_targets
.
append
(
q_target
)
q_targets
.
append
(
q_target
)
p_targets
.
append
(
torch
.
zeros
(
p_targets
.
append
(
(
root_out_degrees
==
0
).
sum
(),
torch
.
zeros
(
device
=
root_out_degrees
.
device
,
(
root_out_degrees
==
0
).
sum
(),
dtype
=
torch
.
int64
))
device
=
root_out_degrees
.
device
,
dtype
=
torch
.
int64
,
)
)
# Batch compute the stop/label prediction losses
# Batch compute the stop/label prediction losses
p_inputs
=
torch
.
cat
(
p_inputs
,
0
)
p_inputs
=
torch
.
cat
(
p_inputs
,
0
)
...
@@ -206,9 +239,12 @@ class DGLJTNNDecoder(nn.Module):
...
@@ -206,9 +239,12 @@ class DGLJTNNDecoder(nn.Module):
q
=
self
.
W_o
(
torch
.
relu
(
self
.
W
(
q_inputs
)))
q
=
self
.
W_o
(
torch
.
relu
(
self
.
W
(
q_inputs
)))
p
=
self
.
U_s
(
torch
.
relu
(
self
.
U
(
p_inputs
)))[:,
0
]
p
=
self
.
U_s
(
torch
.
relu
(
self
.
U
(
p_inputs
)))[:,
0
]
p_loss
=
F
.
binary_cross_entropy_with_logits
(
p_loss
=
(
p
,
p_targets
.
float
(),
size_average
=
False
F
.
binary_cross_entropy_with_logits
(
)
/
n_trees
p
,
p_targets
.
float
(),
size_average
=
False
)
/
n_trees
)
q_loss
=
F
.
cross_entropy
(
q
,
q_targets
,
size_average
=
False
)
/
n_trees
q_loss
=
F
.
cross_entropy
(
q
,
q_targets
,
size_average
=
False
)
/
n_trees
p_acc
=
((
p
>
0
).
long
()
==
p_targets
).
sum
().
float
()
/
p_targets
.
shape
[
0
]
p_acc
=
((
p
>
0
).
long
()
==
p_targets
).
sum
().
float
()
/
p_targets
.
shape
[
0
]
q_acc
=
(
q
.
max
(
1
)[
1
]
==
q_targets
).
float
().
sum
()
/
q_targets
.
shape
[
0
]
q_acc
=
(
q
.
max
(
1
)[
1
]
==
q_targets
).
float
().
sum
()
/
q_targets
.
shape
[
0
]
...
@@ -237,13 +273,14 @@ class DGLJTNNDecoder(nn.Module):
...
@@ -237,13 +273,14 @@ class DGLJTNNDecoder(nn.Module):
_
,
root_wid
=
torch
.
max
(
root_score
,
1
)
_
,
root_wid
=
torch
.
max
(
root_score
,
1
)
root_wid
=
root_wid
.
view
(
1
)
root_wid
=
root_wid
.
view
(
1
)
mol_tree_graph
.
add_nodes
(
1
)
# root
mol_tree_graph
.
add_nodes
(
1
)
# root
mol_tree_graph
.
ndata
[
'
wid
'
]
=
root_wid
mol_tree_graph
.
ndata
[
"
wid
"
]
=
root_wid
mol_tree_graph
.
ndata
[
'x'
]
=
self
.
embedding
(
root_wid
)
mol_tree_graph
.
ndata
[
"x"
]
=
self
.
embedding
(
root_wid
)
mol_tree_graph
.
ndata
[
'h'
]
=
init_hidden
mol_tree_graph
.
ndata
[
"h"
]
=
init_hidden
mol_tree_graph
.
ndata
[
'
fail
'
]
=
cuda
(
torch
.
tensor
([
0
]))
mol_tree_graph
.
ndata
[
"
fail
"
]
=
cuda
(
torch
.
tensor
([
0
]))
mol_tree
.
nodes_dict
[
0
]
=
root_node_dict
=
create_node_dict
(
mol_tree
.
nodes_dict
[
0
]
=
root_node_dict
=
create_node_dict
(
self
.
vocab
.
get_smiles
(
root_wid
))
self
.
vocab
.
get_smiles
(
root_wid
)
)
stack
,
trace
=
[],
[]
stack
,
trace
=
[],
[]
stack
.
append
((
0
,
self
.
vocab
.
get_slots
(
root_wid
)))
stack
.
append
((
0
,
self
.
vocab
.
get_slots
(
root_wid
)))
...
@@ -256,13 +293,13 @@ class DGLJTNNDecoder(nn.Module):
...
@@ -256,13 +293,13 @@ class DGLJTNNDecoder(nn.Module):
for
step
in
range
(
MAX_DECODE_LEN
):
for
step
in
range
(
MAX_DECODE_LEN
):
u
,
u_slots
=
stack
[
-
1
]
u
,
u_slots
=
stack
[
-
1
]
x
=
mol_tree_graph
.
ndata
[
'x'
][
u
:
u
+
1
]
x
=
mol_tree_graph
.
ndata
[
"x"
][
u
:
u
+
1
]
h
=
mol_tree_graph
.
ndata
[
'h'
][
u
:
u
+
1
]
h
=
mol_tree_graph
.
ndata
[
"h"
][
u
:
u
+
1
]
# Predict stop
# Predict stop
p_input
=
torch
.
cat
([
x
,
h
,
mol_vec
],
1
)
p_input
=
torch
.
cat
([
x
,
h
,
mol_vec
],
1
)
p_score
=
torch
.
sigmoid
(
self
.
U_s
(
torch
.
relu
(
self
.
U
(
p_input
))))
p_score
=
torch
.
sigmoid
(
self
.
U_s
(
torch
.
relu
(
self
.
U
(
p_input
))))
backtrack
=
(
p_score
.
item
()
<
0.5
)
backtrack
=
p_score
.
item
()
<
0.5
if
not
backtrack
:
if
not
backtrack
:
# Predict next clique. Note that the prediction may fail due
# Predict next clique. Note that the prediction may fail due
...
@@ -273,49 +310,61 @@ class DGLJTNNDecoder(nn.Module):
...
@@ -273,49 +310,61 @@ class DGLJTNNDecoder(nn.Module):
mol_tree_graph
.
add_edges
(
u
,
v
)
mol_tree_graph
.
add_edges
(
u
,
v
)
uv
=
new_edge_id
uv
=
new_edge_id
new_edge_id
+=
1
new_edge_id
+=
1
if
first
:
if
first
:
mol_tree_graph
.
edata
.
update
({
mol_tree_graph
.
edata
.
update
(
's'
:
cuda
(
torch
.
zeros
(
1
,
self
.
hidden_size
)),
{
'm'
:
cuda
(
torch
.
zeros
(
1
,
self
.
hidden_size
)),
"s"
:
cuda
(
torch
.
zeros
(
1
,
self
.
hidden_size
)),
'r'
:
cuda
(
torch
.
zeros
(
1
,
self
.
hidden_size
)),
"m"
:
cuda
(
torch
.
zeros
(
1
,
self
.
hidden_size
)),
'z'
:
cuda
(
torch
.
zeros
(
1
,
self
.
hidden_size
)),
"r"
:
cuda
(
torch
.
zeros
(
1
,
self
.
hidden_size
)),
'src_x'
:
cuda
(
torch
.
zeros
(
1
,
self
.
hidden_size
)),
"z"
:
cuda
(
torch
.
zeros
(
1
,
self
.
hidden_size
)),
'dst_x'
:
cuda
(
torch
.
zeros
(
1
,
self
.
hidden_size
)),
"src_x"
:
cuda
(
torch
.
zeros
(
1
,
self
.
hidden_size
)),
'rm'
:
cuda
(
torch
.
zeros
(
1
,
self
.
hidden_size
)),
"dst_x"
:
cuda
(
torch
.
zeros
(
1
,
self
.
hidden_size
)),
'accum_rm'
:
cuda
(
torch
.
zeros
(
1
,
self
.
hidden_size
)),
"rm"
:
cuda
(
torch
.
zeros
(
1
,
self
.
hidden_size
)),
})
"accum_rm"
:
cuda
(
torch
.
zeros
(
1
,
self
.
hidden_size
)),
}
)
first
=
False
first
=
False
mol_tree_graph
.
edata
[
'
src_x
'
][
uv
]
=
mol_tree_graph
.
ndata
[
'x'
][
u
]
mol_tree_graph
.
edata
[
"
src_x
"
][
uv
]
=
mol_tree_graph
.
ndata
[
"x"
][
u
]
# keeping dst_x 0 is fine as h on new edge doesn't depend on that.
# keeping dst_x 0 is fine as h on new edge doesn't depend on that.
# DGL doesn't dynamically maintain a line graph.
# DGL doesn't dynamically maintain a line graph.
mol_tree_graph_lg
=
line_graph
(
mol_tree_graph
,
backtracking
=
False
,
shared
=
True
)
mol_tree_graph_lg
=
line_graph
(
mol_tree_graph
,
backtracking
=
False
,
shared
=
True
)
mol_tree_graph_lg
.
pull
(
mol_tree_graph_lg
.
pull
(
uv
,
uv
,
DGLF
.
copy_u
(
"m"
,
"m"
),
DGLF
.
sum
(
"m"
,
"s"
)
DGLF
.
copy_u
(
'm'
,
'm'
),
)
DGLF
.
sum
(
'm'
,
's'
))
mol_tree_graph_lg
.
pull
(
mol_tree_graph_lg
.
pull
(
uv
,
uv
,
DGLF
.
copy_u
(
"rm"
,
"rm"
),
DGLF
.
sum
(
"rm"
,
"accum_rm"
)
DGLF
.
copy_u
(
'rm'
,
'rm'
),
)
DGLF
.
sum
(
'rm'
,
'accum_rm'
))
mol_tree_graph_lg
.
apply_nodes
(
mol_tree_graph_lg
.
apply_nodes
(
self
.
dec_tree_edge_update
.
update_zm
,
v
=
uv
)
self
.
dec_tree_edge_update
.
update_zm
,
v
=
uv
)
mol_tree_graph
.
edata
.
update
(
mol_tree_graph_lg
.
ndata
)
mol_tree_graph
.
edata
.
update
(
mol_tree_graph_lg
.
ndata
)
mol_tree_graph
.
pull
(
v
,
DGLF
.
copy_e
(
'm'
,
'm'
),
DGLF
.
sum
(
'm'
,
'h'
))
mol_tree_graph
.
pull
(
v
,
DGLF
.
copy_e
(
"m"
,
"m"
),
DGLF
.
sum
(
"m"
,
"h"
)
)
h_v
=
mol_tree_graph
.
ndata
[
'h'
][
v
:
v
+
1
]
h_v
=
mol_tree_graph
.
ndata
[
"h"
][
v
:
v
+
1
]
q_input
=
torch
.
cat
([
h_v
,
mol_vec
],
1
)
q_input
=
torch
.
cat
([
h_v
,
mol_vec
],
1
)
q_score
=
torch
.
softmax
(
self
.
W_o
(
torch
.
relu
(
self
.
W
(
q_input
))),
-
1
)
q_score
=
torch
.
softmax
(
self
.
W_o
(
torch
.
relu
(
self
.
W
(
q_input
))),
-
1
)
_
,
sort_wid
=
torch
.
sort
(
q_score
,
1
,
descending
=
True
)
_
,
sort_wid
=
torch
.
sort
(
q_score
,
1
,
descending
=
True
)
sort_wid
=
sort_wid
.
squeeze
()
sort_wid
=
sort_wid
.
squeeze
()
next_wid
=
None
next_wid
=
None
for
wid
in
sort_wid
.
tolist
()[:
5
]:
for
wid
in
sort_wid
.
tolist
()[:
5
]:
slots
=
self
.
vocab
.
get_slots
(
wid
)
slots
=
self
.
vocab
.
get_slots
(
wid
)
cand_node_dict
=
create_node_dict
(
self
.
vocab
.
get_smiles
(
wid
))
cand_node_dict
=
create_node_dict
(
if
(
have_slots
(
u_slots
,
slots
)
and
can_assemble
(
mol_tree
,
u
,
cand_node_dict
)):
self
.
vocab
.
get_smiles
(
wid
)
)
if
have_slots
(
u_slots
,
slots
)
and
can_assemble
(
mol_tree
,
u
,
cand_node_dict
):
next_wid
=
wid
next_wid
=
wid
next_slots
=
slots
next_slots
=
slots
next_node_dict
=
cand_node_dict
next_node_dict
=
cand_node_dict
...
@@ -324,44 +373,59 @@ class DGLJTNNDecoder(nn.Module):
...
@@ -324,44 +373,59 @@ class DGLJTNNDecoder(nn.Module):
if
next_wid
is
None
:
if
next_wid
is
None
:
# Failed adding an actual children; v is a spurious node
# Failed adding an actual children; v is a spurious node
# and we mark it.
# and we mark it.
mol_tree_graph
.
ndata
[
'
fail
'
][
v
]
=
cuda
(
torch
.
tensor
([
1
]))
mol_tree_graph
.
ndata
[
"
fail
"
][
v
]
=
cuda
(
torch
.
tensor
([
1
]))
backtrack
=
True
backtrack
=
True
else
:
else
:
next_wid
=
cuda
(
torch
.
tensor
([
next_wid
]))
next_wid
=
cuda
(
torch
.
tensor
([
next_wid
]))
mol_tree_graph
.
ndata
[
'
wid
'
][
v
]
=
next_wid
mol_tree_graph
.
ndata
[
"
wid
"
][
v
]
=
next_wid
mol_tree_graph
.
ndata
[
'x'
][
v
]
=
self
.
embedding
(
next_wid
)
mol_tree_graph
.
ndata
[
"x"
][
v
]
=
self
.
embedding
(
next_wid
)
mol_tree
.
nodes_dict
[
v
]
=
next_node_dict
mol_tree
.
nodes_dict
[
v
]
=
next_node_dict
all_nodes
[
v
]
=
next_node_dict
all_nodes
[
v
]
=
next_node_dict
stack
.
append
((
v
,
next_slots
))
stack
.
append
((
v
,
next_slots
))
mol_tree_graph
.
add_edges
(
v
,
u
)
mol_tree_graph
.
add_edges
(
v
,
u
)
vu
=
new_edge_id
vu
=
new_edge_id
new_edge_id
+=
1
new_edge_id
+=
1
mol_tree_graph
.
edata
[
'dst_x'
][
uv
]
=
mol_tree_graph
.
ndata
[
'x'
][
v
]
mol_tree_graph
.
edata
[
"dst_x"
][
uv
]
=
mol_tree_graph
.
ndata
[
mol_tree_graph
.
edata
[
'src_x'
][
vu
]
=
mol_tree_graph
.
ndata
[
'x'
][
v
]
"x"
mol_tree_graph
.
edata
[
'dst_x'
][
vu
]
=
mol_tree_graph
.
ndata
[
'x'
][
u
]
][
v
]
mol_tree_graph
.
edata
[
"src_x"
][
vu
]
=
mol_tree_graph
.
ndata
[
"x"
][
v
]
mol_tree_graph
.
edata
[
"dst_x"
][
vu
]
=
mol_tree_graph
.
ndata
[
"x"
][
u
]
# DGL doesn't dynamically maintain a line graph.
# DGL doesn't dynamically maintain a line graph.
mol_tree_graph_lg
=
line_graph
(
mol_tree_graph
,
backtracking
=
False
,
shared
=
True
)
mol_tree_graph_lg
=
line_graph
(
mol_tree_graph
,
backtracking
=
False
,
shared
=
True
)
mol_tree_graph_lg
.
apply_nodes
(
mol_tree_graph_lg
.
apply_nodes
(
self
.
dec_tree_edge_update
.
update_r
,
self
.
dec_tree_edge_update
.
update_r
,
uv
uv
)
)
mol_tree_graph
.
edata
.
update
(
mol_tree_graph_lg
.
ndata
)
mol_tree_graph
.
edata
.
update
(
mol_tree_graph_lg
.
ndata
)
if
backtrack
:
if
backtrack
:
if
len
(
stack
)
==
1
:
if
len
(
stack
)
==
1
:
break
# At root, terminate
break
# At root, terminate
pu
,
_
=
stack
[
-
2
]
pu
,
_
=
stack
[
-
2
]
u_pu
=
mol_tree_graph
.
edge_ids
(
u
,
pu
)
u_pu
=
mol_tree_graph
.
edge_ids
(
u
,
pu
)
mol_tree_graph_lg
.
pull
(
u_pu
,
DGLF
.
copy_u
(
'm'
,
'm'
),
DGLF
.
sum
(
'm'
,
's'
))
mol_tree_graph_lg
.
pull
(
mol_tree_graph_lg
.
pull
(
u_pu
,
DGLF
.
copy_u
(
'rm'
,
'rm'
),
DGLF
.
sum
(
'rm'
,
'accum_rm'
))
u_pu
,
DGLF
.
copy_u
(
"m"
,
"m"
),
DGLF
.
sum
(
"m"
,
"s"
)
)
mol_tree_graph_lg
.
pull
(
u_pu
,
DGLF
.
copy_u
(
"rm"
,
"rm"
),
DGLF
.
sum
(
"rm"
,
"accum_rm"
)
)
mol_tree_graph_lg
.
apply_nodes
(
self
.
dec_tree_edge_update
,
v
=
u_pu
)
mol_tree_graph_lg
.
apply_nodes
(
self
.
dec_tree_edge_update
,
v
=
u_pu
)
mol_tree_graph
.
edata
.
update
(
mol_tree_graph_lg
.
ndata
)
mol_tree_graph
.
edata
.
update
(
mol_tree_graph_lg
.
ndata
)
mol_tree_graph
.
pull
(
pu
,
DGLF
.
copy_e
(
'm'
,
'm'
),
DGLF
.
sum
(
'm'
,
'h'
))
mol_tree_graph
.
pull
(
pu
,
DGLF
.
copy_e
(
"m"
,
"m"
),
DGLF
.
sum
(
"m"
,
"h"
)
)
stack
.
pop
()
stack
.
pop
()
effective_nodes
=
mol_tree_graph
.
filter_nodes
(
lambda
nodes
:
nodes
.
data
[
'fail'
]
!=
1
)
effective_nodes
=
mol_tree_graph
.
filter_nodes
(
lambda
nodes
:
nodes
.
data
[
"fail"
]
!=
1
)
effective_nodes
,
_
=
torch
.
sort
(
effective_nodes
)
effective_nodes
,
_
=
torch
.
sort
(
effective_nodes
)
return
mol_tree
,
all_nodes
,
effective_nodes
return
mol_tree
,
all_nodes
,
effective_nodes
examples/pytorch/jtnn/jtnn/jtnn_enc.py
View file @
704bcaf6
import
dgl.function
as
DGLF
import
numpy
as
np
import
numpy
as
np
import
torch
import
torch
import
torch.nn
as
nn
import
torch.nn
as
nn
import
dgl.function
as
DGLF
from
dgl
import
batch
,
bfs_edges_generator
,
line_graph
from
dgl
import
batch
,
bfs_edges_generator
,
line_graph
from
.nnutils
import
GRUUpdate
,
cuda
,
tocpu
from
.nnutils
import
cuda
,
GRUUpdate
,
tocpu
MAX_NB
=
8
MAX_NB
=
8
...
...
examples/pytorch/jtnn/jtnn/jtnn_vae.py
View file @
704bcaf6
...
@@ -14,12 +14,10 @@ from .chemutils import (
...
@@ -14,12 +14,10 @@ from .chemutils import (
enum_assemble_nx
,
enum_assemble_nx
,
set_atommap
,
set_atommap
,
)
)
from
.jtmpn
import
DGLJTMPN
from
.jtmpn
import
DGLJTMPN
,
mol2dgl_single
as
mol2dgl_dec
from
.jtmpn
import
mol2dgl_single
as
mol2dgl_dec
from
.jtnn_dec
import
DGLJTNNDecoder
from
.jtnn_dec
import
DGLJTNNDecoder
from
.jtnn_enc
import
DGLJTNNEncoder
from
.jtnn_enc
import
DGLJTNNEncoder
from
.mpn
import
DGLMPN
from
.mpn
import
DGLMPN
,
mol2dgl_single
as
mol2dgl_enc
from
.mpn
import
mol2dgl_single
as
mol2dgl_enc
from
.nnutils
import
cuda
from
.nnutils
import
cuda
...
...
examples/pytorch/jtnn/jtnn/mol_tree_nx.py
View file @
704bcaf6
import
dgl
import
numpy
as
np
import
numpy
as
np
import
rdkit.Chem
as
Chem
import
rdkit.Chem
as
Chem
import
dgl
from
.chemutils
import
(
from
.chemutils
import
(
decode_stereo
,
decode_stereo
,
enum_assemble_nx
,
enum_assemble_nx
,
...
...
examples/pytorch/jtnn/jtnn/mpn.py
View file @
704bcaf6
import
dgl
import
dgl.function
as
DGLF
import
rdkit.Chem
as
Chem
import
rdkit.Chem
as
Chem
import
torch
import
torch
import
torch.nn
as
nn
import
torch.nn
as
nn
import
torch.nn.functional
as
F
import
torch.nn.functional
as
F
import
dgl
import
dgl.function
as
DGLF
from
dgl
import
line_graph
,
mean_nodes
from
dgl
import
line_graph
,
mean_nodes
from
.chemutils
import
get_mol
from
.chemutils
import
get_mol
...
...
examples/pytorch/jtnn/jtnn/nnutils.py
View file @
704bcaf6
import
os
import
os
import
dgl
import
torch
import
torch
import
torch.nn
as
nn
import
torch.nn
as
nn
import
dgl
def
cuda
(
x
):
def
cuda
(
x
):
if
torch
.
cuda
.
is_available
()
and
not
os
.
getenv
(
"NOCUDA"
,
None
):
if
torch
.
cuda
.
is_available
()
and
not
os
.
getenv
(
"NOCUDA"
,
None
):
...
...
examples/pytorch/label_propagation/main.py
View file @
704bcaf6
import
argparse
import
argparse
import
torch
import
dgl
import
dgl
import
torch
from
dgl.data
import
CiteseerGraphDataset
,
CoraGraphDataset
,
PubmedGraphDataset
from
dgl.data
import
CiteseerGraphDataset
,
CoraGraphDataset
,
PubmedGraphDataset
from
dgl.nn
import
LabelPropagation
from
dgl.nn
import
LabelPropagation
...
...
examples/pytorch/lda/example_20newsgroups.py
View file @
704bcaf6
...
@@ -20,18 +20,18 @@
...
@@ -20,18 +20,18 @@
import
warnings
import
warnings
from
time
import
time
from
time
import
time
import
dgl
import
matplotlib.pyplot
as
plt
import
matplotlib.pyplot
as
plt
import
numpy
as
np
import
numpy
as
np
import
scipy.sparse
as
ss
import
scipy.sparse
as
ss
import
torch
import
torch
from
dgl
import
function
as
fn
from
lda_model
import
LatentDirichletAllocation
as
LDAModel
from
lda_model
import
LatentDirichletAllocation
as
LDAModel
from
sklearn.datasets
import
fetch_20newsgroups
from
sklearn.datasets
import
fetch_20newsgroups
from
sklearn.decomposition
import
NMF
,
LatentDirichletAllocation
from
sklearn.decomposition
import
LatentDirichletAllocation
,
NMF
from
sklearn.feature_extraction.text
import
CountVectorizer
,
TfidfVectorizer
from
sklearn.feature_extraction.text
import
CountVectorizer
,
TfidfVectorizer
import
dgl
from
dgl
import
function
as
fn
n_samples
=
2000
n_samples
=
2000
n_features
=
1000
n_features
=
1000
n_components
=
10
n_components
=
10
...
...
examples/pytorch/lda/lda_model.py
View file @
704bcaf6
...
@@ -23,12 +23,12 @@ import io
...
@@ -23,12 +23,12 @@ import io
import
os
import
os
import
warnings
import
warnings
import
dgl
import
numpy
as
np
import
numpy
as
np
import
scipy
as
sp
import
scipy
as
sp
import
torch
import
torch
import
dgl
try
:
try
:
from
functools
import
cached_property
from
functools
import
cached_property
except
ImportError
:
except
ImportError
:
...
...
examples/pytorch/line_graph/gnn.py
View file @
704bcaf6
import
copy
import
copy
import
itertools
import
itertools
import
dgl
import
dgl
import
dgl.function
as
fn
import
dgl.function
as
fn
import
networkx
as
nx
import
networkx
as
nx
import
numpy
as
np
import
torch
as
th
import
torch
as
th
import
torch.nn
as
nn
import
torch.nn
as
nn
import
torch.nn.functional
as
F
import
torch.nn.functional
as
F
import
numpy
as
np
class
GNNModule
(
nn
.
Module
):
class
GNNModule
(
nn
.
Module
):
def
__init__
(
self
,
in_feats
,
out_feats
,
radius
):
def
__init__
(
self
,
in_feats
,
out_feats
,
radius
):
...
@@ -15,14 +17,22 @@ class GNNModule(nn.Module):
...
@@ -15,14 +17,22 @@ class GNNModule(nn.Module):
self
.
radius
=
radius
self
.
radius
=
radius
new_linear
=
lambda
:
nn
.
Linear
(
in_feats
,
out_feats
)
new_linear
=
lambda
:
nn
.
Linear
(
in_feats
,
out_feats
)
new_linear_list
=
lambda
:
nn
.
ModuleList
([
new_linear
()
for
i
in
range
(
radius
)])
new_linear_list
=
lambda
:
nn
.
ModuleList
(
[
new_linear
()
for
i
in
range
(
radius
)]
)
self
.
theta_x
,
self
.
theta_deg
,
self
.
theta_y
=
\
self
.
theta_x
,
self
.
theta_deg
,
self
.
theta_y
=
(
new_linear
(),
new_linear
(),
new_linear
()
new_linear
(),
new_linear
(),
new_linear
(),
)
self
.
theta_list
=
new_linear_list
()
self
.
theta_list
=
new_linear_list
()
self
.
gamma_y
,
self
.
gamma_deg
,
self
.
gamma_x
=
\
self
.
gamma_y
,
self
.
gamma_deg
,
self
.
gamma_x
=
(
new_linear
(),
new_linear
(),
new_linear
()
new_linear
(),
new_linear
(),
new_linear
(),
)
self
.
gamma_list
=
new_linear_list
()
self
.
gamma_list
=
new_linear_list
()
self
.
bn_x
=
nn
.
BatchNorm1d
(
out_feats
)
self
.
bn_x
=
nn
.
BatchNorm1d
(
out_feats
)
...
@@ -30,43 +40,61 @@ class GNNModule(nn.Module):
...
@@ -30,43 +40,61 @@ class GNNModule(nn.Module):
def
aggregate
(
self
,
g
,
z
):
def
aggregate
(
self
,
g
,
z
):
z_list
=
[]
z_list
=
[]
g
.
ndata
[
'z'
]
=
z
g
.
ndata
[
"z"
]
=
z
g
.
update_all
(
fn
.
copy_u
(
u
=
'z'
,
out
=
'm'
),
fn
.
sum
(
msg
=
'm'
,
out
=
'z'
))
g
.
update_all
(
fn
.
copy_u
(
u
=
"z"
,
out
=
"m"
),
fn
.
sum
(
msg
=
"m"
,
out
=
"z"
))
z_list
.
append
(
g
.
ndata
[
'z'
])
z_list
.
append
(
g
.
ndata
[
"z"
])
for
i
in
range
(
self
.
radius
-
1
):
for
i
in
range
(
self
.
radius
-
1
):
for
j
in
range
(
2
**
i
):
for
j
in
range
(
2
**
i
):
g
.
update_all
(
fn
.
copy_u
(
u
=
'z'
,
out
=
'm'
),
fn
.
sum
(
msg
=
'm'
,
out
=
'z'
))
g
.
update_all
(
z_list
.
append
(
g
.
ndata
[
'z'
])
fn
.
copy_u
(
u
=
"z"
,
out
=
"m"
),
fn
.
sum
(
msg
=
"m"
,
out
=
"z"
)
)
z_list
.
append
(
g
.
ndata
[
"z"
])
return
z_list
return
z_list
def
forward
(
self
,
g
,
lg
,
x
,
y
,
deg_g
,
deg_lg
,
pm_pd
):
def
forward
(
self
,
g
,
lg
,
x
,
y
,
deg_g
,
deg_lg
,
pm_pd
):
pmpd_x
=
F
.
embedding
(
pm_pd
,
x
)
pmpd_x
=
F
.
embedding
(
pm_pd
,
x
)
sum_x
=
sum
(
theta
(
z
)
for
theta
,
z
in
zip
(
self
.
theta_list
,
self
.
aggregate
(
g
,
x
)))
sum_x
=
sum
(
theta
(
z
)
for
theta
,
z
in
zip
(
self
.
theta_list
,
self
.
aggregate
(
g
,
x
))
)
g
.
edata
[
'y'
]
=
y
g
.
edata
[
"y"
]
=
y
g
.
update_all
(
fn
.
copy_e
(
e
=
'y'
,
out
=
'm'
),
fn
.
sum
(
'm'
,
'
pmpd_y
'
))
g
.
update_all
(
fn
.
copy_e
(
e
=
"y"
,
out
=
"m"
),
fn
.
sum
(
"m"
,
"
pmpd_y
"
))
pmpd_y
=
g
.
ndata
.
pop
(
'
pmpd_y
'
)
pmpd_y
=
g
.
ndata
.
pop
(
"
pmpd_y
"
)
x
=
self
.
theta_x
(
x
)
+
self
.
theta_deg
(
deg_g
*
x
)
+
sum_x
+
self
.
theta_y
(
pmpd_y
)
x
=
(
self
.
theta_x
(
x
)
+
self
.
theta_deg
(
deg_g
*
x
)
+
sum_x
+
self
.
theta_y
(
pmpd_y
)
)
n
=
self
.
out_feats
//
2
n
=
self
.
out_feats
//
2
x
=
th
.
cat
([
x
[:,
:
n
],
F
.
relu
(
x
[:,
n
:])],
1
)
x
=
th
.
cat
([
x
[:,
:
n
],
F
.
relu
(
x
[:,
n
:])],
1
)
x
=
self
.
bn_x
(
x
)
x
=
self
.
bn_x
(
x
)
sum_y
=
sum
(
gamma
(
z
)
for
gamma
,
z
in
zip
(
self
.
gamma_list
,
self
.
aggregate
(
lg
,
y
)))
sum_y
=
sum
(
gamma
(
z
)
for
gamma
,
z
in
zip
(
self
.
gamma_list
,
self
.
aggregate
(
lg
,
y
))
)
y
=
self
.
gamma_y
(
y
)
+
self
.
gamma_deg
(
deg_lg
*
y
)
+
sum_y
+
self
.
gamma_x
(
pmpd_x
)
y
=
(
self
.
gamma_y
(
y
)
+
self
.
gamma_deg
(
deg_lg
*
y
)
+
sum_y
+
self
.
gamma_x
(
pmpd_x
)
)
y
=
th
.
cat
([
y
[:,
:
n
],
F
.
relu
(
y
[:,
n
:])],
1
)
y
=
th
.
cat
([
y
[:,
:
n
],
F
.
relu
(
y
[:,
n
:])],
1
)
y
=
self
.
bn_y
(
y
)
y
=
self
.
bn_y
(
y
)
return
x
,
y
return
x
,
y
class
GNN
(
nn
.
Module
):
class
GNN
(
nn
.
Module
):
def
__init__
(
self
,
feats
,
radius
,
n_classes
):
def
__init__
(
self
,
feats
,
radius
,
n_classes
):
super
(
GNN
,
self
).
__init__
()
super
(
GNN
,
self
).
__init__
()
self
.
linear
=
nn
.
Linear
(
feats
[
-
1
],
n_classes
)
self
.
linear
=
nn
.
Linear
(
feats
[
-
1
],
n_classes
)
self
.
module_list
=
nn
.
ModuleList
([
GNNModule
(
m
,
n
,
radius
)
self
.
module_list
=
nn
.
ModuleList
(
for
m
,
n
in
zip
(
feats
[:
-
1
],
feats
[
1
:])])
[
GNNModule
(
m
,
n
,
radius
)
for
m
,
n
in
zip
(
feats
[:
-
1
],
feats
[
1
:])]
)
def
forward
(
self
,
g
,
lg
,
deg_g
,
deg_lg
,
pm_pd
):
def
forward
(
self
,
g
,
lg
,
deg_g
,
deg_lg
,
pm_pd
):
x
,
y
=
deg_g
,
deg_lg
x
,
y
=
deg_g
,
deg_lg
...
...
examples/pytorch/line_graph/train.py
View file @
704bcaf6
...
@@ -16,9 +16,9 @@ import numpy as np
...
@@ -16,9 +16,9 @@ import numpy as np
import
torch
as
th
import
torch
as
th
import
torch.nn.functional
as
F
import
torch.nn.functional
as
F
import
torch.optim
as
optim
import
torch.optim
as
optim
from
torch.utils.data
import
DataLoader
from
dgl.data
import
SBMMixtureDataset
from
dgl.data
import
SBMMixtureDataset
from
torch.utils.data
import
DataLoader
parser
=
argparse
.
ArgumentParser
()
parser
=
argparse
.
ArgumentParser
()
parser
.
add_argument
(
"--batch-size"
,
type
=
int
,
help
=
"Batch size"
,
default
=
1
)
parser
.
add_argument
(
"--batch-size"
,
type
=
int
,
help
=
"Batch size"
,
default
=
1
)
...
...
examples/pytorch/metapath2vec/download.py
View file @
704bcaf6
import
os
import
os
import
torch
as
th
import
torch
as
th
import
torch.nn
as
nn
import
torch.nn
as
nn
import
tqdm
import
tqdm
...
@@ -20,35 +21,37 @@ class PBar(object):
...
@@ -20,35 +21,37 @@ class PBar(object):
class
AminerDataset
(
object
):
class
AminerDataset
(
object
):
"""
"""
Download Aminer Dataset from Amazon S3 bucket.
Download Aminer Dataset from Amazon S3 bucket.
"""
"""
def
__init__
(
self
,
path
):
self
.
url
=
'https://data.dgl.ai/dataset/aminer.zip'
def
__init__
(
self
,
path
):
self
.
url
=
"https://data.dgl.ai/dataset/aminer.zip"
if
not
os
.
path
.
exists
(
os
.
path
.
join
(
path
,
'
aminer.txt
'
)):
if
not
os
.
path
.
exists
(
os
.
path
.
join
(
path
,
"
aminer.txt
"
)):
print
(
'
File not found. Downloading from
'
,
self
.
url
)
print
(
"
File not found. Downloading from
"
,
self
.
url
)
self
.
_download_and_extract
(
path
,
'
aminer.zip
'
)
self
.
_download_and_extract
(
path
,
"
aminer.zip
"
)
self
.
fn
=
os
.
path
.
join
(
path
,
'
aminer.txt
'
)
self
.
fn
=
os
.
path
.
join
(
path
,
"
aminer.txt
"
)
def
_download_and_extract
(
self
,
path
,
filename
):
def
_download_and_extract
(
self
,
path
,
filename
):
import
shutil
,
zipfile
,
zlib
import
shutil
,
zipfile
,
zlib
from
tqdm
import
tqdm
import
urllib.request
import
urllib.request
from
tqdm
import
tqdm
fn
=
os
.
path
.
join
(
path
,
filename
)
fn
=
os
.
path
.
join
(
path
,
filename
)
with
PBar
()
as
pb
:
with
PBar
()
as
pb
:
urllib
.
request
.
urlretrieve
(
self
.
url
,
fn
,
pb
)
urllib
.
request
.
urlretrieve
(
self
.
url
,
fn
,
pb
)
print
(
'
Download finished. Unzipping the file...
'
)
print
(
"
Download finished. Unzipping the file...
"
)
with
zipfile
.
ZipFile
(
fn
)
as
zf
:
with
zipfile
.
ZipFile
(
fn
)
as
zf
:
zf
.
extractall
(
path
)
zf
.
extractall
(
path
)
print
(
'
Unzip finished.
'
)
print
(
"
Unzip finished.
"
)
class
CustomDataset
(
object
):
class
CustomDataset
(
object
):
"""
"""
Custom dataset generated by sampler.py (e.g. NetDBIS)
Custom dataset generated by sampler.py (e.g. NetDBIS)
"""
"""
def
__init__
(
self
,
path
):
def
__init__
(
self
,
path
):
self
.
fn
=
path
self
.
fn
=
path
examples/pytorch/metapath2vec/metapath2vec.py
View file @
704bcaf6
import
torch
import
argparse
import
argparse
import
torch
import
torch.optim
as
optim
import
torch.optim
as
optim
from
download
import
AminerDataset
,
CustomDataset
from
model
import
SkipGramModel
from
reading_data
import
DataReader
,
Metapath2vecDataset
from
torch.utils.data
import
DataLoader
from
torch.utils.data
import
DataLoader
from
tqdm
import
tqdm
from
tqdm
import
tqdm
from
reading_data
import
DataReader
,
Metapath2vecDataset
from
model
import
SkipGramModel
from
download
import
AminerDataset
,
CustomDataset
class
Metapath2VecTrainer
:
class
Metapath2VecTrainer
:
def
__init__
(
self
,
args
):
def
__init__
(
self
,
args
):
...
@@ -18,8 +19,13 @@ class Metapath2VecTrainer:
...
@@ -18,8 +19,13 @@ class Metapath2VecTrainer:
dataset
=
CustomDataset
(
args
.
path
)
dataset
=
CustomDataset
(
args
.
path
)
self
.
data
=
DataReader
(
dataset
,
args
.
min_count
,
args
.
care_type
)
self
.
data
=
DataReader
(
dataset
,
args
.
min_count
,
args
.
care_type
)
dataset
=
Metapath2vecDataset
(
self
.
data
,
args
.
window_size
)
dataset
=
Metapath2vecDataset
(
self
.
data
,
args
.
window_size
)
self
.
dataloader
=
DataLoader
(
dataset
,
batch_size
=
args
.
batch_size
,
self
.
dataloader
=
DataLoader
(
shuffle
=
True
,
num_workers
=
args
.
num_workers
,
collate_fn
=
dataset
.
collate
)
dataset
,
batch_size
=
args
.
batch_size
,
shuffle
=
True
,
num_workers
=
args
.
num_workers
,
collate_fn
=
dataset
.
collate
,
)
self
.
output_file_name
=
args
.
output_file
self
.
output_file_name
=
args
.
output_file
self
.
emb_size
=
len
(
self
.
data
.
word2id
)
self
.
emb_size
=
len
(
self
.
data
.
word2id
)
...
@@ -35,15 +41,17 @@ class Metapath2VecTrainer:
...
@@ -35,15 +41,17 @@ class Metapath2VecTrainer:
self
.
skip_gram_model
.
cuda
()
self
.
skip_gram_model
.
cuda
()
def
train
(
self
):
def
train
(
self
):
optimizer
=
optim
.
SparseAdam
(
optimizer
=
optim
.
SparseAdam
(
list
(
self
.
skip_gram_model
.
parameters
()),
lr
=
self
.
initial_lr
)
list
(
self
.
skip_gram_model
.
parameters
()),
lr
=
self
.
initial_lr
scheduler
=
torch
.
optim
.
lr_scheduler
.
CosineAnnealingLR
(
optimizer
,
len
(
self
.
dataloader
))
)
scheduler
=
torch
.
optim
.
lr_scheduler
.
CosineAnnealingLR
(
optimizer
,
len
(
self
.
dataloader
)
)
for
iteration
in
range
(
self
.
iterations
):
for
iteration
in
range
(
self
.
iterations
):
print
(
"
\n\n\n
Iteration: "
+
str
(
iteration
+
1
))
print
(
"
\n\n\n
Iteration: "
+
str
(
iteration
+
1
))
running_loss
=
0.0
running_loss
=
0.0
for
i
,
sample_batched
in
enumerate
(
tqdm
(
self
.
dataloader
)):
for
i
,
sample_batched
in
enumerate
(
tqdm
(
self
.
dataloader
)):
if
len
(
sample_batched
[
0
])
>
1
:
if
len
(
sample_batched
[
0
])
>
1
:
pos_u
=
sample_batched
[
0
].
to
(
self
.
device
)
pos_u
=
sample_batched
[
0
].
to
(
self
.
device
)
pos_v
=
sample_batched
[
1
].
to
(
self
.
device
)
pos_v
=
sample_batched
[
1
].
to
(
self
.
device
)
...
@@ -59,23 +67,40 @@ class Metapath2VecTrainer:
...
@@ -59,23 +67,40 @@ class Metapath2VecTrainer:
if
i
>
0
and
i
%
500
==
0
:
if
i
>
0
and
i
%
500
==
0
:
print
(
" Loss: "
+
str
(
running_loss
))
print
(
" Loss: "
+
str
(
running_loss
))
self
.
skip_gram_model
.
save_embedding
(
self
.
data
.
id2word
,
self
.
output_file_name
)
self
.
skip_gram_model
.
save_embedding
(
self
.
data
.
id2word
,
self
.
output_file_name
)
if
__name__
==
'
__main__
'
:
if
__name__
==
"
__main__
"
:
parser
=
argparse
.
ArgumentParser
(
description
=
"Metapath2vec"
)
parser
=
argparse
.
ArgumentParser
(
description
=
"Metapath2vec"
)
#parser.add_argument('--input_file', type=str, help="input_file")
# parser.add_argument('--input_file', type=str, help="input_file")
parser
.
add_argument
(
'--aminer'
,
action
=
'store_true'
,
help
=
'Use AMiner dataset'
)
parser
.
add_argument
(
parser
.
add_argument
(
'--path'
,
type
=
str
,
help
=
"input_path"
)
"--aminer"
,
action
=
"store_true"
,
help
=
"Use AMiner dataset"
parser
.
add_argument
(
'--output_file'
,
type
=
str
,
help
=
'output_file'
)
)
parser
.
add_argument
(
'--dim'
,
default
=
128
,
type
=
int
,
help
=
"embedding dimensions"
)
parser
.
add_argument
(
"--path"
,
type
=
str
,
help
=
"input_path"
)
parser
.
add_argument
(
'--window_size'
,
default
=
7
,
type
=
int
,
help
=
"context window size"
)
parser
.
add_argument
(
"--output_file"
,
type
=
str
,
help
=
"output_file"
)
parser
.
add_argument
(
'--iterations'
,
default
=
5
,
type
=
int
,
help
=
"iterations"
)
parser
.
add_argument
(
parser
.
add_argument
(
'--batch_size'
,
default
=
50
,
type
=
int
,
help
=
"batch size"
)
"--dim"
,
default
=
128
,
type
=
int
,
help
=
"embedding dimensions"
parser
.
add_argument
(
'--care_type'
,
default
=
0
,
type
=
int
,
help
=
"if 1, heterogeneous negative sampling, else normal negative sampling"
)
)
parser
.
add_argument
(
'--initial_lr'
,
default
=
0.025
,
type
=
float
,
help
=
"learning rate"
)
parser
.
add_argument
(
parser
.
add_argument
(
'--min_count'
,
default
=
5
,
type
=
int
,
help
=
"min count"
)
"--window_size"
,
default
=
7
,
type
=
int
,
help
=
"context window size"
parser
.
add_argument
(
'--num_workers'
,
default
=
16
,
type
=
int
,
help
=
"number of workers"
)
)
parser
.
add_argument
(
"--iterations"
,
default
=
5
,
type
=
int
,
help
=
"iterations"
)
parser
.
add_argument
(
"--batch_size"
,
default
=
50
,
type
=
int
,
help
=
"batch size"
)
parser
.
add_argument
(
"--care_type"
,
default
=
0
,
type
=
int
,
help
=
"if 1, heterogeneous negative sampling, else normal negative sampling"
,
)
parser
.
add_argument
(
"--initial_lr"
,
default
=
0.025
,
type
=
float
,
help
=
"learning rate"
)
parser
.
add_argument
(
"--min_count"
,
default
=
5
,
type
=
int
,
help
=
"min count"
)
parser
.
add_argument
(
"--num_workers"
,
default
=
16
,
type
=
int
,
help
=
"number of workers"
)
args
=
parser
.
parse_args
()
args
=
parser
.
parse_args
()
m2v
=
Metapath2VecTrainer
(
args
)
m2v
=
Metapath2VecTrainer
(
args
)
m2v
.
train
()
m2v
.
train
()
examples/pytorch/metapath2vec/model.py
View file @
704bcaf6
...
@@ -10,7 +10,6 @@ from torch.nn import init
...
@@ -10,7 +10,6 @@ from torch.nn import init
class
SkipGramModel
(
nn
.
Module
):
class
SkipGramModel
(
nn
.
Module
):
def
__init__
(
self
,
emb_size
,
emb_dimension
):
def
__init__
(
self
,
emb_size
,
emb_dimension
):
super
(
SkipGramModel
,
self
).
__init__
()
super
(
SkipGramModel
,
self
).
__init__
()
self
.
emb_size
=
emb_size
self
.
emb_size
=
emb_size
...
@@ -39,8 +38,8 @@ class SkipGramModel(nn.Module):
...
@@ -39,8 +38,8 @@ class SkipGramModel(nn.Module):
def
save_embedding
(
self
,
id2word
,
file_name
):
def
save_embedding
(
self
,
id2word
,
file_name
):
embedding
=
self
.
u_embeddings
.
weight
.
cpu
().
data
.
numpy
()
embedding
=
self
.
u_embeddings
.
weight
.
cpu
().
data
.
numpy
()
with
open
(
file_name
,
'w'
)
as
f
:
with
open
(
file_name
,
"w"
)
as
f
:
f
.
write
(
'
%d %d
\n
'
%
(
len
(
id2word
),
self
.
emb_dimension
))
f
.
write
(
"
%d %d
\n
"
%
(
len
(
id2word
),
self
.
emb_dimension
))
for
wid
,
w
in
id2word
.
items
():
for
wid
,
w
in
id2word
.
items
():
e
=
' '
.
join
(
map
(
lambda
x
:
str
(
x
),
embedding
[
wid
]))
e
=
" "
.
join
(
map
(
lambda
x
:
str
(
x
),
embedding
[
wid
]))
f
.
write
(
'%s %s
\n
'
%
(
w
,
e
))
f
.
write
(
"%s %s
\n
"
%
(
w
,
e
))
\ No newline at end of file
examples/pytorch/metapath2vec/reading_data.py
View file @
704bcaf6
import
numpy
as
np
import
numpy
as
np
import
torch
import
torch
from
torch.utils.data
import
Dataset
from
download
import
AminerDataset
from
download
import
AminerDataset
from
torch.utils.data
import
Dataset
np
.
random
.
seed
(
12345
)
np
.
random
.
seed
(
12345
)
class
DataReader
:
class
DataReader
:
NEGATIVE_TABLE_SIZE
=
1e8
NEGATIVE_TABLE_SIZE
=
1e8
def
__init__
(
self
,
dataset
,
min_count
,
care_type
):
def
__init__
(
self
,
dataset
,
min_count
,
care_type
):
self
.
negatives
=
[]
self
.
negatives
=
[]
self
.
discards
=
[]
self
.
discards
=
[]
self
.
negpos
=
0
self
.
negpos
=
0
...
@@ -35,7 +36,11 @@ class DataReader:
...
@@ -35,7 +36,11 @@ class DataReader:
word_frequency
[
word
]
=
word_frequency
.
get
(
word
,
0
)
+
1
word_frequency
[
word
]
=
word_frequency
.
get
(
word
,
0
)
+
1
if
self
.
token_count
%
1000000
==
0
:
if
self
.
token_count
%
1000000
==
0
:
print
(
"Read "
+
str
(
int
(
self
.
token_count
/
1000000
))
+
"M words."
)
print
(
"Read "
+
str
(
int
(
self
.
token_count
/
1000000
))
+
"M words."
)
wid
=
0
wid
=
0
for
w
,
c
in
word_frequency
.
items
():
for
w
,
c
in
word_frequency
.
items
():
...
@@ -71,15 +76,18 @@ class DataReader:
...
@@ -71,15 +76,18 @@ class DataReader:
def getNegatives(self, target, size):
    """Return ``size`` negative-sample ids drawn from the precomputed table.

    Walks a cursor (``self.negpos``) through ``self.negatives`` and wraps
    around to the start of the table when fewer than ``size`` entries
    remain past the cursor.

    ``target`` is currently unused — TODO: check drawn samples for
    equality with ``target`` (carried over from the original).
    """
    if self.care_type == 0:
        start = self.negpos
        # Advance the cursor first; the wrap-around top-up below reads it back.
        self.negpos = (start + size) % len(self.negatives)
        chunk = self.negatives[start : start + size]
        if len(chunk) == size:
            return chunk
        # Ran off the end of the table: top up from the beginning.
        return np.concatenate((chunk, self.negatives[0 : self.negpos]))
# -----------------------------------------------------------------------------------------------------------------
# -----------------------------------------------------------------------------------------------------------------
class
Metapath2vecDataset
(
Dataset
):
class
Metapath2vecDataset
(
Dataset
):
def
__init__
(
self
,
data
,
window_size
):
def
__init__
(
self
,
data
,
window_size
):
# read in data, window_size and input filename
# read in data, window_size and input filename
...
@@ -103,25 +111,44 @@ class Metapath2vecDataset(Dataset):
...
@@ -103,25 +111,44 @@ class Metapath2vecDataset(Dataset):
words
=
line
.
split
()
words
=
line
.
split
()
if
len
(
words
)
>
1
:
if
len
(
words
)
>
1
:
word_ids
=
[
self
.
data
.
word2id
[
w
]
for
w
in
words
if
word_ids
=
[
w
in
self
.
data
.
word2id
and
np
.
random
.
rand
()
<
self
.
data
.
discards
[
self
.
data
.
word2id
[
w
]]]
self
.
data
.
word2id
[
w
]
for
w
in
words
if
w
in
self
.
data
.
word2id
and
np
.
random
.
rand
()
<
self
.
data
.
discards
[
self
.
data
.
word2id
[
w
]]
]
pair_catch
=
[]
pair_catch
=
[]
for
i
,
u
in
enumerate
(
word_ids
):
for
i
,
u
in
enumerate
(
word_ids
):
for
j
,
v
in
enumerate
(
for
j
,
v
in
enumerate
(
word_ids
[
max
(
i
-
self
.
window_size
,
0
):
i
+
self
.
window_size
]):
word_ids
[
max
(
i
-
self
.
window_size
,
0
)
:
i
+
self
.
window_size
]
):
assert
u
<
self
.
data
.
word_count
assert
u
<
self
.
data
.
word_count
assert
v
<
self
.
data
.
word_count
assert
v
<
self
.
data
.
word_count
if
i
==
j
:
if
i
==
j
:
continue
continue
pair_catch
.
append
((
u
,
v
,
self
.
data
.
getNegatives
(
v
,
5
)))
pair_catch
.
append
(
(
u
,
v
,
self
.
data
.
getNegatives
(
v
,
5
))
)
return
pair_catch
return
pair_catch
@
staticmethod
@
staticmethod
def
collate
(
batches
):
def
collate
(
batches
):
all_u
=
[
u
for
batch
in
batches
for
u
,
_
,
_
in
batch
if
len
(
batch
)
>
0
]
all_u
=
[
u
for
batch
in
batches
for
u
,
_
,
_
in
batch
if
len
(
batch
)
>
0
]
all_v
=
[
v
for
batch
in
batches
for
_
,
v
,
_
in
batch
if
len
(
batch
)
>
0
]
all_v
=
[
v
for
batch
in
batches
for
_
,
v
,
_
in
batch
if
len
(
batch
)
>
0
]
all_neg_v
=
[
neg_v
for
batch
in
batches
for
_
,
_
,
neg_v
in
batch
if
len
(
batch
)
>
0
]
all_neg_v
=
[
neg_v
return
torch
.
LongTensor
(
all_u
),
torch
.
LongTensor
(
all_v
),
torch
.
LongTensor
(
all_neg_v
)
for
batch
in
batches
for
_
,
_
,
neg_v
in
batch
if
len
(
batch
)
>
0
]
return
(
torch
.
LongTensor
(
all_u
),
torch
.
LongTensor
(
all_v
),
torch
.
LongTensor
(
all_neg_v
),
)
examples/pytorch/metapath2vec/sampler.py
View file @
704bcaf6
import
numpy
as
np
import
os
import
random
import
random
import
sys
import
time
import
time
import
tqdm
import
dgl
import
dgl
import
sys
import
numpy
as
np
import
os
import
tqdm
num_walks_per_node
=
1000
num_walks_per_node
=
1000
walk_length
=
100
walk_length
=
100
path
=
sys
.
argv
[
1
]
path
=
sys
.
argv
[
1
]
def
construct_graph
():
def
construct_graph
():
paper_ids
=
[]
paper_ids
=
[]
paper_names
=
[]
paper_names
=
[]
...
@@ -31,7 +33,7 @@ def construct_graph():
...
@@ -31,7 +33,7 @@ def construct_graph():
while
True
:
while
True
:
w
=
f_4
.
readline
()
w
=
f_4
.
readline
()
if
not
w
:
if
not
w
:
break
;
break
w
=
w
.
strip
().
split
()
w
=
w
.
strip
().
split
()
identity
=
int
(
w
[
0
])
identity
=
int
(
w
[
0
])
conf_ids
.
append
(
identity
)
conf_ids
.
append
(
identity
)
...
@@ -39,10 +41,10 @@ def construct_graph():
...
@@ -39,10 +41,10 @@ def construct_graph():
while
True
:
while
True
:
v
=
f_5
.
readline
()
v
=
f_5
.
readline
()
if
not
v
:
if
not
v
:
break
;
break
v
=
v
.
strip
().
split
()
v
=
v
.
strip
().
split
()
identity
=
int
(
v
[
0
])
identity
=
int
(
v
[
0
])
paper_name
=
'p'
+
''
.
join
(
v
[
1
:])
paper_name
=
"p"
+
""
.
join
(
v
[
1
:])
paper_ids
.
append
(
identity
)
paper_ids
.
append
(
identity
)
paper_names
.
append
(
paper_name
)
paper_names
.
append
(
paper_name
)
f_3
.
close
()
f_3
.
close
()
...
@@ -60,41 +62,49 @@ def construct_graph():
...
@@ -60,41 +62,49 @@ def construct_graph():
f_1
=
open
(
os
.
path
.
join
(
path
,
"paper_author.txt"
),
"r"
)
f_1
=
open
(
os
.
path
.
join
(
path
,
"paper_author.txt"
),
"r"
)
f_2
=
open
(
os
.
path
.
join
(
path
,
"paper_conf.txt"
),
"r"
)
f_2
=
open
(
os
.
path
.
join
(
path
,
"paper_conf.txt"
),
"r"
)
for
x
in
f_1
:
for
x
in
f_1
:
x
=
x
.
split
(
'
\t
'
)
x
=
x
.
split
(
"
\t
"
)
x
[
0
]
=
int
(
x
[
0
])
x
[
0
]
=
int
(
x
[
0
])
x
[
1
]
=
int
(
x
[
1
].
strip
(
'
\n
'
))
x
[
1
]
=
int
(
x
[
1
].
strip
(
"
\n
"
))
paper_author_src
.
append
(
paper_ids_invmap
[
x
[
0
]])
paper_author_src
.
append
(
paper_ids_invmap
[
x
[
0
]])
paper_author_dst
.
append
(
author_ids_invmap
[
x
[
1
]])
paper_author_dst
.
append
(
author_ids_invmap
[
x
[
1
]])
for
y
in
f_2
:
for
y
in
f_2
:
y
=
y
.
split
(
'
\t
'
)
y
=
y
.
split
(
"
\t
"
)
y
[
0
]
=
int
(
y
[
0
])
y
[
0
]
=
int
(
y
[
0
])
y
[
1
]
=
int
(
y
[
1
].
strip
(
'
\n
'
))
y
[
1
]
=
int
(
y
[
1
].
strip
(
"
\n
"
))
paper_conf_src
.
append
(
paper_ids_invmap
[
y
[
0
]])
paper_conf_src
.
append
(
paper_ids_invmap
[
y
[
0
]])
paper_conf_dst
.
append
(
conf_ids_invmap
[
y
[
1
]])
paper_conf_dst
.
append
(
conf_ids_invmap
[
y
[
1
]])
f_1
.
close
()
f_1
.
close
()
f_2
.
close
()
f_2
.
close
()
hg
=
dgl
.
heterograph
({
hg
=
dgl
.
heterograph
(
(
'paper'
,
'pa'
,
'author'
)
:
(
paper_author_src
,
paper_author_dst
),
{
(
'author'
,
'ap'
,
'paper'
)
:
(
paper_author_dst
,
paper_author_src
),
(
"paper"
,
"pa"
,
"author"
):
(
paper_author_src
,
paper_author_dst
),
(
'paper'
,
'pc'
,
'conf'
)
:
(
paper_conf_src
,
paper_conf_dst
),
(
"author"
,
"ap"
,
"paper"
):
(
paper_author_dst
,
paper_author_src
),
(
'conf'
,
'cp'
,
'paper'
)
:
(
paper_conf_dst
,
paper_conf_src
)})
(
"paper"
,
"pc"
,
"conf"
):
(
paper_conf_src
,
paper_conf_dst
),
(
"conf"
,
"cp"
,
"paper"
):
(
paper_conf_dst
,
paper_conf_src
),
}
)
return
hg
,
author_names
,
conf_names
,
paper_names
return
hg
,
author_names
,
conf_names
,
paper_names
#"conference - paper - Author - paper - conference" metapath sampling
# "conference - paper - Author - paper - conference" metapath sampling
def generate_metapath():
    """Sample "conf - paper - author - paper - conf" metapath walks.

    Starts ``num_walks_per_node`` random walks from every conference node
    and writes one line per walk to ``<path>/output_path.txt``, keeping
    only the conference and author names (paper hops are skipped).
    """
    trace_file = open(os.path.join(path, "output_path.txt"), "w")
    count = 0
    hg, author_names, conf_names, paper_names = construct_graph()
    for conf_idx in tqdm.trange(hg.number_of_nodes("conf")):
        traces, _ = dgl.sampling.random_walk(
            hg,
            [conf_idx] * num_walks_per_node,
            metapath=["cp", "pa", "ap", "pc"] * walk_length,
        )
        for trace in traces:
            # Stride 2 skips the paper hops; among the remaining nodes,
            # positions divisible by 4 are conferences, the rest authors.
            names = []
            for i in range(0, len(trace), 2):
                lookup = conf_names if i % 4 == 0 else author_names
                names.append(lookup[trace[i]])
            print(" ".join(names), file=trace_file)
    trace_file.close()
...
...
Prev
1
…
6
7
8
9
10
11
12
13
14
…
17
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment