Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
4b6faecb
Unverified
Commit
4b6faecb
authored
Nov 01, 2023
by
Mingbang Wang
Committed by
GitHub
Nov 01, 2023
Browse files
[Misc] Modify two examples to make them consistent (#6513)
parent
96226c61
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
11 additions
and
8 deletions
+11
-8
examples/sampling/graphbolt/node_classification.py
examples/sampling/graphbolt/node_classification.py
+6
-6
examples/sampling/node_classification.py
examples/sampling/node_classification.py
+5
-2
No files found.
examples/sampling/graphbolt/node_classification.py
View file @
4b6faecb
...
...
@@ -181,7 +181,7 @@ class SAGE(nn.Module):
self
.
hidden_size
=
hidden_size
self
.
out_size
=
out_size
# Set the dtype for the layers manually.
self
.
set_layer_dtype
(
torch
.
float
64
)
self
.
set_layer_dtype
(
torch
.
float
32
)
def
set_layer_dtype
(
self
,
_dtype
):
for
layer
in
self
.
layers
:
...
...
@@ -221,7 +221,7 @@ class SAGE(nn.Module):
for
step
,
data
in
tqdm
(
enumerate
(
dataloader
)):
x
=
feature
[
data
.
input_nodes
]
hidden_x
=
layer
(
data
.
blocks
[
0
],
x
)
# len(blocks) = 1
hidden_x
=
layer
(
data
.
blocks
[
0
],
x
.
float
()
)
# len(blocks) = 1
if
not
is_last_layer
:
hidden_x
=
F
.
relu
(
hidden_x
)
hidden_x
=
self
.
dropout
(
hidden_x
)
...
...
@@ -266,7 +266,7 @@ def evaluate(args, model, graph, features, itemset, num_classes):
for
step
,
data
in
tqdm
(
enumerate
(
dataloader
)):
x
=
data
.
node_features
[
"feat"
]
y
.
append
(
data
.
labels
)
y_hats
.
append
(
model
(
data
.
blocks
,
x
))
y_hats
.
append
(
model
(
data
.
blocks
,
x
.
float
()
))
return
MF
.
accuracy
(
torch
.
cat
(
y_hats
),
...
...
@@ -286,7 +286,7 @@ def train(args, graph, features, train_set, valid_set, num_classes, model):
t0
=
time
.
time
()
model
.
train
()
total_loss
=
0
for
step
,
data
in
tqdm
(
enumerate
(
dataloader
)
)
:
for
step
,
data
in
enumerate
(
dataloader
):
# The input features from the source nodes in the first layer's
# computation graph.
x
=
data
.
node_features
[
"feat"
]
...
...
@@ -295,7 +295,7 @@ def train(args, graph, features, train_set, valid_set, num_classes, model):
# in the last layer's computation graph.
y
=
data
.
labels
y_hat
=
model
(
data
.
blocks
,
x
)
y_hat
=
model
(
data
.
blocks
,
x
.
float
()
)
# Compute loss.
loss
=
F
.
cross_entropy
(
y_hat
,
y
)
...
...
@@ -399,7 +399,7 @@ def main(args):
model
,
num_classes
,
)
print
(
f
"Test
A
ccuracy
is
{
test_acc
.
item
():.
4
f
}
"
)
print
(
f
"Test
a
ccuracy
{
test_acc
.
item
():.
4
f
}
"
)
if
__name__
==
"__main__"
:
...
...
examples/sampling/node_classification.py
View file @
4b6faecb
...
...
@@ -35,6 +35,7 @@ main
"""
import
argparse
import
time
import
dgl
import
dgl.nn
as
dglnn
...
...
@@ -228,6 +229,7 @@ def train(args, device, g, dataset, model, num_classes, use_uva):
opt
=
torch
.
optim
.
Adam
(
model
.
parameters
(),
lr
=
1e-3
,
weight_decay
=
5e-4
)
for
epoch
in
range
(
10
):
t0
=
time
.
time
()
model
.
train
()
total_loss
=
0
# A block is a graph consisting of two sets of nodes: the
...
...
@@ -252,10 +254,11 @@ def train(args, device, g, dataset, model, num_classes, use_uva):
loss
.
backward
()
opt
.
step
()
total_loss
+=
loss
.
item
()
t1
=
time
.
time
()
acc
=
evaluate
(
model
,
g
,
val_dataloader
,
num_classes
)
print
(
f
"Epoch
{
epoch
:
05
d
}
| Loss
{
total_loss
/
(
it
+
1
):.
4
f
}
| "
f
"Accuracy
{
acc
.
item
():.
4
f
}
"
f
"Accuracy
{
acc
.
item
():.
4
f
}
| Time
{
t1
-
t0
:.
4
f
}
"
)
...
...
@@ -297,4 +300,4 @@ if __name__ == "__main__":
acc
=
layerwise_infer
(
device
,
g
,
dataset
.
test_idx
,
model
,
num_classes
,
batch_size
=
4096
)
print
(
f
"Test
A
ccuracy
{
acc
.
item
():.
4
f
}
"
)
print
(
f
"Test
a
ccuracy
{
acc
.
item
():.
4
f
}
"
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment