Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
364806f2
"git@developer.sourcefind.cn:OpenDAS/vision.git" did not exist on "97f543f3bc917ecaefb5b194f5cd7c342fbf78c9"
Unverified
Commit
364806f2
authored
Oct 28, 2023
by
Mingbang Wang
Committed by
GitHub
Oct 28, 2023
Browse files
[GraphBolt] Modify `node_classification` for benchmark (#6501)
parent
e645d936
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
4 additions
and
1 deletion
+4
-1
examples/sampling/graphbolt/node_classification.py
examples/sampling/graphbolt/node_classification.py
+4
-1
No files found.
examples/sampling/graphbolt/node_classification.py
View file @
364806f2
...
@@ -38,6 +38,7 @@ main
...
@@ -38,6 +38,7 @@ main
└───> All nodes set inference & Test set evaluation
└───> All nodes set inference & Test set evaluation
"""
"""
import
argparse
import
argparse
import
time
import
dgl.graphbolt
as
gb
import
dgl.graphbolt
as
gb
import
dgl.nn
as
dglnn
import
dgl.nn
as
dglnn
...
@@ -282,6 +283,7 @@ def train(args, graph, features, train_set, valid_set, num_classes, model):
...
@@ -282,6 +283,7 @@ def train(args, graph, features, train_set, valid_set, num_classes, model):
)
)
for
epoch
in
range
(
args
.
epochs
):
for
epoch
in
range
(
args
.
epochs
):
t0
=
time
.
time
()
model
.
train
()
model
.
train
()
total_loss
=
0
total_loss
=
0
for
step
,
data
in
tqdm
(
enumerate
(
dataloader
)):
for
step
,
data
in
tqdm
(
enumerate
(
dataloader
)):
...
@@ -304,11 +306,12 @@ def train(args, graph, features, train_set, valid_set, num_classes, model):
...
@@ -304,11 +306,12 @@ def train(args, graph, features, train_set, valid_set, num_classes, model):
total_loss
+=
loss
.
item
()
total_loss
+=
loss
.
item
()
t1
=
time
.
time
()
# Evaluate the model.
# Evaluate the model.
acc
=
evaluate
(
args
,
model
,
graph
,
features
,
valid_set
,
num_classes
)
acc
=
evaluate
(
args
,
model
,
graph
,
features
,
valid_set
,
num_classes
)
print
(
print
(
f
"Epoch
{
epoch
:
05
d
}
| Loss
{
total_loss
/
(
step
+
1
):.
4
f
}
| "
f
"Epoch
{
epoch
:
05
d
}
| Loss
{
total_loss
/
(
step
+
1
):.
4
f
}
| "
f
"Accuracy
{
acc
.
item
():.
4
f
}
"
f
"Accuracy
{
acc
.
item
():.
4
f
}
| Time
{
t1
-
t0
:.
4
f
}
"
)
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment