Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
31f1b30c
"git@developer.sourcefind.cn:OpenDAS/vision.git" did not exist on "28f72f16b7716f6914cc2dad2e6977b6de58aaab"
Unverified
Commit
31f1b30c
authored
Aug 02, 2023
by
Andrei Ivanov
Committed by
GitHub
Aug 03, 2023
Browse files
Improving the GGNN example. (#6055)
parent
83437e67
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
5 additions
and
5 deletions
+5
-5
examples/pytorch/ggnn/data_utils.py
examples/pytorch/ggnn/data_utils.py
+3
-3
examples/pytorch/ggnn/train_ns.py
examples/pytorch/ggnn/train_ns.py
+2
-2
No files found.
examples/pytorch/ggnn/data_utils.py
View file @
31f1b30c
...
@@ -104,7 +104,7 @@ def _ns_dataloader(
...
@@ -104,7 +104,7 @@ def _ns_dataloader(
node_ids
.
append
(
s
)
node_ids
.
append
(
s
)
if
t
not
in
node_ids
:
if
t
not
in
node_ids
:
node_ids
.
append
(
t
)
node_ids
.
append
(
t
)
g
=
dgl
.
DGLG
raph
()
g
=
dgl
.
g
raph
(
[]
)
g
.
add_nodes
(
len
(
node_ids
))
g
.
add_nodes
(
len
(
node_ids
))
g
.
ndata
[
"node_id"
]
=
torch
.
tensor
(
node_ids
,
dtype
=
torch
.
long
)
g
.
ndata
[
"node_id"
]
=
torch
.
tensor
(
node_ids
,
dtype
=
torch
.
long
)
...
@@ -224,7 +224,7 @@ def _gc_dataloader(
...
@@ -224,7 +224,7 @@ def _gc_dataloader(
node_ids
.
append
(
s
)
node_ids
.
append
(
s
)
if
t
not
in
node_ids
:
if
t
not
in
node_ids
:
node_ids
.
append
(
t
)
node_ids
.
append
(
t
)
g
=
dgl
.
DGLG
raph
()
g
=
dgl
.
g
raph
(
[]
)
g
.
add_nodes
(
len
(
node_ids
))
g
.
add_nodes
(
len
(
node_ids
))
g
.
ndata
[
"node_id"
]
=
torch
.
tensor
(
node_ids
,
dtype
=
torch
.
long
)
g
.
ndata
[
"node_id"
]
=
torch
.
tensor
(
node_ids
,
dtype
=
torch
.
long
)
...
@@ -346,7 +346,7 @@ def _path_finding_dataloader(
...
@@ -346,7 +346,7 @@ def _path_finding_dataloader(
node_ids
.
append
(
s
)
node_ids
.
append
(
s
)
if
t
not
in
node_ids
:
if
t
not
in
node_ids
:
node_ids
.
append
(
t
)
node_ids
.
append
(
t
)
g
=
dgl
.
DGLG
raph
()
g
=
dgl
.
g
raph
(
[]
)
g
.
add_nodes
(
len
(
node_ids
))
g
.
add_nodes
(
len
(
node_ids
))
g
.
ndata
[
"node_id"
]
=
torch
.
tensor
(
node_ids
,
dtype
=
torch
.
long
)
g
.
ndata
[
"node_id"
]
=
torch
.
tensor
(
node_ids
,
dtype
=
torch
.
long
)
...
...
examples/pytorch/ggnn/train_ns.py
View file @
31f1b30c
...
@@ -59,7 +59,7 @@ def main(args):
...
@@ -59,7 +59,7 @@ def main(args):
labels
=
labels
.
data
.
numpy
().
tolist
()
labels
=
labels
.
data
.
numpy
().
tolist
()
dev_preds
+=
preds
dev_preds
+=
preds
dev_labels
+=
labels
dev_labels
+=
labels
acc
=
np
.
equal
(
dev_labels
,
dev_preds
).
astype
(
np
.
float
).
tolist
()
acc
=
np
.
equal
(
dev_labels
,
dev_preds
).
astype
(
float
).
tolist
()
acc
=
sum
(
acc
)
/
len
(
acc
)
acc
=
sum
(
acc
)
/
len
(
acc
)
print
(
f
"Epoch
{
epoch
}
, Dev acc
{
acc
}
"
)
print
(
f
"Epoch
{
epoch
}
, Dev acc
{
acc
}
"
)
...
@@ -81,7 +81,7 @@ def main(args):
...
@@ -81,7 +81,7 @@ def main(args):
labels
=
labels
.
data
.
numpy
().
tolist
()
labels
=
labels
.
data
.
numpy
().
tolist
()
test_preds
+=
preds
test_preds
+=
preds
test_labels
+=
labels
test_labels
+=
labels
acc
=
np
.
equal
(
test_labels
,
test_preds
).
astype
(
np
.
float
).
tolist
()
acc
=
np
.
equal
(
test_labels
,
test_preds
).
astype
(
float
).
tolist
()
acc
=
sum
(
acc
)
/
len
(
acc
)
acc
=
sum
(
acc
)
/
len
(
acc
)
test_acc_list
.
append
(
acc
)
test_acc_list
.
append
(
acc
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment