Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
2f47c241
Unverified
Commit
2f47c241
authored
Nov 14, 2023
by
Rhett Ying
Committed by
GitHub
Nov 14, 2023
Browse files
[GraphBolt] fix incorrect call in example (#6563)
parent
07db09f7
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
7 additions
and
4 deletions
+7
-4
examples/sampling/graphbolt/rgcn/hetero_rgcn.py
examples/sampling/graphbolt/rgcn/hetero_rgcn.py
+7
-4
No files found.
examples/sampling/graphbolt/rgcn/hetero_rgcn.py
View file @
2f47c241
...
@@ -46,6 +46,7 @@ import argparse
...
@@ -46,6 +46,7 @@ import argparse
import
itertools
import
itertools
import
sys
import
sys
import
dgl
import
dgl.graphbolt
as
gb
import
dgl.graphbolt
as
gb
import
dgl.nn
as
dglnn
import
dgl.nn
as
dglnn
...
@@ -428,7 +429,9 @@ class Logger(object):
...
@@ -428,7 +429,9 @@ class Logger(object):
def
extract_node_features
(
name
,
block
,
data
,
node_embed
,
device
):
def
extract_node_features
(
name
,
block
,
data
,
node_embed
,
device
):
"""Extract the node features from embedding layer or raw features."""
"""Extract the node features from embedding layer or raw features."""
if
name
==
"ogbn-mag"
:
if
name
==
"ogbn-mag"
:
input_nodes
=
{
k
:
v
.
to
(
device
)
for
k
,
v
in
data
.
input_nodes
.
items
()}
input_nodes
=
{
k
:
v
.
to
(
device
)
for
k
,
v
in
block
.
srcdata
[
dgl
.
NID
].
items
()
}
# Extract node embeddings for the input nodes.
# Extract node embeddings for the input nodes.
node_features
=
extract_embed
(
node_embed
,
input_nodes
)
node_features
=
extract_embed
(
node_embed
,
input_nodes
)
# Add the batch's raw "paper" features. Corresponds to the content
# Add the batch's raw "paper" features. Corresponds to the content
...
@@ -554,12 +557,12 @@ def run(
...
@@ -554,12 +557,12 @@ def run(
num_workers
=
num_workers
,
num_workers
=
num_workers
,
)
)
for
data
in
tqdm
(
data_loader
,
desc
=
f
"Training~Epoch
{
epoch
:
02
d
}
"
):
for
data
in
tqdm
(
data_loader
,
desc
=
f
"Training~Epoch
{
epoch
:
02
d
}
"
):
# Fetch the number of seed nodes in the batch.
num_seeds
=
data
.
output_nodes
[
category
].
shape
[
0
]
# Convert MiniBatch to DGL Blocks.
# Convert MiniBatch to DGL Blocks.
blocks
=
[
block
.
to
(
device
)
for
block
in
data
.
blocks
]
blocks
=
[
block
.
to
(
device
)
for
block
in
data
.
blocks
]
# Fetch the number of seed nodes in the batch.
num_seeds
=
blocks
[
-
1
].
num_dst_nodes
(
category
)
# Extract the node features from embedding layer or raw features.
# Extract the node features from embedding layer or raw features.
node_features
=
extract_node_features
(
node_features
=
extract_node_features
(
name
,
blocks
[
0
],
data
,
node_embed
,
device
name
,
blocks
[
0
],
data
,
node_embed
,
device
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment