Project: OpenDAS/dgl

Commit 704bcaf6 (unverified)
examples (#5323)

Authored Feb 19, 2023 by Hongzhi (Steve), Chen; committed via GitHub on Feb 19, 2023
Co-authored-by: Ubuntu <ubuntu@ip-172-31-28-63.ap-northeast-1.compute.internal>
Parent: 6bc82161
Changes: 332 files in this commit. This page shows 20 changed files, with 300 additions and 195 deletions (+300 / -195).
examples/pytorch/hilander/PSS/Smooth_AP/src/finetune_1head.py    +2   -2
examples/pytorch/hilander/PSS/Smooth_AP/src/get_features.py      +0   -1
examples/pytorch/hilander/PSS/Smooth_AP/src/losses.py            +2   -0
examples/pytorch/hilander/PSS/Smooth_AP/src/main.py              +0   -1
examples/pytorch/hilander/PSS/test_subg_inat.py                  +177 -95
examples/pytorch/hilander/PSS/train_subg_inat.py                 +95  -56
examples/pytorch/hilander/models/graphconv.py                    +2   -3
examples/pytorch/hilander/models/lander.py                       +2   -3
examples/pytorch/hilander/test.py                                +2   -2
examples/pytorch/hilander/test_subg.py                           +2   -2
examples/pytorch/hilander/train.py                               +3   -2
examples/pytorch/hilander/train_subg.py                          +2   -2
examples/pytorch/hilander/utils/deduce.py                        +1   -2
examples/pytorch/hilander/utils/evaluate.py                      +1   -1
examples/pytorch/hilander/utils/faiss_search.py                  +0   -1
examples/pytorch/infograph/evaluate_embedding.py                 +0   -1
examples/pytorch/infograph/model.py                              +2   -7
examples/pytorch/infograph/semisupervised.py                     +3   -4
examples/pytorch/infograph/unsupervised.py                       +4   -7
examples/pytorch/infograph/utils.py                              +0   -3
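Every hunk on this page is mechanical reformatting: single-quoted strings become double-quoted, long calls are wrapped with a trailing comma, and imports are sorted and grouped. That is consistent with the output of the black and isort formatters. As a minimal sketch of how such a rewrite is reproduced programmatically (the Mode settings below are assumptions; the repository's actual formatter configuration is not shown on this page):

import black
import isort

src = "parser.add_argument('--knn_k', type=int, default=10)\n"

# black normalizes quote style and wraps long calls with trailing commas.
print(black.format_str(src, mode=black.Mode(line_length=79)))

# isort alphabetizes imports and separates stdlib from third-party groups.
print(isort.code("import torch\nimport dgl\n"))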
examples/pytorch/hilander/PSS/Smooth_AP/src/finetune_1head.py

@@ -280,7 +280,6 @@ _ = model.to(opt.device)
 # Place trainable parameter in list of parameters to train:
 if "fc_lr_mul" in vars(opt).keys() and opt.fc_lr_mul != 0:
     all_but_fc_params = list(
         filter(lambda x: "last_linear" not in x[0], model.named_parameters())
     )
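The fc_lr_mul branch above separates the backbone parameters from the final last_linear layer so the head can train with a scaled learning rate. A self-contained sketch of that pattern (the dummy model and the Adam optimizer are illustrative assumptions; the hunk only shows the backbone-side filter):

import torch
import torch.nn as nn

# Stand-in for the embedding network; only the parameter names matter.
model = nn.Sequential()
model.add_module("backbone", nn.Linear(8, 8))
model.add_module("last_linear", nn.Linear(8, 4))

lr, fc_lr_mul = 1e-4, 10.0  # hypothetical values for opt.lr and opt.fc_lr_mul
all_but_fc_params = [
    p for n, p in model.named_parameters() if "last_linear" not in n
]
fc_params = [p for n, p in model.named_parameters() if "last_linear" in n]

# Two parameter groups: the head trains fc_lr_mul times faster.
optimizer = torch.optim.Adam(
    [
        {"params": all_but_fc_params, "lr": lr},
        {"params": fc_params, "lr": lr * fc_lr_mul},
    ]
)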
@@ -376,6 +375,8 @@ def same_model(model1, model2):
 """============================================================================"""
 #################### TRAINER FUNCTION ############################
 def train_one_epoch_finetune(
     train_dataloader, model, optimizer, criterion, opt, epoch

@@ -403,7 +404,6 @@ def train_one_epoch_finetune(
     train_dataloader, desc="Epoch {} Training gt labels...".format(epoch)
 )
 for i, (class_labels, input) in enumerate(data_iterator):
     # Compute embeddings for input batch
     features = model(input.to(opt.device))
examples/pytorch/hilander/PSS/Smooth_AP/src/get_features.py

@@ -263,7 +263,6 @@ _ = model.to(opt.device)
 # Place trainable parameter in list of parameters to train:
 if "fc_lr_mul" in vars(opt).keys() and opt.fc_lr_mul != 0:
     all_but_fc_params = list(
         filter(lambda x: "last_linear" not in x[0], model.named_parameters())
     )
examples/pytorch/hilander/PSS/Smooth_AP/src/losses.py

@@ -11,6 +11,8 @@ import torch
 from scipy import sparse

 """================================================================================================="""
 ############ LOSS SELECTION FUNCTION #####################
 def loss_select(loss, opt, to_optim):
     """
examples/pytorch/hilander/PSS/Smooth_AP/src/main.py

@@ -281,7 +281,6 @@ _ = model.to(opt.device)
 # Place trainable parameter in list of parameters to train:
 if "fc_lr_mul" in vars(opt).keys() and opt.fc_lr_mul != 0:
     all_but_fc_params = list(
         filter(lambda x: "last_linear" not in x[0], model.named_parameters())
     )
examples/pytorch/hilander/PSS/test_subg_inat.py

-import argparse, time, os, pickle
+import argparse, os, pickle, time
 import random
 import sys

 sys.path.append("..")
-from utils.deduce import get_edge_dist
-import numpy as np
-import shutil
 import dgl
+import numpy as np
+import seaborn
 import torch
 import torch.optim as optim
-from models import LANDER
 from dataset import LanderDataset
-from utils import evaluation, decode, build_next_level, stop_iterating
 from matplotlib import pyplot as plt
-import seaborn
+from models import LANDER
+from utils import build_next_level, decode, evaluation, stop_iterating
+from utils.deduce import get_edge_dist

 STATISTIC = False
@@ -25,43 +26,47 @@ STATISTIC = False
 parser = argparse.ArgumentParser()
 # Dataset
-parser.add_argument('--data_path', type=str, required=True)
-parser.add_argument('--model_filename', type=str, default='lander.pth')
-parser.add_argument('--faiss_gpu', action='store_true')
-parser.add_argument('--num_workers', type=int, default=0)
-parser.add_argument('--output_filename', type=str, default='data/features.pkl')
+parser.add_argument("--data_path", type=str, required=True)
+parser.add_argument("--model_filename", type=str, default="lander.pth")
+parser.add_argument("--faiss_gpu", action="store_true")
+parser.add_argument("--num_workers", type=int, default=0)
+parser.add_argument("--output_filename", type=str, default="data/features.pkl")
 # HyperParam
-parser.add_argument('--knn_k', type=int, default=10)
-parser.add_argument('--levels', type=int, default=1)
-parser.add_argument('--tau', type=float, default=0.5)
-parser.add_argument('--threshold', type=str, default='prob')
-parser.add_argument('--metrics', type=str, default='pairwise,bcubed,nmi')
-parser.add_argument('--early_stop', action='store_true')
+parser.add_argument("--knn_k", type=int, default=10)
+parser.add_argument("--levels", type=int, default=1)
+parser.add_argument("--tau", type=float, default=0.5)
+parser.add_argument("--threshold", type=str, default="prob")
+parser.add_argument("--metrics", type=str, default="pairwise,bcubed,nmi")
+parser.add_argument("--early_stop", action="store_true")
 # Model
-parser.add_argument('--hidden', type=int, default=512)
-parser.add_argument('--num_conv', type=int, default=4)
-parser.add_argument('--dropout', type=float, default=0.)
-parser.add_argument('--gat', action='store_true')
-parser.add_argument('--gat_k', type=int, default=1)
-parser.add_argument('--balance', action='store_true')
-parser.add_argument('--use_cluster_feat', action='store_true')
-parser.add_argument('--use_focal_loss', action='store_true')
-parser.add_argument('--use_gt', action='store_true')
+parser.add_argument("--hidden", type=int, default=512)
+parser.add_argument("--num_conv", type=int, default=4)
+parser.add_argument("--dropout", type=float, default=0.0)
+parser.add_argument("--gat", action="store_true")
+parser.add_argument("--gat_k", type=int, default=1)
+parser.add_argument("--balance", action="store_true")
+parser.add_argument("--use_cluster_feat", action="store_true")
+parser.add_argument("--use_focal_loss", action="store_true")
+parser.add_argument("--use_gt", action="store_true")
 # Subgraph
-parser.add_argument('--batch_size', type=int, default=4096)
-parser.add_argument('--mode', type=str, default="1head")
-parser.add_argument('--midpoint', type=str, default="false")
-parser.add_argument('--linsize', type=int, default=29011)
-parser.add_argument('--uinsize', type=int, default=18403)
-parser.add_argument('--inclasses', type=int, default=948)
-parser.add_argument('--thresh', type=float, default=1.0)
-parser.add_argument('--draw', type=str, default='false')
-parser.add_argument('--density_distance_pkl', type=str, default="density_distance.pkl")
-parser.add_argument('--density_lindistance_jpg', type=str, default="density_lindistance.jpg")
+parser.add_argument("--batch_size", type=int, default=4096)
+parser.add_argument("--mode", type=str, default="1head")
+parser.add_argument("--midpoint", type=str, default="false")
+parser.add_argument("--linsize", type=int, default=29011)
+parser.add_argument("--uinsize", type=int, default=18403)
+parser.add_argument("--inclasses", type=int, default=948)
+parser.add_argument("--thresh", type=float, default=1.0)
+parser.add_argument("--draw", type=str, default="false")
+parser.add_argument("--density_distance_pkl", type=str, default="density_distance.pkl")
+parser.add_argument(
+    "--density_lindistance_jpg", type=str, default="density_lindistance.jpg"
+)

 args = parser.parse_args()
 print(args)

@@ -70,21 +75,21 @@ linsize = args.linsize
 uinsize = args.uinsize
 inclasses = args.inclasses
-if args.draw == 'false':
+if args.draw == "false":
     args.draw = False
-elif args.draw == 'true':
+elif args.draw == "true":
     args.draw = True

 ###########################
 # Environment Configuration
 if torch.cuda.is_available():
-    device = torch.device('cuda')
+    device = torch.device("cuda")
 else:
-    device = torch.device('cpu')
+    device = torch.device("cpu")

 ##################
 # Data Preparation
-with open(args.data_path, 'rb') as f:
+with open(args.data_path, "rb") as f:
     loaded_data = pickle.load(f)
 path2idx, features, pred_labels, labels, masks = loaded_data

@@ -123,11 +128,12 @@ else:
 print("filtered features:", len(features))
 global_features = features.copy()  # global features
-dataset = LanderDataset(features=features, labels=labels, k=args.knn_k, levels=1, faiss_gpu=False)
+dataset = LanderDataset(
+    features=features, labels=labels, k=args.knn_k, levels=1, faiss_gpu=False
+)
 g = dataset.gs[0]
-g.ndata['pred_den'] = torch.zeros((g.number_of_nodes()))
-g.edata['prob_conn'] = torch.zeros((g.number_of_edges(), 2))
+g.ndata["pred_den"] = torch.zeros((g.number_of_nodes()))
+g.edata["prob_conn"] = torch.zeros((g.number_of_edges(), 2))
 global_labels = labels.copy()
 ids = np.arange(g.number_of_nodes())
 global_edges = ([], [])

@@ -135,7 +141,7 @@ global_peaks = np.array([], dtype=np.long)
 global_edges_len = len(global_edges[0])
 global_num_nodes = g.number_of_nodes()
-global_densities = g.ndata['density'][:linsize]
+global_densities = g.ndata["density"][:linsize]
 global_densities = np.sort(global_densities)
 xs = np.arange(len(global_densities))

@@ -143,23 +149,30 @@ fanouts = [args.knn_k - 1 for i in range(args.num_conv + 1)]
 sampler = dgl.dataloading.MultiLayerNeighborSampler(fanouts)
 # fix the number of edges
 test_loader = dgl.dataloading.DataLoader(
-    g, torch.arange(g.number_of_nodes()), sampler,
+    g,
+    torch.arange(g.number_of_nodes()),
+    sampler,
     batch_size=args.batch_size,
     shuffle=False,
     drop_last=False,
-    num_workers=args.num_workers
+    num_workers=args.num_workers,
 )

 ##################
 # Model Definition
 if not args.use_gt:
-    feature_dim = g.ndata['features'].shape[1]
-    model = LANDER(feature_dim=feature_dim, nhid=args.hidden, num_conv=args.num_conv, dropout=args.dropout, use_GAT=args.gat, K=args.gat_k, balance=args.balance, use_cluster_feat=args.use_cluster_feat, use_focal_loss=args.use_focal_loss)
+    feature_dim = g.ndata["features"].shape[1]
+    model = LANDER(
+        feature_dim=feature_dim,
+        nhid=args.hidden,
+        num_conv=args.num_conv,
+        dropout=args.dropout,
+        use_GAT=args.gat,
+        K=args.gat_k,
+        balance=args.balance,
+        use_cluster_feat=args.use_cluster_feat,
+        use_focal_loss=args.use_focal_loss,
+    )
     model.load_state_dict(torch.load(args.model_filename))
     model = model.to(device)
     model.eval()
@@ -179,46 +192,82 @@ for level in range(args.levels):
         with torch.no_grad():
             output_bipartite = model(bipartites)
         global_nid = output_bipartite.dstdata[dgl.NID]
-        global_eid = output_bipartite.edata['global_eid']
-        g.ndata['pred_den'][global_nid] = output_bipartite.dstdata['pred_den'].to('cpu')
-        g.edata['prob_conn'][global_eid] = output_bipartite.edata['prob_conn'].to('cpu')
+        global_eid = output_bipartite.edata["global_eid"]
+        g.ndata["pred_den"][global_nid] = output_bipartite.dstdata["pred_den"].to("cpu")
+        g.edata["prob_conn"][global_eid] = output_bipartite.edata["prob_conn"].to("cpu")
         torch.cuda.empty_cache()
         if (batch + 1) % 10 == 0:
-            print('Batch %d / %d for inference' % (batch, total_batches))
-    new_pred_labels, peaks, \
-        global_edges, global_pred_labels, global_peaks = decode(g, args.tau, args.threshold, args.use_gt, ids, global_edges, global_num_nodes, global_peaks)
+            print("Batch %d / %d for inference" % (batch, total_batches))
+    (
+        new_pred_labels,
+        peaks,
+        global_edges,
+        global_pred_labels,
+        global_peaks,
+    ) = decode(
+        g,
+        args.tau,
+        args.threshold,
+        args.use_gt,
+        ids,
+        global_edges,
+        global_num_nodes,
+        global_peaks,
+    )
     if level == 0:
-        global_pred_densities = g.ndata['pred_den']
-        global_densities = g.ndata['density']
-        g.edata['prob_conn'] = torch.zeros((g.number_of_edges(), 2))
+        global_pred_densities = g.ndata["pred_den"]
+        global_densities = g.ndata["density"]
+        g.edata["prob_conn"] = torch.zeros((g.number_of_edges(), 2))
     ids = ids[peaks]
     new_global_edges_len = len(global_edges[0])
     num_edges_add_this_level = new_global_edges_len - global_edges_len
-    if stop_iterating(level, args.levels, args.early_stop, num_edges_add_this_level, num_edges_add_last_level, args.knn_k):
+    if stop_iterating(
+        level,
+        args.levels,
+        args.early_stop,
+        num_edges_add_this_level,
+        num_edges_add_last_level,
+        args.knn_k,
+    ):
         break
     global_edges_len = new_global_edges_len
     num_edges_add_last_level = num_edges_add_this_level

     # build new dataset
-    features, labels, cluster_features = build_next_level(features, labels, peaks, global_features, global_pred_labels, global_peaks)
+    features, labels, cluster_features = build_next_level(
+        features,
+        labels,
+        peaks,
+        global_features,
+        global_pred_labels,
+        global_peaks,
+    )
     # After the first level, the number of nodes reduce a lot. Using cpu faiss is faster.
-    dataset = LanderDataset(features=features, labels=labels, k=args.knn_k, levels=1, faiss_gpu=False, cluster_features=cluster_features)
+    dataset = LanderDataset(
+        features=features,
+        labels=labels,
+        k=args.knn_k,
+        levels=1,
+        faiss_gpu=False,
+        cluster_features=cluster_features,
+    )
     g = dataset.gs[0]
-    g.ndata['pred_den'] = torch.zeros((g.number_of_nodes()))
-    g.edata['prob_conn'] = torch.zeros((g.number_of_edges(), 2))
+    g.ndata["pred_den"] = torch.zeros((g.number_of_nodes()))
+    g.edata["prob_conn"] = torch.zeros((g.number_of_edges(), 2))
     test_loader = dgl.dataloading.DataLoader(
-        g, torch.arange(g.number_of_nodes()), sampler,
+        g,
+        torch.arange(g.number_of_nodes()),
+        sampler,
         batch_size=args.batch_size,
         shuffle=False,
         drop_last=False,
-        num_workers=args.num_workers
+        num_workers=args.num_workers,
     )

 if MODE == "selectbydensity":

@@ -261,8 +310,10 @@ if MODE == "selectbydensity":
         idx = np.where(l_in_gt_new == i)
         prototypes[i] = np.mean(l_in_features[idx], axis=0)
-    similarity_matrix = torch.mm(torch.from_numpy(global_features.astype(np.float32)), torch.from_numpy(prototypes.astype(np.float32)).t())
+    similarity_matrix = torch.mm(
+        torch.from_numpy(global_features.astype(np.float32)),
+        torch.from_numpy(prototypes.astype(np.float32)).t(),
+    )
     similarity_matrix = (1 - similarity_matrix) / 2
     minvalues, selected_pred_labels = torch.min(similarity_matrix, 1)
     # far-close ratio
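The block above assigns every sample to its nearest class prototype: for L2-normalized embeddings the dot product is the cosine similarity in [-1, 1], so (1 - similarity) / 2 rescales it to a distance in [0, 1], and torch.min picks the closest prototype per row. A self-contained toy version of the same computation (the random data is illustrative only):

import numpy as np
import torch

rng = np.random.default_rng(0)
feats = rng.normal(size=(5, 8)).astype(np.float32)      # 5 samples
feats /= np.linalg.norm(feats, axis=1, keepdims=True)   # unit norm
protos = rng.normal(size=(2, 8)).astype(np.float32)     # 2 class prototypes
protos /= np.linalg.norm(protos, axis=1, keepdims=True)

sim = torch.mm(torch.from_numpy(feats), torch.from_numpy(protos).t())
dist = (1 - sim) / 2                          # cosine similarity -> [0, 1] distance
minvalues, pred_labels = torch.min(dist, 1)  # nearest prototype per sample
print(pred_labels, minvalues)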
@@ -274,7 +325,7 @@ if MODE == "selectbydensity":
     cutidx = np.where(global_pred_densities >= 0.5)
     draw_minvalues = minvalues[cutidx]
     draw_densities = global_pred_densities[cutidx]
-    with open(args.density_distance_pkl, 'wb') as f:
+    with open(args.density_distance_pkl, "wb") as f:
         pickle.dump((global_pred_densities, minvalues), f)
     print("dumped.")
     plt.clf()

@@ -283,15 +334,29 @@ if MODE == "selectbydensity":
     if len(draw_densities) > 10000:
         samples_idx = random.sample(range(len(draw_minvalues)), 10000)
-        ax.plot(draw_densities[random], draw_minvalues[random], color='tab:blue', marker='*', linestyle="None", markersize=1)
+        ax.plot(
+            draw_densities[random],
+            draw_minvalues[random],
+            color="tab:blue",
+            marker="*",
+            linestyle="None",
+            markersize=1,
+        )
     else:
-        ax.plot(draw_densities[random], draw_minvalues[random], color='tab:blue', marker='*', linestyle="None", markersize=1)
+        ax.plot(
+            draw_densities[random],
+            draw_minvalues[random],
+            color="tab:blue",
+            marker="*",
+            linestyle="None",
+            markersize=1,
+        )
     plt.savefig(args.density_lindistance_jpg)
     global_pred_labels_new[Tidx] = l_in_gt_new
     global_pred_labels[selectidx] = global_pred_labels[selectidx] + len(l_in_unique)
     global_pred_labels_new[selectedidx] = global_pred_labels
     global_pred_labels = global_pred_labels_new

@@ -332,7 +397,9 @@ if MODE == "recluster":
     global_pred_labels_new[Tidx] = l_in_gt_new
     print(len(global_pred_labels))
     print(len(selectedidx[0]))
-    global_pred_labels_new[selectedidx[0]] = global_pred_labels + len(l_in_unique)
+    global_pred_labels_new[selectedidx[0]] = global_pred_labels + len(
+        l_in_unique
+    )
     global_pred_labels = global_pred_labels_new
     global_masks = masks
 print("mask0", len(np.where(global_masks == 0)[0]))

@@ -348,23 +415,29 @@ if MODE == "donothing":
 print("##################### L_in ########################")
 print(linsize)
 if len(global_pred_labels) >= linsize:
-    evaluation(global_pred_labels[:linsize], global_gt_labels[:linsize], args.metrics)
+    evaluation(
+        global_pred_labels[:linsize], global_gt_labels[:linsize], args.metrics
+    )
 else:
     print("No samples in L_in!")
 print("##################### U_in ########################")
-uinidx = np.where(global_pred_labels[linsize:linsize + uinsize] != -1)[0]
+uinidx = np.where(global_pred_labels[linsize : linsize + uinsize] != -1)[0]
 uinidx = uinidx + linsize
 print(len(uinidx))
 if len(uinidx):
-    evaluation(global_pred_labels[uinidx], global_gt_labels[uinidx], args.metrics)
+    evaluation(
+        global_pred_labels[uinidx], global_gt_labels[uinidx], args.metrics
+    )
 else:
     print("No samples in U_in!")
 print("##################### U_out ########################")
-uoutidx = np.where(global_pred_labels[linsize + uinsize:] != -1)[0]
+uoutidx = np.where(global_pred_labels[linsize + uinsize :] != -1)[0]
 uoutidx = uoutidx + linsize + uinsize
 print(len(uoutidx))
 if len(uoutidx):
-    evaluation(global_pred_labels[uoutidx], global_gt_labels[uoutidx], args.metrics)
+    evaluation(
+        global_pred_labels[uoutidx], global_gt_labels[uoutidx], args.metrics
+    )
 else:
     print("No samples in U_out!")
 print("##################### U ########################")

@@ -390,9 +463,18 @@ print(len(nsidx))
 if len(nsidx) != 0:
     evaluation(global_pred_labels[nsidx], global_gt_labels[nsidx], args.metrics)
-with open(args.output_filename, 'wb') as f:
+with open(args.output_filename, "wb") as f:
     print(orifeatures.shape)
     print(global_pred_labels.shape)
     print(global_gt_labels.shape)
     print(global_masks.shape)
-    pickle.dump([path2idx, orifeatures, global_pred_labels, global_gt_labels, global_masks], f)
+    pickle.dump(
+        [
+            path2idx,
+            orifeatures,
+            global_pred_labels,
+            global_gt_labels,
+            global_masks,
+        ],
+        f,
+    )
examples/pytorch/hilander/PSS/train_subg_inat.py

-import argparse, time, os, pickle
+import argparse, os, pickle, time
 import random
-import numpy as np
+import sys

 import dgl
+import numpy as np
 import torch
 import torch.optim as optim
-import sys

 sys.path.append("..")
-from models import LANDER
 from dataset import LanderDataset
+from models import LANDER

 ###########
 # ArgParser
 parser = argparse.ArgumentParser()
 # Dataset
-parser.add_argument('--data_path', type=str, required=True)
-parser.add_argument('--levels', type=str, default='1')
-parser.add_argument('--faiss_gpu', action='store_true')
-parser.add_argument('--model_filename', type=str, default='lander.pth')
+parser.add_argument("--data_path", type=str, required=True)
+parser.add_argument("--levels", type=str, default="1")
+parser.add_argument("--faiss_gpu", action="store_true")
+parser.add_argument("--model_filename", type=str, default="lander.pth")
 # KNN
-parser.add_argument('--knn_k', type=str, default='10')
-parser.add_argument('--num_workers', type=int, default=0)
+parser.add_argument("--knn_k", type=str, default="10")
+parser.add_argument("--num_workers", type=int, default=0)
 # Model
-parser.add_argument('--hidden', type=int, default=512)
-parser.add_argument('--num_conv', type=int, default=1)
-parser.add_argument('--dropout', type=float, default=0.)
-parser.add_argument('--gat', action='store_true')
-parser.add_argument('--gat_k', type=int, default=1)
-parser.add_argument('--balance', action='store_true')
-parser.add_argument('--use_cluster_feat', action='store_true')
-parser.add_argument('--use_focal_loss', action='store_true')
+parser.add_argument("--hidden", type=int, default=512)
+parser.add_argument("--num_conv", type=int, default=1)
+parser.add_argument("--dropout", type=float, default=0.0)
+parser.add_argument("--gat", action="store_true")
+parser.add_argument("--gat_k", type=int, default=1)
+parser.add_argument("--balance", action="store_true")
+parser.add_argument("--use_cluster_feat", action="store_true")
+parser.add_argument("--use_focal_loss", action="store_true")
 # Training
-parser.add_argument('--epochs', type=int, default=100)
-parser.add_argument('--batch_size', type=int, default=1024)
-parser.add_argument('--lr', type=float, default=0.1)
-parser.add_argument('--momentum', type=float, default=0.9)
-parser.add_argument('--weight_decay', type=float, default=1e-5)
+parser.add_argument("--epochs", type=int, default=100)
+parser.add_argument("--batch_size", type=int, default=1024)
+parser.add_argument("--lr", type=float, default=0.1)
+parser.add_argument("--momentum", type=float, default=0.9)
+parser.add_argument("--weight_decay", type=float, default=1e-5)

 args = parser.parse_args()
 print(args)
@@ -49,9 +50,9 @@ print(args)
 ###########################
 # Environment Configuration
 if torch.cuda.is_available():
-    device = torch.device('cuda')
+    device = torch.device("cuda")
 else:
-    device = torch.device('cpu')
+    device = torch.device("cpu")

 def setup_seed(seed):

@@ -66,7 +67,7 @@ def setup_seed(seed):
 ##################
 # Data Preparation
-with open(args.data_path, 'rb') as f:
+with open(args.data_path, "rb") as f:
     path2idx, features, labels, _, masks = pickle.load(f)
 # lidx = np.where(masks==0)
 # features = features[lidx]

@@ -75,8 +76,8 @@ with open(args.data_path, 'rb') as f:
 print("labels.shape:", labels.shape)
-k_list = [int(k) for k in args.knn_k.split(',')]
-lvl_list = [int(l) for l in args.levels.split(',')]
+k_list = [int(k) for k in args.knn_k.split(",")]
+lvl_list = [int(l) for l in args.levels.split(",")]
 gs = []
 nbrs = []
 ks = []

@@ -84,8 +85,13 @@ datasets = []
 for k, l in zip(k_list, lvl_list):
     print("k:", k)
     print("levels:", l)
-    dataset = LanderDataset(features=features, labels=labels, k=k, levels=l, faiss_gpu=args.faiss_gpu)
+    dataset = LanderDataset(
+        features=features,
+        labels=labels,
+        k=k,
+        levels=l,
+        faiss_gpu=args.faiss_gpu,
+    )
     gs += [g for g in dataset.gs]
     ks += [k for g in dataset.gs]
     nbrs += [nbr for nbr in dataset.nbrs]

@@ -101,24 +107,28 @@ for k, l in zip(k_list, lvl_list):
 # nbrs += [nbr for nbr in dataset.nbrs]

-with open("./dataset.pkl", 'wb') as f:
+with open("./dataset.pkl", "wb") as f:
     pickle.dump(datasets, f)
-print('Dataset Prepared.')
+print("Dataset Prepared.")

 def set_train_sampler_loader(g, k):
     fanouts = [k - 1 for i in range(args.num_conv + 1)]
     sampler = dgl.dataloading.MultiLayerNeighborSampler(fanouts)
     # fix the number of edges
     train_dataloader = dgl.dataloading.DataLoader(
-        g, torch.arange(g.number_of_nodes()), sampler,
+        g,
+        torch.arange(g.number_of_nodes()),
+        sampler,
         batch_size=args.batch_size,
         shuffle=True,
         drop_last=False,
-        num_workers=args.num_workers
+        num_workers=args.num_workers,
     )
     return train_dataloader

 train_loaders = []
 for gidx, g in enumerate(gs):
     train_dataloader = set_train_sampler_loader(gs[gidx], ks[gidx])

@@ -126,31 +136,40 @@ for gidx, g in enumerate(gs):
 ##################
 # Model Definition
-feature_dim = gs[0].ndata['features'].shape[1]
+feature_dim = gs[0].ndata["features"].shape[1]
 print("feature dimension:", feature_dim)
-model = LANDER(feature_dim=feature_dim, nhid=args.hidden, num_conv=args.num_conv, dropout=args.dropout, use_GAT=args.gat, K=args.gat_k, balance=args.balance, use_cluster_feat=args.use_cluster_feat, use_focal_loss=args.use_focal_loss)
+model = LANDER(
+    feature_dim=feature_dim,
+    nhid=args.hidden,
+    num_conv=args.num_conv,
+    dropout=args.dropout,
+    use_GAT=args.gat,
+    K=args.gat_k,
+    balance=args.balance,
+    use_cluster_feat=args.use_cluster_feat,
+    use_focal_loss=args.use_focal_loss,
+)
 model = model.to(device)
 model.train()

 #################
 # Hyperparameters
-opt = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)
+opt = optim.SGD(
+    model.parameters(),
+    lr=args.lr,
+    momentum=args.momentum,
+    weight_decay=args.weight_decay,
+)

 # keep num_batch_per_loader the same for every sub_dataloader
 num_batch_per_loader = len(train_loaders[0])
 train_loaders = [iter(train_loader) for train_loader in train_loaders]
 num_loaders = len(train_loaders)
-scheduler = optim.lr_scheduler.CosineAnnealingLR(opt, T_max=args.epochs * num_batch_per_loader * num_loaders, eta_min=1e-5)
+scheduler = optim.lr_scheduler.CosineAnnealingLR(
+    opt, T_max=args.epochs * num_batch_per_loader * num_loaders, eta_min=1e-5
+)
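T_max here is sized to the total number of scheduler steps: the training loop below calls scheduler.step() once per minibatch, and there are epochs * num_batch_per_loader * num_loaders such steps, so the learning rate completes exactly one cosine decay from args.lr down to eta_min over the whole run. A toy sketch of that accounting (the numbers are illustrative):

import torch

epochs, num_batch_per_loader, num_loaders = 2, 3, 2
total_steps = epochs * num_batch_per_loader * num_loaders

params = [torch.nn.Parameter(torch.zeros(1))]
opt = torch.optim.SGD(params, lr=0.1)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    opt, T_max=total_steps, eta_min=1e-5
)
for _ in range(total_steps):
    opt.step()
    scheduler.step()
# After T_max steps the learning rate has annealed to eta_min.
print(opt.param_groups[0]["lr"])  # ~1e-5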
-print('Start Training.')
+print("Start Training.")

 ###############
 # Training Loop
@@ -163,7 +182,9 @@ for epoch in range(args.epochs):
         try:
             minibatch = next(train_loaders[loader_id])
         except:
-            train_loaders[loader_id] = iter(set_train_sampler_loader(gs[loader_id], ks[loader_id]))
+            train_loaders[loader_id] = iter(
+                set_train_sampler_loader(gs[loader_id], ks[loader_id])
+            )
             minibatch = next(train_loaders[loader_id])
         input_nodes, sub_g, bipartites = minibatch
         sub_g = sub_g.to(device)
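The try/except above interleaves several dataloaders of different lengths: when one iterator runs out, it is rebuilt and the batch is fetched again. A self-contained sketch of that cycling pattern with plain Python iterators (catching StopIteration explicitly is a safer variant of the bare except in the hunk):

def make_loader(n):
    # Stand-in for iter(set_train_sampler_loader(g, k)).
    return iter(range(n))

sizes = [2, 4]  # loaders of different lengths
loaders = [make_loader(n) for n in sizes]

for step in range(6):
    loader_id = step % len(loaders)
    try:
        batch = next(loaders[loader_id])
    except StopIteration:
        # Exhausted: rebuild this loader and retry, like the hunk above.
        loaders[loader_id] = make_loader(sizes[loader_id])
        batch = next(loaders[loader_id])
    print(loader_id, batch)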
@@ -171,20 +192,38 @@ for epoch in range(args.epochs):
         # get the feature for the input_nodes
         opt.zero_grad()
         output_bipartite = model(bipartites)
-        loss, loss_den_val, loss_conn_val = model.compute_loss(output_bipartite)
+        loss, loss_den_val, loss_conn_val = model.compute_loss(
+            output_bipartite
+        )
         loss_den_val_total.append(loss_den_val)
         loss_conn_val_total.append(loss_conn_val)
         loss_val_total.append(loss.item())
         loss.backward()
         opt.step()
         if (batch + 1) % 10 == 0:
-            print('epoch: %d, batch: %d / %d, loader_id : %d / %d, loss: %.6f, loss_den: %.6f, loss_conn: %.6f' % (epoch, batch, num_batch_per_loader, loader_id, num_loaders, loss.item(), loss_den_val, loss_conn_val))
+            print(
+                "epoch: %d, batch: %d / %d, loader_id : %d / %d, loss: %.6f, loss_den: %.6f, loss_conn: %.6f"
+                % (
+                    epoch,
+                    batch,
+                    num_batch_per_loader,
+                    loader_id,
+                    num_loaders,
+                    loss.item(),
+                    loss_den_val,
+                    loss_conn_val,
+                )
+            )
         scheduler.step()
-    print('epoch: %d, loss: %.6f, loss_den: %.6f, loss_conn: %.6f' % (epoch, np.array(loss_val_total).mean(), np.array(loss_den_val_total).mean(), np.array(loss_conn_val_total).mean()))
+    print(
+        "epoch: %d, loss: %.6f, loss_den: %.6f, loss_conn: %.6f"
+        % (
+            epoch,
+            np.array(loss_val_total).mean(),
+            np.array(loss_den_val_total).mean(),
+            np.array(loss_conn_val_total).mean(),
+        )
+    )
     torch.save(model.state_dict(), args.model_filename)
examples/pytorch/hilander/models/graphconv.py

 #!/usr/bin/env python
 # -*- coding: utf-8 -*-
+import dgl.function as fn
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from torch.nn import init
-import dgl.function as fn
 from dgl.nn.pytorch import GATConv
+from torch.nn import init

 class GraphConvLayer(nn.Module):
examples/pytorch/hilander/models/lander.py

 #!/usr/bin/env python
 # -*- coding: utf-8 -*-
+import dgl
+import dgl.function as fn
 import numpy as np
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-import dgl
-import dgl.function as fn
 from .focal_loss import FocalLoss
 from .graphconv import GraphConv
examples/pytorch/hilander/test.py

@@ -3,6 +3,8 @@ import os
 import pickle
 import time

+import dgl
 import numpy as np
 import torch
 import torch.optim as optim

@@ -10,8 +12,6 @@ from dataset import LanderDataset
 from models import LANDER
 from utils import build_next_level, decode, evaluation, stop_iterating

-import dgl
-
 ###########
 # ArgParser
 parser = argparse.ArgumentParser()
examples/pytorch/hilander/test_subg.py

@@ -3,6 +3,8 @@ import os
 import pickle
 import time

+import dgl
 import numpy as np
 import torch
 import torch.optim as optim

@@ -10,8 +12,6 @@ from dataset import LanderDataset
 from models import LANDER
 from utils import build_next_level, decode, evaluation, stop_iterating

-import dgl
-
 ###########
 # ArgParser
 parser = argparse.ArgumentParser()
examples/pytorch/hilander/train.py

@@ -3,14 +3,14 @@ import os
 import pickle
 import time

+import dgl
 import numpy as np
 import torch
 import torch.optim as optim
 from dataset import LanderDataset
 from models import LANDER

-import dgl
-
 ###########
 # ArgParser
 parser = argparse.ArgumentParser()

@@ -50,6 +50,7 @@ if torch.cuda.is_available():
 else:
     device = torch.device("cpu")

+
 ##################
 # Data Preparation
 def prepare_dataset_graphs(data_path, k_list, lvl_list):
examples/pytorch/hilander/train_subg.py

@@ -3,14 +3,14 @@ import os
 import pickle
 import time

+import dgl
 import numpy as np
 import torch
 import torch.optim as optim
 from dataset import LanderDataset
 from models import LANDER

-import dgl
-
 ###########
 # ArgParser
 parser = argparse.ArgumentParser()
examples/pytorch/hilander/utils/deduce.py

 """
 This file re-uses implementation from https://github.com/yl-1993/learn-to-cluster
 """
+import dgl
 import numpy as np
 import torch
 from sklearn import mixture

-import dgl
-
 from .density import density_to_peaks, density_to_peaks_vectorize

 __all__ = [
examples/pytorch/hilander/utils/evaluate.py

@@ -6,7 +6,7 @@ import inspect
 import numpy as np
 from clustering_benchmark import ClusteringBenchmark
-from utils import TextColors, Timer, metrics
+from utils import metrics, TextColors, Timer

 def _read_meta(fn):
examples/pytorch/hilander/utils/faiss_search.py

@@ -101,7 +101,6 @@ def faiss_search_knn(
     sort=True,
     verbose=False,
 ):
     dists, nbrs = faiss_search_approx_knn(
         query=feat, target=feat, k=k, nprobe=nprobe, verbose=verbose
     )
examples/pytorch/infograph/evaluate_embedding.py

@@ -70,7 +70,6 @@ def svc_classify(x, y, search):
     kf = StratifiedKFold(n_splits=10, shuffle=True, random_state=None)
     accuracies = []
     for train_index, test_index in kf.split(x, y):
         x_train, x_test = x[train_index], x[test_index]
         y_train, y_test = y[train_index], y[test_index]
examples/pytorch/infograph/model.py

 import torch as th
 import torch.nn as nn
 import torch.nn.functional as F
-from torch.nn import GRU, BatchNorm1d, Linear, ModuleList, ReLU, Sequential
-from utils import global_global_loss_, local_global_loss_
 from dgl.nn import GINConv, NNConv, Set2Set
 from dgl.nn.pytorch.glob import SumPooling
+from torch.nn import BatchNorm1d, GRU, Linear, ModuleList, ReLU, Sequential
+from utils import global_global_loss_, local_global_loss_

 """ Feedforward neural network"""

@@ -102,7 +102,6 @@ class GINEncoder(nn.Module):
         self.pool = SumPooling()

     def forward(self, graph, feat):
         xs = []
         x = feat
         for i in range(self.n_layer):

@@ -163,7 +162,6 @@ class InfoGraph(nn.Module):
         return global_emb

     def forward(self, graph, feat, graph_id):
         global_emb, local_emb = self.encoder(graph, feat)
         global_h = self.global_d(global_emb)  # global hidden representation

@@ -221,7 +219,6 @@ class NNConvEncoder(nn.Module):
         self.set2set = Set2Set(hid_dim, n_iters=3, n_layers=1)

     def forward(self, graph, nfeat, efeat):
         out = F.relu(self.lin0(nfeat))
         h = out.unsqueeze(0)

@@ -279,7 +276,6 @@ class InfoGraphS(nn.Module):
         self.unsup_d = FeedforwardNetwork(2 * hid_dim, hid_dim)

     def forward(self, graph, nfeat, efeat):
         sup_global_emb, sup_local_emb = self.sup_encoder(graph, nfeat, efeat)
         sup_global_pred = self.fc2(F.relu(self.fc1(sup_global_emb)))

@@ -288,7 +284,6 @@ class InfoGraphS(nn.Module):
         return sup_global_pred

     def unsup_forward(self, graph, nfeat, efeat, graph_id):
         sup_global_emb, sup_local_emb = self.sup_encoder(graph, nfeat, efeat)
         unsup_global_emb, unsup_local_emb = self.unsup_encoder(graph, nfeat, efeat
examples/pytorch/infograph/semisupervised.py

 import argparse

+import dgl
 import numpy as np
 import torch as th
 import torch.nn.functional as F
-from model import InfoGraphS
-import dgl
 from dgl.data import QM9EdgeDataset
 from dgl.data.utils import Subset
 from dgl.dataloading import GraphDataLoader
+from model import InfoGraphS

 def argument():

@@ -160,7 +160,6 @@ def evaluate(model, loader, num, device):
 if __name__ == "__main__":
     # Step 1: Prepare graph data ===================================== #
     args = argument()
     label_keys = [args.target]
examples/pytorch/infograph/unsupervised.py

 import argparse

-import torch as th
-from evaluate_embedding import evaluate_embedding
-from model import InfoGraph
 import dgl
+import torch as th
 from dgl.data import GINDataset
 from dgl.dataloading import GraphDataLoader
+from evaluate_embedding import evaluate_embedding
+from model import InfoGraph

 def argument():

@@ -75,7 +75,6 @@ def collate(samples):
 if __name__ == "__main__":
     # Step 1: Prepare graph data ===================================== #
     args = argument()
     print(args)

@@ -131,7 +130,6 @@ if __name__ == "__main__":
         model.train()
         for graph, label in dataloader:
             graph = graph.to(args.device)
             feat = graph.ndata["attr"]
             graph_id = graph.ndata["graph_id"]

@@ -147,7 +145,6 @@ if __name__ == "__main__":
         print("Epoch {}, Loss {:.4f}".format(epoch, loss_all))
         if epoch % args.log_interval == 0:
             # evaluate embeddings
             model.eval()
             emb = model.get_embedding(wholegraph, wholefeat).cpu()
examples/pytorch/infograph/utils.py

@@ -41,7 +41,6 @@ def get_negative_expectation(q_samples, average=True):
 def local_global_loss_(l_enc, g_enc, graph_id):
     num_graphs = g_enc.shape[0]
     num_nodes = l_enc.shape[0]

@@ -51,7 +50,6 @@ def local_global_loss_(l_enc, g_enc, graph_id):
     neg_mask = th.ones((num_nodes, num_graphs)).to(device)
     for nodeidx, graphidx in enumerate(graph_id):
         pos_mask[nodeidx][graphidx] = 1.0
         neg_mask[nodeidx][graphidx] = 0.0

@@ -66,7 +64,6 @@ def local_global_loss_(l_enc, g_enc, graph_id):
 def global_global_loss_(sup_enc, unsup_enc):
     num_graphs = sup_enc.shape[0]
     device = sup_enc.device
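local_global_loss_ scores every (node, graph) pair of embeddings; the masks above mark each node's own graph as the single positive pair and every other graph as a negative. A toy sketch of that mask construction (the graph_id tensor is illustrative):

import torch as th

graph_id = th.tensor([0, 0, 1, 1, 1])  # which graph each node belongs to
num_nodes, num_graphs = len(graph_id), 2

pos_mask = th.zeros((num_nodes, num_graphs))
neg_mask = th.ones((num_nodes, num_graphs))
for nodeidx, graphidx in enumerate(graph_id):
    pos_mask[nodeidx][graphidx] = 1.0  # node's own graph: positive pair
    neg_mask[nodeidx][graphidx] = 0.0  # all other graphs stay negatives
print(pos_mask)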