Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
fc06d7fc
Unverified
Commit
fc06d7fc
authored
Oct 12, 2023
by
LastWhisper
Committed by
GitHub
Oct 12, 2023
Browse files
[GraphBolt] Enable `CSCSamplingGraph::edge_attributes` save and load. (#6422)
parent
2595fa98
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
28 additions
and
0 deletions
+28
-0
graphbolt/src/csc_sampling_graph.cc
graphbolt/src/csc_sampling_graph.cc
+24
-0
tests/python/pytorch/graphbolt/impl/test_ondisk_dataset.py
tests/python/pytorch/graphbolt/impl/test_ondisk_dataset.py
+4
-0
No files found.
graphbolt/src/csc_sampling_graph.cc
View file @
fc06d7fc
...
@@ -80,6 +80,25 @@ void CSCSamplingGraph::Load(torch::serialize::InputArchive& archive) {
...
@@ -80,6 +80,25 @@ void CSCSamplingGraph::Load(torch::serialize::InputArchive& archive) {
type_per_edge_
=
type_per_edge_
=
read_from_archive
(
archive
,
"CSCSamplingGraph/type_per_edge"
).
toTensor
();
read_from_archive
(
archive
,
"CSCSamplingGraph/type_per_edge"
).
toTensor
();
}
}
// Optional edge attributes.
torch
::
IValue
has_edge_attributes
;
if
(
archive
.
try_read
(
"CSCSamplingGraph/has_edge_attributes"
,
has_edge_attributes
)
&&
has_edge_attributes
.
toBool
())
{
torch
::
Dict
<
torch
::
IValue
,
torch
::
IValue
>
generic_dict
=
read_from_archive
(
archive
,
"CSCSamplingGraph/edge_attributes"
)
.
toGenericDict
();
EdgeAttrMap
target_dict
;
for
(
const
auto
&
pair
:
generic_dict
)
{
std
::
string
key
=
pair
.
key
().
toStringRef
();
torch
::
Tensor
value
=
pair
.
value
().
toTensor
();
// Use move to avoid copy.
target_dict
.
insert
(
std
::
move
(
key
),
std
::
move
(
value
));
}
// Same as above.
edge_attributes_
=
std
::
move
(
target_dict
);
}
}
}
void
CSCSamplingGraph
::
Save
(
torch
::
serialize
::
OutputArchive
&
archive
)
const
{
void
CSCSamplingGraph
::
Save
(
torch
::
serialize
::
OutputArchive
&
archive
)
const
{
...
@@ -97,6 +116,11 @@ void CSCSamplingGraph::Save(torch::serialize::OutputArchive& archive) const {
...
@@ -97,6 +116,11 @@ void CSCSamplingGraph::Save(torch::serialize::OutputArchive& archive) const {
if
(
type_per_edge_
)
{
if
(
type_per_edge_
)
{
archive
.
write
(
"CSCSamplingGraph/type_per_edge"
,
type_per_edge_
.
value
());
archive
.
write
(
"CSCSamplingGraph/type_per_edge"
,
type_per_edge_
.
value
());
}
}
archive
.
write
(
"CSCSamplingGraph/has_edge_attributes"
,
edge_attributes_
.
has_value
());
if
(
edge_attributes_
)
{
archive
.
write
(
"CSCSamplingGraph/edge_attributes"
,
edge_attributes_
.
value
());
}
}
}
void
CSCSamplingGraph
::
SetState
(
void
CSCSamplingGraph
::
SetState
(
...
...
tests/python/pytorch/graphbolt/impl/test_ondisk_dataset.py
View file @
fc06d7fc
...
@@ -1577,6 +1577,10 @@ def test_OnDiskDataset_load_graph():
...
@@ -1577,6 +1577,10 @@ def test_OnDiskDataset_load_graph():
with
open
(
yaml_file
,
"w"
)
as
f
:
with
open
(
yaml_file
,
"w"
)
as
f
:
f
.
write
(
yaml_content
)
f
.
write
(
yaml_content
)
# Check if the CSCSamplingGraph.edge_attributes loaded.
dataset
=
gb
.
OnDiskDataset
(
test_dir
).
load
()
assert
dataset
.
graph
.
edge_attributes
is
not
None
# Case1. Test modify the `type` field.
# Case1. Test modify the `type` field.
dataset
=
gb
.
OnDiskDataset
(
test_dir
)
dataset
=
gb
.
OnDiskDataset
(
test_dir
)
dataset
.
yaml_data
[
"graph_topology"
][
"type"
]
=
"fake_type"
dataset
.
yaml_data
[
"graph_topology"
][
"type"
]
=
"fake_type"
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment