Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
704bcaf6
Unverified
Commit
704bcaf6
authored
Feb 19, 2023
by
Hongzhi (Steve), Chen
Committed by
GitHub
Feb 19, 2023
Browse files
examples (#5323)
Co-authored-by:
Ubuntu
<
ubuntu@ip-172-31-28-63.ap-northeast-1.compute.internal
>
parent
6bc82161
Changes
332
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
352 additions
and
178 deletions
+352
-178
examples/pytorch/graphsim/train.py
examples/pytorch/graphsim/train.py
+3
-3
examples/pytorch/graphwriter/graphwriter.py
examples/pytorch/graphwriter/graphwriter.py
+1
-1
examples/pytorch/graphwriter/utlis.py
examples/pytorch/graphwriter/utlis.py
+2
-2
examples/pytorch/gxn/layers.py
examples/pytorch/gxn/layers.py
+3
-3
examples/pytorch/gxn/main.py
examples/pytorch/gxn/main.py
+4
-4
examples/pytorch/gxn/main_early_stop.py
examples/pytorch/gxn/main_early_stop.py
+4
-4
examples/pytorch/gxn/networks.py
examples/pytorch/gxn/networks.py
+188
-76
examples/pytorch/han/model_hetero.py
examples/pytorch/han/model_hetero.py
+1
-2
examples/pytorch/han/train_sampling.py
examples/pytorch/han/train_sampling.py
+4
-4
examples/pytorch/han/utils.py
examples/pytorch/han/utils.py
+3
-4
examples/pytorch/hardgat/hgao.py
examples/pytorch/hardgat/hgao.py
+3
-3
examples/pytorch/hardgat/train.py
examples/pytorch/hardgat/train.py
+4
-5
examples/pytorch/hgp_sl/functions.py
examples/pytorch/hgp_sl/functions.py
+3
-4
examples/pytorch/hgp_sl/layers.py
examples/pytorch/hgp_sl/layers.py
+111
-54
examples/pytorch/hgp_sl/main.py
examples/pytorch/hgp_sl/main.py
+4
-4
examples/pytorch/hgp_sl/networks.py
examples/pytorch/hgp_sl/networks.py
+1
-1
examples/pytorch/hgt/model.py
examples/pytorch/hgt/model.py
+3
-3
examples/pytorch/hilander/PSS/Smooth_AP/src/auxiliaries.py
examples/pytorch/hilander/PSS/Smooth_AP/src/auxiliaries.py
+8
-0
examples/pytorch/hilander/PSS/Smooth_AP/src/datasets.py
examples/pytorch/hilander/PSS/Smooth_AP/src/datasets.py
+2
-0
examples/pytorch/hilander/PSS/Smooth_AP/src/evaluate_model.py
...ples/pytorch/hilander/PSS/Smooth_AP/src/evaluate_model.py
+0
-1
No files found.
examples/pytorch/graphsim/train.py
View file @
704bcaf6
...
@@ -2,6 +2,8 @@ import argparse
...
@@ -2,6 +2,8 @@ import argparse
import
time
import
time
import
traceback
import
traceback
import
dgl
import
networkx
as
nx
import
networkx
as
nx
import
numpy
as
np
import
numpy
as
np
import
torch
import
torch
...
@@ -11,12 +13,10 @@ from dataloader import (
...
@@ -11,12 +13,10 @@ from dataloader import (
MultiBodyTrainDataset
,
MultiBodyTrainDataset
,
MultiBodyValidDataset
,
MultiBodyValidDataset
,
)
)
from
models
import
MLP
,
InteractionNet
,
PrepareLayer
from
models
import
InteractionNet
,
MLP
,
PrepareLayer
from
torch.utils.data
import
DataLoader
from
torch.utils.data
import
DataLoader
from
utils
import
make_video
from
utils
import
make_video
import
dgl
def
train
(
def
train
(
optimizer
,
loss_fn
,
reg_fn
,
model
,
prep
,
dataloader
,
lambda_reg
,
device
optimizer
,
loss_fn
,
reg_fn
,
model
,
prep
,
dataloader
,
lambda_reg
,
device
...
...
examples/pytorch/graphwriter/graphwriter.py
View file @
704bcaf6
import
torch
import
torch
from
modules
import
MSA
,
BiLSTM
,
GraphTrans
from
modules
import
BiLSTM
,
GraphTrans
,
MSA
from
torch
import
nn
from
torch
import
nn
from
utlis
import
*
from
utlis
import
*
...
...
examples/pytorch/graphwriter/utlis.py
View file @
704bcaf6
...
@@ -2,11 +2,11 @@ import json
...
@@ -2,11 +2,11 @@ import json
import
pickle
import
pickle
import
random
import
random
import
dgl
import
numpy
as
np
import
numpy
as
np
import
torch
import
torch
import
dgl
NODE_TYPE
=
{
"entity"
:
0
,
"root"
:
1
,
"relation"
:
2
}
NODE_TYPE
=
{
"entity"
:
0
,
"root"
:
1
,
"relation"
:
2
}
...
...
examples/pytorch/gxn/layers.py
View file @
704bcaf6
from
typing
import
Optional
from
typing
import
Optional
import
dgl
import
torch
import
torch
import
torch.nn
import
torch.nn
from
torch
import
Tensor
import
dgl
from
dgl
import
DGLGraph
from
dgl
import
DGLGraph
from
dgl.nn
import
GraphConv
from
dgl.nn
import
GraphConv
from
torch
import
Tensor
class
GraphConvWithDropout
(
GraphConv
):
class
GraphConvWithDropout
(
GraphConv
):
...
...
examples/pytorch/gxn/main.py
View file @
704bcaf6
...
@@ -3,18 +3,18 @@ import os
...
@@ -3,18 +3,18 @@ import os
from
datetime
import
datetime
from
datetime
import
datetime
from
time
import
time
from
time
import
time
import
dgl
import
torch
import
torch
import
torch.nn.functional
as
F
import
torch.nn.functional
as
F
from
data_preprocess
import
degree_as_feature
,
node_label_as_feature
from
data_preprocess
import
degree_as_feature
,
node_label_as_feature
from
dgl.data
import
LegacyTUDataset
from
dgl.dataloading
import
GraphDataLoader
from
networks
import
GraphClassifier
from
networks
import
GraphClassifier
from
torch
import
Tensor
from
torch
import
Tensor
from
torch.utils.data
import
random_split
from
torch.utils.data
import
random_split
from
utils
import
get_stats
,
parse_args
from
utils
import
get_stats
,
parse_args
import
dgl
from
dgl.data
import
LegacyTUDataset
from
dgl.dataloading
import
GraphDataLoader
def
compute_loss
(
def
compute_loss
(
cls_logits
:
Tensor
,
cls_logits
:
Tensor
,
...
...
examples/pytorch/gxn/main_early_stop.py
View file @
704bcaf6
...
@@ -3,18 +3,18 @@ import os
...
@@ -3,18 +3,18 @@ import os
from
datetime
import
datetime
from
datetime
import
datetime
from
time
import
time
from
time
import
time
import
dgl
import
torch
import
torch
import
torch.nn.functional
as
F
import
torch.nn.functional
as
F
from
data_preprocess
import
degree_as_feature
,
node_label_as_feature
from
data_preprocess
import
degree_as_feature
,
node_label_as_feature
from
dgl.data
import
LegacyTUDataset
from
dgl.dataloading
import
GraphDataLoader
from
networks
import
GraphClassifier
from
networks
import
GraphClassifier
from
torch
import
Tensor
from
torch
import
Tensor
from
torch.utils.data
import
random_split
from
torch.utils.data
import
random_split
from
utils
import
get_stats
,
parse_args
from
utils
import
get_stats
,
parse_args
import
dgl
from
dgl.data
import
LegacyTUDataset
from
dgl.dataloading
import
GraphDataLoader
def
compute_loss
(
def
compute_loss
(
cls_logits
:
Tensor
,
cls_logits
:
Tensor
,
...
...
examples/pytorch/gxn/networks.py
View file @
704bcaf6
from
typing
import
List
,
Tuple
,
Union
from
typing
import
List
,
Tuple
,
Union
from
layers
import
*
from
layers
import
*
import
dgl.function
as
fn
import
torch
import
torch
import
torch.nn
import
torch.nn
import
torch.nn.functional
as
F
import
torch.nn.functional
as
F
import
dgl.function
as
fn
from
dgl.nn.pytorch.glob
import
SortPooling
from
dgl.nn.pytorch.glob
import
SortPooling
...
@@ -35,9 +35,18 @@ class GraphCrossModule(torch.nn.Module):
...
@@ -35,9 +35,18 @@ class GraphCrossModule(torch.nn.Module):
The weight parameter used at the end of GXN for channel fusion.
The weight parameter used at the end of GXN for channel fusion.
Default: :obj:`1.0`
Default: :obj:`1.0`
"""
"""
def
__init__
(
self
,
pool_ratios
:
Union
[
float
,
List
[
float
]],
in_dim
:
int
,
out_dim
:
int
,
hidden_dim
:
int
,
cross_weight
:
float
=
1.
,
def
__init__
(
fuse_weight
:
float
=
1.
,
dist
:
int
=
1
,
num_cross_layers
:
int
=
2
):
self
,
pool_ratios
:
Union
[
float
,
List
[
float
]],
in_dim
:
int
,
out_dim
:
int
,
hidden_dim
:
int
,
cross_weight
:
float
=
1.0
,
fuse_weight
:
float
=
1.0
,
dist
:
int
=
1
,
num_cross_layers
:
int
=
2
,
):
super
(
GraphCrossModule
,
self
).
__init__
()
super
(
GraphCrossModule
,
self
).
__init__
()
if
isinstance
(
pool_ratios
,
float
):
if
isinstance
(
pool_ratios
,
float
):
pool_ratios
=
(
pool_ratios
,
pool_ratios
)
pool_ratios
=
(
pool_ratios
,
pool_ratios
)
...
@@ -50,8 +59,12 @@ class GraphCrossModule(torch.nn.Module):
...
@@ -50,8 +59,12 @@ class GraphCrossModule(torch.nn.Module):
self
.
start_gcn_scale2
=
GraphConvWithDropout
(
hidden_dim
,
hidden_dim
)
self
.
start_gcn_scale2
=
GraphConvWithDropout
(
hidden_dim
,
hidden_dim
)
self
.
end_gcn
=
GraphConvWithDropout
(
2
*
hidden_dim
,
out_dim
)
self
.
end_gcn
=
GraphConvWithDropout
(
2
*
hidden_dim
,
out_dim
)
self
.
index_select_scale1
=
IndexSelect
(
pool_ratios
[
0
],
hidden_dim
,
act
=
"prelu"
,
dist
=
dist
)
self
.
index_select_scale1
=
IndexSelect
(
self
.
index_select_scale2
=
IndexSelect
(
pool_ratios
[
1
],
hidden_dim
,
act
=
"prelu"
,
dist
=
dist
)
pool_ratios
[
0
],
hidden_dim
,
act
=
"prelu"
,
dist
=
dist
)
self
.
index_select_scale2
=
IndexSelect
(
pool_ratios
[
1
],
hidden_dim
,
act
=
"prelu"
,
dist
=
dist
)
self
.
start_pool_s12
=
GraphPool
(
hidden_dim
)
self
.
start_pool_s12
=
GraphPool
(
hidden_dim
)
self
.
start_pool_s23
=
GraphPool
(
hidden_dim
)
self
.
start_pool_s23
=
GraphPool
(
hidden_dim
)
self
.
end_unpool_s21
=
GraphUnpool
(
hidden_dim
)
self
.
end_unpool_s21
=
GraphUnpool
(
hidden_dim
)
...
@@ -85,72 +98,139 @@ class GraphCrossModule(torch.nn.Module):
...
@@ -85,72 +98,139 @@ class GraphCrossModule(torch.nn.Module):
graph_scale1
=
graph
graph_scale1
=
graph
feat_scale1
=
self
.
start_gcn_scale1
(
graph_scale1
,
feat
)
feat_scale1
=
self
.
start_gcn_scale1
(
graph_scale1
,
feat
)
feat_origin
=
feat_scale1
feat_origin
=
feat_scale1
feat_scale1_neg
=
feat_scale1
[
torch
.
randperm
(
feat_scale1
.
size
(
0
))]
# negative samples
feat_scale1_neg
=
feat_scale1
[
logit_s1
,
scores_s1
,
select_idx_s1
,
non_select_idx_s1
,
feat_down_s1
=
\
torch
.
randperm
(
feat_scale1
.
size
(
0
))
self
.
index_select_scale1
(
graph_scale1
,
feat_scale1
,
feat_scale1_neg
)
]
# negative samples
feat_scale2
,
graph_scale2
=
self
.
start_pool_s12
(
graph_scale1
,
feat_scale1
,
(
select_idx_s1
,
non_select_idx_s1
,
logit_s1
,
scores_s1
,
pool_graph
=
True
)
scores_s1
,
select_idx_s1
,
non_select_idx_s1
,
feat_down_s1
,
)
=
self
.
index_select_scale1
(
graph_scale1
,
feat_scale1
,
feat_scale1_neg
)
feat_scale2
,
graph_scale2
=
self
.
start_pool_s12
(
graph_scale1
,
feat_scale1
,
select_idx_s1
,
non_select_idx_s1
,
scores_s1
,
pool_graph
=
True
,
)
# start of scale-2
# start of scale-2
feat_scale2
=
self
.
start_gcn_scale2
(
graph_scale2
,
feat_scale2
)
feat_scale2
=
self
.
start_gcn_scale2
(
graph_scale2
,
feat_scale2
)
feat_scale2_neg
=
feat_scale2
[
torch
.
randperm
(
feat_scale2
.
size
(
0
))]
# negative samples
feat_scale2_neg
=
feat_scale2
[
logit_s2
,
scores_s2
,
select_idx_s2
,
non_select_idx_s2
,
feat_down_s2
=
\
torch
.
randperm
(
feat_scale2
.
size
(
0
))
self
.
index_select_scale2
(
graph_scale2
,
feat_scale2
,
feat_scale2_neg
)
]
# negative samples
feat_scale3
,
graph_scale3
=
self
.
start_pool_s23
(
graph_scale2
,
feat_scale2
,
(
select_idx_s2
,
non_select_idx_s2
,
logit_s2
,
scores_s2
,
pool_graph
=
True
)
scores_s2
,
select_idx_s2
,
non_select_idx_s2
,
feat_down_s2
,
)
=
self
.
index_select_scale2
(
graph_scale2
,
feat_scale2
,
feat_scale2_neg
)
feat_scale3
,
graph_scale3
=
self
.
start_pool_s23
(
graph_scale2
,
feat_scale2
,
select_idx_s2
,
non_select_idx_s2
,
scores_s2
,
pool_graph
=
True
,
)
# layer-1
# layer-1
res_s1_0
,
res_s2_0
,
res_s3_0
=
feat_scale1
,
feat_scale2
,
feat_scale3
res_s1_0
,
res_s2_0
,
res_s3_0
=
feat_scale1
,
feat_scale2
,
feat_scale3
feat_scale1
=
F
.
relu
(
self
.
s1_l1_gcn
(
graph_scale1
,
feat_scale1
))
feat_scale1
=
F
.
relu
(
self
.
s1_l1_gcn
(
graph_scale1
,
feat_scale1
))
feat_scale2
=
F
.
relu
(
self
.
s2_l1_gcn
(
graph_scale2
,
feat_scale2
))
feat_scale2
=
F
.
relu
(
self
.
s2_l1_gcn
(
graph_scale2
,
feat_scale2
))
feat_scale3
=
F
.
relu
(
self
.
s3_l1_gcn
(
graph_scale3
,
feat_scale3
))
feat_scale3
=
F
.
relu
(
self
.
s3_l1_gcn
(
graph_scale3
,
feat_scale3
))
if
self
.
num_cross_layers
>=
1
:
if
self
.
num_cross_layers
>=
1
:
feat_s12_fu
=
self
.
pool_s12_1
(
graph_scale1
,
feat_scale1
,
feat_s12_fu
=
self
.
pool_s12_1
(
select_idx_s1
,
non_select_idx_s1
,
graph_scale1
,
scores_s1
)
feat_scale1
,
feat_s21_fu
=
self
.
unpool_s21_1
(
graph_scale1
,
feat_scale2
,
select_idx_s1
)
select_idx_s1
,
feat_s23_fu
=
self
.
pool_s23_1
(
graph_scale2
,
feat_scale2
,
non_select_idx_s1
,
select_idx_s2
,
non_select_idx_s2
,
scores_s1
,
scores_s2
)
)
feat_s32_fu
=
self
.
unpool_s32_1
(
graph_scale2
,
feat_scale3
,
select_idx_s2
)
feat_s21_fu
=
self
.
unpool_s21_1
(
graph_scale1
,
feat_scale2
,
select_idx_s1
feat_scale1
=
feat_scale1
+
self
.
cross_weight
*
feat_s21_fu
+
res_s1_0
)
feat_scale2
=
feat_scale2
+
self
.
cross_weight
*
(
feat_s12_fu
+
feat_s32_fu
)
/
2
+
res_s2_0
feat_s23_fu
=
self
.
pool_s23_1
(
feat_scale3
=
feat_scale3
+
self
.
cross_weight
*
feat_s23_fu
+
res_s3_0
graph_scale2
,
feat_scale2
,
select_idx_s2
,
non_select_idx_s2
,
scores_s2
,
)
feat_s32_fu
=
self
.
unpool_s32_1
(
graph_scale2
,
feat_scale3
,
select_idx_s2
)
feat_scale1
=
(
feat_scale1
+
self
.
cross_weight
*
feat_s21_fu
+
res_s1_0
)
feat_scale2
=
(
feat_scale2
+
self
.
cross_weight
*
(
feat_s12_fu
+
feat_s32_fu
)
/
2
+
res_s2_0
)
feat_scale3
=
(
feat_scale3
+
self
.
cross_weight
*
feat_s23_fu
+
res_s3_0
)
# layer-2
# layer-2
feat_scale1
=
F
.
relu
(
self
.
s1_l2_gcn
(
graph_scale1
,
feat_scale1
))
feat_scale1
=
F
.
relu
(
self
.
s1_l2_gcn
(
graph_scale1
,
feat_scale1
))
feat_scale2
=
F
.
relu
(
self
.
s2_l2_gcn
(
graph_scale2
,
feat_scale2
))
feat_scale2
=
F
.
relu
(
self
.
s2_l2_gcn
(
graph_scale2
,
feat_scale2
))
feat_scale3
=
F
.
relu
(
self
.
s3_l2_gcn
(
graph_scale3
,
feat_scale3
))
feat_scale3
=
F
.
relu
(
self
.
s3_l2_gcn
(
graph_scale3
,
feat_scale3
))
if
self
.
num_cross_layers
>=
2
:
if
self
.
num_cross_layers
>=
2
:
feat_s12_fu
=
self
.
pool_s12_2
(
graph_scale1
,
feat_scale1
,
feat_s12_fu
=
self
.
pool_s12_2
(
select_idx_s1
,
non_select_idx_s1
,
graph_scale1
,
scores_s1
)
feat_scale1
,
feat_s21_fu
=
self
.
unpool_s21_2
(
graph_scale1
,
feat_scale2
,
select_idx_s1
)
select_idx_s1
,
feat_s23_fu
=
self
.
pool_s23_2
(
graph_scale2
,
feat_scale2
,
non_select_idx_s1
,
select_idx_s2
,
non_select_idx_s2
,
scores_s1
,
scores_s2
)
)
feat_s32_fu
=
self
.
unpool_s32_2
(
graph_scale2
,
feat_scale3
,
select_idx_s2
)
feat_s21_fu
=
self
.
unpool_s21_2
(
graph_scale1
,
feat_scale2
,
select_idx_s1
)
feat_s23_fu
=
self
.
pool_s23_2
(
graph_scale2
,
feat_scale2
,
select_idx_s2
,
non_select_idx_s2
,
scores_s2
,
)
feat_s32_fu
=
self
.
unpool_s32_2
(
graph_scale2
,
feat_scale3
,
select_idx_s2
)
cross_weight
=
self
.
cross_weight
*
0.05
cross_weight
=
self
.
cross_weight
*
0.05
feat_scale1
=
feat_scale1
+
cross_weight
*
feat_s21_fu
feat_scale1
=
feat_scale1
+
cross_weight
*
feat_s21_fu
feat_scale2
=
feat_scale2
+
cross_weight
*
(
feat_s12_fu
+
feat_s32_fu
)
/
2
feat_scale2
=
(
feat_scale2
+
cross_weight
*
(
feat_s12_fu
+
feat_s32_fu
)
/
2
)
feat_scale3
=
feat_scale3
+
cross_weight
*
feat_s23_fu
feat_scale3
=
feat_scale3
+
cross_weight
*
feat_s23_fu
# layer-3
# layer-3
feat_scale1
=
F
.
relu
(
self
.
s1_l3_gcn
(
graph_scale1
,
feat_scale1
))
feat_scale1
=
F
.
relu
(
self
.
s1_l3_gcn
(
graph_scale1
,
feat_scale1
))
feat_scale2
=
F
.
relu
(
self
.
s2_l3_gcn
(
graph_scale2
,
feat_scale2
))
feat_scale2
=
F
.
relu
(
self
.
s2_l3_gcn
(
graph_scale2
,
feat_scale2
))
feat_scale3
=
F
.
relu
(
self
.
s3_l3_gcn
(
graph_scale3
,
feat_scale3
))
feat_scale3
=
F
.
relu
(
self
.
s3_l3_gcn
(
graph_scale3
,
feat_scale3
))
# final layers
# final layers
feat_s3_out
=
self
.
end_unpool_s32
(
graph_scale2
,
feat_scale3
,
select_idx_s2
)
+
feat_down_s2
feat_s3_out
=
(
feat_s2_out
=
self
.
end_unpool_s21
(
graph_scale1
,
feat_scale2
+
feat_s3_out
,
select_idx_s1
)
self
.
end_unpool_s32
(
graph_scale2
,
feat_scale3
,
select_idx_s2
)
feat_agg
=
feat_scale1
+
self
.
fuse_weight
*
feat_s2_out
+
self
.
fuse_weight
*
feat_down_s1
+
feat_down_s2
)
feat_s2_out
=
self
.
end_unpool_s21
(
graph_scale1
,
feat_scale2
+
feat_s3_out
,
select_idx_s1
)
feat_agg
=
(
feat_scale1
+
self
.
fuse_weight
*
feat_s2_out
+
self
.
fuse_weight
*
feat_down_s1
)
feat_agg
=
torch
.
cat
((
feat_agg
,
feat_origin
),
dim
=
1
)
feat_agg
=
torch
.
cat
((
feat_agg
,
feat_origin
),
dim
=
1
)
feat_agg
=
self
.
end_gcn
(
graph_scale1
,
feat_agg
)
feat_agg
=
self
.
end_gcn
(
graph_scale1
,
feat_agg
)
...
@@ -198,11 +278,21 @@ class GraphCrossNet(torch.nn.Module):
...
@@ -198,11 +278,21 @@ class GraphCrossNet(torch.nn.Module):
The weight parameter used at the end of GXN for channel fusion.
The weight parameter used at the end of GXN for channel fusion.
Default: :obj:`1.0`
Default: :obj:`1.0`
"""
"""
def
__init__
(
self
,
in_dim
:
int
,
out_dim
:
int
,
edge_feat_dim
:
int
=
0
,
hidden_dim
:
int
=
96
,
pool_ratios
:
Union
[
List
[
float
],
float
]
=
[
0.9
,
0.7
],
def
__init__
(
readout_nodes
:
int
=
30
,
conv1d_dims
:
List
[
int
]
=
[
16
,
32
],
self
,
conv1d_kws
:
List
[
int
]
=
[
5
],
in_dim
:
int
,
cross_weight
:
float
=
1.
,
fuse_weight
:
float
=
1.
,
dist
:
int
=
1
):
out_dim
:
int
,
edge_feat_dim
:
int
=
0
,
hidden_dim
:
int
=
96
,
pool_ratios
:
Union
[
List
[
float
],
float
]
=
[
0.9
,
0.7
],
readout_nodes
:
int
=
30
,
conv1d_dims
:
List
[
int
]
=
[
16
,
32
],
conv1d_kws
:
List
[
int
]
=
[
5
],
cross_weight
:
float
=
1.0
,
fuse_weight
:
float
=
1.0
,
dist
:
int
=
1
,
):
super
(
GraphCrossNet
,
self
).
__init__
()
super
(
GraphCrossNet
,
self
).
__init__
()
self
.
in_dim
=
in_dim
self
.
in_dim
=
in_dim
self
.
out_dim
=
out_dim
self
.
out_dim
=
out_dim
...
@@ -217,26 +307,35 @@ class GraphCrossNet(torch.nn.Module):
...
@@ -217,26 +307,35 @@ class GraphCrossNet(torch.nn.Module):
else
:
else
:
self
.
e2l_lin
=
None
self
.
e2l_lin
=
None
self
.
gxn
=
GraphCrossModule
(
pool_ratios
,
in_dim
=
self
.
in_dim
,
out_dim
=
hidden_dim
,
self
.
gxn
=
GraphCrossModule
(
hidden_dim
=
hidden_dim
//
2
,
cross_weight
=
cross_weight
,
pool_ratios
,
fuse_weight
=
fuse_weight
,
dist
=
dist
)
in_dim
=
self
.
in_dim
,
out_dim
=
hidden_dim
,
hidden_dim
=
hidden_dim
//
2
,
cross_weight
=
cross_weight
,
fuse_weight
=
fuse_weight
,
dist
=
dist
,
)
self
.
sortpool
=
SortPooling
(
readout_nodes
)
self
.
sortpool
=
SortPooling
(
readout_nodes
)
# final updates
# final updates
self
.
final_conv1
=
torch
.
nn
.
Conv1d
(
1
,
conv1d_dims
[
0
],
self
.
final_conv1
=
torch
.
nn
.
Conv1d
(
kernel_size
=
conv1d_kws
[
0
],
1
,
conv1d_dims
[
0
],
kernel_size
=
conv1d_kws
[
0
],
stride
=
conv1d_kws
[
0
]
stride
=
conv1d_kws
[
0
]
)
)
self
.
final_maxpool
=
torch
.
nn
.
MaxPool1d
(
2
,
2
)
self
.
final_maxpool
=
torch
.
nn
.
MaxPool1d
(
2
,
2
)
self
.
final_conv2
=
torch
.
nn
.
Conv1d
(
conv1d_dims
[
0
],
conv1d_dims
[
1
],
self
.
final_conv2
=
torch
.
nn
.
Conv1d
(
kernel_size
=
conv1d_kws
[
1
],
stride
=
1
)
conv1d_dims
[
0
],
conv1d_dims
[
1
],
kernel_size
=
conv1d_kws
[
1
],
stride
=
1
)
self
.
final_dense_dim
=
int
((
readout_nodes
-
2
)
/
2
+
1
)
self
.
final_dense_dim
=
int
((
readout_nodes
-
2
)
/
2
+
1
)
self
.
final_dense_dim
=
(
self
.
final_dense_dim
-
conv1d_kws
[
1
]
+
1
)
*
conv1d_dims
[
1
]
self
.
final_dense_dim
=
(
self
.
final_dense_dim
-
conv1d_kws
[
1
]
+
1
)
*
conv1d_dims
[
1
]
if
self
.
out_dim
>
0
:
if
self
.
out_dim
>
0
:
self
.
out_lin
=
torch
.
nn
.
Linear
(
self
.
final_dense_dim
,
out_dim
)
self
.
out_lin
=
torch
.
nn
.
Linear
(
self
.
final_dense_dim
,
out_dim
)
self
.
init_weights
()
self
.
init_weights
()
def
init_weights
(
self
):
def
init_weights
(
self
):
if
self
.
e2l_lin
is
not
None
:
if
self
.
e2l_lin
is
not
None
:
torch
.
nn
.
init
.
xavier_normal_
(
self
.
e2l_lin
.
weight
)
torch
.
nn
.
init
.
xavier_normal_
(
self
.
e2l_lin
.
weight
)
...
@@ -245,7 +344,12 @@ class GraphCrossNet(torch.nn.Module):
...
@@ -245,7 +344,12 @@ class GraphCrossNet(torch.nn.Module):
if
self
.
out_dim
>
0
:
if
self
.
out_dim
>
0
:
torch
.
nn
.
init
.
xavier_normal_
(
self
.
out_lin
.
weight
)
torch
.
nn
.
init
.
xavier_normal_
(
self
.
out_lin
.
weight
)
def
forward
(
self
,
graph
:
DGLGraph
,
node_feat
:
Tensor
,
edge_feat
:
Optional
[
Tensor
]
=
None
):
def
forward
(
self
,
graph
:
DGLGraph
,
node_feat
:
Tensor
,
edge_feat
:
Optional
[
Tensor
]
=
None
,
):
num_batch
=
graph
.
batch_size
num_batch
=
graph
.
batch_size
if
edge_feat
is
not
None
:
if
edge_feat
is
not
None
:
edge_feat
=
self
.
e2l_lin
(
edge_feat
)
edge_feat
=
self
.
e2l_lin
(
edge_feat
)
...
@@ -263,13 +367,13 @@ class GraphCrossNet(torch.nn.Module):
...
@@ -263,13 +367,13 @@ class GraphCrossNet(torch.nn.Module):
conv1d_result
=
F
.
relu
(
self
.
final_conv1
(
to_conv1d
))
conv1d_result
=
F
.
relu
(
self
.
final_conv1
(
to_conv1d
))
conv1d_result
=
self
.
final_maxpool
(
conv1d_result
)
conv1d_result
=
self
.
final_maxpool
(
conv1d_result
)
conv1d_result
=
F
.
relu
(
self
.
final_conv2
(
conv1d_result
))
conv1d_result
=
F
.
relu
(
self
.
final_conv2
(
conv1d_result
))
to_dense
=
conv1d_result
.
view
(
num_batch
,
-
1
)
to_dense
=
conv1d_result
.
view
(
num_batch
,
-
1
)
if
self
.
out_dim
>
0
:
if
self
.
out_dim
>
0
:
out
=
F
.
relu
(
self
.
out_lin
(
to_dense
))
out
=
F
.
relu
(
self
.
out_lin
(
to_dense
))
else
:
else
:
out
=
to_dense
out
=
to_dense
return
out
,
logits1
,
logits2
return
out
,
logits1
,
logits2
...
@@ -280,23 +384,31 @@ class GraphClassifier(torch.nn.Module):
...
@@ -280,23 +384,31 @@ class GraphClassifier(torch.nn.Module):
Graph Classifier for graph classification.
Graph Classifier for graph classification.
GXN + MLP
GXN + MLP
"""
"""
def
__init__
(
self
,
args
):
def
__init__
(
self
,
args
):
super
(
GraphClassifier
,
self
).
__init__
()
super
(
GraphClassifier
,
self
).
__init__
()
self
.
gxn
=
GraphCrossNet
(
in_dim
=
args
.
in_dim
,
self
.
gxn
=
GraphCrossNet
(
out_dim
=
args
.
embed_dim
,
in_dim
=
args
.
in_dim
,
edge_feat_dim
=
args
.
edge_feat_dim
,
out_dim
=
args
.
embed_dim
,
hidden_dim
=
args
.
hidden_dim
,
edge_feat_dim
=
args
.
edge_feat_dim
,
pool_ratios
=
args
.
pool_ratios
,
hidden_dim
=
args
.
hidden_dim
,
readout_nodes
=
args
.
readout_nodes
,
pool_ratios
=
args
.
pool_ratios
,
conv1d_dims
=
args
.
conv1d_dims
,
readout_nodes
=
args
.
readout_nodes
,
conv1d_kws
=
args
.
conv1d_kws
,
conv1d_dims
=
args
.
conv1d_dims
,
cross_weight
=
args
.
cross_weight
,
conv1d_kws
=
args
.
conv1d_kws
,
fuse_weight
=
args
.
fuse_weight
)
cross_weight
=
args
.
cross_weight
,
fuse_weight
=
args
.
fuse_weight
,
)
self
.
lin1
=
torch
.
nn
.
Linear
(
args
.
embed_dim
,
args
.
final_dense_hidden_dim
)
self
.
lin1
=
torch
.
nn
.
Linear
(
args
.
embed_dim
,
args
.
final_dense_hidden_dim
)
self
.
lin2
=
torch
.
nn
.
Linear
(
args
.
final_dense_hidden_dim
,
args
.
out_dim
)
self
.
lin2
=
torch
.
nn
.
Linear
(
args
.
final_dense_hidden_dim
,
args
.
out_dim
)
self
.
dropout
=
args
.
dropout
self
.
dropout
=
args
.
dropout
def
forward
(
self
,
graph
:
DGLGraph
,
node_feat
:
Tensor
,
edge_feat
:
Optional
[
Tensor
]
=
None
):
def
forward
(
self
,
graph
:
DGLGraph
,
node_feat
:
Tensor
,
edge_feat
:
Optional
[
Tensor
]
=
None
,
):
embed
,
logits1
,
logits2
=
self
.
gxn
(
graph
,
node_feat
,
edge_feat
)
embed
,
logits1
,
logits2
=
self
.
gxn
(
graph
,
node_feat
,
edge_feat
)
logits
=
F
.
relu
(
self
.
lin1
(
embed
))
logits
=
F
.
relu
(
self
.
lin1
(
embed
))
if
self
.
dropout
>
0
:
if
self
.
dropout
>
0
:
...
...
examples/pytorch/han/model_hetero.py
View file @
704bcaf6
...
@@ -7,11 +7,10 @@ constructed another dataset from ACM with a different set of papers, connections
...
@@ -7,11 +7,10 @@ constructed another dataset from ACM with a different set of papers, connections
labels.
labels.
"""
"""
import
dgl
import
torch
import
torch
import
torch.nn
as
nn
import
torch.nn
as
nn
import
torch.nn.functional
as
F
import
torch.nn.functional
as
F
import
dgl
from
dgl.nn.pytorch
import
GATConv
from
dgl.nn.pytorch
import
GATConv
...
...
examples/pytorch/han/train_sampling.py
View file @
704bcaf6
...
@@ -6,19 +6,19 @@ so we sampled twice as many neighbors during val/test than training.
...
@@ -6,19 +6,19 @@ so we sampled twice as many neighbors during val/test than training.
"""
"""
import
argparse
import
argparse
import
dgl
import
numpy
import
numpy
import
torch
import
torch
import
torch.nn
as
nn
import
torch.nn
as
nn
import
torch.nn.functional
as
F
import
torch.nn.functional
as
F
from
dgl.nn.pytorch
import
GATConv
from
dgl.sampling
import
RandomWalkNeighborSampler
from
model_hetero
import
SemanticAttention
from
model_hetero
import
SemanticAttention
from
sklearn.metrics
import
f1_score
from
sklearn.metrics
import
f1_score
from
torch.utils.data
import
DataLoader
from
torch.utils.data
import
DataLoader
from
utils
import
EarlyStopping
,
set_random_seed
from
utils
import
EarlyStopping
,
set_random_seed
import
dgl
from
dgl.nn.pytorch
import
GATConv
from
dgl.sampling
import
RandomWalkNeighborSampler
class
HANLayer
(
torch
.
nn
.
Module
):
class
HANLayer
(
torch
.
nn
.
Module
):
"""
"""
...
...
examples/pytorch/han/utils.py
View file @
704bcaf6
...
@@ -5,13 +5,12 @@ import pickle
...
@@ -5,13 +5,12 @@ import pickle
import
random
import
random
from
pprint
import
pprint
from
pprint
import
pprint
import
dgl
import
numpy
as
np
import
numpy
as
np
import
torch
import
torch
from
scipy
import
io
as
sio
from
scipy
import
sparse
import
dgl
from
dgl.data.utils
import
_get_dgl_url
,
download
,
get_download_dir
from
dgl.data.utils
import
_get_dgl_url
,
download
,
get_download_dir
from
scipy
import
io
as
sio
,
sparse
def
set_random_seed
(
seed
=
0
):
def
set_random_seed
(
seed
=
0
):
...
...
examples/pytorch/hardgat/hgao.py
View file @
704bcaf6
...
@@ -7,12 +7,12 @@ Paper: https://arxiv.org/abs/1907.04652
...
@@ -7,12 +7,12 @@ Paper: https://arxiv.org/abs/1907.04652
from
functools
import
partial
from
functools
import
partial
import
dgl
import
dgl.function
as
fn
import
torch
import
torch
import
torch.nn
as
nn
import
torch.nn
as
nn
import
torch.nn.functional
as
F
import
torch.nn.functional
as
F
import
dgl
import
dgl.function
as
fn
from
dgl.base
import
DGLError
from
dgl.base
import
DGLError
from
dgl.nn.pytorch
import
edge_softmax
from
dgl.nn.pytorch
import
edge_softmax
from
dgl.nn.pytorch.utils
import
Identity
from
dgl.nn.pytorch.utils
import
Identity
...
...
examples/pytorch/hardgat/train.py
View file @
704bcaf6
...
@@ -8,19 +8,19 @@ Paper: https://arxiv.org/abs/1907.04652
...
@@ -8,19 +8,19 @@ Paper: https://arxiv.org/abs/1907.04652
import
argparse
import
argparse
import
time
import
time
import
dgl
import
numpy
as
np
import
numpy
as
np
import
torch
import
torch
import
torch.nn.functional
as
F
import
torch.nn.functional
as
F
from
hgao
import
HardGAT
from
utils
import
EarlyStopping
import
dgl
from
dgl.data
import
(
from
dgl.data
import
(
CiteseerGraphDataset
,
CiteseerGraphDataset
,
CoraGraphDataset
,
CoraGraphDataset
,
PubmedGraphDataset
,
PubmedGraphDataset
,
register_data_args
,
register_data_args
,
)
)
from
hgao
import
HardGAT
from
utils
import
EarlyStopping
def
accuracy
(
logits
,
labels
):
def
accuracy
(
logits
,
labels
):
...
@@ -161,7 +161,6 @@ def main(args):
...
@@ -161,7 +161,6 @@ def main(args):
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
parser
=
argparse
.
ArgumentParser
(
description
=
"GAT"
)
parser
=
argparse
.
ArgumentParser
(
description
=
"GAT"
)
register_data_args
(
parser
)
register_data_args
(
parser
)
parser
.
add_argument
(
parser
.
add_argument
(
...
...
examples/pytorch/hgp_sl/functions.py
View file @
704bcaf6
...
@@ -7,15 +7,14 @@ for detailed description.
...
@@ -7,15 +7,14 @@ for detailed description.
Here we implement a graph-edge version of sparsemax where we perform sparsemax for all edges
Here we implement a graph-edge version of sparsemax where we perform sparsemax for all edges
with the same node as end-node in graphs.
with the same node as end-node in graphs.
"""
"""
import
torch
from
torch
import
Tensor
from
torch.autograd
import
Function
import
dgl
import
dgl
import
torch
from
dgl.backend
import
astype
from
dgl.backend
import
astype
from
dgl.base
import
ALL
,
is_all
from
dgl.base
import
ALL
,
is_all
from
dgl.heterograph_index
import
HeteroGraphIndex
from
dgl.heterograph_index
import
HeteroGraphIndex
from
dgl.sparse
import
_gsddmm
,
_gspmm
from
dgl.sparse
import
_gsddmm
,
_gspmm
from
torch
import
Tensor
from
torch.autograd
import
Function
def
_neighbor_sort
(
def
_neighbor_sort
(
...
...
examples/pytorch/hgp_sl/layers.py
View file @
704bcaf6
...
@@ -7,10 +7,10 @@ import torch.nn.functional as F
...
@@ -7,10 +7,10 @@ import torch.nn.functional as F
from
dgl
import
DGLGraph
from
dgl
import
DGLGraph
from
dgl.nn
import
AvgPooling
,
GraphConv
,
MaxPooling
from
dgl.nn
import
AvgPooling
,
GraphConv
,
MaxPooling
from
dgl.ops
import
edge_softmax
from
dgl.ops
import
edge_softmax
from
torch
import
Tensor
from
torch.nn
import
Parameter
from
functions
import
edge_sparsemax
from
functions
import
edge_sparsemax
from
torch
import
Tensor
from
torch.nn
import
Parameter
from
utils
import
get_batch_id
,
topk
from
utils
import
get_batch_id
,
topk
...
@@ -30,10 +30,11 @@ class WeightedGraphConv(GraphConv):
...
@@ -30,10 +30,11 @@ class WeightedGraphConv(GraphConv):
e_feat : torch.Tensor, optional
e_feat : torch.Tensor, optional
The edge features. Default: :obj:`None`
The edge features. Default: :obj:`None`
"""
"""
def
forward
(
self
,
graph
:
DGLGraph
,
n_feat
,
e_feat
=
None
):
def
forward
(
self
,
graph
:
DGLGraph
,
n_feat
,
e_feat
=
None
):
if
e_feat
is
None
:
if
e_feat
is
None
:
return
super
(
WeightedGraphConv
,
self
).
forward
(
graph
,
n_feat
)
return
super
(
WeightedGraphConv
,
self
).
forward
(
graph
,
n_feat
)
with
graph
.
local_scope
():
with
graph
.
local_scope
():
if
self
.
weight
is
not
None
:
if
self
.
weight
is
not
None
:
n_feat
=
torch
.
matmul
(
n_feat
,
self
.
weight
)
n_feat
=
torch
.
matmul
(
n_feat
,
self
.
weight
)
...
@@ -44,8 +45,7 @@ class WeightedGraphConv(GraphConv):
...
@@ -44,8 +45,7 @@ class WeightedGraphConv(GraphConv):
n_feat
=
n_feat
*
src_norm
n_feat
=
n_feat
*
src_norm
graph
.
ndata
[
"h"
]
=
n_feat
graph
.
ndata
[
"h"
]
=
n_feat
graph
.
edata
[
"e"
]
=
e_feat
graph
.
edata
[
"e"
]
=
e_feat
graph
.
update_all
(
fn
.
u_mul_e
(
"h"
,
"e"
,
"m"
),
graph
.
update_all
(
fn
.
u_mul_e
(
"h"
,
"e"
,
"m"
),
fn
.
sum
(
"m"
,
"h"
))
fn
.
sum
(
"m"
,
"h"
))
n_feat
=
graph
.
ndata
.
pop
(
"h"
)
n_feat
=
graph
.
ndata
.
pop
(
"h"
)
n_feat
=
n_feat
*
dst_norm
n_feat
=
n_feat
*
dst_norm
if
self
.
bias
is
not
None
:
if
self
.
bias
is
not
None
:
...
@@ -77,35 +77,40 @@ class NodeInfoScoreLayer(nn.Module):
...
@@ -77,35 +77,40 @@ class NodeInfoScoreLayer(nn.Module):
The node features
The node features
e_feat : torch.Tensor, optional
e_feat : torch.Tensor, optional
The edge features. Default: :obj:`None`
The edge features. Default: :obj:`None`
Returns
Returns
-------
-------
Tensor
Tensor
Score for each node.
Score for each node.
"""
"""
def
__init__
(
self
,
sym_norm
:
bool
=
True
):
def
__init__
(
self
,
sym_norm
:
bool
=
True
):
super
(
NodeInfoScoreLayer
,
self
).
__init__
()
super
(
NodeInfoScoreLayer
,
self
).
__init__
()
self
.
sym_norm
=
sym_norm
self
.
sym_norm
=
sym_norm
def
forward
(
self
,
graph
:
dgl
.
DGLGraph
,
feat
:
Tensor
,
e_feat
:
Tensor
):
def
forward
(
self
,
graph
:
dgl
.
DGLGraph
,
feat
:
Tensor
,
e_feat
:
Tensor
):
with
graph
.
local_scope
():
with
graph
.
local_scope
():
if
self
.
sym_norm
:
if
self
.
sym_norm
:
src_norm
=
torch
.
pow
(
graph
.
out_degrees
().
float
().
clamp
(
min
=
1
),
-
0.5
)
src_norm
=
torch
.
pow
(
graph
.
out_degrees
().
float
().
clamp
(
min
=
1
),
-
0.5
)
src_norm
=
src_norm
.
view
(
-
1
,
1
).
to
(
feat
.
device
)
src_norm
=
src_norm
.
view
(
-
1
,
1
).
to
(
feat
.
device
)
dst_norm
=
torch
.
pow
(
graph
.
in_degrees
().
float
().
clamp
(
min
=
1
),
-
0.5
)
dst_norm
=
torch
.
pow
(
graph
.
in_degrees
().
float
().
clamp
(
min
=
1
),
-
0.5
)
dst_norm
=
dst_norm
.
view
(
-
1
,
1
).
to
(
feat
.
device
)
dst_norm
=
dst_norm
.
view
(
-
1
,
1
).
to
(
feat
.
device
)
src_feat
=
feat
*
src_norm
src_feat
=
feat
*
src_norm
graph
.
ndata
[
"h"
]
=
src_feat
graph
.
ndata
[
"h"
]
=
src_feat
graph
.
edata
[
"e"
]
=
e_feat
graph
.
edata
[
"e"
]
=
e_feat
graph
=
dgl
.
remove_self_loop
(
graph
)
graph
=
dgl
.
remove_self_loop
(
graph
)
graph
.
update_all
(
fn
.
u_mul_e
(
"h"
,
"e"
,
"m"
),
fn
.
sum
(
"m"
,
"h"
))
graph
.
update_all
(
fn
.
u_mul_e
(
"h"
,
"e"
,
"m"
),
fn
.
sum
(
"m"
,
"h"
))
dst_feat
=
graph
.
ndata
.
pop
(
"h"
)
*
dst_norm
dst_feat
=
graph
.
ndata
.
pop
(
"h"
)
*
dst_norm
feat
=
feat
-
dst_feat
feat
=
feat
-
dst_feat
else
:
else
:
dst_norm
=
1.
/
graph
.
in_degrees
().
float
().
clamp
(
min
=
1
)
dst_norm
=
1.
0
/
graph
.
in_degrees
().
float
().
clamp
(
min
=
1
)
dst_norm
=
dst_norm
.
view
(
-
1
,
1
)
dst_norm
=
dst_norm
.
view
(
-
1
,
1
)
graph
.
ndata
[
"h"
]
=
feat
graph
.
ndata
[
"h"
]
=
feat
...
@@ -124,7 +129,7 @@ class HGPSLPool(nn.Module):
...
@@ -124,7 +129,7 @@ class HGPSLPool(nn.Module):
Description
Description
-----------
-----------
The HGP-SL pooling layer from
The HGP-SL pooling layer from
`Hierarchical Graph Pooling with Structure Learning <https://arxiv.org/pdf/1911.05954.pdf>`
`Hierarchical Graph Pooling with Structure Learning <https://arxiv.org/pdf/1911.05954.pdf>`
Parameters
Parameters
...
@@ -134,7 +139,7 @@ class HGPSLPool(nn.Module):
...
@@ -134,7 +139,7 @@ class HGPSLPool(nn.Module):
ratio : float, optional
ratio : float, optional
Pooling ratio. Default: 0.8
Pooling ratio. Default: 0.8
sample : bool, optional
sample : bool, optional
Whether use k-hop union graph to increase efficiency.
Whether use k-hop union graph to increase efficiency.
Currently we only support full graph. Default: :obj:`False`
Currently we only support full graph. Default: :obj:`False`
sym_score_norm : bool, optional
sym_score_norm : bool, optional
Use symmetric norm for adjacency or not. Default: :obj:`True`
Use symmetric norm for adjacency or not. Default: :obj:`True`
...
@@ -147,7 +152,7 @@ class HGPSLPool(nn.Module):
...
@@ -147,7 +152,7 @@ class HGPSLPool(nn.Module):
HGP-SL paper. Default: 1.0
HGP-SL paper. Default: 1.0
negative_slop : float, optional
negative_slop : float, optional
Negative slop for leaky_relu. Default: 0.2
Negative slop for leaky_relu. Default: 0.2
Returns
Returns
-------
-------
DGLGraph
DGLGraph
...
@@ -159,9 +164,19 @@ class HGPSLPool(nn.Module):
...
@@ -159,9 +164,19 @@ class HGPSLPool(nn.Module):
torch.Tensor
torch.Tensor
Permutation index
Permutation index
"""
"""
def
__init__
(
self
,
in_feat
:
int
,
ratio
=
0.8
,
sample
=
True
,
sym_score_norm
=
True
,
sparse
=
True
,
sl
=
True
,
def
__init__
(
lamb
=
1.0
,
negative_slop
=
0.2
,
k_hop
=
3
):
self
,
in_feat
:
int
,
ratio
=
0.8
,
sample
=
True
,
sym_score_norm
=
True
,
sparse
=
True
,
sl
=
True
,
lamb
=
1.0
,
negative_slop
=
0.2
,
k_hop
=
3
,
):
super
(
HGPSLPool
,
self
).
__init__
()
super
(
HGPSLPool
,
self
).
__init__
()
self
.
in_feat
=
in_feat
self
.
in_feat
=
in_feat
self
.
ratio
=
ratio
self
.
ratio
=
ratio
...
@@ -180,16 +195,17 @@ class HGPSLPool(nn.Module):
...
@@ -180,16 +195,17 @@ class HGPSLPool(nn.Module):
def
reset_parameters
(
self
):
def
reset_parameters
(
self
):
nn
.
init
.
xavier_normal_
(
self
.
att
.
data
)
nn
.
init
.
xavier_normal_
(
self
.
att
.
data
)
def
forward
(
self
,
graph
:
DGLGraph
,
feat
:
Tensor
,
e_feat
=
None
):
def
forward
(
self
,
graph
:
DGLGraph
,
feat
:
Tensor
,
e_feat
=
None
):
# top-k pool first
# top-k pool first
if
e_feat
is
None
:
if
e_feat
is
None
:
e_feat
=
torch
.
ones
((
graph
.
number_of_edges
(),),
e_feat
=
torch
.
ones
(
dtype
=
feat
.
dtype
,
device
=
feat
.
device
)
(
graph
.
number_of_edges
(),),
dtype
=
feat
.
dtype
,
device
=
feat
.
device
)
batch_num_nodes
=
graph
.
batch_num_nodes
()
batch_num_nodes
=
graph
.
batch_num_nodes
()
x_score
=
self
.
calc_info_score
(
graph
,
feat
,
e_feat
)
x_score
=
self
.
calc_info_score
(
graph
,
feat
,
e_feat
)
perm
,
next_batch_num_nodes
=
topk
(
x_score
,
self
.
ratio
,
perm
,
next_batch_num_nodes
=
topk
(
get_batch_id
(
batch_num_nodes
),
x_score
,
self
.
ratio
,
get_batch_id
(
batch_num_nodes
),
batch_num_nodes
batch_num_nodes
)
)
feat
=
feat
[
perm
]
feat
=
feat
[
perm
]
pool_graph
=
None
pool_graph
=
None
if
not
self
.
sample
or
not
self
.
sl
:
if
not
self
.
sample
or
not
self
.
sl
:
...
@@ -210,36 +226,48 @@ class HGPSLPool(nn.Module):
...
@@ -210,36 +226,48 @@ class HGPSLPool(nn.Module):
# pair of nodes is time consuming. To accelerate this process,
# pair of nodes is time consuming. To accelerate this process,
# we sample it's K-Hop neighbors for each node and then learn the
# we sample it's K-Hop neighbors for each node and then learn the
# edge weights between them.
# edge weights between them.
# first build multi-hop graph
# first build multi-hop graph
row
,
col
=
graph
.
all_edges
()
row
,
col
=
graph
.
all_edges
()
num_nodes
=
graph
.
num_nodes
()
num_nodes
=
graph
.
num_nodes
()
scipy_adj
=
scipy
.
sparse
.
coo_matrix
((
e_feat
.
detach
().
cpu
(),
(
row
.
detach
().
cpu
(),
col
.
detach
().
cpu
())),
shape
=
(
num_nodes
,
num_nodes
))
scipy_adj
=
scipy
.
sparse
.
coo_matrix
(
(
e_feat
.
detach
().
cpu
(),
(
row
.
detach
().
cpu
(),
col
.
detach
().
cpu
()),
),
shape
=
(
num_nodes
,
num_nodes
),
)
for
_
in
range
(
self
.
k_hop
):
for
_
in
range
(
self
.
k_hop
):
two_hop
=
scipy_adj
**
2
two_hop
=
scipy_adj
**
2
two_hop
=
two_hop
*
(
1e-5
/
two_hop
.
max
())
two_hop
=
two_hop
*
(
1e-5
/
two_hop
.
max
())
scipy_adj
=
two_hop
+
scipy_adj
scipy_adj
=
two_hop
+
scipy_adj
row
,
col
=
scipy_adj
.
nonzero
()
row
,
col
=
scipy_adj
.
nonzero
()
row
=
torch
.
tensor
(
row
,
dtype
=
torch
.
long
,
device
=
graph
.
device
)
row
=
torch
.
tensor
(
row
,
dtype
=
torch
.
long
,
device
=
graph
.
device
)
col
=
torch
.
tensor
(
col
,
dtype
=
torch
.
long
,
device
=
graph
.
device
)
col
=
torch
.
tensor
(
col
,
dtype
=
torch
.
long
,
device
=
graph
.
device
)
e_feat
=
torch
.
tensor
(
scipy_adj
.
data
,
dtype
=
torch
.
float
,
device
=
feat
.
device
)
e_feat
=
torch
.
tensor
(
scipy_adj
.
data
,
dtype
=
torch
.
float
,
device
=
feat
.
device
)
# perform pooling on multi-hop graph
# perform pooling on multi-hop graph
mask
=
perm
.
new_full
((
num_nodes
,
),
-
1
)
mask
=
perm
.
new_full
((
num_nodes
,),
-
1
)
i
=
torch
.
arange
(
perm
.
size
(
0
),
dtype
=
torch
.
long
,
device
=
perm
.
device
)
i
=
torch
.
arange
(
perm
.
size
(
0
),
dtype
=
torch
.
long
,
device
=
perm
.
device
)
mask
[
perm
]
=
i
mask
[
perm
]
=
i
row
,
col
=
mask
[
row
],
mask
[
col
]
row
,
col
=
mask
[
row
],
mask
[
col
]
mask
=
(
row
>=
0
)
&
(
col
>=
0
)
mask
=
(
row
>=
0
)
&
(
col
>=
0
)
row
,
col
=
row
[
mask
],
col
[
mask
]
row
,
col
=
row
[
mask
],
col
[
mask
]
e_feat
=
e_feat
[
mask
]
e_feat
=
e_feat
[
mask
]
# add remaining self loops
# add remaining self loops
mask
=
row
!=
col
mask
=
row
!=
col
num_nodes
=
perm
.
size
(
0
)
# num nodes after pool
num_nodes
=
perm
.
size
(
0
)
# num nodes after pool
loop_index
=
torch
.
arange
(
0
,
num_nodes
,
dtype
=
row
.
dtype
,
device
=
row
.
device
)
loop_index
=
torch
.
arange
(
0
,
num_nodes
,
dtype
=
row
.
dtype
,
device
=
row
.
device
)
inv_mask
=
~
mask
inv_mask
=
~
mask
loop_weight
=
torch
.
full
((
num_nodes
,
),
0
,
dtype
=
e_feat
.
dtype
,
device
=
e_feat
.
device
)
loop_weight
=
torch
.
full
(
(
num_nodes
,),
0
,
dtype
=
e_feat
.
dtype
,
device
=
e_feat
.
device
)
remaining_e_feat
=
e_feat
[
inv_mask
]
remaining_e_feat
=
e_feat
[
inv_mask
]
if
remaining_e_feat
.
numel
()
>
0
:
if
remaining_e_feat
.
numel
()
>
0
:
loop_weight
[
row
[
inv_mask
]]
=
remaining_e_feat
loop_weight
[
row
[
inv_mask
]]
=
remaining_e_feat
...
@@ -249,16 +277,20 @@ class HGPSLPool(nn.Module):
...
@@ -249,16 +277,20 @@ class HGPSLPool(nn.Module):
col
=
torch
.
cat
([
col
,
loop_index
],
dim
=
0
)
col
=
torch
.
cat
([
col
,
loop_index
],
dim
=
0
)
# attention scores
# attention scores
weights
=
(
torch
.
cat
([
feat
[
row
],
feat
[
col
]],
dim
=
1
)
*
self
.
att
).
sum
(
dim
=-
1
)
weights
=
(
torch
.
cat
([
feat
[
row
],
feat
[
col
]],
dim
=
1
)
*
self
.
att
).
sum
(
weights
=
F
.
leaky_relu
(
weights
,
self
.
negative_slop
)
+
e_feat
*
self
.
lamb
dim
=-
1
)
weights
=
(
F
.
leaky_relu
(
weights
,
self
.
negative_slop
)
+
e_feat
*
self
.
lamb
)
# sl and normalization
# sl and normalization
sl_graph
=
dgl
.
graph
((
row
,
col
))
sl_graph
=
dgl
.
graph
((
row
,
col
))
if
self
.
sparse
:
if
self
.
sparse
:
weights
=
edge_sparsemax
(
sl_graph
,
weights
)
weights
=
edge_sparsemax
(
sl_graph
,
weights
)
else
:
else
:
weights
=
edge_softmax
(
sl_graph
,
weights
)
weights
=
edge_softmax
(
sl_graph
,
weights
)
# get final graph
# get final graph
mask
=
torch
.
abs
(
weights
)
>
0
mask
=
torch
.
abs
(
weights
)
>
0
row
,
col
,
weights
=
row
[
mask
],
col
[
mask
],
weights
[
mask
]
row
,
col
,
weights
=
row
[
mask
],
col
[
mask
],
weights
[
mask
]
...
@@ -266,7 +298,7 @@ class HGPSLPool(nn.Module):
...
@@ -266,7 +298,7 @@ class HGPSLPool(nn.Module):
pool_graph
.
set_batch_num_nodes
(
next_batch_num_nodes
)
pool_graph
.
set_batch_num_nodes
(
next_batch_num_nodes
)
e_feat
=
weights
e_feat
=
weights
else
:
else
:
# Learning the possible edge weights between each pair of
# Learning the possible edge weights between each pair of
# nodes in the pooled subgraph, relative slower.
# nodes in the pooled subgraph, relative slower.
...
@@ -274,19 +306,27 @@ class HGPSLPool(nn.Module):
...
@@ -274,19 +306,27 @@ class HGPSLPool(nn.Module):
# use dense to build, then transform to sparse.
# use dense to build, then transform to sparse.
# maybe there's more efficient way?
# maybe there's more efficient way?
batch_num_nodes
=
next_batch_num_nodes
batch_num_nodes
=
next_batch_num_nodes
block_begin_idx
=
torch
.
cat
([
batch_num_nodes
.
new_zeros
(
1
),
block_begin_idx
=
torch
.
cat
(
batch_num_nodes
.
cumsum
(
dim
=
0
)[:
-
1
]],
dim
=
0
)
[
batch_num_nodes
.
new_zeros
(
1
),
batch_num_nodes
.
cumsum
(
dim
=
0
)[:
-
1
],
],
dim
=
0
,
)
block_end_idx
=
batch_num_nodes
.
cumsum
(
dim
=
0
)
block_end_idx
=
batch_num_nodes
.
cumsum
(
dim
=
0
)
dense_adj
=
torch
.
zeros
((
pool_graph
.
num_nodes
(),
dense_adj
=
torch
.
zeros
(
pool_graph
.
num_nodes
()),
(
pool_graph
.
num_nodes
(),
pool_graph
.
num_nodes
()),
dtype
=
torch
.
float
,
dtype
=
torch
.
float
,
device
=
feat
.
device
)
device
=
feat
.
device
,
)
for
idx_b
,
idx_e
in
zip
(
block_begin_idx
,
block_end_idx
):
for
idx_b
,
idx_e
in
zip
(
block_begin_idx
,
block_end_idx
):
dense_adj
[
idx_b
:
idx_e
,
idx_b
:
idx_e
]
=
1.
dense_adj
[
idx_b
:
idx_e
,
idx_b
:
idx_e
]
=
1.
0
row
,
col
=
torch
.
nonzero
(
dense_adj
).
t
().
contiguous
()
row
,
col
=
torch
.
nonzero
(
dense_adj
).
t
().
contiguous
()
# compute weights for node-pairs
# compute weights for node-pairs
weights
=
(
torch
.
cat
([
feat
[
row
],
feat
[
col
]],
dim
=
1
)
*
self
.
att
).
sum
(
dim
=-
1
)
weights
=
(
torch
.
cat
([
feat
[
row
],
feat
[
col
]],
dim
=
1
)
*
self
.
att
).
sum
(
dim
=-
1
)
weights
=
F
.
leaky_relu
(
weights
,
self
.
negative_slop
)
weights
=
F
.
leaky_relu
(
weights
,
self
.
negative_slop
)
dense_adj
[
row
,
col
]
=
weights
dense_adj
[
row
,
col
]
=
weights
...
@@ -316,15 +356,30 @@ class HGPSLPool(nn.Module):
...
@@ -316,15 +356,30 @@ class HGPSLPool(nn.Module):
class
ConvPoolReadout
(
torch
.
nn
.
Module
):
class
ConvPoolReadout
(
torch
.
nn
.
Module
):
"""A helper class. (GraphConv -> Pooling -> Readout)"""
"""A helper class. (GraphConv -> Pooling -> Readout)"""
def
__init__
(
self
,
in_feat
:
int
,
out_feat
:
int
,
pool_ratio
=
0.8
,
sample
:
bool
=
False
,
sparse
:
bool
=
True
,
sl
:
bool
=
True
,
def
__init__
(
lamb
:
float
=
1.
,
pool
:
bool
=
True
):
self
,
in_feat
:
int
,
out_feat
:
int
,
pool_ratio
=
0.8
,
sample
:
bool
=
False
,
sparse
:
bool
=
True
,
sl
:
bool
=
True
,
lamb
:
float
=
1.0
,
pool
:
bool
=
True
,
):
super
(
ConvPoolReadout
,
self
).
__init__
()
super
(
ConvPoolReadout
,
self
).
__init__
()
self
.
use_pool
=
pool
self
.
use_pool
=
pool
self
.
conv
=
WeightedGraphConv
(
in_feat
,
out_feat
)
self
.
conv
=
WeightedGraphConv
(
in_feat
,
out_feat
)
if
pool
:
if
pool
:
self
.
pool
=
HGPSLPool
(
out_feat
,
ratio
=
pool_ratio
,
sparse
=
sparse
,
self
.
pool
=
HGPSLPool
(
sample
=
sample
,
sl
=
sl
,
lamb
=
lamb
)
out_feat
,
ratio
=
pool_ratio
,
sparse
=
sparse
,
sample
=
sample
,
sl
=
sl
,
lamb
=
lamb
,
)
else
:
else
:
self
.
pool
=
None
self
.
pool
=
None
self
.
avgpool
=
AvgPooling
()
self
.
avgpool
=
AvgPooling
()
...
@@ -334,5 +389,7 @@ class ConvPoolReadout(torch.nn.Module):
...
@@ -334,5 +389,7 @@ class ConvPoolReadout(torch.nn.Module):
out
=
F
.
relu
(
self
.
conv
(
graph
,
feature
,
e_feat
))
out
=
F
.
relu
(
self
.
conv
(
graph
,
feature
,
e_feat
))
if
self
.
use_pool
:
if
self
.
use_pool
:
graph
,
out
,
e_feat
,
_
=
self
.
pool
(
graph
,
out
,
e_feat
)
graph
,
out
,
e_feat
,
_
=
self
.
pool
(
graph
,
out
,
e_feat
)
readout
=
torch
.
cat
([
self
.
avgpool
(
graph
,
out
),
self
.
maxpool
(
graph
,
out
)],
dim
=-
1
)
readout
=
torch
.
cat
(
[
self
.
avgpool
(
graph
,
out
),
self
.
maxpool
(
graph
,
out
)],
dim
=-
1
)
return
graph
,
out
,
e_feat
,
readout
return
graph
,
out
,
e_feat
,
readout
examples/pytorch/hgp_sl/main.py
View file @
704bcaf6
...
@@ -4,17 +4,17 @@ import logging
...
@@ -4,17 +4,17 @@ import logging
import
os
import
os
from
time
import
time
from
time
import
time
import
dgl
import
torch
import
torch
import
torch.nn
import
torch.nn
import
torch.nn.functional
as
F
import
torch.nn.functional
as
F
from
dgl.data
import
LegacyTUDataset
from
dgl.dataloading
import
GraphDataLoader
from
networks
import
HGPSLModel
from
networks
import
HGPSLModel
from
torch.utils.data
import
random_split
from
torch.utils.data
import
random_split
from
utils
import
get_stats
from
utils
import
get_stats
import
dgl
from
dgl.data
import
LegacyTUDataset
from
dgl.dataloading
import
GraphDataLoader
def
parse_args
():
def
parse_args
():
parser
=
argparse
.
ArgumentParser
(
description
=
"HGP-SL-DGL"
)
parser
=
argparse
.
ArgumentParser
(
description
=
"HGP-SL-DGL"
)
...
...
examples/pytorch/hgp_sl/networks.py
View file @
704bcaf6
import
torch
import
torch
import
torch.nn
import
torch.nn
import
torch.nn.functional
as
F
import
torch.nn.functional
as
F
from
layers
import
ConvPoolReadout
from
dgl.nn
import
AvgPooling
,
MaxPooling
from
dgl.nn
import
AvgPooling
,
MaxPooling
from
layers
import
ConvPoolReadout
class
HGPSLModel
(
torch
.
nn
.
Module
):
class
HGPSLModel
(
torch
.
nn
.
Module
):
...
...
examples/pytorch/hgt/model.py
View file @
704bcaf6
import
math
import
math
import
dgl
import
dgl.function
as
fn
import
torch
import
torch
import
torch.nn
as
nn
import
torch.nn
as
nn
import
torch.nn.functional
as
F
import
torch.nn.functional
as
F
import
dgl
import
dgl.function
as
fn
from
dgl.nn.functional
import
edge_softmax
from
dgl.nn.functional
import
edge_softmax
...
...
examples/pytorch/hilander/PSS/Smooth_AP/src/auxiliaries.py
View file @
704bcaf6
...
@@ -20,6 +20,8 @@ from torch import nn
...
@@ -20,6 +20,8 @@ from torch import nn
from
tqdm
import
tqdm
from
tqdm
import
tqdm
"""============================================================================================================="""
"""============================================================================================================="""
################### TensorBoard Settings ###################
################### TensorBoard Settings ###################
def
args2exp_name
(
args
):
def
args2exp_name
(
args
):
exp_name
=
f
"
{
args
.
dataset
}
_
{
args
.
loss
}
_
{
args
.
lr
}
_bs
{
args
.
bs
}
_spc
{
args
.
samples_per_class
}
_embed
{
args
.
embed_dim
}
_arch
{
args
.
arch
}
_decay
{
args
.
decay
}
_fclr
{
args
.
fc_lr_mul
}
_anneal
{
args
.
sigmoid_temperature
}
"
exp_name
=
f
"
{
args
.
dataset
}
_
{
args
.
loss
}
_
{
args
.
lr
}
_bs
{
args
.
bs
}
_spc
{
args
.
samples_per_class
}
_embed
{
args
.
embed_dim
}
_arch
{
args
.
arch
}
_decay
{
args
.
decay
}
_fclr
{
args
.
fc_lr_mul
}
_anneal
{
args
.
sigmoid_temperature
}
"
...
@@ -381,6 +383,8 @@ def eval_metrics_query_and_gallery_dataset(
...
@@ -381,6 +383,8 @@ def eval_metrics_query_and_gallery_dataset(
"""============================================================================================================="""
"""============================================================================================================="""
####### RECOVER CLOSEST EXAMPLE IMAGES #######
####### RECOVER CLOSEST EXAMPLE IMAGES #######
def
recover_closest_one_dataset
(
def
recover_closest_one_dataset
(
feature_matrix_all
,
image_paths
,
save_path
,
n_image_samples
=
10
,
n_closest
=
3
feature_matrix_all
,
image_paths
,
save_path
,
n_image_samples
=
10
,
n_closest
=
3
...
@@ -489,6 +493,8 @@ def recover_closest_inshop(
...
@@ -489,6 +493,8 @@ def recover_closest_inshop(
"""============================================================================================================="""
"""============================================================================================================="""
################## SET NETWORK TRAINING CHECKPOINT #####################
################## SET NETWORK TRAINING CHECKPOINT #####################
def
set_checkpoint
(
model
,
opt
,
progress_saver
,
savepath
):
def
set_checkpoint
(
model
,
opt
,
progress_saver
,
savepath
):
"""
"""
...
@@ -514,6 +520,8 @@ def set_checkpoint(model, opt, progress_saver, savepath):
...
@@ -514,6 +520,8 @@ def set_checkpoint(model, opt, progress_saver, savepath):
"""============================================================================================================="""
"""============================================================================================================="""
################## WRITE TO CSV FILE #####################
################## WRITE TO CSV FILE #####################
class
CSV_Writer
:
class
CSV_Writer
:
"""
"""
...
...
examples/pytorch/hilander/PSS/Smooth_AP/src/datasets.py
View file @
704bcaf6
...
@@ -20,6 +20,8 @@ from torch.utils.data import Dataset
...
@@ -20,6 +20,8 @@ from torch.utils.data import Dataset
from
torchvision
import
transforms
from
torchvision
import
transforms
"""============================================================================"""
"""============================================================================"""
################ FUNCTION TO RETURN ALL DATALOADERS NECESSARY ####################
################ FUNCTION TO RETURN ALL DATALOADERS NECESSARY ####################
def
give_dataloaders
(
dataset
,
trainset
,
testset
,
opt
,
cluster_path
=
""
):
def
give_dataloaders
(
dataset
,
trainset
,
testset
,
opt
,
cluster_path
=
""
):
"""
"""
...
...
examples/pytorch/hilander/PSS/Smooth_AP/src/evaluate_model.py
View file @
704bcaf6
...
@@ -8,7 +8,6 @@ import netlib as netlib
...
@@ -8,7 +8,6 @@ import netlib as netlib
import
torch
import
torch
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
################## INPUT ARGUMENTS ###################
################## INPUT ARGUMENTS ###################
parser
=
argparse
.
ArgumentParser
()
parser
=
argparse
.
ArgumentParser
()
####### Main Parameter: Dataset to use for Training
####### Main Parameter: Dataset to use for Training
...
...
Prev
1
…
4
5
6
7
8
9
10
11
12
…
17
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment