OpenDAS / dgl / Commits / 704bcaf6

Unverified commit 704bcaf6, authored Feb 19, 2023 by Hongzhi (Steve), Chen; committed by GitHub, Feb 19, 2023.

examples (#5323)

Co-authored-by: Ubuntu <ubuntu@ip-172-31-28-63.ap-northeast-1.compute.internal>

parent 6bc82161
Changes: 332 changed files in this commit; this page shows 20 of them, with 300 additions and 195 deletions (+300 −195).
examples/pytorch/hilander/PSS/Smooth_AP/src/finetune_1head.py  +2 −2
examples/pytorch/hilander/PSS/Smooth_AP/src/get_features.py    +0 −1
examples/pytorch/hilander/PSS/Smooth_AP/src/losses.py          +2 −0
examples/pytorch/hilander/PSS/Smooth_AP/src/main.py            +0 −1
examples/pytorch/hilander/PSS/test_subg_inat.py                +177 −95
examples/pytorch/hilander/PSS/train_subg_inat.py               +95 −56
examples/pytorch/hilander/models/graphconv.py                  +2 −3
examples/pytorch/hilander/models/lander.py                     +2 −3
examples/pytorch/hilander/test.py                              +2 −2
examples/pytorch/hilander/test_subg.py                         +2 −2
examples/pytorch/hilander/train.py                             +3 −2
examples/pytorch/hilander/train_subg.py                        +2 −2
examples/pytorch/hilander/utils/deduce.py                      +1 −2
examples/pytorch/hilander/utils/evaluate.py                    +1 −1
examples/pytorch/hilander/utils/faiss_search.py                +0 −1
examples/pytorch/infograph/evaluate_embedding.py               +0 −1
examples/pytorch/infograph/model.py                            +2 −7
examples/pytorch/infograph/semisupervised.py                   +3 −4
examples/pytorch/infograph/unsupervised.py                     +4 −7
examples/pytorch/infograph/utils.py                            +0 −3
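Nearly every hunk below is mechanical restyling rather than a behavior change: single quotes become double quotes, bare float literals like 0. become 0.0, long calls are exploded one argument per line with a trailing comma, and imports are regrouped and sorted. The commit message does not name the tooling, but the pattern matches black/isort conventions. A minimal, runnable sketch of the style, reusing two flags from the diffs below:

import argparse

parser = argparse.ArgumentParser()

# Before: single quotes and a bare float literal on one line
# parser.add_argument('--dropout', type=float, default=0.)

# After: double quotes and an explicit 0.0
parser.add_argument("--dropout", type=float, default=0.0)

# Calls that exceed the line limit are wrapped inside the parentheses
parser.add_argument(
    "--density_distance_pkl", type=str, default="density_distance.pkl"
)

args = parser.parse_args([])  # empty argv so the sketch runs standalone
print(args)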
examples/pytorch/hilander/PSS/Smooth_AP/src/finetune_1head.py  View file @ 704bcaf6

@@ -280,7 +280,6 @@ _ = model.to(opt.device)
# Place trainable parameter in list of parameters to train:
if "fc_lr_mul" in vars(opt).keys() and opt.fc_lr_mul != 0:
    all_but_fc_params = list(
        filter(lambda x: "last_linear" not in x[0], model.named_parameters())
    )

@@ -376,6 +375,8 @@ def same_model(model1, model2):
"""============================================================================"""


#################### TRAINER FUNCTION ############################
def train_one_epoch_finetune(
    train_dataloader, model, optimizer, criterion, opt, epoch

@@ -403,7 +404,6 @@ def train_one_epoch_finetune(
        train_dataloader, desc="Epoch {} Training gt labels...".format(epoch)
    )
    for i, (class_labels, input) in enumerate(data_iterator):
        # Compute embeddings for input batch
        features = model(input.to(opt.device))
examples/pytorch/hilander/PSS/Smooth_AP/src/get_features.py  View file @ 704bcaf6

@@ -263,7 +263,6 @@ _ = model.to(opt.device)
# Place trainable parameter in list of parameters to train:
if "fc_lr_mul" in vars(opt).keys() and opt.fc_lr_mul != 0:
    all_but_fc_params = list(
        filter(lambda x: "last_linear" not in x[0], model.named_parameters())
    )
examples/pytorch/hilander/PSS/Smooth_AP/src/losses.py  View file @ 704bcaf6

@@ -11,6 +11,8 @@ import torch
from scipy import sparse

"""================================================================================================="""


############ LOSS SELECTION FUNCTION #####################
def loss_select(loss, opt, to_optim):
    """
examples/pytorch/hilander/PSS/Smooth_AP/src/main.py  View file @ 704bcaf6

@@ -281,7 +281,6 @@ _ = model.to(opt.device)
# Place trainable parameter in list of parameters to train:
if "fc_lr_mul" in vars(opt).keys() and opt.fc_lr_mul != 0:
    all_but_fc_params = list(
        filter(lambda x: "last_linear" not in x[0], model.named_parameters())
    )
examples/pytorch/hilander/PSS/test_subg_inat.py  View file @ 704bcaf6

import argparse, os, pickle, time
import random

import dgl
import numpy as np
import seaborn
import shutil
import torch
import torch.optim as optim

import sys

sys.path.append("..")
from dataset import LanderDataset
from matplotlib import pyplot as plt
from models import LANDER
from utils import build_next_level, decode, evaluation, stop_iterating
from utils.deduce import get_edge_dist

STATISTIC = False

@@ -25,43 +26,47 @@ STATISTIC = False
parser = argparse.ArgumentParser()
# Dataset
parser.add_argument("--data_path", type=str, required=True)
parser.add_argument("--model_filename", type=str, default="lander.pth")
parser.add_argument("--faiss_gpu", action="store_true")
parser.add_argument("--num_workers", type=int, default=0)
parser.add_argument("--output_filename", type=str, default="data/features.pkl")
# HyperParam
parser.add_argument("--knn_k", type=int, default=10)
parser.add_argument("--levels", type=int, default=1)
parser.add_argument("--tau", type=float, default=0.5)
parser.add_argument("--threshold", type=str, default="prob")
parser.add_argument("--metrics", type=str, default="pairwise,bcubed,nmi")
parser.add_argument("--early_stop", action="store_true")
# Model
parser.add_argument("--hidden", type=int, default=512)
parser.add_argument("--num_conv", type=int, default=4)
parser.add_argument("--dropout", type=float, default=0.0)
parser.add_argument("--gat", action="store_true")
parser.add_argument("--gat_k", type=int, default=1)
parser.add_argument("--balance", action="store_true")
parser.add_argument("--use_cluster_feat", action="store_true")
parser.add_argument("--use_focal_loss", action="store_true")
parser.add_argument("--use_gt", action="store_true")
# Subgraph
parser.add_argument("--batch_size", type=int, default=4096)
parser.add_argument("--mode", type=str, default="1head")
parser.add_argument("--midpoint", type=str, default="false")
parser.add_argument("--linsize", type=int, default=29011)
parser.add_argument("--uinsize", type=int, default=18403)
parser.add_argument("--inclasses", type=int, default=948)
parser.add_argument("--thresh", type=float, default=1.0)
parser.add_argument("--draw", type=str, default="false")
parser.add_argument(
    "--density_distance_pkl", type=str, default="density_distance.pkl"
)
parser.add_argument(
    "--density_lindistance_jpg", type=str, default="density_lindistance.jpg"
)

args = parser.parse_args()
print(args)

@@ -70,21 +75,21 @@ linsize = args.linsize
uinsize = args.uinsize
inclasses = args.inclasses
if args.draw == "false":
    args.draw = False
elif args.draw == "true":
    args.draw = True

###########################
# Environment Configuration
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

##################
# Data Preparation
with open(args.data_path, "rb") as f:
    loaded_data = pickle.load(f)
    path2idx, features, pred_labels, labels, masks = loaded_data

@@ -123,11 +128,12 @@ else:
print("filtered features:", len(features))
global_features = features.copy()  # global features
dataset = LanderDataset(
    features=features, labels=labels, k=args.knn_k, levels=1, faiss_gpu=False
)
g = dataset.gs[0]
g.ndata["pred_den"] = torch.zeros((g.number_of_nodes()))
g.edata["prob_conn"] = torch.zeros((g.number_of_edges(), 2))
global_labels = labels.copy()
ids = np.arange(g.number_of_nodes())
global_edges = ([], [])

@@ -135,7 +141,7 @@ global_peaks = np.array([], dtype=np.long)
global_edges_len = len(global_edges[0])
global_num_nodes = g.number_of_nodes()
global_densities = g.ndata["density"][:linsize]
global_densities = np.sort(global_densities)
xs = np.arange(len(global_densities))

@@ -143,23 +149,30 @@ fanouts = [args.knn_k - 1 for i in range(args.num_conv + 1)]
sampler = dgl.dataloading.MultiLayerNeighborSampler(fanouts)
# fix the number of edges
test_loader = dgl.dataloading.DataLoader(
    g,
    torch.arange(g.number_of_nodes()),
    sampler,
    batch_size=args.batch_size,
    shuffle=False,
    drop_last=False,
    num_workers=args.num_workers,
)

##################
# Model Definition
if not args.use_gt:
    feature_dim = g.ndata["features"].shape[1]
    model = LANDER(
        feature_dim=feature_dim,
        nhid=args.hidden,
        num_conv=args.num_conv,
        dropout=args.dropout,
        use_GAT=args.gat,
        K=args.gat_k,
        balance=args.balance,
        use_cluster_feat=args.use_cluster_feat,
        use_focal_loss=args.use_focal_loss,
    )
    model.load_state_dict(torch.load(args.model_filename))
    model = model.to(device)
    model.eval()

@@ -179,46 +192,82 @@ for level in range(args.levels):
        with torch.no_grad():
            output_bipartite = model(bipartites)
        global_nid = output_bipartite.dstdata[dgl.NID]
        global_eid = output_bipartite.edata["global_eid"]
        g.ndata["pred_den"][global_nid] = output_bipartite.dstdata[
            "pred_den"
        ].to("cpu")
        g.edata["prob_conn"][global_eid] = output_bipartite.edata[
            "prob_conn"
        ].to("cpu")
        torch.cuda.empty_cache()
        if (batch + 1) % 10 == 0:
            print("Batch %d / %d for inference" % (batch, total_batches))
    (
        new_pred_labels,
        peaks,
        global_edges,
        global_pred_labels,
        global_peaks,
    ) = decode(
        g,
        args.tau,
        args.threshold,
        args.use_gt,
        ids,
        global_edges,
        global_num_nodes,
        global_peaks,
    )
    if level == 0:
        global_pred_densities = g.ndata["pred_den"]
        global_densities = g.ndata["density"]
    g.edata["prob_conn"] = torch.zeros((g.number_of_edges(), 2))
    ids = ids[peaks]
    new_global_edges_len = len(global_edges[0])
    num_edges_add_this_level = new_global_edges_len - global_edges_len
    if stop_iterating(
        level,
        args.levels,
        args.early_stop,
        num_edges_add_this_level,
        num_edges_add_last_level,
        args.knn_k,
    ):
        break
    global_edges_len = new_global_edges_len
    num_edges_add_last_level = num_edges_add_this_level

    # build new dataset
    features, labels, cluster_features = build_next_level(
        features,
        labels,
        peaks,
        global_features,
        global_pred_labels,
        global_peaks,
    )
    # After the first level, the number of nodes reduce a lot. Using cpu faiss is faster.
    dataset = LanderDataset(
        features=features,
        labels=labels,
        k=args.knn_k,
        levels=1,
        faiss_gpu=False,
        cluster_features=cluster_features,
    )
    g = dataset.gs[0]
    g.ndata["pred_den"] = torch.zeros((g.number_of_nodes()))
    g.edata["prob_conn"] = torch.zeros((g.number_of_edges(), 2))
    test_loader = dgl.dataloading.DataLoader(
        g,
        torch.arange(g.number_of_nodes()),
        sampler,
        batch_size=args.batch_size,
        shuffle=False,
        drop_last=False,
        num_workers=args.num_workers,
    )

if MODE == "selectbydensity":

@@ -261,8 +310,10 @@ if MODE == "selectbydensity":
        idx = np.where(l_in_gt_new == i)
        prototypes[i] = np.mean(l_in_features[idx], axis=0)
    similarity_matrix = torch.mm(
        torch.from_numpy(global_features.astype(np.float32)),
        torch.from_numpy(prototypes.astype(np.float32)).t(),
    )
    similarity_matrix = (1 - similarity_matrix) / 2
    minvalues, selected_pred_labels = torch.min(similarity_matrix, 1)
    # far-close ratio

@@ -274,7 +325,7 @@ if MODE == "selectbydensity":
    cutidx = np.where(global_pred_densities >= 0.5)
    draw_minvalues = minvalues[cutidx]
    draw_densities = global_pred_densities[cutidx]
    with open(args.density_distance_pkl, "wb") as f:
        pickle.dump((global_pred_densities, minvalues), f)
        print("dumped.")
    plt.clf()

@@ -283,15 +334,29 @@ if MODE == "selectbydensity":
    if len(draw_densities) > 10000:
        samples_idx = random.sample(range(len(draw_minvalues)), 10000)
        ax.plot(
            draw_densities[random],
            draw_minvalues[random],
            color="tab:blue",
            marker="*",
            linestyle="None",
            markersize=1,
        )
    else:
        ax.plot(
            draw_densities[random],
            draw_minvalues[random],
            color="tab:blue",
            marker="*",
            linestyle="None",
            markersize=1,
        )
    plt.savefig(args.density_lindistance_jpg)
    global_pred_labels_new[Tidx] = l_in_gt_new
    global_pred_labels[selectidx] = global_pred_labels[selectidx] + len(l_in_unique)
    global_pred_labels_new[selectedidx] = global_pred_labels
    global_pred_labels = global_pred_labels_new

@@ -332,7 +397,9 @@ if MODE == "recluster":
    global_pred_labels_new[Tidx] = l_in_gt_new
    print(len(global_pred_labels))
    print(len(selectedidx[0]))
    global_pred_labels_new[selectedidx[0]] = global_pred_labels + len(l_in_unique)
    global_pred_labels = global_pred_labels_new
    global_masks = masks
    print("mask0", len(np.where(global_masks == 0)[0]))

@@ -348,23 +415,29 @@ if MODE == "donothing":
print("##################### L_in ########################")
print(linsize)
if len(global_pred_labels) >= linsize:
    evaluation(
        global_pred_labels[:linsize], global_gt_labels[:linsize], args.metrics
    )
else:
    print("No samples in L_in!")
print("##################### U_in ########################")
uinidx = np.where(global_pred_labels[linsize : linsize + uinsize] != -1)[0]
uinidx = uinidx + linsize
print(len(uinidx))
if len(uinidx):
    evaluation(global_pred_labels[uinidx], global_gt_labels[uinidx], args.metrics)
else:
    print("No samples in U_in!")
print("##################### U_out ########################")
uoutidx = np.where(global_pred_labels[linsize + uinsize :] != -1)[0]
uoutidx = uoutidx + linsize + uinsize
print(len(uoutidx))
if len(uoutidx):
    evaluation(global_pred_labels[uoutidx], global_gt_labels[uoutidx], args.metrics)
else:
    print("No samples in U_out!")
print("##################### U ########################")

@@ -390,9 +463,18 @@ print(len(nsidx))
if len(nsidx) != 0:
    evaluation(global_pred_labels[nsidx], global_gt_labels[nsidx], args.metrics)
with open(args.output_filename, "wb") as f:
    print(orifeatures.shape)
    print(global_pred_labels.shape)
    print(global_gt_labels.shape)
    print(global_masks.shape)
    pickle.dump(
        [
            path2idx,
            orifeatures,
            global_pred_labels,
            global_gt_labels,
            global_masks,
        ],
        f,
    )
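test_subg_inat.py rebuilds a MultiLayerNeighborSampler plus DataLoader pair for each level of the cluster hierarchy. A self-contained sketch of that sampling pattern on a toy random graph (the graph size, feature width, and batch size are invented for illustration; it assumes a DGL version where dgl.dataloading.DataLoader is available, as the script itself uses):

import dgl
import torch

# Toy graph standing in for one level of the hierarchy
g = dgl.rand_graph(1000, 10000)
g.ndata["features"] = torch.randn(1000, 32)

knn_k, num_conv = 10, 4
fanouts = [knn_k - 1 for _ in range(num_conv + 1)]  # as in the script above
sampler = dgl.dataloading.MultiLayerNeighborSampler(fanouts)
loader = dgl.dataloading.DataLoader(
    g,
    torch.arange(g.number_of_nodes()),
    sampler,
    batch_size=256,
    shuffle=False,
    drop_last=False,
    num_workers=0,
)

for input_nodes, output_nodes, bipartites in loader:
    # `bipartites` is the list of message-flow graphs (blocks), one per
    # convolution layer, that a model such as LANDER consumes.
    print(len(bipartites), input_nodes.shape, output_nodes.shape)
    break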
examples/pytorch/hilander/PSS/train_subg_inat.py  View file @ 704bcaf6

import argparse, os, pickle, time
import random

import dgl
import numpy as np
import torch
import torch.optim as optim

import sys

sys.path.append("..")
from dataset import LanderDataset
from models import LANDER

###########
# ArgParser
parser = argparse.ArgumentParser()
# Dataset
parser.add_argument("--data_path", type=str, required=True)
parser.add_argument("--levels", type=str, default="1")
parser.add_argument("--faiss_gpu", action="store_true")
parser.add_argument("--model_filename", type=str, default="lander.pth")
# KNN
parser.add_argument("--knn_k", type=str, default="10")
parser.add_argument("--num_workers", type=int, default=0)
# Model
parser.add_argument("--hidden", type=int, default=512)
parser.add_argument("--num_conv", type=int, default=1)
parser.add_argument("--dropout", type=float, default=0.0)
parser.add_argument("--gat", action="store_true")
parser.add_argument("--gat_k", type=int, default=1)
parser.add_argument("--balance", action="store_true")
parser.add_argument("--use_cluster_feat", action="store_true")
parser.add_argument("--use_focal_loss", action="store_true")
# Training
parser.add_argument("--epochs", type=int, default=100)
parser.add_argument("--batch_size", type=int, default=1024)
parser.add_argument("--lr", type=float, default=0.1)
parser.add_argument("--momentum", type=float, default=0.9)
parser.add_argument("--weight_decay", type=float, default=1e-5)

args = parser.parse_args()
print(args)

@@ -49,9 +50,9 @@ print(args)
###########################
# Environment Configuration
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")


def setup_seed(seed):

@@ -66,7 +67,7 @@ def setup_seed(seed):
##################
# Data Preparation
with open(args.data_path, "rb") as f:
    path2idx, features, labels, _, masks = pickle.load(f)
    # lidx = np.where(masks==0)
    # features = features[lidx]

@@ -75,8 +76,8 @@ with open(args.data_path, 'rb') as f:
print("labels.shape:", labels.shape)
k_list = [int(k) for k in args.knn_k.split(",")]
lvl_list = [int(l) for l in args.levels.split(",")]
gs = []
nbrs = []
ks = []

@@ -84,8 +85,13 @@ datasets = []
for k, l in zip(k_list, lvl_list):
    print("k:", k)
    print("levels:", l)
    dataset = LanderDataset(
        features=features,
        labels=labels,
        k=k,
        levels=l,
        faiss_gpu=args.faiss_gpu,
    )
    gs += [g for g in dataset.gs]
    ks += [k for g in dataset.gs]
    nbrs += [nbr for nbr in dataset.nbrs]

@@ -101,24 +107,28 @@ for k, l in zip(k_list, lvl_list):
    # nbrs += [nbr for nbr in dataset.nbrs]
with open("./dataset.pkl", "wb") as f:
    pickle.dump(datasets, f)
print("Dataset Prepared.")


def set_train_sampler_loader(g, k):
    fanouts = [k - 1 for i in range(args.num_conv + 1)]
    sampler = dgl.dataloading.MultiLayerNeighborSampler(fanouts)
    # fix the number of edges
    train_dataloader = dgl.dataloading.DataLoader(
        g,
        torch.arange(g.number_of_nodes()),
        sampler,
        batch_size=args.batch_size,
        shuffle=True,
        drop_last=False,
        num_workers=args.num_workers,
    )
    return train_dataloader


train_loaders = []
for gidx, g in enumerate(gs):
    train_dataloader = set_train_sampler_loader(gs[gidx], ks[gidx])

@@ -126,31 +136,40 @@ for gidx, g in enumerate(gs):
##################
# Model Definition
feature_dim = gs[0].ndata["features"].shape[1]
print("feature dimension:", feature_dim)
model = LANDER(
    feature_dim=feature_dim,
    nhid=args.hidden,
    num_conv=args.num_conv,
    dropout=args.dropout,
    use_GAT=args.gat,
    K=args.gat_k,
    balance=args.balance,
    use_cluster_feat=args.use_cluster_feat,
    use_focal_loss=args.use_focal_loss,
)
model = model.to(device)
model.train()

#################
# Hyperparameters
opt = optim.SGD(
    model.parameters(),
    lr=args.lr,
    momentum=args.momentum,
    weight_decay=args.weight_decay,
)

# keep num_batch_per_loader the same for every sub_dataloader
num_batch_per_loader = len(train_loaders[0])
train_loaders = [iter(train_loader) for train_loader in train_loaders]
num_loaders = len(train_loaders)
scheduler = optim.lr_scheduler.CosineAnnealingLR(
    opt,
    T_max=args.epochs * num_batch_per_loader * num_loaders,
    eta_min=1e-5,
)

print("Start Training.")

###############
# Training Loop

@@ -163,7 +182,9 @@ for epoch in range(args.epochs):
        try:
            minibatch = next(train_loaders[loader_id])
        except:
            train_loaders[loader_id] = iter(
                set_train_sampler_loader(gs[loader_id], ks[loader_id])
            )
            minibatch = next(train_loaders[loader_id])
        input_nodes, sub_g, bipartites = minibatch
        sub_g = sub_g.to(device)

@@ -171,20 +192,38 @@ for epoch in range(args.epochs):
        # get the feature for the input_nodes
        opt.zero_grad()
        output_bipartite = model(bipartites)
        loss, loss_den_val, loss_conn_val = model.compute_loss(output_bipartite)
        loss_den_val_total.append(loss_den_val)
        loss_conn_val_total.append(loss_conn_val)
        loss_val_total.append(loss.item())
        loss.backward()
        opt.step()
        if (batch + 1) % 10 == 0:
            print(
                "epoch: %d, batch: %d / %d, loader_id : %d / %d, loss: %.6f, loss_den: %.6f, loss_conn: %.6f"
                % (
                    epoch,
                    batch,
                    num_batch_per_loader,
                    loader_id,
                    num_loaders,
                    loss.item(),
                    loss_den_val,
                    loss_conn_val,
                )
            )
        scheduler.step()
    print(
        "epoch: %d, loss: %.6f, loss_den: %.6f, loss_conn: %.6f"
        % (
            epoch,
            np.array(loss_val_total).mean(),
            np.array(loss_den_val_total).mean(),
            np.array(loss_conn_val_total).mean(),
        )
    )
    torch.save(model.state_dict(), args.model_filename)

torch.save(model.state_dict(), args.model_filename)
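train_subg_inat.py steps its cosine scheduler once per batch rather than per epoch, so T_max must be the total number of optimizer steps: epochs * num_batch_per_loader * num_loaders. A standalone sketch of that schedule with a dummy model and made-up loop sizes:

import torch
import torch.optim as optim

model = torch.nn.Linear(8, 2)  # stand-in for LANDER
epochs, num_batch_per_loader, num_loaders = 3, 5, 2

opt = optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=1e-5)
scheduler = optim.lr_scheduler.CosineAnnealingLR(
    opt,
    T_max=epochs * num_batch_per_loader * num_loaders,  # total batch steps
    eta_min=1e-5,
)

for epoch in range(epochs):
    for step in range(num_batch_per_loader * num_loaders):
        opt.zero_grad()
        model(torch.randn(4, 8)).sum().backward()
        opt.step()
        scheduler.step()  # per batch, matching the training loop above
    print("epoch", epoch, "lr", scheduler.get_last_lr()[0])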
examples/pytorch/hilander/models/graphconv.py  View file @ 704bcaf6

#!/usr/bin/env python
# -*- coding: utf-8 -*-
import dgl.function as fn
import torch
import torch.nn as nn
import torch.nn.functional as F
from dgl.nn.pytorch import GATConv
from torch.nn import init


class GraphConvLayer(nn.Module):
examples/pytorch/hilander/models/lander.py  View file @ 704bcaf6

#!/usr/bin/env python
# -*- coding: utf-8 -*-
import dgl
import dgl.function as fn
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from .focal_loss import FocalLoss
from .graphconv import GraphConv
examples/pytorch/hilander/test.py  View file @ 704bcaf6

@@ -3,6 +3,8 @@ import os
import pickle
import time

import dgl
import numpy as np
import torch
import torch.optim as optim

@@ -10,8 +12,6 @@ from dataset import LanderDataset
from models import LANDER
from utils import build_next_level, decode, evaluation, stop_iterating

###########
# ArgParser
parser = argparse.ArgumentParser()
examples/pytorch/hilander/test_subg.py  View file @ 704bcaf6

@@ -3,6 +3,8 @@ import os
import pickle
import time

import dgl
import numpy as np
import torch
import torch.optim as optim

@@ -10,8 +12,6 @@ from dataset import LanderDataset
from models import LANDER
from utils import build_next_level, decode, evaluation, stop_iterating

###########
# ArgParser
parser = argparse.ArgumentParser()
examples/pytorch/hilander/train.py  View file @ 704bcaf6

@@ -3,14 +3,14 @@ import os
import pickle
import time

import dgl
import numpy as np
import torch
import torch.optim as optim
from dataset import LanderDataset
from models import LANDER

###########
# ArgParser
parser = argparse.ArgumentParser()

@@ -50,6 +50,7 @@ if torch.cuda.is_available():
else:
    device = torch.device("cpu")


##################
# Data Preparation
def prepare_dataset_graphs(data_path, k_list, lvl_list):
examples/pytorch/hilander/train_subg.py  View file @ 704bcaf6

@@ -3,14 +3,14 @@ import os
import pickle
import time

import dgl
import numpy as np
import torch
import torch.optim as optim
from dataset import LanderDataset
from models import LANDER

###########
# ArgParser
parser = argparse.ArgumentParser()
examples/pytorch/hilander/utils/deduce.py  View file @ 704bcaf6

"""
This file re-uses implementation from https://github.com/yl-1993/learn-to-cluster
"""
import dgl
import numpy as np
import torch
from sklearn import mixture

from .density import density_to_peaks, density_to_peaks_vectorize

__all__ = [
examples/pytorch/hilander/utils/evaluate.py  View file @ 704bcaf6

@@ -6,7 +6,7 @@ import inspect
import numpy as np
from clustering_benchmark import ClusteringBenchmark
from utils import metrics, TextColors, Timer


def _read_meta(fn):
examples/pytorch/hilander/utils/faiss_search.py  View file @ 704bcaf6

@@ -101,7 +101,6 @@ def faiss_search_knn(
    sort=True,
    verbose=False,
):
    dists, nbrs = faiss_search_approx_knn(
        query=feat, target=feat, k=k, nprobe=nprobe, verbose=verbose
    )
examples/pytorch/infograph/evaluate_embedding.py  View file @ 704bcaf6

@@ -70,7 +70,6 @@ def svc_classify(x, y, search):
    kf = StratifiedKFold(n_splits=10, shuffle=True, random_state=None)
    accuracies = []
    for train_index, test_index in kf.split(x, y):
        x_train, x_test = x[train_index], x[test_index]
        y_train, y_test = y[train_index], y[test_index]
examples/pytorch/infograph/model.py  View file @ 704bcaf6

import torch as th
import torch.nn as nn
import torch.nn.functional as F
from dgl.nn import GINConv, NNConv, Set2Set
from dgl.nn.pytorch.glob import SumPooling
from torch.nn import BatchNorm1d, GRU, Linear, ModuleList, ReLU, Sequential
from utils import global_global_loss_, local_global_loss_

""" Feedforward neural network"""

@@ -102,7 +102,6 @@ class GINEncoder(nn.Module):
        self.pool = SumPooling()

    def forward(self, graph, feat):
        xs = []
        x = feat
        for i in range(self.n_layer):

@@ -163,7 +162,6 @@ class InfoGraph(nn.Module):
        return global_emb

    def forward(self, graph, feat, graph_id):
        global_emb, local_emb = self.encoder(graph, feat)
        global_h = self.global_d(global_emb)  # global hidden representation

@@ -221,7 +219,6 @@ class NNConvEncoder(nn.Module):
        self.set2set = Set2Set(hid_dim, n_iters=3, n_layers=1)

    def forward(self, graph, nfeat, efeat):
        out = F.relu(self.lin0(nfeat))
        h = out.unsqueeze(0)

@@ -279,7 +276,6 @@ class InfoGraphS(nn.Module):
        self.unsup_d = FeedforwardNetwork(2 * hid_dim, hid_dim)

    def forward(self, graph, nfeat, efeat):
        sup_global_emb, sup_local_emb = self.sup_encoder(graph, nfeat, efeat)
        sup_global_pred = self.fc2(F.relu(self.fc1(sup_global_emb)))

@@ -288,7 +284,6 @@ class InfoGraphS(nn.Module):
        return sup_global_pred

    def unsup_forward(self, graph, nfeat, efeat, graph_id):
        sup_global_emb, sup_local_emb = self.sup_encoder(graph, nfeat, efeat)
        unsup_global_emb, unsup_local_emb = self.unsup_encoder(
            graph, nfeat, efeat
examples/pytorch/infograph/semisupervised.py  View file @ 704bcaf6

import argparse

import dgl
import numpy as np
import torch as th
import torch.nn.functional as F
from dgl.data import QM9EdgeDataset
from dgl.data.utils import Subset
from dgl.dataloading import GraphDataLoader
from model import InfoGraphS


def argument():

@@ -160,7 +160,6 @@ def evaluate(model, loader, num, device):
if __name__ == "__main__":
    # Step 1: Prepare graph data ===================================== #
    args = argument()
    label_keys = [args.target]
examples/pytorch/infograph/unsupervised.py  View file @ 704bcaf6

import argparse

import dgl
import torch as th
from dgl.data import GINDataset
from dgl.dataloading import GraphDataLoader
from evaluate_embedding import evaluate_embedding
from model import InfoGraph


def argument():

@@ -75,7 +75,6 @@ def collate(samples):
if __name__ == "__main__":
    # Step 1: Prepare graph data ===================================== #
    args = argument()
    print(args)

@@ -131,7 +130,6 @@ if __name__ == "__main__":
    model.train()
    for graph, label in dataloader:
        graph = graph.to(args.device)
        feat = graph.ndata["attr"]
        graph_id = graph.ndata["graph_id"]

@@ -147,7 +145,6 @@ if __name__ == "__main__":
    print("Epoch {}, Loss {:.4f}".format(epoch, loss_all))
    if epoch % args.log_interval == 0:
        # evaluate embeddings
        model.eval()
        emb = model.get_embedding(wholegraph, wholefeat).cpu()
examples/pytorch/infograph/utils.py  View file @ 704bcaf6

@@ -41,7 +41,6 @@ def get_negative_expectation(q_samples, average=True):
def local_global_loss_(l_enc, g_enc, graph_id):
    num_graphs = g_enc.shape[0]
    num_nodes = l_enc.shape[0]

@@ -51,7 +50,6 @@ def local_global_loss_(l_enc, g_enc, graph_id):
    neg_mask = th.ones((num_nodes, num_graphs)).to(device)
    for nodeidx, graphidx in enumerate(graph_id):
        pos_mask[nodeidx][graphidx] = 1.0
        neg_mask[nodeidx][graphidx] = 0.0

@@ -66,7 +64,6 @@ def local_global_loss_(l_enc, g_enc, graph_id):
def global_global_loss_(sup_enc, unsup_enc):
    num_graphs = sup_enc.shape[0]
    device = sup_enc.device
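The masking loop in local_global_loss_ pairs every node with its own graph summary as a positive and with every other graph as a negative. A tiny sketch of just that step (the zeros initialization of pos_mask is not visible in this hunk and is assumed here):

import torch as th

num_nodes, num_graphs = 5, 2
graph_id = th.tensor([0, 0, 0, 1, 1])  # which graph each node belongs to
device = graph_id.device

pos_mask = th.zeros((num_nodes, num_graphs)).to(device)  # assumed init
neg_mask = th.ones((num_nodes, num_graphs)).to(device)
for nodeidx, graphidx in enumerate(graph_id):
    pos_mask[nodeidx][graphidx] = 1.0  # node vs. its own graph summary
    neg_mask[nodeidx][graphidx] = 0.0  # everything else stays a negative pair

print(pos_mask)
print(neg_mask)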