Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
a8e672e7
Unverified
Commit
a8e672e7
authored
Nov 24, 2023
by
LastWhisper
Committed by
GitHub
Nov 24, 2023
Browse files
[Benchmark] Align the DGL/Graphbolt link pred examples (#6609)
parent
77ec365d
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
56 additions
and
13 deletions
+56
-13
examples/sampling/graphbolt/link_prediction.py
examples/sampling/graphbolt/link_prediction.py
+8
-2
examples/sampling/link_prediction.py
examples/sampling/link_prediction.py
+48
-11
No files found.
examples/sampling/graphbolt/link_prediction.py
View file @
a8e672e7
...
@@ -189,7 +189,7 @@ def create_dataloader(args, graph, features, itemset, is_train=True):
...
@@ -189,7 +189,7 @@ def create_dataloader(args, graph, features, itemset, is_train=True):
# to ensure that positive samples are not inadvertently included within
# to ensure that positive samples are not inadvertently included within
# the negative samples.
# the negative samples.
############################################################################
############################################################################
if
is_train
:
if
is_train
and
args
.
exclude_edges
:
datapipe
=
datapipe
.
transform
(
gb
.
exclude_seed_edges
)
datapipe
=
datapipe
.
transform
(
gb
.
exclude_seed_edges
)
############################################################################
############################################################################
...
@@ -369,7 +369,7 @@ def parse_args():
...
@@ -369,7 +369,7 @@ def parse_args():
parser
.
add_argument
(
"--neg-ratio"
,
type
=
int
,
default
=
1
)
parser
.
add_argument
(
"--neg-ratio"
,
type
=
int
,
default
=
1
)
parser
.
add_argument
(
"--train-batch-size"
,
type
=
int
,
default
=
512
)
parser
.
add_argument
(
"--train-batch-size"
,
type
=
int
,
default
=
512
)
parser
.
add_argument
(
"--eval-batch-size"
,
type
=
int
,
default
=
1024
)
parser
.
add_argument
(
"--eval-batch-size"
,
type
=
int
,
default
=
1024
)
parser
.
add_argument
(
"--num-workers"
,
type
=
int
,
default
=
4
)
parser
.
add_argument
(
"--num-workers"
,
type
=
int
,
default
=
0
)
parser
.
add_argument
(
parser
.
add_argument
(
"--early-stop"
,
"--early-stop"
,
type
=
int
,
type
=
int
,
...
@@ -382,6 +382,12 @@ def parse_args():
...
@@ -382,6 +382,12 @@ def parse_args():
default
=
"15,10,5"
,
default
=
"15,10,5"
,
help
=
"Fan-out of neighbor sampling. Default: 15,10,5"
,
help
=
"Fan-out of neighbor sampling. Default: 15,10,5"
,
)
)
parser
.
add_argument
(
"--exclude-edges"
,
type
=
int
,
default
=
1
,
help
=
"Whether to exclude reverse edges during sampling. Default: 1"
,
)
parser
.
add_argument
(
parser
.
add_argument
(
"--device"
,
"--device"
,
default
=
"cpu"
,
default
=
"cpu"
,
...
...
examples/sampling/link_prediction.py
View file @
a8e672e7
...
@@ -283,7 +283,9 @@ def evaluate(device, graph, edge_split, model, batch_size):
...
@@ -283,7 +283,9 @@ def evaluate(device, graph, edge_split, model, batch_size):
return
results
return
results
def
train
(
args
,
device
,
g
,
reverse_eids
,
seed_edges
,
model
,
use_uva
):
def
train
(
args
,
device
,
g
,
reverse_eids
,
seed_edges
,
model
,
use_uva
,
fused_sampling
):
#####################################################################
#####################################################################
# (HIGHLIGHT) Instantiate a NeighborSampler object for efficient
# (HIGHLIGHT) Instantiate a NeighborSampler object for efficient
# training of Graph Neural Networks (GNNs) on large-scale graphs.
# training of Graph Neural Networks (GNNs) on large-scale graphs.
...
@@ -320,11 +322,15 @@ def train(args, device, g, reverse_eids, seed_edges, model, use_uva):
...
@@ -320,11 +322,15 @@ def train(args, device, g, reverse_eids, seed_edges, model, use_uva):
# not just to learn node representations, but also to predict the
# not just to learn node representations, but also to predict the
# likelihood of an edge existing between two nodes (link prediction).
# likelihood of an edge existing between two nodes (link prediction).
#####################################################################
#####################################################################
sampler
=
NeighborSampler
([
15
,
10
,
5
],
prefetch_node_feats
=
[
"feat"
])
sampler
=
NeighborSampler
(
[
15
,
10
,
5
],
prefetch_node_feats
=
[
"feat"
],
fused
=
fused_sampling
,
)
sampler
=
as_edge_prediction_sampler
(
sampler
=
as_edge_prediction_sampler
(
sampler
,
sampler
,
exclude
=
"reverse_id"
,
exclude
=
"reverse_id"
if
args
.
exclude_edges
else
None
,
reverse_eids
=
reverse_eids
,
reverse_eids
=
reverse_eids
if
args
.
exclude_edges
else
None
,
negative_sampler
=
negative_sampler
.
Uniform
(
1
),
negative_sampler
=
negative_sampler
.
Uniform
(
1
),
)
)
...
@@ -333,7 +339,7 @@ def train(args, device, g, reverse_eids, seed_edges, model, use_uva):
...
@@ -333,7 +339,7 @@ def train(args, device, g, reverse_eids, seed_edges, model, use_uva):
seed_edges
,
seed_edges
,
sampler
,
sampler
,
device
=
device
,
device
=
device
,
batch_size
=
args
.
batch_size
,
batch_size
=
args
.
train_
batch_size
,
shuffle
=
True
,
shuffle
=
True
,
drop_last
=
False
,
drop_last
=
False
,
# If `g` is on gpu or `use_uva` is True, `num_workers` must be zero,
# If `g` is on gpu or `use_uva` is True, `num_workers` must be zero,
...
@@ -342,7 +348,7 @@ def train(args, device, g, reverse_eids, seed_edges, model, use_uva):
...
@@ -342,7 +348,7 @@ def train(args, device, g, reverse_eids, seed_edges, model, use_uva):
use_uva
=
use_uva
,
use_uva
=
use_uva
,
)
)
opt
=
torch
.
optim
.
Adam
(
model
.
parameters
(),
lr
=
args
.
lr
)
opt
=
torch
.
optim
.
Adam
(
model
.
parameters
(),
lr
=
args
.
lr
)
for
epoch
in
range
(
10
):
for
epoch
in
range
(
args
.
epochs
):
model
.
train
()
model
.
train
()
total_loss
=
0
total_loss
=
0
# A block is a graph consisting of two sets of nodes: the
# A block is a graph consisting of two sets of nodes: the
...
@@ -377,6 +383,7 @@ def train(args, device, g, reverse_eids, seed_edges, model, use_uva):
...
@@ -377,6 +383,7 @@ def train(args, device, g, reverse_eids, seed_edges, model, use_uva):
def
parse_args
():
def
parse_args
():
parser
=
argparse
.
ArgumentParser
()
parser
=
argparse
.
ArgumentParser
()
parser
.
add_argument
(
"--epochs"
,
type
=
int
,
default
=
10
)
parser
.
add_argument
(
parser
.
add_argument
(
"--lr"
,
"--lr"
,
type
=
float
,
type
=
float
,
...
@@ -384,10 +391,16 @@ def parse_args():
...
@@ -384,10 +391,16 @@ def parse_args():
help
=
"Learning rate. Default: 0.0005"
,
help
=
"Learning rate. Default: 0.0005"
,
)
)
parser
.
add_argument
(
parser
.
add_argument
(
"--batch-size"
,
"--
train-
batch-size"
,
type
=
int
,
type
=
int
,
default
=
512
,
default
=
512
,
help
=
"Batch size. Default: 512"
,
help
=
"Batch size for training. Default: 512"
,
)
parser
.
add_argument
(
"--eval-batch-size"
,
type
=
int
,
default
=
1024
,
help
=
"Batch size during evaluation. Default: 1024"
,
)
)
parser
.
add_argument
(
parser
.
add_argument
(
"--early-stop"
,
"--early-stop"
,
...
@@ -395,6 +408,17 @@ def parse_args():
...
@@ -395,6 +408,17 @@ def parse_args():
default
=
0
,
default
=
0
,
help
=
"0 means no early stop, otherwise stop at the input-th step"
,
help
=
"0 means no early stop, otherwise stop at the input-th step"
,
)
)
parser
.
add_argument
(
"--exclude-edges"
,
type
=
int
,
default
=
1
,
help
=
"Whether to exclude reverse edges during sampling. Default: 1"
,
)
parser
.
add_argument
(
"--compare-graphbolt"
,
action
=
"store_true"
,
help
=
"Compare with GraphBolt"
,
)
parser
.
add_argument
(
parser
.
add_argument
(
"--mode"
,
"--mode"
,
default
=
"mixed"
,
default
=
"mixed"
,
...
@@ -414,7 +438,11 @@ def main(args):
...
@@ -414,7 +438,11 @@ def main(args):
print
(
"Loading data"
)
print
(
"Loading data"
)
dataset
=
DglLinkPropPredDataset
(
"ogbl-citation2"
)
dataset
=
DglLinkPropPredDataset
(
"ogbl-citation2"
)
g
=
dataset
[
0
]
g
=
dataset
[
0
]
g
=
g
.
to
(
"cuda"
if
args
.
mode
==
"puregpu"
else
"cpu"
)
if
args
.
compare_graphbolt
:
fused_sampling
=
False
else
:
fused_sampling
=
True
g
=
g
.
to
(
"cuda"
if
args
.
mode
==
"puregpu"
else
"cpu"
)
# Whether use Unified Virtual Addressing (UVA) for CUDA computation.
# Whether use Unified Virtual Addressing (UVA) for CUDA computation.
use_uva
=
args
.
mode
==
"mixed"
use_uva
=
args
.
mode
==
"mixed"
...
@@ -432,12 +460,21 @@ def main(args):
...
@@ -432,12 +460,21 @@ def main(args):
# Model training.
# Model training.
print
(
"Training..."
)
print
(
"Training..."
)
train
(
args
,
device
,
g
,
reverse_eids
,
seed_edges
,
model
,
use_uva
)
train
(
args
,
device
,
g
,
reverse_eids
,
seed_edges
,
model
,
use_uva
,
fused_sampling
,
)
# Validate/Test the model.
# Validate/Test the model.
print
(
"Validation/Testing..."
)
print
(
"Validation/Testing..."
)
valid_mrr
,
test_mrr
=
evaluate
(
valid_mrr
,
test_mrr
=
evaluate
(
device
,
g
,
edge_split
,
model
,
batch_size
=
1000
device
,
g
,
edge_split
,
model
,
batch_size
=
args
.
eval_batch_size
)
)
print
(
print
(
f
"Validation MRR
{
valid_mrr
.
item
():.
4
f
}
, Test MRR
{
test_mrr
.
item
():.
4
f
}
"
f
"Validation MRR
{
valid_mrr
.
item
():.
4
f
}
, Test MRR
{
test_mrr
.
item
():.
4
f
}
"
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment