OpenDAS / dgl / commit 704bcaf6

Commit 704bcaf6 (unverified), authored Feb 19, 2023 by Hongzhi (Steve), Chen, and committed via GitHub on Feb 19, 2023.

examples (#5323)

Co-authored-by: Ubuntu <ubuntu@ip-172-31-28-63.ap-northeast-1.compute.internal>

Parent commit: 6bc82161
Changes: 332 files in the full commit; this capture shows one page of 20.
Showing 20 changed files with 392 additions and 253 deletions (+392, -253):

- examples/pytorch/grace/eval.py (+1, -1)
- examples/pytorch/grace/main.py (+0, -1)
- examples/pytorch/grand/main.py (+3, -7)
- examples/pytorch/grand/model.py (+1, -8)
- examples/pytorch/graph_matching/examples.py (+1, -2)
- examples/pytorch/graph_matching/ged.py (+3, -4)
- examples/pytorch/graphsage/advanced/model.py (+2, -3)
- examples/pytorch/graphsage/advanced/negative_sampler.py (+1, -2)
- examples/pytorch/graphsage/advanced/train_lightning_unsupervised.py (+141, -87)
- examples/pytorch/graphsage/lightning/node_classification.py (+3, -3)
- examples/pytorch/graphsage/link_pred.py (+119, -59)
- examples/pytorch/graphsage/load_graph.py (+1, -2)
- examples/pytorch/graphsage/node_classification.py (+100, -54)
- examples/pytorch/graphsage/train_full.py (+2, -2)
- examples/pytorch/graphsaint/modules.py (+1, -2)
- examples/pytorch/graphsaint/sampler.py (+4, -7)
- examples/pytorch/graphsaint/train_sampling.py (+1, -1)
- examples/pytorch/graphsaint/utils.py (+2, -2)
- examples/pytorch/graphsim/dataloader.py (+2, -2)
- examples/pytorch/graphsim/models.py (+4, -4)
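Every hunk on this page is a mechanical restyling rather than a behavior change: single-quoted strings become double-quoted, long calls are split to one argument per line with trailing commas, and imports are regrouped and sorted case-insensitively. The pattern is consistent with an automated pass of a formatter such as black plus an import sorter such as isort; the commit message does not name the tools, so treat that attribution as an inference. In the reconstructed hunks below, lines prefixed with - show the pre-commit form, lines prefixed with + the post-commit form, and unprefixed lines are unchanged context.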
examples/pytorch/grace/eval.py (view file @ 704bcaf6)

```diff
@@ -10,7 +10,7 @@ from sklearn.linear_model import LogisticRegression
 from sklearn.metrics import f1_score
 from sklearn.model_selection import GridSearchCV, train_test_split
 from sklearn.multiclass import OneVsRestClassifier
-from sklearn.preprocessing import OneHotEncoder, normalize
+from sklearn.preprocessing import normalize, OneHotEncoder


 def repeat(n_times):
```
examples/pytorch/grace/main.py (view file @ 704bcaf6)

```diff
@@ -75,7 +75,6 @@ else:
     args.device = "cpu"


 if __name__ == "__main__":
     # Step 1: Load hyperparameters =================================================================== #
     lr = args.lr
     hid_dim = args.hid_dim
```
examples/pytorch/grand/main.py (view file @ 704bcaf6)

```diff
 import argparse
 import warnings

+import dgl
 import numpy as np
 import torch as th
 import torch.nn as nn
 import torch.nn.functional as F
 import torch.optim as optim
-from model import GRAND
-import dgl
 from dgl.data import CiteseerGraphDataset, CoraGraphDataset, PubmedGraphDataset
+from model import GRAND

 warnings.filterwarnings("ignore")


 def argument():
     parser = argparse.ArgumentParser(description="GRAND")
     # data source params
```

```diff
@@ -111,7 +110,6 @@ def consis_loss(logps, temp, lam):
 if __name__ == "__main__":
     # Step 1: Prepare graph data and retrieve train/validation/test index ============================= #
     # Load from DGL dataset
     args = argument()
```

```diff
@@ -175,7 +173,6 @@ if __name__ == "__main__":
     # Step 4: training epoches =============================================================== #
     for epoch in range(args.epochs):
         """Training"""
         model.train()
```

```diff
@@ -204,7 +201,6 @@ if __name__ == "__main__":
         """ Validating """
         model.eval()
         with th.no_grad():
             val_logits = model(graph, feats, False)
             loss_val = F.nll_loss(val_logits[val_idx], labels[val_idx])
```
examples/pytorch/grand/model.py (view file @ 704bcaf6)

```diff
+import dgl.function as fn
 import numpy as np
 import torch as th
 import torch.nn as nn
 import torch.nn.functional as F
-import dgl.function as fn


 def drop_node(feats, drop_rate, training):
     n = feats.shape[0]
     drop_rates = th.FloatTensor(np.ones(n) * drop_rate)
     if training:
         masks = th.bernoulli(1.0 - drop_rates).unsqueeze(1)
         feats = masks.to(feats.device) * feats
```
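The drop_node step above zeroes entire feature rows at random during training (GRAND's random propagation). A minimal standalone run of the same recipe, for illustration only:

```python
import numpy as np
import torch as th

# Same recipe as drop_node above: each node's feature row is kept with
# probability 1 - drop_rate and zeroed otherwise, during training only.
feats = th.ones(5, 3)
drop_rate = 0.5
drop_rates = th.FloatTensor(np.ones(feats.shape[0]) * drop_rate)
masks = th.bernoulli(1.0 - drop_rates).unsqueeze(1)
print(masks.to(feats.device) * feats)  # some rows come out all zeros
```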
```diff
@@ -42,7 +39,6 @@ class MLP(nn.Module):
         self.layer2.reset_parameters()

     def forward(self, x):
         if self.use_bn:
             x = self.bn1(x)
         x = self.input_dropout(x)
```
```diff
@@ -68,7 +64,6 @@ def GRANDConv(graph, feats, order):
     Propagation Steps
     """
     with graph.local_scope():
         """Calculate Symmetric normalized adjacency matrix \hat{A}"""
         degs = graph.in_degrees().float().clamp(min=1)
         norm = th.pow(degs, -0.5).to(feats.device).unsqueeze(1)
```
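For context on the GRANDConv hunk: the per-node factor deg^(-1/2) implements the symmetric normalization \hat{A} = D^{-1/2} A D^{-1/2}, with degrees clamped to at least 1 so isolated nodes do not divide by zero. A quick check of the scaling factor:

```python
import torch as th

# deg^(-1/2) with clamping, as in GRANDConv above.
degs = th.tensor([0.0, 1.0, 4.0])
norm = th.pow(degs.clamp(min=1), -0.5)
print(norm)  # tensor([1.0000, 1.0000, 0.5000])
```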
```diff
@@ -127,7 +122,6 @@ class GRAND(nn.Module):
         hidden_droprate=0.0,
         batchnorm=False,
     ):
         super(GRAND, self).__init__()
         self.in_dim = in_dim
         self.hid_dim = hid_dim
```

```diff
@@ -143,7 +137,6 @@ class GRAND(nn.Module):
         self.node_dropout = nn.Dropout(node_dropout)

     def forward(self, graph, feats, training=True):
         X = feats
         S = self.S
```
examples/pytorch/graph_matching/examples.py (view file @ 704bcaf6)

```diff
+import dgl
 import numpy as np
 from ged import graph_edit_distance
-import dgl

 src1 = [0, 1, 2, 3, 4, 5]
 dst1 = [1, 2, 3, 4, 5, 6]
```
examples/pytorch/graph_matching/ged.py (view file @ 704bcaf6)

```diff
+from copy import deepcopy
+from heapq import heapify, heappop, heappush, nsmallest
+
 import dgl
 import numpy as np
-from heapq import heappush, heappop, heapify, nsmallest
-from copy import deepcopy

 # We use lapjv implementation (https://github.com/src-d/lapjv) to solve assignment problem, because of its scalability
 # Also see https://github.com/berhane/LAP-solvers for benchmarking of LAP solvers
```

```diff
@@ -247,7 +248,6 @@ class search_tree_node:
         cost_matrix_nodes,
         cost_matrix_edges,
     ):
         self.matched_cost = parent_matched_cost
         self.future_approximate_cost = 0.0
         self.matched_nodes = deepcopy(parent_matched_nodes)
```

```diff
@@ -1156,7 +1156,6 @@ def graph_edit_distance(
     algorithm="bipartite",
     max_beam_size=100,
 ):
     """Returns GED (graph edit distance) between DGLGraphs G1 and G2.
```
examples/pytorch/graphsage/advanced/model.py (view file @ 704bcaf6)

```diff
+import dgl
+import dgl.nn as dglnn
 import sklearn.linear_model as lm
 import sklearn.metrics as skm
 import torch as th
```

```diff
@@ -5,9 +7,6 @@ import torch.functional as F
 import torch.nn as nn
 import tqdm
-import dgl
-import dgl.nn as dglnn


 class SAGE(nn.Module):
     def __init__(
```
examples/pytorch/graphsage/advanced/negative_sampler.py (view file @ 704bcaf6)

```diff
-import torch as th
 import dgl
+import torch as th


 class NegativeSampler(object):
```
examples/pytorch/graphsage/advanced/train_lightning_unsupervised.py (view file @ 704bcaf6)

```diff
+import argparse
+import glob
+import os
+import sys
+import time
+
+import dgl
+import dgl.function as fn
+import dgl.nn.pytorch as dglnn
 import numpy as np
 import torch as th
 import torch.nn as nn
 import torch.nn.functional as F
 import torch.optim as optim
-import dgl.nn.pytorch as dglnn
-import dgl.function as fn
-import time
-import argparse
 import tqdm
-import glob
-import os
+from model import compute_acc_unsupervised as compute_acc, SAGE
 from negative_sampler import NegativeSampler
-from pytorch_lightning.callbacks import ModelCheckpoint, Callback
 from pytorch_lightning import LightningDataModule, LightningModule, Trainer
-from model import SAGE, compute_acc_unsupervised as compute_acc
-import sys
-sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
-from load_graph import load_reddit, inductive_split, load_ogb
+from pytorch_lightning.callbacks import Callback, ModelCheckpoint
+
+sys.path.append(os.path.join(os.path.dirname(__file__), ".."))
+from load_graph import inductive_split, load_ogb, load_reddit
```
```diff
 class CrossEntropyLoss(nn.Module):
     def forward(self, block_outputs, pos_graph, neg_graph):
         with pos_graph.local_scope():
-            pos_graph.ndata['h'] = block_outputs
-            pos_graph.apply_edges(fn.u_dot_v('h', 'h', 'score'))
-            pos_score = pos_graph.edata['score']
+            pos_graph.ndata["h"] = block_outputs
+            pos_graph.apply_edges(fn.u_dot_v("h", "h", "score"))
+            pos_score = pos_graph.edata["score"]
         with neg_graph.local_scope():
-            neg_graph.ndata['h'] = block_outputs
-            neg_graph.apply_edges(fn.u_dot_v('h', 'h', 'score'))
-            neg_score = neg_graph.edata['score']
+            neg_graph.ndata["h"] = block_outputs
+            neg_graph.apply_edges(fn.u_dot_v("h", "h", "score"))
+            neg_score = neg_graph.edata["score"]

         score = th.cat([pos_score, neg_score])
-        label = th.cat([th.ones_like(pos_score), th.zeros_like(neg_score)]).long()
+        label = th.cat(
+            [th.ones_like(pos_score), th.zeros_like(neg_score)]
+        ).long()
         loss = F.binary_cross_entropy_with_logits(score, label.float())
         return loss
```
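This CrossEntropyLoss is the usual unsupervised GraphSAGE objective: dot-product scores on observed (positive) edges are pushed toward label 1, scores on sampled non-edges toward label 0, via a logistic loss. A toy sketch of the final two steps, with made-up scores:

```python
import torch as th
import torch.nn.functional as F

pos_score = th.tensor([2.0, 1.5])   # u.v scores on observed edges
neg_score = th.tensor([-1.0, 0.3])  # u.v scores on sampled non-edges
score = th.cat([pos_score, neg_score])
label = th.cat([th.ones_like(pos_score), th.zeros_like(neg_score)])
print(F.binary_cross_entropy_with_logits(score, label).item())
```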
```diff
 class SAGELightning(LightningModule):
-    def __init__(self, in_feats, n_hidden, n_classes, n_layers, activation, dropout, lr):
+    def __init__(
+        self, in_feats, n_hidden, n_classes, n_layers, activation, dropout, lr
+    ):
         super().__init__()
         self.save_hyperparameters()
-        self.module = SAGE(in_feats, n_hidden, n_classes, n_layers, activation, dropout)
+        self.module = SAGE(
+            in_feats, n_hidden, n_classes, n_layers, activation, dropout
+        )
         self.lr = lr
         self.loss_fcn = CrossEntropyLoss()
```
```diff
@@ -57,18 +60,20 @@ class SAGELightning(LightningModule):
         mfgs = [mfg.int().to(device) for mfg in mfgs]
         pos_graph = pos_graph.to(device)
         neg_graph = neg_graph.to(device)
-        batch_inputs = mfgs[0].srcdata['features']
-        batch_labels = mfgs[-1].dstdata['labels']
+        batch_inputs = mfgs[0].srcdata["features"]
+        batch_labels = mfgs[-1].dstdata["labels"]
         batch_pred = self.module(mfgs, batch_inputs)
         loss = self.loss_fcn(batch_pred, pos_graph, neg_graph)
-        self.log('train_loss', loss, prog_bar=True, on_step=False, on_epoch=True)
+        self.log(
+            "train_loss", loss, prog_bar=True, on_step=False, on_epoch=True
+        )
         return loss

     def validation_step(self, batch, batch_idx):
         input_nodes, output_nodes, mfgs = batch
         mfgs = [mfg.int().to(device) for mfg in mfgs]
-        batch_inputs = mfgs[0].srcdata['features']
-        batch_labels = mfgs[-1].dstdata['labels']
+        batch_inputs = mfgs[0].srcdata["features"]
+        batch_labels = mfgs[-1].dstdata["labels"]
         batch_pred = self.module(mfgs, batch_inputs)
         return batch_pred
```
```diff
@@ -78,54 +83,73 @@ class SAGELightning(LightningModule):
 class DataModule(LightningDataModule):
-    def __init__(self, dataset_name, data_cpu=False, fan_out=[10, 25],
-                 device=th.device('cpu'), batch_size=1000, num_workers=4):
+    def __init__(
+        self,
+        dataset_name,
+        data_cpu=False,
+        fan_out=[10, 25],
+        device=th.device("cpu"),
+        batch_size=1000,
+        num_workers=4,
+    ):
         super().__init__()
-        if dataset_name == 'reddit':
+        if dataset_name == "reddit":
             g, n_classes = load_reddit()
             n_edges = g.num_edges()
-            reverse_eids = th.cat([
-                th.arange(n_edges // 2, n_edges),
-                th.arange(0, n_edges // 2)])
-        elif dataset_name == 'ogbn-products':
-            g, n_classes = load_ogb('ogbn-products')
+            reverse_eids = th.cat(
+                [th.arange(n_edges // 2, n_edges), th.arange(0, n_edges // 2)]
+            )
+        elif dataset_name == "ogbn-products":
+            g, n_classes = load_ogb("ogbn-products")
             n_edges = g.num_edges()
             # The reverse edge of edge 0 in OGB products dataset is 1.
             # The reverse edge of edge 2 is 3. So on so forth.
             reverse_eids = th.arange(n_edges) ^ 1
         else:
-            raise ValueError('unknown dataset')
+            raise ValueError("unknown dataset")

-        train_nid = th.nonzero(g.ndata['train_mask'], as_tuple=True)[0]
-        val_nid = th.nonzero(g.ndata['val_mask'], as_tuple=True)[0]
-        test_nid = th.nonzero(~(g.ndata['train_mask'] | g.ndata['val_mask']), as_tuple=True)[0]
+        train_nid = th.nonzero(g.ndata["train_mask"], as_tuple=True)[0]
+        val_nid = th.nonzero(g.ndata["val_mask"], as_tuple=True)[0]
+        test_nid = th.nonzero(
+            ~(g.ndata["train_mask"] | g.ndata["val_mask"]), as_tuple=True
+        )[0]

-        sampler = dgl.dataloading.MultiLayerNeighborSampler([int(_) for _ in fan_out])
+        sampler = dgl.dataloading.MultiLayerNeighborSampler(
+            [int(_) for _ in fan_out]
+        )

-        dataloader_device = th.device('cpu')
+        dataloader_device = th.device("cpu")
         if not data_cpu:
             train_nid = train_nid.to(device)
             val_nid = val_nid.to(device)
             test_nid = test_nid.to(device)
-            g = g.formats(['csc'])
+            g = g.formats(["csc"])
             g = g.to(device)
             dataloader_device = device

         self.g = g
-        self.train_nid, self.val_nid, self.test_nid = train_nid, val_nid, test_nid
+        self.train_nid, self.val_nid, self.test_nid = (
+            train_nid,
+            val_nid,
+            test_nid,
+        )
         self.sampler = sampler
         self.device = dataloader_device
         self.batch_size = batch_size
         self.num_workers = num_workers
-        self.in_feats = g.ndata['features'].shape[1]
+        self.in_feats = g.ndata["features"].shape[1]
         self.n_classes = n_classes
         self.reverse_eids = reverse_eids

     def train_dataloader(self):
         sampler = dgl.dataloading.as_edge_prediction_sampler(
-            self.sampler, exclude='reverse_id',
+            self.sampler,
+            exclude="reverse_id",
             reverse_eids=self.reverse_eids,
-            negative_sampler=NegativeSampler(self.g, args.num_negs, args.neg_share))
+            negative_sampler=NegativeSampler(
+                self.g, args.num_negs, args.neg_share
+            ),
+        )
         return dgl.dataloading.DataLoader(
             self.g,
             np.arange(self.g.num_edges()),
```
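The ogbn-products branch above relies on edges being stored as consecutive (u to v, v to u) pairs, so XOR-ing each edge ID with 1 swaps IDs 0 and 1, 2 and 3, and so on:

```python
import torch as th

# XOR with 1 flips the lowest bit, pairing each even ID with the next odd ID.
reverse_eids = th.arange(6) ^ 1
print(reverse_eids)  # tensor([1, 0, 3, 2, 5, 4])
```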
```diff
@@ -134,7 +158,8 @@ class DataModule(LightningDataModule):
             batch_size=self.batch_size,
             shuffle=True,
             drop_last=False,
-            num_workers=self.num_workers)
+            num_workers=self.num_workers,
+        )

     def val_dataloader(self):
         # Note that the validation data loader is a DataLoader
```
```diff
@@ -147,63 +172,92 @@ class DataModule(LightningDataModule):
             batch_size=self.batch_size,
             shuffle=False,
             drop_last=False,
-            num_workers=self.num_workers)
+            num_workers=self.num_workers,
+        )


 class UnsupervisedClassification(Callback):
     def on_validation_epoch_start(self, trainer, pl_module):
         self.val_outputs = []

-    def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
+    def on_validation_batch_end(
+        self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx
+    ):
         self.val_outputs.append(outputs)

     def on_validation_epoch_end(self, trainer, pl_module):
         node_emb = th.cat(self.val_outputs, 0)
         g = trainer.datamodule.g
-        labels = g.ndata['labels']
+        labels = g.ndata["labels"]
         f1_micro, f1_macro = compute_acc(
-            node_emb, labels, trainer.datamodule.train_nid,
-            trainer.datamodule.val_nid, trainer.datamodule.test_nid)
-        pl_module.log('val_f1_micro', f1_micro)
+            node_emb,
+            labels,
+            trainer.datamodule.train_nid,
+            trainer.datamodule.val_nid,
+            trainer.datamodule.test_nid,
+        )
+        pl_module.log("val_f1_micro", f1_micro)


-if __name__ == '__main__':
+if __name__ == "__main__":
     argparser = argparse.ArgumentParser("multi-gpu training")
     argparser.add_argument("--gpu", type=int, default=0)
-    argparser.add_argument('--dataset', type=str, default='reddit')
-    argparser.add_argument('--num-epochs', type=int, default=20)
-    argparser.add_argument('--num-hidden', type=int, default=16)
-    argparser.add_argument('--num-layers', type=int, default=2)
-    argparser.add_argument('--num-negs', type=int, default=1)
-    argparser.add_argument('--neg-share', default=False, action='store_true',
-                           help="sharing neg nodes for positive nodes")
-    argparser.add_argument('--fan-out', type=str, default='10,25')
-    argparser.add_argument('--batch-size', type=int, default=10000)
-    argparser.add_argument('--log-every', type=int, default=20)
-    argparser.add_argument('--eval-every', type=int, default=1000)
-    argparser.add_argument('--lr', type=float, default=0.003)
-    argparser.add_argument('--dropout', type=float, default=0.5)
-    argparser.add_argument('--num-workers', type=int, default=0,
-                           help="Number of sampling processes. Use 0 for no extra process.")
+    argparser.add_argument("--dataset", type=str, default="reddit")
+    argparser.add_argument("--num-epochs", type=int, default=20)
+    argparser.add_argument("--num-hidden", type=int, default=16)
+    argparser.add_argument("--num-layers", type=int, default=2)
+    argparser.add_argument("--num-negs", type=int, default=1)
+    argparser.add_argument(
+        "--neg-share",
+        default=False,
+        action="store_true",
+        help="sharing neg nodes for positive nodes",
+    )
+    argparser.add_argument("--fan-out", type=str, default="10,25")
+    argparser.add_argument("--batch-size", type=int, default=10000)
+    argparser.add_argument("--log-every", type=int, default=20)
+    argparser.add_argument("--eval-every", type=int, default=1000)
+    argparser.add_argument("--lr", type=float, default=0.003)
+    argparser.add_argument("--dropout", type=float, default=0.5)
+    argparser.add_argument(
+        "--num-workers",
+        type=int,
+        default=0,
+        help="Number of sampling processes. Use 0 for no extra process.",
+    )
     args = argparser.parse_args()

     if args.gpu >= 0:
-        device = th.device('cuda:%d' % args.gpu)
+        device = th.device("cuda:%d" % args.gpu)
     else:
-        device = th.device('cpu')
+        device = th.device("cpu")

     datamodule = DataModule(
-        args.dataset, True, [int(_) for _ in args.fan_out.split(',')],
-        device, args.batch_size, args.num_workers)
+        args.dataset,
+        True,
+        [int(_) for _ in args.fan_out.split(",")],
+        device,
+        args.batch_size,
+        args.num_workers,
+    )
     model = SAGELightning(
-        datamodule.in_feats, args.num_hidden, datamodule.n_classes,
-        args.num_layers, F.relu, args.dropout, args.lr)
+        datamodule.in_feats,
+        args.num_hidden,
+        datamodule.n_classes,
+        args.num_layers,
+        F.relu,
+        args.dropout,
+        args.lr,
+    )

     # Train
     unsupervised_callback = UnsupervisedClassification()
-    checkpoint_callback = ModelCheckpoint(monitor='val_f1_micro', save_top_k=1)
-    trainer = Trainer(gpus=[args.gpu] if args.gpu != -1 else None,
+    checkpoint_callback = ModelCheckpoint(monitor="val_f1_micro", save_top_k=1)
+    trainer = Trainer(
+        gpus=[args.gpu] if args.gpu != -1 else None,
         max_epochs=args.num_epochs,
         val_check_interval=1000,
         callbacks=[checkpoint_callback, unsupervised_callback],
-        num_sanity_val_steps=0)
+        num_sanity_val_steps=0,
+    )
     trainer.fit(model, datamodule=datamodule)
```
examples/pytorch/graphsage/lightning/node_classification.py (view file @ 704bcaf6)

```diff
 import glob
 import os

+import dgl
+import dgl.nn.pytorch as dglnn
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
```

```diff
@@ -12,9 +15,6 @@ from pytorch_lightning import LightningDataModule, LightningModule, Trainer
 from pytorch_lightning.callbacks import ModelCheckpoint
 from torchmetrics import Accuracy
-import dgl
-import dgl.nn.pytorch as dglnn


 class SAGE(LightningModule):
     def __init__(self, in_feats, n_hidden, n_classes):
```
examples/pytorch/graphsage/link_pred.py (view file @ 704bcaf6)

```diff
+import argparse
+
+import dgl
+import dgl.nn as dglnn
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 import torchmetrics.functional as MF
-import dgl
-import dgl.nn as dglnn
-from dgl.dataloading import DataLoader, NeighborSampler, MultiLayerFullNeighborSampler, as_edge_prediction_sampler, negative_sampler
 import tqdm
-import argparse
+from dgl.dataloading import (
+    as_edge_prediction_sampler,
+    DataLoader,
+    MultiLayerFullNeighborSampler,
+    negative_sampler,
+    NeighborSampler,
+)
 from ogb.linkproppred import DglLinkPropPredDataset, Evaluator


 def to_bidirected_with_reverse_mapping(g):
     """Makes a graph bidirectional, and returns a mapping array ``mapping`` where ``mapping[i]``
     is the reverse edge of edge ID ``i``. Does not work with graphs that have self-loops.
     """
     g_simple, mapping = dgl.to_simple(
-        dgl.add_reverse_edges(g), return_counts='count', writeback_mapping=True)
-    c = g_simple.edata['count']
+        dgl.add_reverse_edges(g), return_counts="count", writeback_mapping=True
+    )
+    c = g_simple.edata["count"]
     num_edges = g.num_edges()
-    mapping_offset = torch.zeros(g_simple.num_edges() + 1, dtype=g_simple.idtype)
+    mapping_offset = torch.zeros(
+        g_simple.num_edges() + 1, dtype=g_simple.idtype
+    )
     mapping_offset[1:] = c.cumsum(0)
     idx = mapping.argsort()
     idx_uniq = idx[mapping_offset[:-1]]
-    reverse_idx = torch.where(idx_uniq >= num_edges, idx_uniq - num_edges, idx_uniq + num_edges)
+    reverse_idx = torch.where(
+        idx_uniq >= num_edges, idx_uniq - num_edges, idx_uniq + num_edges
+    )
     reverse_mapping = mapping[reverse_idx]

     # sanity check
     src1, dst1 = g_simple.edges()
```
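A minimal check of the property the docstring promises, as hypothetical usage assuming the function above is in scope (its closing lines, including the file's own sanity check, appear at the top of the next hunk): for every edge i of the bidirected graph, edge reverse_mapping[i] has swapped endpoints.

```python
import dgl
import torch

g = dgl.graph((torch.tensor([0, 1]), torch.tensor([1, 2])))  # path 0->1->2
g_bi, rev = to_bidirected_with_reverse_mapping(g)
src, dst = g_bi.edges()
# Edge rev[i] runs in the opposite direction of edge i.
assert torch.equal(src[rev], dst) and torch.equal(dst[rev], src)
```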
```diff
@@ -30,21 +43,23 @@ def to_bidirected_with_reverse_mapping(g):
     assert torch.equal(src2, dst1)
     return g_simple, reverse_mapping


 class SAGE(nn.Module):
     def __init__(self, in_size, hid_size):
         super().__init__()
         self.layers = nn.ModuleList()
         # three-layer GraphSAGE-mean
-        self.layers.append(dglnn.SAGEConv(in_size, hid_size, 'mean'))
-        self.layers.append(dglnn.SAGEConv(hid_size, hid_size, 'mean'))
-        self.layers.append(dglnn.SAGEConv(hid_size, hid_size, 'mean'))
+        self.layers.append(dglnn.SAGEConv(in_size, hid_size, "mean"))
+        self.layers.append(dglnn.SAGEConv(hid_size, hid_size, "mean"))
+        self.layers.append(dglnn.SAGEConv(hid_size, hid_size, "mean"))
         self.hid_size = hid_size
         self.predictor = nn.Sequential(
             nn.Linear(hid_size, hid_size),
             nn.ReLU(),
             nn.Linear(hid_size, hid_size),
             nn.ReLU(),
-            nn.Linear(hid_size, 1))
+            nn.Linear(hid_size, 1),
+        )

     def forward(self, pair_graph, neg_pair_graph, blocks, x):
         h = x
```
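The predictor above scores a candidate pair (u, v) from the Hadamard product of the two node embeddings, mapped to a single logit by an MLP. A stripped-down two-layer sketch of the same idea (the file's version uses three Linear layers):

```python
import torch
import torch.nn as nn

hid_size = 4
predictor = nn.Sequential(
    nn.Linear(hid_size, hid_size), nn.ReLU(), nn.Linear(hid_size, 1)
)
h_u, h_v = torch.randn(hid_size), torch.randn(hid_size)
print(predictor(h_u * h_v))  # one logit for the pair (u, v)
```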
```diff
@@ -60,19 +75,31 @@ class SAGE(nn.Module):
     def inference(self, g, device, batch_size):
         """Layer-wise inference algorithm to compute GNN node embeddings."""
-        feat = g.ndata['feat']
-        sampler = MultiLayerFullNeighborSampler(1, prefetch_node_feats=['feat'])
+        feat = g.ndata["feat"]
+        sampler = MultiLayerFullNeighborSampler(1, prefetch_node_feats=["feat"])
         dataloader = DataLoader(
-            g, torch.arange(g.num_nodes()).to(g.device), sampler, device=device,
-            batch_size=batch_size, shuffle=False, drop_last=False,
-            num_workers=0)
-        buffer_device = torch.device('cpu')
-        pin_memory = (buffer_device != device)
+            g,
+            torch.arange(g.num_nodes()).to(g.device),
+            sampler,
+            device=device,
+            batch_size=batch_size,
+            shuffle=False,
+            drop_last=False,
+            num_workers=0,
+        )
+        buffer_device = torch.device("cpu")
+        pin_memory = buffer_device != device
         for l, layer in enumerate(self.layers):
-            y = torch.empty(g.num_nodes(), self.hid_size, device=buffer_device,
-                            pin_memory=pin_memory)
+            y = torch.empty(
+                g.num_nodes(),
+                self.hid_size,
+                device=buffer_device,
+                pin_memory=pin_memory,
+            )
             feat = feat.to(device)
-            for input_nodes, output_nodes, blocks in tqdm.tqdm(dataloader, desc='Inference'):
+            for input_nodes, output_nodes, blocks in tqdm.tqdm(
+                dataloader, desc="Inference"
+            ):
                 x = feat[input_nodes]
                 h = layer(blocks[0], x)
                 if l != len(self.layers) - 1:
```
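One detail worth noting in the inference hunk: the full-graph output buffer y lives on the CPU (buffer_device) and is pinned only when the compute device differs, which allows per-batch results to be copied back efficiently. The idiom in isolation:

```python
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
buffer_device = torch.device("cpu")
pin_memory = buffer_device != device  # pin only for CPU<->GPU transfers
y = torch.empty(1000, 256, device=buffer_device, pin_memory=pin_memory)
print(pin_memory, y.is_pinned())
```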
```diff
@@ -81,49 +108,70 @@ class SAGE(nn.Module):
             feat = y
         return y


-def compute_mrr(model, evaluator, node_emb, src, dst, neg_dst, device, batch_size=500):
+def compute_mrr(
+    model, evaluator, node_emb, src, dst, neg_dst, device, batch_size=500
+):
     """Compute Mean Reciprocal Rank (MRR) in batches."""
     rr = torch.zeros(src.shape[0])
-    for start in tqdm.trange(0, src.shape[0], batch_size, desc='Evaluate'):
+    for start in tqdm.trange(0, src.shape[0], batch_size, desc="Evaluate"):
         end = min(start + batch_size, src.shape[0])
         all_dst = torch.cat([dst[start:end, None], neg_dst[start:end]], 1)
         h_src = node_emb[src[start:end]][:, None, :].to(device)
         h_dst = node_emb[all_dst.view(-1)].view(*all_dst.shape, -1).to(device)
         pred = model.predictor(h_src * h_dst).squeeze(-1)
-        input_dict = {'y_pred_pos': pred[:, 0], 'y_pred_neg': pred[:, 1:]}
-        rr[start:end] = evaluator.eval(input_dict)['mrr_list']
+        input_dict = {"y_pred_pos": pred[:, 0], "y_pred_neg": pred[:, 1:]}
+        rr[start:end] = evaluator.eval(input_dict)["mrr_list"]
     return rr.mean()


 def evaluate(device, graph, edge_split, model, batch_size):
     model.eval()
-    evaluator = Evaluator(name='ogbl-citation2')
+    evaluator = Evaluator(name="ogbl-citation2")
     with torch.no_grad():
         node_emb = model.inference(graph, device, batch_size)
         results = []
-        for split in ['valid', 'test']:
-            src = edge_split[split]['source_node'].to(node_emb.device)
-            dst = edge_split[split]['target_node'].to(node_emb.device)
-            neg_dst = edge_split[split]['target_node_neg'].to(node_emb.device)
-            results.append(compute_mrr(model, evaluator, node_emb, src, dst, neg_dst, device))
+        for split in ["valid", "test"]:
+            src = edge_split[split]["source_node"].to(node_emb.device)
+            dst = edge_split[split]["target_node"].to(node_emb.device)
+            neg_dst = edge_split[split]["target_node_neg"].to(node_emb.device)
+            results.append(
+                compute_mrr(
+                    model, evaluator, node_emb, src, dst, neg_dst, device
+                )
+            )
     return results


 def train(args, device, g, reverse_eids, seed_edges, model):
     # create sampler & dataloader
-    sampler = NeighborSampler([15, 10, 5], prefetch_node_feats=['feat'])
+    sampler = NeighborSampler([15, 10, 5], prefetch_node_feats=["feat"])
     sampler = as_edge_prediction_sampler(
-        sampler, exclude='reverse_id', reverse_eids=reverse_eids,
-        negative_sampler=negative_sampler.Uniform(1))
-    use_uva = (args.mode == 'mixed')
+        sampler,
+        exclude="reverse_id",
+        reverse_eids=reverse_eids,
+        negative_sampler=negative_sampler.Uniform(1),
+    )
+    use_uva = args.mode == "mixed"
     dataloader = DataLoader(
-        g, seed_edges, sampler, device=device, batch_size=512, shuffle=True,
-        drop_last=False, num_workers=0, use_uva=use_uva)
+        g,
+        seed_edges,
+        sampler,
+        device=device,
+        batch_size=512,
+        shuffle=True,
+        drop_last=False,
+        num_workers=0,
+        use_uva=use_uva,
+    )
    opt = torch.optim.Adam(model.parameters(), lr=0.0005)
     for epoch in range(10):
         model.train()
         total_loss = 0
-        for it, (input_nodes, pair_graph, neg_pair_graph, blocks) in enumerate(dataloader):
-            x = blocks[0].srcdata['feat']
+        for it, (input_nodes, pair_graph, neg_pair_graph, blocks) in enumerate(
+            dataloader
+        ):
+            x = blocks[0].srcdata["feat"]
             pos_score, neg_score = model(pair_graph, neg_pair_graph, blocks, x)
             score = torch.cat([pos_score, neg_score])
             pos_label = torch.ones_like(pos_score)
```
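In compute_mrr above, column 0 of pred holds the positive candidate's score and columns 1: hold its negatives; the OGB Evaluator turns each row into a reciprocal rank (1 over the rank of the positive) and the function returns the mean. A hand-rolled approximation of one row, noting that the Evaluator's exact tie handling may differ:

```python
import torch

pred = torch.tensor([[0.7, 0.9, 0.1, 0.4]])  # positive scores 0.7; one negative beats it
rank = 1 + (pred[:, 1:] >= pred[:, :1]).sum(dim=1)  # rank of the positive
print(1.0 / rank.float())  # tensor([0.5000])
```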
```diff
@@ -134,39 +182,51 @@ def train(args, device, g, reverse_eids, seed_edges, model):
             loss.backward()
             opt.step()
             total_loss += loss.item()
             if (it + 1) == 1000:
                 break
-        print("Epoch {:05d} | Loss {:.4f}".format(epoch, total_loss / (it + 1)))
+        print(
+            "Epoch {:05d} | Loss {:.4f}".format(epoch, total_loss / (it + 1))
+        )


-if __name__ == '__main__':
+if __name__ == "__main__":
     parser = argparse.ArgumentParser()
-    parser.add_argument("--mode", default='mixed', choices=['cpu', 'mixed', 'puregpu'],
-                        help="Training mode. 'cpu' for CPU training, 'mixed' for CPU-GPU mixed training, "
-                             "'puregpu' for pure-GPU training.")
+    parser.add_argument(
+        "--mode",
+        default="mixed",
+        choices=["cpu", "mixed", "puregpu"],
+        help="Training mode. 'cpu' for CPU training, 'mixed' for CPU-GPU mixed training, "
+        "'puregpu' for pure-GPU training.",
+    )
     args = parser.parse_args()
     if not torch.cuda.is_available():
-        args.mode = 'cpu'
-    print(f'Training in {args.mode} mode.')
+        args.mode = "cpu"
+    print(f"Training in {args.mode} mode.")

     # load and preprocess dataset
-    print('Loading data')
-    dataset = DglLinkPropPredDataset('ogbl-citation2')
+    print("Loading data")
+    dataset = DglLinkPropPredDataset("ogbl-citation2")
     g = dataset[0]
-    g = g.to('cuda' if args.mode == 'puregpu' else 'cpu')
-    device = torch.device('cpu' if args.mode == 'cpu' else 'cuda')
+    g = g.to("cuda" if args.mode == "puregpu" else "cpu")
+    device = torch.device("cpu" if args.mode == "cpu" else "cuda")
     g, reverse_eids = to_bidirected_with_reverse_mapping(g)
     reverse_eids = reverse_eids.to(device)
     seed_edges = torch.arange(g.num_edges()).to(device)
     edge_split = dataset.get_edge_split()

     # create GraphSAGE model
-    in_size = g.ndata['feat'].shape[1]
+    in_size = g.ndata["feat"].shape[1]
     model = SAGE(in_size, 256).to(device)

     # model training
-    print('Training...')
+    print("Training...")
     train(args, device, g, reverse_eids, seed_edges, model)

     # validate/test the model
-    print('Validation/Testing...')
-    valid_mrr, test_mrr = evaluate(device, g, edge_split, model, batch_size=1000)
-    print('Validation MRR {:.4f}, Test MRR {:.4f}'.format(valid_mrr.item(), test_mrr.item()))
+    print("Validation/Testing...")
+    valid_mrr, test_mrr = evaluate(
+        device, g, edge_split, model, batch_size=1000
+    )
+    print(
+        "Validation MRR {:.4f}, Test MRR {:.4f}".format(
+            valid_mrr.item(), test_mrr.item()
+        )
+    )
```
examples/pytorch/graphsage/load_graph.py (view file @ 704bcaf6)

```diff
-import torch as th
 import dgl
+import torch as th


 def load_reddit(self_loop=True):
```
examples/pytorch/graphsage/node_classification.py (view file @ 704bcaf6)

```diff
+import argparse
+
+import dgl
+import dgl.nn as dglnn
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 import torchmetrics.functional as MF
-import dgl
-import dgl.nn as dglnn
+import tqdm
 from dgl.data import AsNodePredDataset
-from dgl.dataloading import DataLoader, NeighborSampler, MultiLayerFullNeighborSampler
+from dgl.dataloading import (
+    DataLoader,
+    MultiLayerFullNeighborSampler,
+    NeighborSampler,
+)
 from ogb.nodeproppred import DglNodePropPredDataset
-import tqdm
-import argparse


 class SAGE(nn.Module):
     def __init__(self, in_size, hid_size, out_size):
         super().__init__()
         self.layers = nn.ModuleList()
         # three-layer GraphSAGE-mean
-        self.layers.append(dglnn.SAGEConv(in_size, hid_size, 'mean'))
-        self.layers.append(dglnn.SAGEConv(hid_size, hid_size, 'mean'))
-        self.layers.append(dglnn.SAGEConv(hid_size, out_size, 'mean'))
+        self.layers.append(dglnn.SAGEConv(in_size, hid_size, "mean"))
+        self.layers.append(dglnn.SAGEConv(hid_size, hid_size, "mean"))
+        self.layers.append(dglnn.SAGEConv(hid_size, out_size, "mean"))
         self.dropout = nn.Dropout(0.5)
         self.hid_size = hid_size
         self.out_size = out_size
```
```diff
@@ -33,19 +39,28 @@ class SAGE(nn.Module):
     def inference(self, g, device, batch_size):
         """Conduct layer-wise inference to get all the node embeddings."""
-        feat = g.ndata['feat']
-        sampler = MultiLayerFullNeighborSampler(1, prefetch_node_feats=['feat'])
+        feat = g.ndata["feat"]
+        sampler = MultiLayerFullNeighborSampler(1, prefetch_node_feats=["feat"])
         dataloader = DataLoader(
-            g, torch.arange(g.num_nodes()).to(g.device), sampler, device=device,
-            batch_size=batch_size, shuffle=False, drop_last=False,
-            num_workers=0)
-        buffer_device = torch.device('cpu')
-        pin_memory = (buffer_device != device)
+            g,
+            torch.arange(g.num_nodes()).to(g.device),
+            sampler,
+            device=device,
+            batch_size=batch_size,
+            shuffle=False,
+            drop_last=False,
+            num_workers=0,
+        )
+        buffer_device = torch.device("cpu")
+        pin_memory = buffer_device != device
         for l, layer in enumerate(self.layers):
-            y = torch.empty(g.num_nodes(), self.hid_size if l != len(self.layers) - 1 else self.out_size,
-                            device=buffer_device, pin_memory=pin_memory)
+            y = torch.empty(
+                g.num_nodes(),
+                self.hid_size if l != len(self.layers) - 1 else self.out_size,
+                device=buffer_device,
+                pin_memory=pin_memory,
+            )
             feat = feat.to(device)
             for input_nodes, output_nodes, blocks in tqdm.tqdm(dataloader):
                 x = feat[input_nodes]
```
```diff
@@ -54,55 +69,78 @@ class SAGE(nn.Module):
             h = F.relu(h)
             h = self.dropout(h)
             # by design, our output nodes are contiguous
-            y[output_nodes[0]:output_nodes[-1] + 1] = h.to(buffer_device)
+            y[output_nodes[0] : output_nodes[-1] + 1] = h.to(buffer_device)
             feat = y
         return y


 def evaluate(model, graph, dataloader):
     model.eval()
     ys = []
     y_hats = []
     for it, (input_nodes, output_nodes, blocks) in enumerate(dataloader):
         with torch.no_grad():
-            x = blocks[0].srcdata['feat']
-            ys.append(blocks[-1].dstdata['label'])
+            x = blocks[0].srcdata["feat"]
+            ys.append(blocks[-1].dstdata["label"])
             y_hats.append(model(blocks, x))
     return MF.accuracy(torch.cat(y_hats), torch.cat(ys))


 def layerwise_infer(device, graph, nid, model, batch_size):
     model.eval()
     with torch.no_grad():
-        pred = model.inference(graph, device, batch_size)  # pred in buffer_device
+        pred = model.inference(
+            graph, device, batch_size
+        )  # pred in buffer_device
         pred = pred[nid]
-        label = graph.ndata['label'][nid].to(pred.device)
+        label = graph.ndata["label"][nid].to(pred.device)
         return MF.accuracy(pred, label)


 def train(args, device, g, dataset, model):
     # create sampler & dataloader
     train_idx = dataset.train_idx.to(device)
     val_idx = dataset.val_idx.to(device)
-    sampler = NeighborSampler([10, 10, 10],  # fanout for [layer-0, layer-1, layer-2]
-                              prefetch_node_feats=['feat'],
-                              prefetch_labels=['label'])
-    use_uva = (args.mode == 'mixed')
-    train_dataloader = DataLoader(g, train_idx, sampler, device=device,
-                                  batch_size=1024, shuffle=True,
-                                  drop_last=False, num_workers=0,
-                                  use_uva=use_uva)
-    val_dataloader = DataLoader(g, val_idx, sampler, device=device,
-                                batch_size=1024, shuffle=True,
-                                drop_last=False, num_workers=0,
-                                use_uva=use_uva)
+    sampler = NeighborSampler(
+        [10, 10, 10],  # fanout for [layer-0, layer-1, layer-2]
+        prefetch_node_feats=["feat"],
+        prefetch_labels=["label"],
+    )
+    use_uva = args.mode == "mixed"
+    train_dataloader = DataLoader(
+        g,
+        train_idx,
+        sampler,
+        device=device,
+        batch_size=1024,
+        shuffle=True,
+        drop_last=False,
+        num_workers=0,
+        use_uva=use_uva,
+    )
+    val_dataloader = DataLoader(
+        g,
+        val_idx,
+        sampler,
+        device=device,
+        batch_size=1024,
+        shuffle=True,
+        drop_last=False,
+        num_workers=0,
+        use_uva=use_uva,
+    )
     opt = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=5e-4)
     for epoch in range(10):
         model.train()
         total_loss = 0
-        for it, (input_nodes, output_nodes, blocks) in enumerate(train_dataloader):
-            x = blocks[0].srcdata['feat']
-            y = blocks[-1].dstdata['label']
+        for it, (input_nodes, output_nodes, blocks) in enumerate(
+            train_dataloader
+        ):
+            x = blocks[0].srcdata["feat"]
+            y = blocks[-1].dstdata["label"]
             y_hat = model(blocks, x)
             loss = F.cross_entropy(y_hat, y)
             opt.zero_grad()
```
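The fanout list passed to NeighborSampler above has one entry per GNN layer, so [10, 10, 10] samples up to 10 neighbors per node at each of three hops and yields three message-flow graphs per batch. A self-contained illustration on a random graph (toy sizes, for demonstration only):

```python
import dgl
import torch

g = dgl.rand_graph(100, 500)  # 100 nodes, 500 random edges
sampler = dgl.dataloading.NeighborSampler([10, 10, 10])
loader = dgl.dataloading.DataLoader(
    g, torch.arange(10), sampler, batch_size=5, shuffle=False
)
input_nodes, output_nodes, blocks = next(iter(loader))
print(len(blocks))  # 3 blocks, one per layer
```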
```diff
@@ -110,36 +148,44 @@ def train(args, device, g, dataset, model):
             opt.step()
             total_loss += loss.item()
         acc = evaluate(model, g, val_dataloader)
-        print("Epoch {:05d} | Loss {:.4f} | Accuracy {:.4f} "
-              .format(epoch, total_loss / (it + 1), acc.item()))
+        print(
+            "Epoch {:05d} | Loss {:.4f} | Accuracy {:.4f} ".format(
+                epoch, total_loss / (it + 1), acc.item()
+            )
+        )


-if __name__ == '__main__':
+if __name__ == "__main__":
     parser = argparse.ArgumentParser()
-    parser.add_argument("--mode", default='mixed', choices=['cpu', 'mixed', 'puregpu'],
-                        help="Training mode. 'cpu' for CPU training, 'mixed' for CPU-GPU mixed training, "
-                             "'puregpu' for pure-GPU training.")
+    parser.add_argument(
+        "--mode",
+        default="mixed",
+        choices=["cpu", "mixed", "puregpu"],
+        help="Training mode. 'cpu' for CPU training, 'mixed' for CPU-GPU mixed training, "
+        "'puregpu' for pure-GPU training.",
+    )
     args = parser.parse_args()
     if not torch.cuda.is_available():
-        args.mode = 'cpu'
-    print(f'Training in {args.mode} mode.')
+        args.mode = "cpu"
+    print(f"Training in {args.mode} mode.")

     # load and preprocess dataset
-    print('Loading data')
-    dataset = AsNodePredDataset(DglNodePropPredDataset('ogbn-products'))
+    print("Loading data")
+    dataset = AsNodePredDataset(DglNodePropPredDataset("ogbn-products"))
     g = dataset[0]
-    g = g.to('cuda' if args.mode == 'puregpu' else 'cpu')
-    device = torch.device('cpu' if args.mode == 'cpu' else 'cuda')
+    g = g.to("cuda" if args.mode == "puregpu" else "cpu")
+    device = torch.device("cpu" if args.mode == "cpu" else "cuda")

     # create GraphSAGE model
-    in_size = g.ndata['feat'].shape[1]
+    in_size = g.ndata["feat"].shape[1]
     out_size = dataset.num_classes
     model = SAGE(in_size, 256, out_size).to(device)

     # model training
-    print('Training...')
+    print("Training...")
     train(args, device, g, dataset, model)

     # test the model
-    print('Testing...')
+    print("Testing...")
     acc = layerwise_infer(device, g, dataset.test_idx, model, batch_size=4096)
     print("Test Accuracy {:.4f}".format(acc.item()))
```
examples/pytorch/graphsage/train_full.py (view file @ 704bcaf6)

```diff
 import argparse

+import dgl.nn as dglnn
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-import dgl.nn as dglnn
 from dgl import AddSelfLoop
 from dgl.data import CiteseerGraphDataset, CoraGraphDataset, PubmedGraphDataset
```
examples/pytorch/graphsaint/modules.py (view file @ 704bcaf6)

```diff
+import dgl.function as fn
 import torch as th
 import torch.nn as nn
 import torch.nn.functional as F
-import dgl.function as fn


 class GCNLayer(nn.Module):
     def __init__(
```
examples/pytorch/graphsaint/sampler.py (view file @ 704bcaf6)

```diff
@@ -3,14 +3,14 @@ import os
 import random
 import time

+import dgl
+import dgl.function as fn
 import numpy as np
 import scipy
 import torch as th
-from torch.utils.data import DataLoader
-import dgl
-import dgl.function as fn
 from dgl.sampling import pack_traces, random_walk
+from torch.utils.data import DataLoader


 # The base class of sampler
```

```diff
@@ -123,7 +123,6 @@ class SAINTSampler:
             t = time.perf_counter()
             for num_nodes, subgraphs_nids, subgraphs_eids in loader:
                 self.subgraphs.extend(subgraphs_nids)
                 sampled_nodes += num_nodes
```

```diff
@@ -214,7 +213,6 @@ class SAINTSampler:
         raise NotImplementedError

     def __compute_norm__(self):
         self.node_counter[self.node_counter == 0] = 1
         self.edge_counter[self.edge_counter == 0] = 1
```

```diff
@@ -231,7 +229,6 @@ class SAINTSampler:
         return aggr_norm.numpy(), loss_norm.numpy()

     def __compute_degree_norm(self):
         self.train_g.ndata["train_D_norm"] = (
             1.0 / self.train_g.in_degrees().float().clamp(min=1).unsqueeze(1)
         )
```
examples/pytorch/graphsaint/train_sampling.py (view file @ 704bcaf6)

```diff
@@ -9,7 +9,7 @@ from config import CONFIG
 from modules import GCNNet
 from sampler import SAINTEdgeSampler, SAINTNodeSampler, SAINTRandomWalkSampler
 from torch.utils.data import DataLoader
-from utils import Logger, calc_f1, evaluate, load_data, save_log_dir
+from utils import calc_f1, evaluate, load_data, Logger, save_log_dir


 def main(args, task):
```
examples/pytorch/graphsaint/utils.py (view file @ 704bcaf6)

```diff
@@ -2,14 +2,14 @@ import json
 import os
 from functools import namedtuple

+import dgl
 import numpy as np
 import scipy.sparse
 import torch
 from sklearn.metrics import f1_score
 from sklearn.preprocessing import StandardScaler
-import dgl


 class Logger(object):
     """A custom logger to log stdout to a logging file."""
```
examples/pytorch/graphsim/dataloader.py (view file @ 704bcaf6)

```diff
 import copy
 import os

+import dgl
 import networkx as nx
 import numpy as np
 import torch
 from torch.utils.data import DataLoader, Dataset
-import dgl


 def build_dense_graph(n_particles):
     g = nx.complete_graph(n_particles)
```
examples/pytorch/graphsim/models.py (view file @ 704bcaf6)

```diff
 import copy
 from functools import partial

-import torch
-import torch.nn as nn
-from torch.nn import functional as F
 import dgl
 import dgl.function as fn
 import dgl.nn as dglnn
+import torch
+import torch.nn as nn
+from torch.nn import functional as F


 class MLP(nn.Module):
     def __init__(self, in_feats, out_feats, num_layers=2, hidden=128):
```