Unverified Commit fc06d7fc authored by LastWhisper's avatar LastWhisper Committed by GitHub
Browse files

[GraphBolt] Enable `CSCSamplingGraph::edge_attributes` save and load. (#6422)

parent 2595fa98
...@@ -80,6 +80,25 @@ void CSCSamplingGraph::Load(torch::serialize::InputArchive& archive) { ...@@ -80,6 +80,25 @@ void CSCSamplingGraph::Load(torch::serialize::InputArchive& archive) {
type_per_edge_ = type_per_edge_ =
read_from_archive(archive, "CSCSamplingGraph/type_per_edge").toTensor(); read_from_archive(archive, "CSCSamplingGraph/type_per_edge").toTensor();
} }
// Optional edge attributes.
torch::IValue has_edge_attributes;
if (archive.try_read(
"CSCSamplingGraph/has_edge_attributes", has_edge_attributes) &&
has_edge_attributes.toBool()) {
torch::Dict<torch::IValue, torch::IValue> generic_dict =
read_from_archive(archive, "CSCSamplingGraph/edge_attributes")
.toGenericDict();
EdgeAttrMap target_dict;
for (const auto& pair : generic_dict) {
std::string key = pair.key().toStringRef();
torch::Tensor value = pair.value().toTensor();
// Use move to avoid copy.
target_dict.insert(std::move(key), std::move(value));
}
// Same as above.
edge_attributes_ = std::move(target_dict);
}
} }
void CSCSamplingGraph::Save(torch::serialize::OutputArchive& archive) const { void CSCSamplingGraph::Save(torch::serialize::OutputArchive& archive) const {
...@@ -97,6 +116,11 @@ void CSCSamplingGraph::Save(torch::serialize::OutputArchive& archive) const { ...@@ -97,6 +116,11 @@ void CSCSamplingGraph::Save(torch::serialize::OutputArchive& archive) const {
if (type_per_edge_) { if (type_per_edge_) {
archive.write("CSCSamplingGraph/type_per_edge", type_per_edge_.value()); archive.write("CSCSamplingGraph/type_per_edge", type_per_edge_.value());
} }
archive.write(
"CSCSamplingGraph/has_edge_attributes", edge_attributes_.has_value());
if (edge_attributes_) {
archive.write("CSCSamplingGraph/edge_attributes", edge_attributes_.value());
}
} }
void CSCSamplingGraph::SetState( void CSCSamplingGraph::SetState(
......
...@@ -1577,6 +1577,10 @@ def test_OnDiskDataset_load_graph(): ...@@ -1577,6 +1577,10 @@ def test_OnDiskDataset_load_graph():
with open(yaml_file, "w") as f: with open(yaml_file, "w") as f:
f.write(yaml_content) f.write(yaml_content)
# Check if the CSCSamplingGraph.edge_attributes loaded.
dataset = gb.OnDiskDataset(test_dir).load()
assert dataset.graph.edge_attributes is not None
# Case1. Test modify the `type` field. # Case1. Test modify the `type` field.
dataset = gb.OnDiskDataset(test_dir) dataset = gb.OnDiskDataset(test_dir)
dataset.yaml_data["graph_topology"]["type"] = "fake_type" dataset.yaml_data["graph_topology"]["type"] = "fake_type"
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment