Commit 72d16f78 (unverified), authored by Rhett Ying, committed by GitHub
Browse files

[GraphBolt] optimize load/save logic (#5713)

parent 1e295c53
......@@ -96,6 +96,9 @@ struct HeteroInfo {
*/
class CSCSamplingGraph : public torch::CustomClassHolder {
public:
/** @brief Default constructor. */
CSCSamplingGraph() = default;
/**
* @brief Constructor for CSC with data.
* @param num_nodes The number of nodes in the graph.
......@@ -203,33 +206,4 @@ class CSCSamplingGraph : public torch::CustomClassHolder {
} // namespace sampling
} // namespace graphbolt
/**
 * @brief Overload stream operators to enable `torch::save()` and
 * `torch::load()` for CSCSamplingGraph.
 *
 * The operators are declared without `inline`: this header only carries the
 * declarations, and an inline function would have to be defined in every
 * translation unit that uses it.
 */
namespace torch {

/**
 * @brief Overload input stream operator for CSCSamplingGraph deserialization.
 * @param archive Input stream for deserializing.
 * @param graph CSCSamplingGraph.
 *
 * @return archive
 */
serialize::InputArchive& operator>>(
    serialize::InputArchive& archive,
    graphbolt::sampling::CSCSamplingGraph& graph);

/**
 * @brief Overload output stream operator for CSCSamplingGraph serialization.
 * @param archive Output stream for serializing.
 * @param graph CSCSamplingGraph.
 *
 * @return archive
 */
serialize::OutputArchive& operator<<(
    serialize::OutputArchive& archive,
    const graphbolt::sampling::CSCSamplingGraph& graph);

}  // namespace torch
#endif // GRAPHBOLT_CSC_SAMPLING_GRAPH_H_
......@@ -12,134 +12,69 @@
#include <string>
#include <vector>
namespace graphbolt {
namespace utils {
#include "csc_sampling_graph.h"
/**
* @brief Utility function to write to archive.
* @param archive Output archive.
* @param key Key name used in saving.
* @param data Data that could be constructed as `torch::IValue`.
* @brief Overload stream operator to enable `torch::save()` and `torch::load()`
* for CSCSamplingGraph.
*/
/**
 * @brief Store one value in the output archive.
 * @param archive Archive that receives the value.
 * @param key Name under which the value is stored.
 * @param data Value forwarded unchanged to `archive.write()`.
 */
template <typename DataT>
void write_to_archive(
    torch::serialize::OutputArchive& archive, const std::string& key,
    const DataT& data) {
  // Thin forwarding wrapper; specializations handle types that need a
  // custom on-disk layout.
  archive.write(key, data);
}
namespace torch {
/**
* @brief Specialization utility function to save string vector.
* @param archive Output archive.
* @param key Key name used in saving.
* @param data Vector of string.
* @brief Overload input stream operator for CSCSamplingGraph deserialization.
* @param archive Input stream for deserializing.
* @param graph CSCSamplingGraph.
*
* @return archive
*/
/**
 * @brief Specialization that saves a vector of strings.
 *
 * Layout: the element count is written as a tensor under `<key>/size`, and
 * element `i` is written under `<key>/<i>`.
 *
 * Marked `inline` because an explicit specialization is an ordinary function:
 * defining it in a header without `inline` yields one definition per
 * translation unit that includes the header.
 *
 * @param archive Output archive.
 * @param key Key name used in saving.
 * @param data Vector of string.
 */
template <>
inline void write_to_archive<std::vector<std::string>>(
    torch::serialize::OutputArchive& archive, const std::string& key,
    const std::vector<std::string>& data) {
  archive.write(
      key + "/size", torch::tensor(static_cast<int64_t>(data.size())));
  for (const auto index : c10::irange(data.size())) {
    archive.write(key + "/" + std::to_string(index), data[index]);
  }
}
// Declared without `inline`: only the declaration is visible here, and an
// inline function must be defined in every translation unit that uses it.
serialize::InputArchive& operator>>(
    serialize::InputArchive& archive,
    graphbolt::sampling::CSCSamplingGraph& graph);
/**
* @brief Utility function to read from archive.
* @param archive Input archive.
* @param key Key name used in reading.
* @param data Data that could be constructed as `torch::IValue`.
* @brief Overload output stream operator for CSCSamplingGraph serialization.
* @param archive Output stream for serializing.
* @param graph CSCSamplingGraph.
*
* @return archive
*/
/**
 * @brief Fetch one value from the input archive.
 * @param archive Archive to read from.
 * @param key Name under which the value was stored.
 * @param data Destination passed straight to `archive.read()`.
 */
template <typename DataT = torch::IValue>
void read_from_archive(
    torch::serialize::InputArchive& archive, const std::string& key,
    DataT& data) {
  // Thin forwarding wrapper; specializations convert from torch::IValue for
  // concrete destination types.
  archive.read(key, data);
}
// Declared without `inline`: only the declaration is visible here, and an
// inline function must be defined in every translation unit that uses it.
serialize::OutputArchive& operator<<(
    serialize::OutputArchive& archive,
    const graphbolt::sampling::CSCSamplingGraph& graph);
/**
* @brief Specialization utility function to read from archive.
* @param archive Input archive.
* @param key Key name used in reading.
* @param data Data that is `bool`.
*/
/**
 * @brief Specialization that reads a `bool` from the archive.
 *
 * Marked `inline` because an explicit specialization defined in a header is an
 * ordinary function and would otherwise be defined once per including
 * translation unit.
 *
 * @param archive Input archive.
 * @param key Key name used in reading.
 * @param data Receives the value converted with `IValue::toBool()`.
 */
template <>
inline void read_from_archive<bool>(
    torch::serialize::InputArchive& archive, const std::string& key,
    bool& data) {
  torch::IValue iv_data;
  archive.read(key, iv_data);
  data = iv_data.toBool();
}
} // namespace torch
/**
* @brief Specialization utility function to read from archive.
* @param archive Input archive.
* @param key Key name used in reading.
* @param data Data that is `int64_t`.
*/
/**
 * @brief Specialization that reads an `int64_t` from the archive.
 *
 * Marked `inline` because an explicit specialization defined in a header is an
 * ordinary function and would otherwise be defined once per including
 * translation unit.
 *
 * @param archive Input archive.
 * @param key Key name used in reading.
 * @param data Receives the value converted with `IValue::toInt()`.
 */
template <>
inline void read_from_archive<int64_t>(
    torch::serialize::InputArchive& archive, const std::string& key,
    int64_t& data) {
  torch::IValue iv_data;
  archive.read(key, iv_data);
  data = iv_data.toInt();
}
namespace graphbolt {
/**
* @brief Specialization utility function to read from archive.
* @param archive Input archive.
* @param key Key name used in reading.
* @param data Data that is `std::string`.
* @brief Load CSCSamplingGraph from file.
* @param filename File name to read.
*
* @return CSCSamplingGraph.
*/
/**
 * @brief Specialization that reads a `std::string` from the archive.
 *
 * Uses `IValue::toStringRef()`, which yields the stored `std::string`.
 * `IValue::toString()` returns an intrusive pointer to the string holder, not
 * the string itself, so it must not be assigned to `data`.
 *
 * Marked `inline` because an explicit specialization defined in a header is an
 * ordinary function and would otherwise be defined once per including
 * translation unit.
 *
 * @param archive Input archive.
 * @param key Key name used in reading.
 * @param data Receives a copy of the stored string.
 */
template <>
inline void read_from_archive<std::string>(
    torch::serialize::InputArchive& archive, const std::string& key,
    std::string& data) {
  torch::IValue iv_data;
  archive.read(key, iv_data);
  data = iv_data.toStringRef();
}
c10::intrusive_ptr<sampling::CSCSamplingGraph> LoadCSCSamplingGraph(
const std::string& filename);
/**
* @brief Specialization utility function to read from archive.
* @param archive Input archive.
* @param key Key name used in reading.
* @param data Data that is `torch::Tensor`.
* @brief Save CSCSamplingGraph to file.
* @param graph CSCSamplingGraph to save.
* @param filename File name to save.
*
*/
/**
 * @brief Specialization that reads a `torch::Tensor` from the archive.
 *
 * Marked `inline` because an explicit specialization defined in a header is an
 * ordinary function and would otherwise be defined once per including
 * translation unit.
 *
 * @param archive Input archive.
 * @param key Key name used in reading.
 * @param data Receives the value converted with `IValue::toTensor()`.
 */
template <>
inline void read_from_archive<torch::Tensor>(
    torch::serialize::InputArchive& archive, const std::string& key,
    torch::Tensor& data) {
  torch::IValue iv_data;
  archive.read(key, iv_data);
  data = iv_data.toTensor();
}
void SaveCSCSamplingGraph(
c10::intrusive_ptr<sampling::CSCSamplingGraph> graph,
const std::string& filename);
/**
* @brief Specialization utility function to read to string vector.
* @param archive Output archive.
* @param key Key name used in saving.
* @param data Vector of string.
* @brief Read data from archive.
* @param archive Input archive.
* @param key Key name of data.
*
* @return data.
*/
/**
 * @brief Specialization that reads a vector of strings from the archive.
 *
 * Expected layout: the element count under `<key>/size` and element `i` under
 * `<key>/<i>`. The count is accepted either as a tensor (which is how the
 * vector writer in this header stores it) or as a plain integer.
 *
 * Marked `inline` because an explicit specialization defined in a header is an
 * ordinary function and would otherwise be defined once per including
 * translation unit.
 *
 * @param archive Input archive.
 * @param key Key name used in reading.
 * @param data Resized to the stored count and filled with the stored strings.
 */
template <>
inline void read_from_archive<std::vector<std::string>>(
    torch::serialize::InputArchive& archive, const std::string& key,
    std::vector<std::string>& data) {
  torch::IValue iv_size;
  archive.read(key + "/size", iv_size);
  // A tensor-valued count cannot be converted with toInt(); unwrap it first.
  const int64_t size = iv_size.isTensor()
                           ? iv_size.toTensor().item<int64_t>()
                           : iv_size.toInt();
  data.resize(static_cast<size_t>(size));
  for (int64_t index = 0; index < size; ++index) {
    torch::IValue iv_element;
    archive.read(key + "/" + std::to_string(index), iv_element);
    // toStringRef() gives the stored std::string; assign straight into the
    // slot instead of going through a temporary.
    data[index] = iv_element.toStringRef();
  }
}
torch::IValue read_from_archive(
torch::serialize::InputArchive& archive, const std::string& key);
} // namespace utils
} // namespace graphbolt
#endif // GRAPHBOLT_SERIALIZE_H_
......@@ -11,26 +11,27 @@ namespace graphbolt {
namespace sampling {
void HeteroInfo::Load(torch::serialize::InputArchive& archive) {
int64_t magic_num = 0x0;
utils::read_from_archive(archive, "HeteroInfo/magic_num", magic_num);
const int64_t magic_num =
read_from_archive(archive, "HeteroInfo/magic_num").toInt();
TORCH_CHECK(
magic_num == kHeteroInfoSerializeMagic,
"Magic numbers mismatch when loading HeteroInfo.");
utils::read_from_archive(archive, "HeteroInfo/node_types", node_types);
utils::read_from_archive(archive, "HeteroInfo/edge_types", edge_types);
utils::read_from_archive(
archive, "HeteroInfo/node_type_offset", node_type_offset);
utils::read_from_archive(archive, "HeteroInfo/type_per_edge", type_per_edge);
node_types = read_from_archive(archive, "HeteroInfo/node_types")
.to<decltype(node_types)>();
edge_types = read_from_archive(archive, "HeteroInfo/edge_types")
.to<decltype(edge_types)>();
node_type_offset =
read_from_archive(archive, "HeteroInfo/node_type_offset").toTensor();
type_per_edge =
read_from_archive(archive, "HeteroInfo/type_per_edge").toTensor();
}
void HeteroInfo::Save(torch::serialize::OutputArchive& archive) const {
utils::write_to_archive(
archive, "HeteroInfo/magic_num", kHeteroInfoSerializeMagic);
utils::write_to_archive(archive, "HeteroInfo/node_types", node_types);
utils::write_to_archive(archive, "HeteroInfo/edge_types", edge_types);
utils::write_to_archive(
archive, "HeteroInfo/node_type_offset", node_type_offset);
utils::write_to_archive(archive, "HeteroInfo/type_per_edge", type_per_edge);
archive.write("HeteroInfo/magic_num", kHeteroInfoSerializeMagic);
archive.write("HeteroInfo/node_types", node_types);
archive.write("HeteroInfo/edge_types", edge_types);
archive.write("HeteroInfo/node_type_offset", node_type_offset);
archive.write("HeteroInfo/type_per_edge", type_per_edge);
}
CSCSamplingGraph::CSCSamplingGraph(
......@@ -73,17 +74,16 @@ c10::intrusive_ptr<CSCSamplingGraph> CSCSamplingGraph::FromCSCWithHeteroInfo(
}
void CSCSamplingGraph::Load(torch::serialize::InputArchive& archive) {
int64_t magic_num = 0x0;
utils::read_from_archive(archive, "CSCSamplingGraph/magic_num", magic_num);
const int64_t magic_num =
read_from_archive(archive, "CSCSamplingGraph/magic_num").toInt();
TORCH_CHECK(
magic_num == kCSCSamplingGraphSerializeMagic,
"Magic numbers mismatch when loading CSCSamplingGraph.");
utils::read_from_archive(archive, "CSCSamplingGraph/num_nodes", num_nodes_);
utils::read_from_archive(archive, "CSCSamplingGraph/indptr", indptr_);
utils::read_from_archive(archive, "CSCSamplingGraph/indices", indices_);
bool is_heterogeneous = false;
utils::read_from_archive(
archive, "CSCSamplingGraph/is_hetero", is_heterogeneous);
num_nodes_ = read_from_archive(archive, "CSCSamplingGraph/num_nodes").toInt();
indptr_ = read_from_archive(archive, "CSCSamplingGraph/indptr").toTensor();
indices_ = read_from_archive(archive, "CSCSamplingGraph/indices").toTensor();
const bool is_heterogeneous =
read_from_archive(archive, "CSCSamplingGraph/is_hetero").toBool();
if (is_heterogeneous) {
hetero_info_ = std::make_shared<HeteroInfo>();
hetero_info_->Load(archive);
......@@ -91,13 +91,11 @@ void CSCSamplingGraph::Load(torch::serialize::InputArchive& archive) {
}
void CSCSamplingGraph::Save(torch::serialize::OutputArchive& archive) const {
archive.write(
"CSCSamplingGraph/magic_num",
torch::IValue(kCSCSamplingGraphSerializeMagic));
archive.write("CSCSamplingGraph/num_nodes", torch::IValue(num_nodes_));
archive.write("CSCSamplingGraph/magic_num", kCSCSamplingGraphSerializeMagic);
archive.write("CSCSamplingGraph/num_nodes", num_nodes_);
archive.write("CSCSamplingGraph/indptr", indptr_);
archive.write("CSCSamplingGraph/indices", indices_);
archive.write("CSCSamplingGraph/is_hetero", torch::IValue(IsHeterogeneous()));
archive.write("CSCSamplingGraph/is_hetero", IsHeterogeneous());
if (IsHeterogeneous()) {
hetero_info_->Save(archive);
}
......@@ -105,21 +103,3 @@ void CSCSamplingGraph::Save(torch::serialize::OutputArchive& archive) const {
} // namespace sampling
} // namespace graphbolt
namespace torch {

/**
 * @brief Deserialize a CSCSamplingGraph from an input archive.
 * @param archive Archive to read from.
 * @param graph Graph whose `Load()` is invoked with the archive.
 *
 * @return The same archive, so that extractions can be chained.
 */
serialize::InputArchive& operator>>(
    serialize::InputArchive& archive,
    graphbolt::sampling::CSCSamplingGraph& graph) {
  graph.Load(archive);
  return archive;
}

/**
 * @brief Serialize a CSCSamplingGraph into an output archive.
 * @param archive Archive to write to.
 * @param graph Graph whose `Save()` is invoked with the archive.
 *
 * @return The same archive, so that insertions can be chained.
 */
serialize::OutputArchive& operator<<(
    serialize::OutputArchive& archive,
    const graphbolt::sampling::CSCSamplingGraph& graph) {
  graph.Save(archive);
  return archive;
}

}  // namespace torch
/**
* Copyright (c) 2023 by Contributors
* @file graphbolt/src/serialize.cc
* @brief Source file of serialize.
*/
#include <graphbolt/serialize.h>
namespace torch {

/**
 * @brief Input stream operator: fills `graph` from `archive`.
 * @param archive Input archive handed to `CSCSamplingGraph::Load()`.
 * @param graph Destination graph.
 *
 * @return archive
 */
serialize::InputArchive& operator>>(
    serialize::InputArchive& archive,
    graphbolt::sampling::CSCSamplingGraph& graph) {
  graph.Load(archive);
  return archive;
}

/**
 * @brief Output stream operator: writes `graph` into `archive`.
 * @param archive Output archive handed to `CSCSamplingGraph::Save()`.
 * @param graph Source graph.
 *
 * @return archive
 */
serialize::OutputArchive& operator<<(
    serialize::OutputArchive& archive,
    const graphbolt::sampling::CSCSamplingGraph& graph) {
  graph.Save(archive);
  return archive;
}

}  // namespace torch
namespace graphbolt {
/**
 * @brief Load a CSCSamplingGraph from file.
 *
 * Default-constructs a graph and fills it through `torch::load()`.
 *
 * @param filename File name to read.
 *
 * @return Pointer to the loaded graph.
 */
c10::intrusive_ptr<sampling::CSCSamplingGraph> LoadCSCSamplingGraph(
    const std::string& filename) {
  // Hold the pointer by value (not `auto&&`) so that returning it can be
  // elided or moved instead of copying the reference-counted pointer.
  auto graph = c10::make_intrusive<sampling::CSCSamplingGraph>();
  torch::load(*graph, filename);
  return graph;
}
/**
 * @brief Save a CSCSamplingGraph to file.
 * @param graph Pointer to the graph to save; it is dereferenced
 * unconditionally.
 * @param filename File name to write.
 */
void SaveCSCSamplingGraph(
    c10::intrusive_ptr<sampling::CSCSamplingGraph> graph,
    const std::string& filename) {
  const sampling::CSCSamplingGraph& graph_ref = *graph;
  torch::save(graph_ref, filename);
}
/**
 * @brief Read the entry stored under `key` from the archive.
 * @param archive Input archive.
 * @param key Key name of the entry.
 *
 * @return The entry as a `torch::IValue`; the caller converts it to the
 * concrete type.
 */
torch::IValue read_from_archive(
    torch::serialize::InputArchive& archive, const std::string& key) {
  torch::IValue value;
  archive.read(key, value);
  return value;
}
} // namespace graphbolt
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment