Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
6566c31f
Unverified
Commit
6566c31f
authored
Nov 08, 2022
by
Chang Liu
Committed by
GitHub
Nov 09, 2022
Browse files
Fix ogb/ogbn-mag/heter-RGCN example (#4839)
Co-authored-by:
Mufei Li
<
mufeili1996@gmail.com
>
parent
344be1ef
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
5 additions
and
2 deletions
+5
-2
examples/pytorch/ogb/ogbn-mag/hetero_rgcn.py
examples/pytorch/ogb/ogbn-mag/hetero_rgcn.py
+5
-2
No files found.
examples/pytorch/ogb/ogbn-mag/hetero_rgcn.py
View file @
6566c31f
...
@@ -266,11 +266,13 @@ def train(
...
@@ -266,11 +266,13 @@ def train(
category
category
]
# we only predict the nodes with type "category"
]
# we only predict the nodes with type "category"
batch_size
=
seeds
.
shape
[
0
]
batch_size
=
seeds
.
shape
[
0
]
input_nodes_indexes
=
input_nodes
[
"paper"
].
to
(
g
.
device
)
seeds
=
seeds
.
to
(
labels
.
device
)
emb
=
extract_embed
(
node_embed
,
input_nodes
)
emb
=
extract_embed
(
node_embed
,
input_nodes
)
# Add the batch's raw "paper" features
# Add the batch's raw "paper" features
emb
.
update
(
emb
.
update
(
{
"paper"
:
g
.
ndata
[
"feat"
][
"paper"
][
input_nodes
[
"paper"
]
]}
{
"paper"
:
g
.
ndata
[
"feat"
][
"paper"
][
input_nodes
_indexes
]}
)
)
emb
=
{
k
:
e
.
to
(
device
)
for
k
,
e
in
emb
.
items
()}
emb
=
{
k
:
e
.
to
(
device
)
for
k
,
e
in
emb
.
items
()}
...
@@ -334,10 +336,11 @@ def test(g, model, node_embed, y_true, device, split_idx):
...
@@ -334,10 +336,11 @@ def test(g, model, node_embed, y_true, device, split_idx):
category
category
]
# we only predict the nodes with type "category"
]
# we only predict the nodes with type "category"
batch_size
=
seeds
.
shape
[
0
]
batch_size
=
seeds
.
shape
[
0
]
input_nodes_indexes
=
input_nodes
[
"paper"
].
to
(
g
.
device
)
emb
=
extract_embed
(
node_embed
,
input_nodes
)
emb
=
extract_embed
(
node_embed
,
input_nodes
)
# Get the batch's raw "paper" features
# Get the batch's raw "paper" features
emb
.
update
({
"paper"
:
g
.
ndata
[
"feat"
][
"paper"
][
input_nodes
[
"paper"
]
]})
emb
.
update
({
"paper"
:
g
.
ndata
[
"feat"
][
"paper"
][
input_nodes
_indexes
]})
emb
=
{
k
:
e
.
to
(
device
)
for
k
,
e
in
emb
.
items
()}
emb
=
{
k
:
e
.
to
(
device
)
for
k
,
e
in
emb
.
items
()}
logits
=
model
(
emb
,
blocks
)[
category
]
logits
=
model
(
emb
,
blocks
)[
category
]
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment