"git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "a7f07c1ef592fdcd60f37b1481bebb3de9705808"
Unverified Commit 70af8f0d authored by Rhett Ying's avatar Rhett Ying Committed by GitHub
Browse files

[GraphBolt] add support for load/save CSCSamplingGraph in python level (#5733)

parent 6862e372
...@@ -48,12 +48,34 @@ void CSCSamplingGraph::Load(torch::serialize::InputArchive& archive) { ...@@ -48,12 +48,34 @@ void CSCSamplingGraph::Load(torch::serialize::InputArchive& archive) {
"Magic numbers mismatch when loading CSCSamplingGraph."); "Magic numbers mismatch when loading CSCSamplingGraph.");
indptr_ = read_from_archive(archive, "CSCSamplingGraph/indptr").toTensor(); indptr_ = read_from_archive(archive, "CSCSamplingGraph/indptr").toTensor();
indices_ = read_from_archive(archive, "CSCSamplingGraph/indices").toTensor(); indices_ = read_from_archive(archive, "CSCSamplingGraph/indices").toTensor();
if (read_from_archive(archive, "CSCSamplingGraph/has_node_type_offset")
.toBool()) {
node_type_offset_ =
read_from_archive(archive, "CSCSamplingGraph/node_type_offset")
.toTensor();
}
if (read_from_archive(archive, "CSCSamplingGraph/has_type_per_edge")
.toBool()) {
type_per_edge_ =
read_from_archive(archive, "CSCSamplingGraph/type_per_edge").toTensor();
}
} }
void CSCSamplingGraph::Save(torch::serialize::OutputArchive& archive) const {
  // Write the magic number first so Load() can validate that the archive
  // really holds a serialized CSCSamplingGraph.
  archive.write("CSCSamplingGraph/magic_num", kCSCSamplingGraphSerializeMagic);
  // Mandatory CSC topology tensors.
  archive.write("CSCSamplingGraph/indptr", indptr_);
  archive.write("CSCSamplingGraph/indices", indices_);
  // Each optional tensor is preceded by a boolean presence flag so that
  // Load() knows whether the corresponding entry exists in the archive.
  const bool has_node_type_offset = node_type_offset_.has_value();
  archive.write("CSCSamplingGraph/has_node_type_offset", has_node_type_offset);
  if (has_node_type_offset) {
    archive.write("CSCSamplingGraph/node_type_offset", *node_type_offset_);
  }
  const bool has_type_per_edge = type_per_edge_.has_value();
  archive.write("CSCSamplingGraph/has_type_per_edge", has_type_per_edge);
  if (has_type_per_edge) {
    archive.write("CSCSamplingGraph/type_per_edge", *type_per_edge_);
  }
}
} // namespace sampling } // namespace sampling
......
...@@ -5,6 +5,7 @@ ...@@ -5,6 +5,7 @@
*/ */
#include <graphbolt/csc_sampling_graph.h> #include <graphbolt/csc_sampling_graph.h>
#include <graphbolt/serialize.h>
namespace graphbolt { namespace graphbolt {
namespace sampling { namespace sampling {
...@@ -18,6 +19,8 @@ TORCH_LIBRARY(graphbolt, m) { ...@@ -18,6 +19,8 @@ TORCH_LIBRARY(graphbolt, m) {
.def("node_type_offset", &CSCSamplingGraph::NodeTypeOffset) .def("node_type_offset", &CSCSamplingGraph::NodeTypeOffset)
.def("type_per_edge", &CSCSamplingGraph::TypePerEdge); .def("type_per_edge", &CSCSamplingGraph::TypePerEdge);
m.def("from_csc", &CSCSamplingGraph::FromCSC); m.def("from_csc", &CSCSamplingGraph::FromCSC);
m.def("load_csc_sampling_graph", &LoadCSCSamplingGraph);
m.def("save_csc_sampling_graph", &SaveCSCSamplingGraph);
} }
} // namespace sampling } // namespace sampling
......
"""CSC format sampling graph.""" """CSC format sampling graph."""
# pylint: disable= invalid-name # pylint: disable= invalid-name
import os
import tarfile
import tempfile
from typing import Dict, Optional, Tuple from typing import Dict, Optional, Tuple
import torch import torch
...@@ -254,3 +257,35 @@ def _csc_sampling_graph_str(graph: CSCSamplingGraph) -> str: ...@@ -254,3 +257,35 @@ def _csc_sampling_graph_str(graph: CSCSamplingGraph) -> str:
final_str = prefix + _add_indent(final_str, len(prefix)) final_str = prefix + _add_indent(final_str, len(prefix))
return final_str return final_str
def load_csc_sampling_graph(filename):
    """Load CSCSamplingGraph from tar file.

    Parameters
    ----------
    filename : str
        Path to the tar archive produced by ``save_csc_sampling_graph``.
        The archive is expected to contain ``csc_sampling_graph.pt`` (the
        serialized C++ graph) and ``metadata.pt`` (the pickled metadata).

    Returns
    -------
    CSCSamplingGraph
        The deserialized graph with its metadata attached.

    Raises
    ------
    ValueError
        If the archive contains a member whose path would escape the
        extraction directory (path traversal).
    """
    with tempfile.TemporaryDirectory() as temp_dir:
        with tarfile.open(filename, "r") as archive:
            # tarfile.extractall trusts member names; reject entries such as
            # "../x" or absolute paths before extracting anything.
            safe_root = os.path.abspath(temp_dir)
            for member in archive.getmembers():
                target = os.path.abspath(os.path.join(temp_dir, member.name))
                if not target.startswith(safe_root + os.sep):
                    raise ValueError(
                        f"Refusing to extract unsafe tar member: {member.name}"
                    )
            archive.extractall(temp_dir)
        graph_filename = os.path.join(temp_dir, "csc_sampling_graph.pt")
        metadata_filename = os.path.join(temp_dir, "metadata.pt")
        return CSCSamplingGraph(
            torch.ops.graphbolt.load_csc_sampling_graph(graph_filename),
            torch.load(metadata_filename),
        )
def save_csc_sampling_graph(graph, filename):
    """Save CSCSamplingGraph to tar file.

    Parameters
    ----------
    graph : CSCSamplingGraph
        The graph to serialize. Its C++ handle is written via the
        ``graphbolt.save_csc_sampling_graph`` op and its metadata via
        ``torch.save``.
    filename : str
        Destination path of the tar archive.
    """
    with tempfile.TemporaryDirectory() as temp_dir:
        # Serialize the C++ graph object to a temporary file.
        graph_filename = os.path.join(temp_dir, "csc_sampling_graph.pt")
        torch.ops.graphbolt.save_csc_sampling_graph(
            graph._c_csc_graph, graph_filename
        )
        # Metadata lives on the Python side; pickle it separately.
        metadata_filename = os.path.join(temp_dir, "metadata.pt")
        torch.save(graph.metadata, metadata_filename)
        # Bundle both files into a single tar archive at `filename`.
        with tarfile.open(filename, "w") as archive:
            archive.add(
                graph_filename, arcname=os.path.basename(graph_filename)
            )
            archive.add(
                metadata_filename, arcname=os.path.basename(metadata_filename)
            )
    # Fix: the message previously printed a literal garbled token instead of
    # interpolating the actual destination path.
    print(f"CSCSamplingGraph has been saved to {filename}.")
import os
import tempfile
import unittest import unittest
import backend as F import backend as F
...@@ -211,7 +213,64 @@ def test_node_type_offset_wrong_legnth(node_type_offset): ...@@ -211,7 +213,64 @@ def test_node_type_offset_wrong_legnth(node_type_offset):
) )
@unittest.skipIf(
    F._default_context_str == "gpu",
    reason="Graph is CPU only at present.",
)
@pytest.mark.parametrize(
    "num_nodes, num_edges", [(1, 1), (100, 1), (10, 50), (1000, 50000)]
)
def test_load_save_homo_graph(num_nodes, num_edges):
    # Round-trip a homogeneous graph through save/load and verify that every
    # field is preserved (and the hetero-only fields stay absent).
    csc_indptr, indices = random_homo_graph(num_nodes, num_edges)
    original = gb.from_csc(csc_indptr, indices)
    with tempfile.TemporaryDirectory() as test_dir:
        filename = os.path.join(test_dir, "csc_sampling_graph.tar")
        gb.save_csc_sampling_graph(original, filename)
        restored = gb.load_csc_sampling_graph(filename)

    assert original.num_nodes == restored.num_nodes
    assert original.num_edges == restored.num_edges
    assert torch.equal(original.csc_indptr, restored.csc_indptr)
    assert torch.equal(original.indices, restored.indices)

    # A homogeneous graph carries no heterogeneous annotations.
    for g in (original, restored):
        assert g.metadata is None
        assert g.node_type_offset is None
        assert g.type_per_edge is None
@unittest.skipIf(
    F._default_context_str == "gpu",
    reason="Graph is CPU only at present.",
)
@pytest.mark.parametrize(
    "num_nodes, num_edges", [(1, 1), (100, 1), (10, 50), (1000, 50000)]
)
@pytest.mark.parametrize("num_ntypes, num_etypes", [(1, 1), (3, 5), (100, 1)])
def test_load_save_hetero_graph(num_nodes, num_edges, num_ntypes, num_etypes):
    # Round-trip a heterogeneous graph through save/load and verify that the
    # topology, type annotations, and metadata all survive serialization.
    (
        csc_indptr,
        indices,
        node_type_offset,
        type_per_edge,
        metadata,
    ) = random_hetero_graph(num_nodes, num_edges, num_ntypes, num_etypes)
    original = gb.from_csc(
        csc_indptr, indices, node_type_offset, type_per_edge, metadata
    )
    with tempfile.TemporaryDirectory() as test_dir:
        filename = os.path.join(test_dir, "csc_sampling_graph.tar")
        gb.save_csc_sampling_graph(original, filename)
        restored = gb.load_csc_sampling_graph(filename)

    assert original.num_nodes == restored.num_nodes
    assert original.num_edges == restored.num_edges

    # Tensor-valued fields must match exactly.
    assert torch.equal(original.csc_indptr, restored.csc_indptr)
    assert torch.equal(original.indices, restored.indices)
    assert torch.equal(original.node_type_offset, restored.node_type_offset)
    assert torch.equal(original.type_per_edge, restored.type_per_edge)

    # Metadata type maps must survive the round trip.
    assert original.metadata.node_type_to_id == restored.metadata.node_type_to_id
    assert original.metadata.edge_type_to_id == restored.metadata.edge_type_to_id
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment