OpenDAS / dgl · Commits

Commit 704bcaf6 (unverified)
Authored Feb 19, 2023 by Hongzhi (Steve), Chen; committed by GitHub on Feb 19, 2023.
Parent: 6bc82161

    examples (#5323)

    Co-authored-by: Ubuntu <ubuntu@ip-172-31-28-63.ap-northeast-1.compute.internal>

Changes: 332 files in this commit. This page shows 20 changed files with 281 additions and 213 deletions (+281 −213).
Files on this page:

examples/pytorch/correct_and_smooth/main.py                      +3   −4
examples/pytorch/correct_and_smooth/model.py                     +1   −2
examples/pytorch/dagnn/main.py                                   +4   −7
examples/pytorch/deepergcn/layers.py                             +3   −4
examples/pytorch/deepergcn/main.py                               +1   −1
examples/pytorch/deepergcn/models.py                             +2   −3
examples/pytorch/dgi/train.py                                    +97  −59
examples/pytorch/dgmg/configure.py                               +0   −1
examples/pytorch/dgmg/main.py                                    +0   −1
examples/pytorch/dgmg/model.py                                   +72  −62
examples/pytorch/diffpool/model/dgl_layers/gnn.py                +75  −43
examples/pytorch/diffpool/model/encoder.py                       +2   −3
examples/pytorch/diffpool/model/tensorized_layers/assignment.py  +2   −1
examples/pytorch/diffpool/model/tensorized_layers/diffpool.py    +2   −1
examples/pytorch/diffpool/train.py                               +7   −7
examples/pytorch/dimenet/main.py                                 +3   −3
examples/pytorch/dimenet/modules/interaction_block.py            +1   −2
examples/pytorch/dimenet/modules/interaction_pp_block.py         +2   −3
examples/pytorch/dimenet/modules/output_block.py                 +2   −3
examples/pytorch/dimenet/modules/output_pp_block.py              +2   −3
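Note: every hunk below is consistent with an automated formatting pass over the examples tree — imports re-sorted, single quotes normalized to double quotes, and long calls re-wrapped. A minimal sketch of how such a pass is typically produced, assuming (these are assumptions, not confirmed by the commit message) that isort and black were the tools and that an 80-column limit was used:

    # Hedged sketch: reproduce a formatting-only pass like the one below.
    # Tool choice and line length are inferred from the hunks, not confirmed.
    import subprocess

    for cmd in (
        ["isort", "--profile", "black", "--line-length", "80", "examples/"],
        ["black", "--line-length", "80", "examples/"],
    ):
        subprocess.run(cmd, check=True)  # raises if a tool exits non-zero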
examples/pytorch/correct_and_smooth/main.py

@@ -2,14 +2,14 @@ import argparse
 import copy
 import os
+import dgl
 import torch
 import torch.nn.functional as F
 import torch.optim as optim
-from model import MLP, CorrectAndSmooth, MLPLinear
+from model import CorrectAndSmooth, MLP, MLPLinear
 from ogb.nodeproppred import DglNodePropPredDataset, Evaluator
-import dgl


 def evaluate(y_pred, y_true, idx, evaluator):
     return evaluator.eval({"y_true": y_true[idx], "y_pred": y_pred[idx]})["acc"]
...
@@ -104,7 +104,6 @@ def main():
     # training
     print("---------- Training ----------")
     for i in range(args.epochs):
         model.train()
         opt.zero_grad()
...
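The import reshuffle in this hunk — dgl lifted into the sorted third-party block and the names inside the from-import alphabetized — matches what isort produces. A small sketch against isort's Python API; the profile is an assumption:

    # Hedged sketch of the import sort above via isort's Python API
    # (profile="black" is an assumption about this repo's config).
    import isort

    src = (
        "from model import MLP, CorrectAndSmooth, MLPLinear\n"
        "from ogb.nodeproppred import DglNodePropPredDataset, Evaluator\n"
        "import dgl\n"
    )
    # isort alphabetizes the names inside each from-import and orders the
    # statements by module; where "dgl" lands depends on section config.
    print(isort.code(src, profile="black"))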
examples/pytorch/correct_and_smooth/model.py

+import dgl.function as fn
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-import dgl.function as fn


 class MLPLinear(nn.Module):
     def __init__(self, in_dim, out_dim):
...
examples/pytorch/dagnn/main.py

 import argparse

+import dgl.function as fn
 import numpy as np
 import torch
+from dgl.data import CiteseerGraphDataset, CoraGraphDataset, PubmedGraphDataset
 from torch import nn
-from torch.nn import Parameter
-from torch.nn import functional as F
+from torch.nn import functional as F, Parameter
 from tqdm import trange
 from utils import evaluate, generate_random_seeds, set_random_state
-import dgl.function as fn
-from dgl.data import CiteseerGraphDataset, CoraGraphDataset, PubmedGraphDataset


 class DAGNNConv(nn.Module):
     def __init__(self, in_dim, k):
...
@@ -26,7 +25,6 @@ class DAGNNConv(nn.Module):
         nn.init.xavier_uniform_(self.s, gain=gain)

     def forward(self, graph, feats):
         with graph.local_scope():
             results = [feats]
...
@@ -68,7 +66,6 @@ class MLPLayer(nn.Module):
         nn.init.zeros_(self.linear.bias)

     def forward(self, feats):
         feats = self.dropout(feats)
         feats = self.linear(feats)
         if self.activation:
...
examples/pytorch/deepergcn/layers.py

+import dgl.function as fn
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from modules import MLP, MessageNorm
-from ogb.graphproppred.mol_encoder import BondEncoder
-import dgl.function as fn
 from dgl.nn.functional import edge_softmax
+from modules import MessageNorm, MLP
+from ogb.graphproppred.mol_encoder import BondEncoder


 class GENConv(nn.Module):
...
examples/pytorch/deepergcn/main.py

@@ -6,7 +6,7 @@ import torch
 import torch.nn as nn
 import torch.optim as optim
 from models import DeeperGCN
-from ogb.graphproppred import DglGraphPropPredDataset, Evaluator, collate_dgl
+from ogb.graphproppred import collate_dgl, DglGraphPropPredDataset, Evaluator
 from torch.utils.data import DataLoader
...
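The new name order above (collate_dgl before DglGraphPropPredDataset) is what a case-insensitive alphabetical sort gives; a one-liner to check:

    # The reordered names above follow a case-insensitive sort:
    names = ["DglGraphPropPredDataset", "Evaluator", "collate_dgl"]
    print(sorted(names, key=str.lower))
    # -> ['collate_dgl', 'DglGraphPropPredDataset', 'Evaluator']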
examples/pytorch/deepergcn/models.py

+import dgl.function as fn
 import torch.nn as nn
 import torch.nn.functional as F
+from dgl.nn.pytorch.glob import AvgPooling
 from layers import GENConv
 from ogb.graphproppred.mol_encoder import AtomEncoder
-import dgl.function as fn
-from dgl.nn.pytorch.glob import AvgPooling


 class DeeperGCN(nn.Module):
     r"""
...
examples/pytorch/dgi/train.py

 import argparse, time

-import numpy as np
+import dgl
 import networkx as nx
+import numpy as np
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-import dgl
+from dgi import Classifier, DGI
 from dgl import DGLGraph
-from dgl.data import register_data_args, load_data
-from dgi import DGI, Classifier
+from dgl.data import load_data, register_data_args


 def evaluate(model, features, labels, mask):
     model.eval()
...
@@ -19,20 +21,21 @@ def evaluate(model, features, labels, mask):
     correct = torch.sum(indices == labels)
     return correct.item() * 1.0 / len(labels)


 def main(args):
     # load and preprocess dataset
     data = load_data(args)
     g = data[0]
-    features = torch.FloatTensor(g.ndata['feat'])
-    labels = torch.LongTensor(g.ndata['label'])
-    if hasattr(torch, 'BoolTensor'):
-        train_mask = torch.BoolTensor(g.ndata['train_mask'])
-        val_mask = torch.BoolTensor(g.ndata['val_mask'])
-        test_mask = torch.BoolTensor(g.ndata['test_mask'])
+    features = torch.FloatTensor(g.ndata["feat"])
+    labels = torch.LongTensor(g.ndata["label"])
+    if hasattr(torch, "BoolTensor"):
+        train_mask = torch.BoolTensor(g.ndata["train_mask"])
+        val_mask = torch.BoolTensor(g.ndata["val_mask"])
+        test_mask = torch.BoolTensor(g.ndata["test_mask"])
     else:
-        train_mask = torch.ByteTensor(g.ndata['train_mask'])
-        val_mask = torch.ByteTensor(g.ndata['val_mask'])
-        test_mask = torch.ByteTensor(g.ndata['test_mask'])
+        train_mask = torch.ByteTensor(g.ndata["train_mask"])
+        val_mask = torch.ByteTensor(g.ndata["val_mask"])
+        test_mask = torch.ByteTensor(g.ndata["test_mask"])
     in_feats = features.shape[1]
     n_classes = data.num_classes
     n_edges = g.number_of_edges()
...
@@ -57,19 +60,21 @@ def main(args):
     if args.gpu >= 0:
         g = g.to(args.gpu)
     # create DGI model
-    dgi = DGI(g,
-              in_feats,
-              args.n_hidden,
-              args.n_layers,
-              nn.PReLU(args.n_hidden),
-              args.dropout)
+    dgi = DGI(
+        g,
+        in_feats,
+        args.n_hidden,
+        args.n_layers,
+        nn.PReLU(args.n_hidden),
+        args.dropout,
+    )

     if cuda:
         dgi.cuda()

-    dgi_optimizer = torch.optim.Adam(dgi.parameters(),
-                                     lr=args.dgi_lr,
-                                     weight_decay=args.weight_decay)
+    dgi_optimizer = torch.optim.Adam(
+        dgi.parameters(), lr=args.dgi_lr, weight_decay=args.weight_decay
+    )

     # train deep graph infomax
     cnt_wait = 0
...
@@ -90,33 +95,38 @@ def main(args):
             best = loss
             best_t = epoch
             cnt_wait = 0
-            torch.save(dgi.state_dict(), 'best_dgi.pkl')
+            torch.save(dgi.state_dict(), "best_dgi.pkl")
         else:
             cnt_wait += 1

         if cnt_wait == args.patience:
-            print('Early stopping!')
+            print("Early stopping!")
             break

         if epoch >= 3:
             dur.append(time.time() - t0)

-        print("Epoch {:05d} | Time(s) {:.4f} | Loss {:.4f} | "
-              "ETputs(KTEPS) {:.2f}".format(epoch, np.mean(dur), loss.item(),
-                                            n_edges / np.mean(dur) / 1000))
+        print(
+            "Epoch {:05d} | Time(s) {:.4f} | Loss {:.4f} | "
+            "ETputs(KTEPS) {:.2f}".format(
+                epoch,
+                np.mean(dur),
+                loss.item(),
+                n_edges / np.mean(dur) / 1000,
+            )
+        )

     # create classifier model
     classifier = Classifier(args.n_hidden, n_classes)
     if cuda:
         classifier.cuda()

-    classifier_optimizer = torch.optim.Adam(classifier.parameters(),
-                                            lr=args.classifier_lr,
-                                            weight_decay=args.weight_decay)
+    classifier_optimizer = torch.optim.Adam(
+        classifier.parameters(),
+        lr=args.classifier_lr,
+        weight_decay=args.weight_decay,
+    )

     # train classifier
-    print('Loading {}th epoch'.format(best_t))
-    dgi.load_state_dict(torch.load('best_dgi.pkl'))
+    print("Loading {}th epoch".format(best_t))
+    dgi.load_state_dict(torch.load("best_dgi.pkl"))
     embeds = dgi.encoder(features, corrupt=False)
     embeds = embeds.detach()
     dur = []
...
@@ -135,39 +145,67 @@ def main(args):
             dur.append(time.time() - t0)

         acc = evaluate(classifier, embeds, labels, val_mask)
-        print("Epoch {:05d} | Time(s) {:.4f} | Loss {:.4f} | Accuracy {:.4f} | "
-              "ETputs(KTEPS) {:.2f}".format(epoch, np.mean(dur), loss.item(),
-                                            acc, n_edges / np.mean(dur) / 1000))
+        print(
+            "Epoch {:05d} | Time(s) {:.4f} | Loss {:.4f} | Accuracy {:.4f} | "
+            "ETputs(KTEPS) {:.2f}".format(
+                epoch,
+                np.mean(dur),
+                loss.item(),
+                acc,
+                n_edges / np.mean(dur) / 1000,
+            )
+        )

     print()
     acc = evaluate(classifier, embeds, labels, test_mask)
     print("Test Accuracy {:.4f}".format(acc))


-if __name__ == '__main__':
-    parser = argparse.ArgumentParser(description='DGI')
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description="DGI")
     register_data_args(parser)
-    parser.add_argument("--dropout", type=float, default=0.,
-                        help="dropout probability")
-    parser.add_argument("--gpu", type=int, default=-1,
-                        help="gpu")
-    parser.add_argument("--dgi-lr", type=float, default=1e-3,
-                        help="dgi learning rate")
-    parser.add_argument("--classifier-lr", type=float, default=1e-2,
-                        help="classifier learning rate")
-    parser.add_argument("--n-dgi-epochs", type=int, default=300,
-                        help="number of training epochs")
-    parser.add_argument("--n-classifier-epochs", type=int, default=300,
-                        help="number of training epochs")
-    parser.add_argument("--n-hidden", type=int, default=512,
-                        help="number of hidden gcn units")
-    parser.add_argument("--n-layers", type=int, default=1,
-                        help="number of hidden gcn layers")
-    parser.add_argument("--weight-decay", type=float, default=0.,
-                        help="Weight for L2 loss")
-    parser.add_argument("--patience", type=int, default=20,
-                        help="early stop patience condition")
-    parser.add_argument("--self-loop", action='store_true',
-                        help="graph self-loop (default=False)")
+    parser.add_argument(
+        "--dropout", type=float, default=0.0, help="dropout probability"
+    )
+    parser.add_argument("--gpu", type=int, default=-1, help="gpu")
+    parser.add_argument(
+        "--dgi-lr", type=float, default=1e-3, help="dgi learning rate"
+    )
+    parser.add_argument(
+        "--classifier-lr",
+        type=float,
+        default=1e-2,
+        help="classifier learning rate",
+    )
+    parser.add_argument(
+        "--n-dgi-epochs",
+        type=int,
+        default=300,
+        help="number of training epochs",
+    )
+    parser.add_argument(
+        "--n-classifier-epochs",
+        type=int,
+        default=300,
+        help="number of training epochs",
+    )
+    parser.add_argument(
+        "--n-hidden", type=int, default=512, help="number of hidden gcn units"
+    )
+    parser.add_argument(
+        "--n-layers", type=int, default=1, help="number of hidden gcn layers"
+    )
+    parser.add_argument(
+        "--weight-decay", type=float, default=0.0, help="Weight for L2 loss"
+    )
+    parser.add_argument(
+        "--patience", type=int, default=20, help="early stop patience condition"
+    )
+    parser.add_argument(
+        "--self-loop",
+        action="store_true",
+        help="graph self-loop (default=False)",
+    )
    parser.set_defaults(self_loop=False)
    args = parser.parse_args()
    print(args)
...
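Two change classes recur in this file: single quotes become double quotes, and bare float literals like 0. become 0.0. Both are default black behaviors; a small sketch against black's Python API (assuming black produced these hunks, and an 80-column limit):

    # Hedged sketch: black normalizes quotes and pads bare float literals.
    import black

    src = "print('Early stopping!')\nx = 0.\n"
    print(black.format_str(src, mode=black.Mode(line_length=80)))
    # print("Early stopping!")
    # x = 0.0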
examples/pytorch/dgmg/configure.py

@@ -4,7 +4,6 @@ and will be loaded when setting up."""
 def dataset_based_configure(opts):
     if opts["dataset"] == "cycles":
         ds_configure = cycles_configure
     else:
...
examples/pytorch/dgmg/main.py

@@ -65,7 +65,6 @@ def main(opts):
     optimizer.zero_grad()
     for i, data in enumerate(data_loader):
         log_prob = model(actions=data)
         prob = log_prob.detach().exp()
...
examples/pytorch/dgmg/model.py

+from functools import partial
+
 import dgl
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from functools import partial
 from torch.distributions import Bernoulli, Categorical
...
@@ -15,20 +16,19 @@ class GraphEmbed(nn.Module):
         # Embed graphs
         self.node_gating = nn.Sequential(
-            nn.Linear(node_hidden_size, 1),
-            nn.Sigmoid()
+            nn.Linear(node_hidden_size, 1), nn.Sigmoid()
         )
-        self.node_to_graph = nn.Linear(node_hidden_size,
-                                       self.graph_hidden_size)
+        self.node_to_graph = nn.Linear(
+            node_hidden_size, self.graph_hidden_size
+        )

     def forward(self, g):
         if g.number_of_nodes() == 0:
             return torch.zeros(1, self.graph_hidden_size)
         else:
             # Node features are stored as hv in ndata.
-            hvs = g.ndata['hv']
-            return (self.node_gating(hvs) *
-                    self.node_to_graph(hvs)).sum(0, keepdim=True)
+            hvs = g.ndata["hv"]
+            return (self.node_gating(hvs) * self.node_to_graph(hvs)).sum(
+                0, keepdim=True
+            )


 class GraphProp(nn.Module):
...
@@ -46,41 +46,45 @@ class GraphProp(nn.Module):
         for t in range(num_prop_rounds):
             # input being [hv, hu, xuv]
-            message_funcs.append(nn.Linear(2 * node_hidden_size + 1,
-                                           self.node_activation_hidden_size))
+            message_funcs.append(
+                nn.Linear(
+                    2 * node_hidden_size + 1, self.node_activation_hidden_size
+                )
+            )

             self.reduce_funcs.append(partial(self.dgmg_reduce, round=t))
             node_update_funcs.append(
-                nn.GRUCell(self.node_activation_hidden_size,
-                           node_hidden_size)
+                nn.GRUCell(self.node_activation_hidden_size, node_hidden_size)
             )

         self.message_funcs = nn.ModuleList(message_funcs)
         self.node_update_funcs = nn.ModuleList(node_update_funcs)

     def dgmg_msg(self, edges):
         """For an edge u->v, return concat([h_u, x_uv])"""
-        return {'m': torch.cat([edges.src['hv'],
-                                edges.data['he']], dim=1)}
+        return {"m": torch.cat([edges.src["hv"], edges.data["he"]], dim=1)}

     def dgmg_reduce(self, nodes, round):
-        hv_old = nodes.data['hv']
-        m = nodes.mailbox['m']
-        message = torch.cat([
-            hv_old.unsqueeze(1).expand(-1, m.size(1), -1), m], dim=2)
+        hv_old = nodes.data["hv"]
+        m = nodes.mailbox["m"]
+        message = torch.cat(
+            [hv_old.unsqueeze(1).expand(-1, m.size(1), -1), m], dim=2
+        )
         node_activation = (self.message_funcs[round](message)).sum(1)
-        return {'a': node_activation}
+        return {"a": node_activation}

     def forward(self, g):
         if g.number_of_edges() == 0:
             return
         else:
             for t in range(self.num_prop_rounds):
-                g.update_all(message_func=self.dgmg_msg,
-                             reduce_func=self.reduce_funcs[t])
-                g.ndata['hv'] = self.node_update_funcs[t](
-                    g.ndata['a'], g.ndata['hv'])
+                g.update_all(
+                    message_func=self.dgmg_msg,
+                    reduce_func=self.reduce_funcs[t],
+                )
+                g.ndata["hv"] = self.node_update_funcs[t](
+                    g.ndata["a"], g.ndata["hv"]
+                )


 def bernoulli_action_log_prob(logit, action):
...
@@ -96,33 +100,39 @@ class AddNode(nn.Module):
     def __init__(self, graph_embed_func, node_hidden_size):
         super(AddNode, self).__init__()

-        self.graph_op = {'embed': graph_embed_func}
+        self.graph_op = {"embed": graph_embed_func}

         self.stop = 1
         self.add_node = nn.Linear(graph_embed_func.graph_hidden_size, 1)

         # If to add a node, initialize its hv
         self.node_type_embed = nn.Embedding(1, node_hidden_size)
-        self.initialize_hv = nn.Linear(node_hidden_size + \
-                                       graph_embed_func.graph_hidden_size,
-                                       node_hidden_size)
+        self.initialize_hv = nn.Linear(
+            node_hidden_size + graph_embed_func.graph_hidden_size,
+            node_hidden_size,
+        )

         self.init_node_activation = torch.zeros(1, 2 * node_hidden_size)

     def _initialize_node_repr(self, g, node_type, graph_embed):
         num_nodes = g.number_of_nodes()
-        hv_init = self.initialize_hv(
-            torch.cat([
-                self.node_type_embed(torch.LongTensor([node_type])),
-                graph_embed], dim=1))
-        g.nodes[num_nodes - 1].data['hv'] = hv_init
-        g.nodes[num_nodes - 1].data['a'] = self.init_node_activation
+        hv_init = self.initialize_hv(
+            torch.cat(
+                [
+                    self.node_type_embed(torch.LongTensor([node_type])),
+                    graph_embed,
+                ],
+                dim=1,
+            )
+        )
+        g.nodes[num_nodes - 1].data["hv"] = hv_init
+        g.nodes[num_nodes - 1].data["a"] = self.init_node_activation

     def prepare_training(self):
         self.log_prob = []

     def forward(self, g, action=None):
-        graph_embed = self.graph_op['embed'](g)
+        graph_embed = self.graph_op["embed"](g)
         logit = self.add_node(graph_embed)
         prob = torch.sigmoid(logit)
...
@@ -146,19 +156,19 @@ class AddEdge(nn.Module):
     def __init__(self, graph_embed_func, node_hidden_size):
         super(AddEdge, self).__init__()

-        self.graph_op = {'embed': graph_embed_func}
-        self.add_edge = nn.Linear(graph_embed_func.graph_hidden_size + \
-                                  node_hidden_size, 1)
+        self.graph_op = {"embed": graph_embed_func}
+        self.add_edge = nn.Linear(
+            graph_embed_func.graph_hidden_size + node_hidden_size, 1
+        )

     def prepare_training(self):
         self.log_prob = []

     def forward(self, g, action=None):
-        graph_embed = self.graph_op['embed'](g)
-        src_embed = g.nodes[g.number_of_nodes() - 1].data['hv']
-        logit = self.add_edge(torch.cat(
-            [graph_embed, src_embed], dim=1))
+        graph_embed = self.graph_op["embed"](g)
+        src_embed = g.nodes[g.number_of_nodes() - 1].data["hv"]
+        logit = self.add_edge(torch.cat([graph_embed, src_embed], dim=1))
         prob = torch.sigmoid(logit)

         if not self.training:
...
@@ -176,7 +186,7 @@ class ChooseDestAndUpdate(nn.Module):
     def __init__(self, graph_prop_func, node_hidden_size):
         super(ChooseDestAndUpdate, self).__init__()

-        self.graph_op = {'prop': graph_prop_func}
+        self.graph_op = {"prop": graph_prop_func}
         self.choose_dest = nn.Linear(2 * node_hidden_size, 1)

     def _initialize_edge_repr(self, g, src_list, dest_list):
...
@@ -184,7 +194,7 @@ class ChooseDestAndUpdate(nn.Module):
         # For multiple edge types, we can use a one hot representation
         # or an embedding module.
         edge_repr = torch.ones(len(src_list), 1)
-        g.edges[src_list, dest_list].data['he'] = edge_repr
+        g.edges[src_list, dest_list].data["he"] = edge_repr

     def prepare_training(self):
         self.log_prob = []
...
@@ -193,12 +203,12 @@ class ChooseDestAndUpdate(nn.Module):
         src = g.number_of_nodes() - 1
         possible_dests = range(src)

-        src_embed_expand = g.nodes[src].data['hv'].expand(src, -1)
-        possible_dests_embed = g.nodes[possible_dests].data['hv']
+        src_embed_expand = g.nodes[src].data["hv"].expand(src, -1)
+        possible_dests_embed = g.nodes[possible_dests].data["hv"]

         dests_scores = self.choose_dest(
-            torch.cat([possible_dests_embed,
-                       src_embed_expand], dim=1)).view(1, -1)
+            torch.cat([possible_dests_embed, src_embed_expand], dim=1)
+        ).view(1, -1)
         dests_probs = F.softmax(dests_scores, dim=1)

         if not self.training:
...
@@ -213,17 +223,17 @@ class ChooseDestAndUpdate(nn.Module):
             g.add_edges(src_list, dest_list)
             self._initialize_edge_repr(g, src_list, dest_list)

-            self.graph_op['prop'](g)
+            self.graph_op["prop"](g)

         if self.training:
             if dests_probs.nelement() > 1:
-                self.log_prob.append(
-                    F.log_softmax(dests_scores, dim=1)[:, dest : dest + 1])
+                self.log_prob.append(
+                    F.log_softmax(dests_scores, dim=1)[:, dest : dest + 1]
+                )


 class DGMG(nn.Module):
     def __init__(self, v_max, node_hidden_size,
                  num_prop_rounds):
         super(DGMG, self).__init__()

         # Graph configuration
...
@@ -233,22 +243,20 @@ class DGMG(nn.Module):
         self.graph_embed = GraphEmbed(node_hidden_size)

         # Graph propagation module
-        self.graph_prop = GraphProp(num_prop_rounds,
-                                    node_hidden_size)
+        self.graph_prop = GraphProp(num_prop_rounds, node_hidden_size)

         # Actions
-        self.add_node_agent = AddNode(
-            self.graph_embed, node_hidden_size)
+        self.add_node_agent = AddNode(self.graph_embed, node_hidden_size)
         self.add_edge_agent = AddEdge(self.graph_embed, node_hidden_size)
-        self.choose_dest_agent = ChooseDestAndUpdate(
-            self.graph_prop, node_hidden_size)
+        self.choose_dest_agent = ChooseDestAndUpdate(
+            self.graph_prop, node_hidden_size
+        )

         # Weight initialization
         self.init_weights()

     def init_weights(self):
-        from utils import weights_init, dgmg_message_weight_init
+        from utils import dgmg_message_weight_init, weights_init

         self.graph_embed.apply(weights_init)
         self.graph_prop.apply(weights_init)
...
@@ -290,9 +298,11 @@ class DGMG(nn.Module):
             self.choose_dest_agent(self.g, a)

     def get_log_prob(self):
-        return torch.cat(self.add_node_agent.log_prob).sum() \
-               + torch.cat(self.add_edge_agent.log_prob).sum() \
-               + torch.cat(self.choose_dest_agent.log_prob).sum()
+        return (
+            torch.cat(self.add_node_agent.log_prob).sum()
+            + torch.cat(self.add_edge_agent.log_prob).sum()
+            + torch.cat(self.choose_dest_agent.log_prob).sum()
+        )

     def forward_train(self, actions):
         self.prepare_for_train()
...
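One recurring rewrite in this file replaces backslash continuations with a single parenthesized expression, as in get_log_prob above; PEP 8 prefers the parenthesized form because trailing whitespace after a backslash silently breaks it. A generic before/after illustrating the idiom (names here are illustrative, not from the diff):

    # Before: backslash continuations (fragile across edits)
    total = 1 \
        + 2 \
        + 3

    # After: implicit line joining inside parentheses (the style applied above)
    total = (
        1
        + 2
        + 3
    )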
examples/pytorch/diffpool/model/dgl_layers/gnn.py

+import dgl.function as fn
+import numpy as np
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-import numpy as np
 from scipy.linalg import block_diag
-import dgl.function as fn
-from model.loss import EntropyLoss
+from ..model_utils import masked_softmax
-from .aggregator import MaxPoolAggregator, MeanAggregator, LSTMAggregator
+from .aggregator import LSTMAggregator, MaxPoolAggregator, MeanAggregator
 from .bundler import Bundler
-from ..model_utils import masked_softmax
+from model.loss import EntropyLoss


 class GraphSageLayer(nn.Module):
...
@@ -18,17 +18,27 @@ class GraphSageLayer(nn.Module):
     Here, graphsage layer is a reduced function in DGL framework
     """

-    def __init__(self, in_feats, out_feats, activation, dropout,
-                 aggregator_type, bn=False, bias=True):
+    def __init__(
+        self,
+        in_feats,
+        out_feats,
+        activation,
+        dropout,
+        aggregator_type,
+        bn=False,
+        bias=True,
+    ):
         super(GraphSageLayer, self).__init__()
         self.use_bn = bn
-        self.bundler = Bundler(in_feats, out_feats, activation, dropout,
-                               bias=bias)
+        self.bundler = Bundler(
+            in_feats, out_feats, activation, dropout, bias=bias
+        )
         self.dropout = nn.Dropout(p=dropout)

         if aggregator_type == "maxpool":
-            self.aggregator = MaxPoolAggregator(in_feats, in_feats,
-                                                activation, bias)
+            self.aggregator = MaxPoolAggregator(
+                in_feats, in_feats, activation, bias
+            )
         elif aggregator_type == "lstm":
             self.aggregator = LSTMAggregator(in_feats, in_feats)
         else:
...
@@ -36,15 +46,14 @@ class GraphSageLayer(nn.Module):
     def forward(self, g, h):
         h = self.dropout(h)
-        g.ndata['h'] = h
-        if self.use_bn and not hasattr(self, 'bn'):
+        g.ndata["h"] = h
+        if self.use_bn and not hasattr(self, "bn"):
             device = h.device
             self.bn = nn.BatchNorm1d(h.size()[1]).to(device)
-        g.update_all(fn.copy_u(u='h', out='m'), self.aggregator,
-                     self.bundler)
+        g.update_all(
+            fn.copy_u(u="h", out="m"), self.aggregator, self.bundler
+        )
         if self.use_bn:
             h = self.bn(h)
-        h = g.ndata.pop('h')
+        h = g.ndata.pop("h")
         return h
...
@@ -53,21 +62,36 @@ class GraphSage(nn.Module):
     Grahpsage network that concatenate several graphsage layer
     """

-    def __init__(self, in_feats, n_hidden, n_classes, n_layers, activation,
-                 dropout, aggregator_type):
+    def __init__(
+        self,
+        in_feats,
+        n_hidden,
+        n_classes,
+        n_layers,
+        activation,
+        dropout,
+        aggregator_type,
+    ):
         super(GraphSage, self).__init__()
         self.layers = nn.ModuleList()

         # input layer
-        self.layers.append(GraphSageLayer(in_feats, n_hidden, activation,
-                                          dropout, aggregator_type))
+        self.layers.append(
+            GraphSageLayer(
+                in_feats, n_hidden, activation, dropout, aggregator_type
+            )
+        )
         # hidden layers
         for _ in range(n_layers - 1):
-            self.layers.append(GraphSageLayer(n_hidden, n_hidden, activation,
-                                              dropout, aggregator_type))
+            self.layers.append(
+                GraphSageLayer(
+                    n_hidden, n_hidden, activation, dropout, aggregator_type
+                )
+            )
         # output layer
-        self.layers.append(GraphSageLayer(n_hidden, n_classes, None,
-                                          dropout, aggregator_type))
+        self.layers.append(
+            GraphSageLayer(
+                n_hidden, n_classes, None, dropout, aggregator_type
+            )
+        )

     def forward(self, g, features):
         h = features
...
@@ -77,37 +101,44 @@ class GraphSage(nn.Module):

 class DiffPoolBatchedGraphLayer(nn.Module):

-    def __init__(self, input_dim, assign_dim, output_feat_dim,
-                 activation, dropout, aggregator_type, link_pred):
+    def __init__(
+        self,
+        input_dim,
+        assign_dim,
+        output_feat_dim,
+        activation,
+        dropout,
+        aggregator_type,
+        link_pred,
+    ):
         super(DiffPoolBatchedGraphLayer, self).__init__()
         self.embedding_dim = input_dim
         self.assign_dim = assign_dim
         self.hidden_dim = output_feat_dim
         self.link_pred = link_pred
-        self.feat_gc = GraphSageLayer(input_dim, output_feat_dim, activation,
-                                      dropout, aggregator_type)
-        self.pool_gc = GraphSageLayer(input_dim, assign_dim, activation,
-                                      dropout, aggregator_type)
+        self.feat_gc = GraphSageLayer(
+            input_dim, output_feat_dim, activation, dropout, aggregator_type
+        )
+        self.pool_gc = GraphSageLayer(
+            input_dim, assign_dim, activation, dropout, aggregator_type
+        )
         self.reg_loss = nn.ModuleList([])
         self.loss_log = {}
         self.reg_loss.append(EntropyLoss())

     def forward(self, g, h):
         feat = self.feat_gc(g, h)  # size = (sum_N, F_out), sum_N is num of nodes in this batch
         device = feat.device
         assign_tensor = self.pool_gc(g, h)  # size = (sum_N, N_a), N_a is num of nodes in pooled graph.
         assign_tensor = F.softmax(assign_tensor, dim=1)
         assign_tensor = torch.split(assign_tensor, g.batch_num_nodes().tolist())
         assign_tensor = torch.block_diag(*assign_tensor)  # size = (sum_N, batch_size * N_a)

         h = torch.matmul(torch.t(assign_tensor), feat)
         adj = g.adjacency_matrix(transpose=True, ctx=device)
...
@@ -115,9 +146,10 @@ class DiffPoolBatchedGraphLayer(nn.Module):
         adj_new = torch.mm(torch.t(assign_tensor), adj_new)

         if self.link_pred:
-            current_lp_loss = torch.norm(adj.to_dense() -
-                                         torch.mm(assign_tensor, torch.t(assign_tensor))) / np.power(g.number_of_nodes(), 2)
-            self.loss_log['LinkPredLoss'] = current_lp_loss
+            current_lp_loss = torch.norm(
+                adj.to_dense() - torch.mm(assign_tensor, torch.t(assign_tensor))
+            ) / np.power(g.number_of_nodes(), 2)
+            self.loss_log["LinkPredLoss"] = current_lp_loss

         for loss_layer in self.reg_loss:
             loss_name = str(type(loss_layer).__name__)
...
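The forward pass above stacks the per-graph soft-assignment matrices with torch.block_diag so a whole batch can be pooled with a single matmul. A minimal illustration of that call:

    # Minimal illustration of torch.block_diag as used in the diff above.
    import torch

    a = torch.ones(2, 3)  # graph 1: 2 nodes -> 3 clusters
    b = torch.ones(4, 3)  # graph 2: 4 nodes -> 3 clusters
    batched = torch.block_diag(a, b)  # zeros everywhere off the blocks
    print(batched.shape)  # torch.Size([6, 6])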
examples/pytorch/diffpool/model/encoder.py

 import time

+import dgl
 import numpy as np
 import torch
 import torch.nn as nn
...
@@ -7,8 +9,6 @@ import torch.nn.functional as F
 from scipy.linalg import block_diag
 from torch.nn import init

-import dgl
 from .dgl_layers import DiffPoolBatchedGraphLayer, GraphSage, GraphSageLayer
 from .model_utils import batch2tensor
 from .tensorized_layers import *
...
@@ -91,7 +91,6 @@ class DiffPool(nn.Module):
             # and return pool_embedding_dim node embedding
             pool_embedding_dim = hidden_dim * (n_layers - 1) + embedding_dim
         else:
             pool_embedding_dim = embedding_dim

         self.first_diffpool_layer = DiffPoolBatchedGraphLayer(
...
examples/pytorch/diffpool/model/tensorized_layers/assignment.py

 import torch
-from model.tensorized_layers.graphsage import BatchedGraphSAGE
 from torch import nn as nn
 from torch.autograd import Variable
 from torch.nn import functional as F

+from model.tensorized_layers.graphsage import BatchedGraphSAGE


 class DiffPoolAssignment(nn.Module):
     def __init__(self, nfeat, nnext):
...
examples/pytorch/diffpool/model/tensorized_layers/diffpool.py

 import torch
-from torch import nn as nn
 from model.loss import EntropyLoss, LinkPredLoss
 from model.tensorized_layers.assignment import DiffPoolAssignment
 from model.tensorized_layers.graphsage import BatchedGraphSAGE
+from torch import nn as nn


 class BatchedDiffPool(nn.Module):
...
examples/pytorch/diffpool/train.py

@@ -3,6 +3,9 @@ import os
 import random
 import time

+import dgl
+import dgl.function as fn
 import networkx as nx
 import numpy as np
 import torch
...
@@ -10,12 +13,9 @@ import torch.nn as nn
 import torch.nn.functional as F
 import torch.utils.data
 from data_utils import pre_process
-from model.encoder import DiffPool
-import dgl
-import dgl.function as fn
 from dgl import DGLGraph
 from dgl.data import tu
+from model.encoder import DiffPool

 global_train_time_per_epoch = []
...
@@ -261,8 +261,8 @@ def train(dataset, model, prog_args, same_feat=True, val_dataset=None):
         total = 0
         print("\nEPOCH ###### {} ######".format(epoch))
         computation_time = 0.0
-        for (batch_idx, (batch_graph, graph_labels)) in enumerate(dataloader):
-            for (key, value) in batch_graph.ndata.items():
+        for batch_idx, (batch_graph, graph_labels) in enumerate(dataloader):
+            for key, value in batch_graph.ndata.items():
                 batch_graph.ndata[key] = value.float()
             graph_labels = graph_labels.long()
             if torch.cuda.is_available():
...
@@ -341,7 +341,7 @@ def evaluate(dataloader, model, prog_args, logger=None):
     correct_label = 0
     with torch.no_grad():
         for batch_idx, (batch_graph, graph_labels) in enumerate(dataloader):
-            for (key, value) in batch_graph.ndata.items():
+            for key, value in batch_graph.ndata.items():
                 batch_graph.ndata[key] = value.float()
             graph_labels = graph_labels.long()
             if torch.cuda.is_available():
...
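The loop-header changes in this file only drop redundant parentheses around the unpacking target; both spellings compile to the same thing:

    # The two loop headers in the hunk above are equivalent.
    pairs = [("a", (1, 2)), ("b", (3, 4))]
    for key, (x, y) in pairs:    # new style
        print(key, x + y)
    for (key, (x, y)) in pairs:  # old style, same behavior
        print(key, x + y)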
examples/pytorch/dimenet/main.py

@@ -2,11 +2,14 @@ import copy
 from pathlib import Path

 import click
+import dgl
 import numpy as np
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 import torch.optim as optim
+from dgl.data.utils import Subset
 from logzero import logger
 from modules.dimenet import DimeNet
 from modules.dimenet_pp import DimeNetPP
...
@@ -16,9 +19,6 @@ from ruamel.yaml import YAML
 from sklearn.metrics import mean_absolute_error
 from torch.utils.data import DataLoader

-import dgl
-from dgl.data.utils import Subset


 def split_dataset(
     dataset, num_train, num_valid, shuffle=False, random_state=None
...
examples/pytorch/dimenet/modules/interaction_block.py

+import dgl.function as fn
 import torch
 import torch.nn as nn
 from modules.initializers import GlorotOrthogonal
 from modules.residual_layer import ResidualLayer

-import dgl.function as fn


 class InteractionBlock(nn.Module):
     def __init__(
...
examples/pytorch/dimenet/modules/interaction_pp_block.py

+import dgl
+import dgl.function as fn
 import torch.nn as nn
 from modules.initializers import GlorotOrthogonal
 from modules.residual_layer import ResidualLayer

-import dgl
-import dgl.function as fn


 class InteractionPPBlock(nn.Module):
     def __init__(
...
examples/pytorch/dimenet/modules/output_block.py

-import torch.nn as nn
-from modules.initializers import GlorotOrthogonal
 import dgl
 import dgl.function as fn
+import torch.nn as nn
+from modules.initializers import GlorotOrthogonal


 class OutputBlock(nn.Module):
...
examples/pytorch/dimenet/modules/output_pp_block.py

-import torch.nn as nn
-from modules.initializers import GlorotOrthogonal
 import dgl
 import dgl.function as fn
+import torch.nn as nn
+from modules.initializers import GlorotOrthogonal


 class OutputPPBlock(nn.Module):
...