Unverified Commit 72d16f78 authored by Rhett Ying's avatar Rhett Ying Committed by GitHub
Browse files

[GraphBolt] optimize load/save logic (#5713)

parent 1e295c53
...@@ -96,6 +96,9 @@ struct HeteroInfo { ...@@ -96,6 +96,9 @@ struct HeteroInfo {
*/ */
class CSCSamplingGraph : public torch::CustomClassHolder { class CSCSamplingGraph : public torch::CustomClassHolder {
public: public:
/** @brief Default constructor. */
CSCSamplingGraph() = default;
/** /**
* @brief Constructor for CSC with data. * @brief Constructor for CSC with data.
* @param num_nodes The number of nodes in the graph. * @param num_nodes The number of nodes in the graph.
...@@ -203,33 +206,4 @@ class CSCSamplingGraph : public torch::CustomClassHolder { ...@@ -203,33 +206,4 @@ class CSCSamplingGraph : public torch::CustomClassHolder {
} // namespace sampling } // namespace sampling
} // namespace graphbolt } // namespace graphbolt
/**
* @brief Overload stream operator to enable `torch::save()` and `torch.load()`
* for CSCSamplingGraph.
*/
namespace torch {
/**
* @brief Overload input stream operator for CSCSamplingGraph deserialization.
* @param archive Input stream for deserializing.
* @param graph CSCSamplingGraph.
*
* @return archive
*/
inline serialize::InputArchive& operator>>(
serialize::InputArchive& archive,
graphbolt::sampling::CSCSamplingGraph& graph);
/**
* @brief Overload output stream operator for CSCSamplingGraph serialization.
* @param archive Output stream for serializing.
* @param graph CSCSamplingGraph.
*
* @return archive
*/
inline serialize::OutputArchive& operator<<(
serialize::OutputArchive& archive,
const graphbolt::sampling::CSCSamplingGraph& graph);
} // namespace torch
#endif // GRAPHBOLT_CSC_SAMPLING_GRAPH_H_ #endif // GRAPHBOLT_CSC_SAMPLING_GRAPH_H_
...@@ -12,134 +12,69 @@ ...@@ -12,134 +12,69 @@
#include <string> #include <string>
#include <vector> #include <vector>
namespace graphbolt { #include "csc_sampling_graph.h"
namespace utils {
/** /**
* @brief Utility function to write to archive. * @brief Overload stream operator to enable `torch::save()` and `torch.load()`
* @param archive Output archive. * for CSCSamplingGraph.
* @param key Key name used in saving.
* @param data Data that could be constructed as `torch::IValue`.
*/ */
template <typename DataT> namespace torch {
void write_to_archive(
torch::serialize::OutputArchive& archive, const std::string& key,
const DataT& data) {
archive.write(key, data);
}
/** /**
* @brief Specialization utility function to save string vector. * @brief Overload input stream operator for CSCSamplingGraph deserialization.
* @param archive Output archive. * @param archive Input stream for deserializing.
* @param key Key name used in saving. * @param graph CSCSamplingGraph.
* @param data Vector of string. *
* @return archive
*/ */
template <> inline serialize::InputArchive& operator>>(
void write_to_archive<std::vector<std::string>>( serialize::InputArchive& archive,
torch::serialize::OutputArchive& archive, const std::string& key, graphbolt::sampling::CSCSamplingGraph& graph);
const std::vector<std::string>& data) {
archive.write(
key + "/size", torch::tensor(static_cast<int64_t>(data.size())));
for (const auto index : c10::irange(data.size())) {
archive.write(key + "/" + std::to_string(index), data[index]);
}
}
/** /**
* @brief Utility function to read from archive. * @brief Overload output stream operator for CSCSamplingGraph serialization.
* @param archive Input archive. * @param archive Output stream for serializing.
* @param key Key name used in reading. * @param graph CSCSamplingGraph.
* @param data Data that could be constructed as `torch::IValue`. *
* @return archive
*/ */
template <typename DataT = torch::IValue> inline serialize::OutputArchive& operator<<(
void read_from_archive( serialize::OutputArchive& archive,
torch::serialize::InputArchive& archive, const std::string& key, const graphbolt::sampling::CSCSamplingGraph& graph);
DataT& data) {
archive.read(key, data);
}
/** } // namespace torch
* @brief Specialization utility function to read from archive.
* @param archive Input archive.
* @param key Key name used in reading.
* @param data Data that is `bool`.
*/
template <>
void read_from_archive<bool>(
torch::serialize::InputArchive& archive, const std::string& key,
bool& data) {
torch::IValue iv_data;
archive.read(key, iv_data);
data = iv_data.toBool();
}
/** namespace graphbolt {
* @brief Specialization utility function to read from archive.
* @param archive Input archive.
* @param key Key name used in reading.
* @param data Data that is `int64_t`.
*/
template <>
void read_from_archive<int64_t>(
torch::serialize::InputArchive& archive, const std::string& key,
int64_t& data) {
torch::IValue iv_data;
archive.read(key, iv_data);
data = iv_data.toInt();
}
/** /**
* @brief Specialization utility function to read from archive. * @brief Load CSCSamplingGraph from file.
* @param archive Input archive. * @param filename File name to read.
* @param key Key name used in reading. *
* @param data Data that is `std::string`. * @return CSCSamplingGraph.
*/ */
template <> c10::intrusive_ptr<sampling::CSCSamplingGraph> LoadCSCSamplingGraph(
void read_from_archive<std::string>( const std::string& filename);
torch::serialize::InputArchive& archive, const std::string& key,
std::string& data) {
torch::IValue iv_data;
archive.read(key, iv_data);
data = iv_data.toString();
}
/** /**
* @brief Specialization utility function to read from archive. * @brief Save CSCSamplingGraph to file.
* @param archive Input archive. * @param graph CSCSamplingGraph to save.
* @param key Key name used in reading. * @param filename File name to save.
* @param data Data that is `torch::Tensor`. *
*/ */
template <> void SaveCSCSamplingGraph(
void read_from_archive<torch::Tensor>( c10::intrusive_ptr<sampling::CSCSamplingGraph> graph,
torch::serialize::InputArchive& archive, const std::string& key, const std::string& filename);
torch::Tensor& data) {
torch::IValue iv_data;
archive.read(key, iv_data);
data = iv_data.toTensor();
}
/** /**
* @brief Specialization utility function to read to string vector. * @brief Read data from archive.
* @param archive Output archive. * @param archive Input archive.
* @param key Key name used in saving. * @param key Key name of data.
* @param data Vector of string. *
* @return data.
*/ */
template <> torch::IValue read_from_archive(
void read_from_archive<std::vector<std::string>>( torch::serialize::InputArchive& archive, const std::string& key);
torch::serialize::InputArchive& archive, const std::string& key,
std::vector<std::string>& data) {
int64_t size = 0;
read_from_archive<int64_t>(archive, key + "/size", size);
data.resize(static_cast<size_t>(size));
std::string element;
for (int64_t index = 0; index < size; ++index) {
read_from_archive<std::string>(
archive, key + "/" + std::to_string(index), element);
data[index] = element;
}
}
} // namespace utils
} // namespace graphbolt } // namespace graphbolt
#endif // GRAPHBOLT_SERIALIZE_H_ #endif // GRAPHBOLT_SERIALIZE_H_
...@@ -11,26 +11,27 @@ namespace graphbolt { ...@@ -11,26 +11,27 @@ namespace graphbolt {
namespace sampling { namespace sampling {
void HeteroInfo::Load(torch::serialize::InputArchive& archive) { void HeteroInfo::Load(torch::serialize::InputArchive& archive) {
int64_t magic_num = 0x0; const int64_t magic_num =
utils::read_from_archive(archive, "HeteroInfo/magic_num", magic_num); read_from_archive(archive, "HeteroInfo/magic_num").toInt();
TORCH_CHECK( TORCH_CHECK(
magic_num == kHeteroInfoSerializeMagic, magic_num == kHeteroInfoSerializeMagic,
"Magic numbers mismatch when loading HeteroInfo."); "Magic numbers mismatch when loading HeteroInfo.");
utils::read_from_archive(archive, "HeteroInfo/node_types", node_types); node_types = read_from_archive(archive, "HeteroInfo/node_types")
utils::read_from_archive(archive, "HeteroInfo/edge_types", edge_types); .to<decltype(node_types)>();
utils::read_from_archive( edge_types = read_from_archive(archive, "HeteroInfo/edge_types")
archive, "HeteroInfo/node_type_offset", node_type_offset); .to<decltype(edge_types)>();
utils::read_from_archive(archive, "HeteroInfo/type_per_edge", type_per_edge); node_type_offset =
read_from_archive(archive, "HeteroInfo/node_type_offset").toTensor();
type_per_edge =
read_from_archive(archive, "HeteroInfo/type_per_edge").toTensor();
} }
void HeteroInfo::Save(torch::serialize::OutputArchive& archive) const { void HeteroInfo::Save(torch::serialize::OutputArchive& archive) const {
utils::write_to_archive( archive.write("HeteroInfo/magic_num", kHeteroInfoSerializeMagic);
archive, "HeteroInfo/magic_num", kHeteroInfoSerializeMagic); archive.write("HeteroInfo/node_types", node_types);
utils::write_to_archive(archive, "HeteroInfo/node_types", node_types); archive.write("HeteroInfo/edge_types", edge_types);
utils::write_to_archive(archive, "HeteroInfo/edge_types", edge_types); archive.write("HeteroInfo/node_type_offset", node_type_offset);
utils::write_to_archive( archive.write("HeteroInfo/type_per_edge", type_per_edge);
archive, "HeteroInfo/node_type_offset", node_type_offset);
utils::write_to_archive(archive, "HeteroInfo/type_per_edge", type_per_edge);
} }
CSCSamplingGraph::CSCSamplingGraph( CSCSamplingGraph::CSCSamplingGraph(
...@@ -73,17 +74,16 @@ c10::intrusive_ptr<CSCSamplingGraph> CSCSamplingGraph::FromCSCWithHeteroInfo( ...@@ -73,17 +74,16 @@ c10::intrusive_ptr<CSCSamplingGraph> CSCSamplingGraph::FromCSCWithHeteroInfo(
} }
void CSCSamplingGraph::Load(torch::serialize::InputArchive& archive) { void CSCSamplingGraph::Load(torch::serialize::InputArchive& archive) {
int64_t magic_num = 0x0; const int64_t magic_num =
utils::read_from_archive(archive, "CSCSamplingGraph/magic_num", magic_num); read_from_archive(archive, "CSCSamplingGraph/magic_num").toInt();
TORCH_CHECK( TORCH_CHECK(
magic_num == kCSCSamplingGraphSerializeMagic, magic_num == kCSCSamplingGraphSerializeMagic,
"Magic numbers mismatch when loading CSCSamplingGraph."); "Magic numbers mismatch when loading CSCSamplingGraph.");
utils::read_from_archive(archive, "CSCSamplingGraph/num_nodes", num_nodes_); num_nodes_ = read_from_archive(archive, "CSCSamplingGraph/num_nodes").toInt();
utils::read_from_archive(archive, "CSCSamplingGraph/indptr", indptr_); indptr_ = read_from_archive(archive, "CSCSamplingGraph/indptr").toTensor();
utils::read_from_archive(archive, "CSCSamplingGraph/indices", indices_); indices_ = read_from_archive(archive, "CSCSamplingGraph/indices").toTensor();
bool is_heterogeneous = false; const bool is_heterogeneous =
utils::read_from_archive( read_from_archive(archive, "CSCSamplingGraph/is_hetero").toBool();
archive, "CSCSamplingGraph/is_hetero", is_heterogeneous);
if (is_heterogeneous) { if (is_heterogeneous) {
hetero_info_ = std::make_shared<HeteroInfo>(); hetero_info_ = std::make_shared<HeteroInfo>();
hetero_info_->Load(archive); hetero_info_->Load(archive);
...@@ -91,13 +91,11 @@ void CSCSamplingGraph::Load(torch::serialize::InputArchive& archive) { ...@@ -91,13 +91,11 @@ void CSCSamplingGraph::Load(torch::serialize::InputArchive& archive) {
} }
void CSCSamplingGraph::Save(torch::serialize::OutputArchive& archive) const { void CSCSamplingGraph::Save(torch::serialize::OutputArchive& archive) const {
archive.write( archive.write("CSCSamplingGraph/magic_num", kCSCSamplingGraphSerializeMagic);
"CSCSamplingGraph/magic_num", archive.write("CSCSamplingGraph/num_nodes", num_nodes_);
torch::IValue(kCSCSamplingGraphSerializeMagic));
archive.write("CSCSamplingGraph/num_nodes", torch::IValue(num_nodes_));
archive.write("CSCSamplingGraph/indptr", indptr_); archive.write("CSCSamplingGraph/indptr", indptr_);
archive.write("CSCSamplingGraph/indices", indices_); archive.write("CSCSamplingGraph/indices", indices_);
archive.write("CSCSamplingGraph/is_hetero", torch::IValue(IsHeterogeneous())); archive.write("CSCSamplingGraph/is_hetero", IsHeterogeneous());
if (IsHeterogeneous()) { if (IsHeterogeneous()) {
hetero_info_->Save(archive); hetero_info_->Save(archive);
} }
...@@ -105,21 +103,3 @@ void CSCSamplingGraph::Save(torch::serialize::OutputArchive& archive) const { ...@@ -105,21 +103,3 @@ void CSCSamplingGraph::Save(torch::serialize::OutputArchive& archive) const {
} // namespace sampling } // namespace sampling
} // namespace graphbolt } // namespace graphbolt
namespace torch {
serialize::InputArchive& operator>>(
serialize::InputArchive& archive,
graphbolt::sampling::CSCSamplingGraph& graph) {
graph.Load(archive);
return archive;
}
serialize::OutputArchive& operator<<(
serialize::OutputArchive& archive,
const graphbolt::sampling::CSCSamplingGraph& graph) {
graph.Save(archive);
return archive;
}
} // namespace torch
/**
* Copyright (c) 2023 by Contributors
* @file graphbolt/src/serialize.cc
* @brief Source file of serialize.
*/
#include <graphbolt/serialize.h>
namespace torch {
serialize::InputArchive& operator>>(
serialize::InputArchive& archive,
graphbolt::sampling::CSCSamplingGraph& graph) {
graph.Load(archive);
return archive;
}
serialize::OutputArchive& operator<<(
serialize::OutputArchive& archive,
const graphbolt::sampling::CSCSamplingGraph& graph) {
graph.Save(archive);
return archive;
}
} // namespace torch
namespace graphbolt {
c10::intrusive_ptr<sampling::CSCSamplingGraph> LoadCSCSamplingGraph(
const std::string& filename) {
auto&& graph = c10::make_intrusive<sampling::CSCSamplingGraph>();
torch::load(*graph, filename);
return graph;
}
void SaveCSCSamplingGraph(
c10::intrusive_ptr<sampling::CSCSamplingGraph> graph,
const std::string& filename) {
torch::save(*graph, filename);
}
torch::IValue read_from_archive(
torch::serialize::InputArchive& archive, const std::string& key) {
torch::IValue data;
archive.read(key, data);
return data;
}
} // namespace graphbolt
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment