Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
b452043c
Unverified
Commit
b452043c
authored
Nov 15, 2023
by
Mingbang Wang
Committed by
GitHub
Nov 15, 2023
Browse files
[Misc] Add compare-to-graphbolt mode for regression test (#6569)
parent
23649071
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
24 additions
and
10 deletions
+24
-10
examples/sampling/node_classification.py
examples/sampling/node_classification.py
+24
-10
No files found.
examples/sampling/node_classification.py
View file @
b452043c
...
@@ -75,7 +75,7 @@ class SAGE(nn.Module):
...
@@ -75,7 +75,7 @@ class SAGE(nn.Module):
hidden_x
=
self
.
dropout
(
hidden_x
)
hidden_x
=
self
.
dropout
(
hidden_x
)
return
hidden_x
return
hidden_x
def
inference
(
self
,
g
,
device
,
batch_size
):
def
inference
(
self
,
g
,
device
,
batch_size
,
fused_sampling
:
bool
=
True
):
"""Conduct layer-wise inference to get all the node embeddings."""
"""Conduct layer-wise inference to get all the node embeddings."""
feat
=
g
.
ndata
[
"feat"
]
feat
=
g
.
ndata
[
"feat"
]
#####################################################################
#####################################################################
...
@@ -109,7 +109,9 @@ class SAGE(nn.Module):
...
@@ -109,7 +109,9 @@ class SAGE(nn.Module):
# │ │ │
# │ │ │
# └─Compute1 └─Compute2 └─Compute3
# └─Compute1 └─Compute2 └─Compute3
#####################################################################
#####################################################################
sampler
=
MultiLayerFullNeighborSampler
(
1
,
prefetch_node_feats
=
[
"feat"
])
sampler
=
MultiLayerFullNeighborSampler
(
1
,
prefetch_node_feats
=
[
"feat"
],
fused
=
fused_sampling
)
dataloader
=
DataLoader
(
dataloader
=
DataLoader
(
g
,
g
,
...
@@ -167,18 +169,22 @@ def evaluate(model, graph, dataloader, num_classes):
...
@@ -167,18 +169,22 @@ def evaluate(model, graph, dataloader, num_classes):
@
torch
.
no_grad
()
@
torch
.
no_grad
()
def
layerwise_infer
(
device
,
graph
,
nid
,
model
,
num_classes
,
batch_size
):
def
layerwise_infer
(
device
,
graph
,
nid
,
model
,
num_classes
,
batch_size
,
fused_sampling
):
model
.
eval
()
model
.
eval
()
pred
=
model
.
inference
(
graph
,
device
,
batch_size
)
# pred in buffer_device.
pred
=
model
.
inference
(
graph
,
device
,
batch_size
,
fused_sampling
)
# pred in buffer_device.
pred
=
pred
[
nid
]
pred
=
pred
[
nid
]
label
=
graph
.
ndata
[
"label"
][
nid
].
to
(
pred
.
device
)
label
=
graph
.
ndata
[
"label"
][
nid
].
to
(
pred
.
device
)
return
MF
.
accuracy
(
pred
,
label
,
task
=
"multiclass"
,
num_classes
=
num_classes
)
return
MF
.
accuracy
(
pred
,
label
,
task
=
"multiclass"
,
num_classes
=
num_classes
)
def
train
(
args
,
device
,
g
,
dataset
,
model
,
num_classes
,
use_uva
):
def
train
(
device
,
g
,
dataset
,
model
,
num_classes
,
use_uva
,
fused_sampling
):
# Create sampler & dataloader.
# Create sampler & dataloader.
train_idx
=
dataset
.
train_idx
.
to
(
device
)
train_idx
=
dataset
.
train_idx
.
to
(
g
.
device
if
not
use_uva
else
device
)
val_idx
=
dataset
.
val_idx
.
to
(
device
)
val_idx
=
dataset
.
val_idx
.
to
(
g
.
device
if
not
use_uva
else
device
)
#####################################################################
#####################################################################
# (HIGHLIGHT) Instantiate a NeighborSampler object for efficient
# (HIGHLIGHT) Instantiate a NeighborSampler object for efficient
# training of Graph Neural Networks (GNNs) on large-scale graphs.
# training of Graph Neural Networks (GNNs) on large-scale graphs.
...
@@ -197,6 +203,7 @@ def train(args, device, g, dataset, model, num_classes, use_uva):
...
@@ -197,6 +203,7 @@ def train(args, device, g, dataset, model, num_classes, use_uva):
[
10
,
10
,
10
],
# fanout for [layer-0, layer-1, layer-2]
[
10
,
10
,
10
],
# fanout for [layer-0, layer-1, layer-2]
prefetch_node_feats
=
[
"feat"
],
prefetch_node_feats
=
[
"feat"
],
prefetch_labels
=
[
"label"
],
prefetch_labels
=
[
"label"
],
fused
=
fused_sampling
,
)
)
train_dataloader
=
DataLoader
(
train_dataloader
=
DataLoader
(
...
@@ -267,7 +274,7 @@ if __name__ == "__main__":
...
@@ -267,7 +274,7 @@ if __name__ == "__main__":
parser
.
add_argument
(
parser
.
add_argument
(
"--mode"
,
"--mode"
,
default
=
"mixed"
,
default
=
"mixed"
,
choices
=
[
"cpu"
,
"mixed"
,
"gpu"
],
choices
=
[
"cpu"
,
"mixed"
,
"gpu"
,
"compare-to-graphbolt"
],
help
=
"Training mode. 'cpu' for CPU training, 'mixed' for "
help
=
"Training mode. 'cpu' for CPU training, 'mixed' for "
"CPU-GPU mixed training, 'gpu' for pure-GPU training."
,
"CPU-GPU mixed training, 'gpu' for pure-GPU training."
,
)
)
...
@@ -285,6 +292,7 @@ if __name__ == "__main__":
...
@@ -285,6 +292,7 @@ if __name__ == "__main__":
# Whether use Unified Virtual Addressing (UVA) for CUDA computation.
# Whether use Unified Virtual Addressing (UVA) for CUDA computation.
use_uva
=
args
.
mode
==
"mixed"
use_uva
=
args
.
mode
==
"mixed"
device
=
torch
.
device
(
"cpu"
if
args
.
mode
==
"cpu"
else
"cuda"
)
device
=
torch
.
device
(
"cpu"
if
args
.
mode
==
"cpu"
else
"cuda"
)
fused_sampling
=
args
.
mode
!=
"compare-to-graphbolt"
# Create GraphSAGE model.
# Create GraphSAGE model.
in_size
=
g
.
ndata
[
"feat"
].
shape
[
1
]
in_size
=
g
.
ndata
[
"feat"
].
shape
[
1
]
...
@@ -293,11 +301,17 @@ if __name__ == "__main__":
...
@@ -293,11 +301,17 @@ if __name__ == "__main__":
# Model training.
# Model training.
print
(
"Training..."
)
print
(
"Training..."
)
train
(
args
,
device
,
g
,
dataset
,
model
,
num_classes
,
use_uva
)
train
(
device
,
g
,
dataset
,
model
,
num_classes
,
use_uva
,
fused_sampling
)
# Test the model.
# Test the model.
print
(
"Testing..."
)
print
(
"Testing..."
)
acc
=
layerwise_infer
(
acc
=
layerwise_infer
(
device
,
g
,
dataset
.
test_idx
,
model
,
num_classes
,
batch_size
=
4096
device
,
g
,
dataset
.
test_idx
,
model
,
num_classes
,
batch_size
=
4096
,
fused_sampling
=
fused_sampling
,
)
)
print
(
f
"Test accuracy
{
acc
.
item
():.
4
f
}
"
)
print
(
f
"Test accuracy
{
acc
.
item
():.
4
f
}
"
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment