Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
2d8d6fbb
Unverified
Commit
2d8d6fbb
authored
Nov 22, 2023
by
Mingbang Wang
Committed by
GitHub
Nov 22, 2023
Browse files
[Misc] Modify two node examples to make them consitent. (#6599)
parent
a664b0c4
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
15 additions
and
7 deletions
+15
-7
examples/sampling/graphbolt/node_classification.py
examples/sampling/graphbolt/node_classification.py
+4
-4
examples/sampling/node_classification.py
examples/sampling/node_classification.py
+11
-3
No files found.
examples/sampling/graphbolt/node_classification.py
View file @
2d8d6fbb
...
...
@@ -207,7 +207,7 @@ class SAGE(nn.Module):
y
=
torch
.
empty
(
graph
.
total_num_nodes
,
self
.
out_size
if
is_last_layer
else
self
.
hidden_size
,
dtype
=
torch
.
float
64
,
dtype
=
torch
.
float
32
,
device
=
buffer_device
,
pin_memory
=
pin_memory
,
)
...
...
@@ -215,7 +215,7 @@ class SAGE(nn.Module):
for
step
,
data
in
tqdm
(
enumerate
(
dataloader
)):
x
=
feature
[
data
.
input_nodes
]
hidden_x
=
layer
(
data
.
blocks
[
0
],
x
.
float
()
)
# len(blocks) = 1
hidden_x
=
layer
(
data
.
blocks
[
0
],
x
)
# len(blocks) = 1
if
not
is_last_layer
:
hidden_x
=
F
.
relu
(
hidden_x
)
hidden_x
=
self
.
dropout
(
hidden_x
)
...
...
@@ -274,7 +274,7 @@ def evaluate(args, model, graph, features, itemset, num_classes):
for
step
,
data
in
tqdm
(
enumerate
(
dataloader
)):
x
=
data
.
node_features
[
"feat"
]
y
.
append
(
data
.
labels
)
y_hats
.
append
(
model
(
data
.
blocks
,
x
.
float
()
))
y_hats
.
append
(
model
(
data
.
blocks
,
x
))
return
MF
.
accuracy
(
torch
.
cat
(
y_hats
),
...
...
@@ -310,7 +310,7 @@ def train(args, graph, features, train_set, valid_set, num_classes, model):
# in the last layer's computation graph.
y
=
data
.
labels
y_hat
=
model
(
data
.
blocks
,
x
.
float
()
)
y_hat
=
model
(
data
.
blocks
,
x
)
# Compute loss.
loss
=
F
.
cross_entropy
(
y_hat
,
y
)
...
...
examples/sampling/node_classification.py
View file @
2d8d6fbb
...
...
@@ -274,10 +274,16 @@ if __name__ == "__main__":
parser
.
add_argument
(
"--mode"
,
default
=
"mixed"
,
choices
=
[
"cpu"
,
"mixed"
,
"gpu"
,
"compare-to-graphbolt"
],
choices
=
[
"cpu"
,
"mixed"
,
"gpu"
],
help
=
"Training mode. 'cpu' for CPU training, 'mixed' for "
"CPU-GPU mixed training, 'gpu' for pure-GPU training."
,
)
parser
.
add_argument
(
"--compare-to-graphbolt"
,
default
=
"false"
,
choices
=
[
"false"
,
"true"
],
help
=
"Whether comparing to GraphBolt or not, 'false' by default."
,
)
args
=
parser
.
parse_args
()
if
not
torch
.
cuda
.
is_available
():
args
.
mode
=
"cpu"
...
...
@@ -286,13 +292,15 @@ if __name__ == "__main__":
# Load and preprocess dataset.
print
(
"Loading data"
)
dataset
=
AsNodePredDataset
(
DglNodePropPredDataset
(
"ogbn-products"
))
g
=
dataset
[
0
]
if
args
.
compare_to_graphbolt
==
"false"
:
g
=
g
.
to
(
"cuda"
if
args
.
mode
==
"gpu"
else
"cpu"
)
num_classes
=
dataset
.
num_classes
# Whether use Unified Virtual Addressing (UVA) for CUDA computation.
use_uva
=
args
.
mode
==
"mixed"
device
=
torch
.
device
(
"cpu"
if
args
.
mode
==
"cpu"
else
"cuda"
)
fused_sampling
=
args
.
mode
!=
"
compare
-
to
-
graphbolt"
fused_sampling
=
args
.
compare
_
to
_
graphbolt
==
"false
"
# Create GraphSAGE model.
in_size
=
g
.
ndata
[
"feat"
].
shape
[
1
]
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment