Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
397b7599
Unverified
Commit
397b7599
authored
Jan 03, 2024
by
czkkkkkk
Committed by
GitHub
Jan 03, 2024
Browse files
[Graphbolt] Support loading heterogeneous attributes in sampling graph. (#6873)
parent
f5981789
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
39 additions
and
5 deletions
+39
-5
python/dgl/graphbolt/impl/fused_csc_sampling_graph.py
python/dgl/graphbolt/impl/fused_csc_sampling_graph.py
+10
-3
python/dgl/graphbolt/impl/ondisk_dataset.py
python/dgl/graphbolt/impl/ondisk_dataset.py
+29
-2
No files found.
python/dgl/graphbolt/impl/fused_csc_sampling_graph.py
View file @
397b7599
...
...
@@ -6,7 +6,7 @@ import torch
from
dgl.utils
import
recursive_apply
from
...base
import
EID
,
ETYPE
from
...base
import
EID
,
ETYPE
,
NID
,
NTYPE
from
...convert
import
to_homogeneous
from
...heterograph
import
DGLGraph
from
..base
import
etype_str_to_tuple
,
etype_tuple_to_str
,
ORIGINAL_EDGE_ID
...
...
@@ -1117,7 +1117,9 @@ def from_dglgraph(
)
->
FusedCSCSamplingGraph
:
"""Convert a DGLGraph to FusedCSCSamplingGraph."""
homo_g
,
ntype_count
,
_
=
to_homogeneous
(
g
,
return_count
=
True
)
homo_g
,
ntype_count
,
_
=
to_homogeneous
(
g
,
ndata
=
g
.
ndata
,
edata
=
g
.
edata
,
return_count
=
True
)
if
is_homogeneous
:
node_type_to_id
=
None
...
...
@@ -1147,8 +1149,13 @@ def from_dglgraph(
)
node_attributes
=
{}
edge_attributes
=
{}
for
feat_name
,
feat_data
in
homo_g
.
ndata
.
items
():
if
feat_name
not
in
(
NID
,
NTYPE
):
node_attributes
[
feat_name
]
=
feat_data
for
feat_name
,
feat_data
in
homo_g
.
edata
.
items
():
if
feat_name
not
in
(
EID
,
ETYPE
):
edge_attributes
[
feat_name
]
=
feat_data
if
include_original_edge_id
:
# Assign edge attributes according to the original eids mapping.
edge_attributes
[
ORIGINAL_EDGE_ID
]
=
torch
.
index_select
(
...
...
python/dgl/graphbolt/impl/ondisk_dataset.py
View file @
397b7599
...
...
@@ -129,14 +129,41 @@ def preprocess_ondisk_dataset(
graph_feature
[
"format"
],
in_memory
=
in_memory
,
)
g
.
ndata
[
graph_feature
[
"name"
]]
=
node_data
if
is_homogeneous
:
g
.
ndata
[
graph_feature
[
"name"
]]
=
node_data
else
:
g
.
nodes
[
graph_feature
[
"type"
]].
data
[
graph_feature
[
"name"
]
]
=
node_data
if
graph_feature
[
"domain"
]
==
"edge"
:
edge_data
=
read_data
(
os
.
path
.
join
(
dataset_dir
,
graph_feature
[
"path"
]),
graph_feature
[
"format"
],
in_memory
=
in_memory
,
)
g
.
edata
[
graph_feature
[
"name"
]]
=
edge_data
if
is_homogeneous
:
g
.
edata
[
graph_feature
[
"name"
]]
=
edge_data
else
:
g
.
edges
[
etype_str_to_tuple
(
graph_feature
[
"type"
])].
data
[
graph_feature
[
"name"
]
]
=
edge_data
if
not
is_homogeneous
:
# For homogeneous graph, a node/edge feature must cover all
# node/edge types.
for
feat_name
,
feat_data
in
g
.
ndata
.
items
():
existing_types
=
set
(
feat_data
.
keys
())
assert
existing_types
==
set
(
g
.
ntypes
),
(
f
"Node feature
{
feat_name
}
does not cover all node types."
+
f
"Existing types:
{
existing_types
}
."
+
f
"Expected types:
{
g
.
ntypes
}
."
)
for
feat_name
,
feat_data
in
g
.
edata
.
items
():
existing_types
=
set
(
feat_data
.
keys
())
assert
existing_types
==
set
(
g
.
canonical_etypes
),
(
f
"Edge feature
{
feat_name
}
does not cover all edge types."
+
f
"Existing types:
{
existing_types
}
."
+
f
"Expected types:
{
g
.
etypes
}
."
)
# 4. Convert the DGLGraph to a FusedCSCSamplingGraph.
fused_csc_sampling_graph
=
from_dglgraph
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment