Unverified Commit 77fd747f authored by Rhett Ying's avatar Rhett Ying Committed by GitHub
Browse files

[GraphBolt] enable load/save for CSCSamplingGraph (#5702)

parent 46af76c3
......@@ -49,6 +49,9 @@ struct HeteroInfo {
node_type_offset(node_type_offset),
type_per_edge(type_per_edge) {}
/** @brief Default constructor. */
HeteroInfo() = default;
/** @brief List of node types in the graph.*/
StringList node_types;
......@@ -66,6 +69,24 @@ struct HeteroInfo {
* edge_types. The length of it is equal to the number of edges.
*/
torch::Tensor type_per_edge;
/**
** @brief Magic number to indicate Hetero info version in serialize/
** deserialize stages.
**/
static constexpr int64_t kHeteroInfoSerializeMagic = 0xDD2E60F0F6B4A129;
/**
** @brief Load hetero info from stream.
** @param archive Input stream for deserializing.
**/
void Load(torch::serialize::InputArchive& archive);
/**
** @brief Save hetero info to stream.
** @param archive Output stream for serializing.
**/
void Save(torch::serialize::OutputArchive& archive) const;
};
/**
......@@ -148,6 +169,24 @@ class CSCSamplingGraph : public torch::CustomClassHolder {
return hetero_info_->type_per_edge;
}
/**
** @brief Magic number to indicate graph version in serialize/deserialize
** stage.
**/
static constexpr int64_t kCSCSamplingGraphSerializeMagic = 0xDD2E60F0F6B4A128;
/**
** @brief Load graph from stream.
** @param archive Input stream for deserializing.
**/
void Load(torch::serialize::InputArchive& archive);
/**
** @brief Save graph to stream.
** @param archive Output stream for serializing.
**/
void Save(torch::serialize::OutputArchive& archive) const;
private:
/** @brief The number of nodes of the graph. */
int64_t num_nodes_ = 0;
......@@ -161,3 +200,33 @@ class CSCSamplingGraph : public torch::CustomClassHolder {
} // namespace sampling
} // namespace graphbolt
/**
** @brief Overload stream operators to enable `torch::save()` and
** `torch::load()` for CSCSamplingGraph.
**/
namespace torch {

/**
 * @brief Overload input stream operator for CSCSamplingGraph deserialization.
 *
 * Declared without `inline` on purpose: the definition lives in a single
 * source file, whereas an inline function would have to be defined in every
 * translation unit that uses it.
 *
 * @param archive Input stream for deserializing.
 * @param graph CSCSamplingGraph to be populated from the archive.
 *
 * @return archive
 */
serialize::InputArchive& operator>>(
    serialize::InputArchive& archive,
    graphbolt::sampling::CSCSamplingGraph& graph);

/**
 * @brief Overload output stream operator for CSCSamplingGraph serialization.
 *
 * Declared without `inline` for the same reason as `operator>>`.
 *
 * @param archive Output stream for serializing.
 * @param graph CSCSamplingGraph to be written to the archive.
 *
 * @return archive
 */
serialize::OutputArchive& operator<<(
    serialize::OutputArchive& archive,
    const graphbolt::sampling::CSCSamplingGraph& graph);

}  // namespace torch
/**
* Copyright (c) 2023 by Contributors
* @file graphbolt/include/serialize.h
* @brief Utility functions for serialize and deserialize.
*/
#ifndef GRAPHBOLT_INCLUDE_SERIALIZE_H_
#define GRAPHBOLT_INCLUDE_SERIALIZE_H_
#include <torch/torch.h>
#include <string>
#include <vector>
namespace graphbolt {
namespace utils {
/**
* @brief Generic helper that stores a single value in an output archive.
*
* Forwards directly to `archive.write(key, data)`, so `DataT` must be a type
* accepted by one of the `OutputArchive::write` overloads.
*
* @tparam DataT Type of the value being stored.
* @param archive Output archive.
* @param key Key name under which the value is saved.
* @param data Value to save.
**/
template <typename DataT>
void write_to_archive(
torch::serialize::OutputArchive& archive, const std::string& key,
const DataT& data) {
archive.write(key, data);
}
/**
 * @brief Specialization of `write_to_archive` that saves a vector of strings.
 *
 * The number of elements is stored under `<key>/size` and every element is
 * stored under `<key>/<index>`.
 *
 * Marked `inline` because an explicit specialization of a function template
 * is an ordinary function: defining it in a header without `inline` yields
 * duplicate definitions once the header is included from several translation
 * units.
 *
 * @param archive Output archive.
 * @param key Key name prefix used in saving.
 * @param data Vector of strings to save.
 */
template <>
inline void write_to_archive<std::vector<std::string>>(
    torch::serialize::OutputArchive& archive, const std::string& key,
    const std::vector<std::string>& data) {
  // Store the count as an integer value rather than a tensor so that it can
  // be decoded with `IValue::toInt()` on the reading side.
  archive.write(
      key + "/size", torch::IValue(static_cast<int64_t>(data.size())));
  for (const auto index : c10::irange(data.size())) {
    archive.write(key + "/" + std::to_string(index), data[index]);
  }
}
/**
* @brief Generic helper that loads a single value from an input archive.
*
* Forwards directly to `archive.read(key, data)`, so `DataT` must be a type
* accepted by one of the `InputArchive::read` overloads.
*
* @tparam DataT Type of the destination object; defaults to `torch::IValue`.
* @param archive Input archive.
* @param key Key name used in reading.
* @param data Destination that receives the value read from the archive.
**/
template <typename DataT = torch::IValue>
void read_from_archive(
torch::serialize::InputArchive& archive, const std::string& key,
DataT& data) {
archive.read(key, data);
}
/**
 * @brief Specialization of `read_from_archive` for `bool`.
 *
 * Reads the entry as a `torch::IValue` and converts it with `toBool()`.
 * Marked `inline` because a header-defined explicit specialization would
 * otherwise be defined once per including translation unit.
 *
 * @param archive Input archive.
 * @param key Key name used in reading.
 * @param data Destination that receives the boolean value.
 */
template <>
inline void read_from_archive<bool>(
    torch::serialize::InputArchive& archive, const std::string& key,
    bool& data) {
  torch::IValue iv_data;
  archive.read(key, iv_data);
  data = iv_data.toBool();
}
/**
 * @brief Specialization of `read_from_archive` for `int64_t`.
 *
 * Reads the entry as a `torch::IValue` and converts it with `toInt()`.
 * Marked `inline` because a header-defined explicit specialization would
 * otherwise be defined once per including translation unit.
 *
 * @param archive Input archive.
 * @param key Key name used in reading.
 * @param data Destination that receives the integer value.
 */
template <>
inline void read_from_archive<int64_t>(
    torch::serialize::InputArchive& archive, const std::string& key,
    int64_t& data) {
  torch::IValue iv_data;
  archive.read(key, iv_data);
  data = iv_data.toInt();
}
/**
 * @brief Specialization of `read_from_archive` for `std::string`.
 *
 * Reads the entry as a `torch::IValue` and copies out its string payload.
 * Marked `inline` because a header-defined explicit specialization would
 * otherwise be defined once per including translation unit.
 *
 * @param archive Input archive.
 * @param key Key name used in reading.
 * @param data Destination that receives the string value.
 */
template <>
inline void read_from_archive<std::string>(
    torch::serialize::InputArchive& archive, const std::string& key,
    std::string& data) {
  torch::IValue iv_data;
  archive.read(key, iv_data);
  // `toStringRef()` exposes the stored `std::string`; `toString()` would hand
  // back a handle to the internal string object instead of the text itself.
  data = iv_data.toStringRef();
}
/**
 * @brief Specialization of `read_from_archive` for `torch::Tensor`.
 *
 * Reads the entry as a `torch::IValue` and converts it with `toTensor()`.
 * Marked `inline` because a header-defined explicit specialization would
 * otherwise be defined once per including translation unit.
 *
 * @param archive Input archive.
 * @param key Key name used in reading.
 * @param data Destination that receives the tensor.
 */
template <>
inline void read_from_archive<torch::Tensor>(
    torch::serialize::InputArchive& archive, const std::string& key,
    torch::Tensor& data) {
  torch::IValue iv_data;
  archive.read(key, iv_data);
  data = iv_data.toTensor();
}
/**
 * @brief Specialization of `read_from_archive` that loads a vector of strings.
 *
 * Expects the element count under `<key>/size` and each element under
 * `<key>/<index>`. The count is accepted both as an integer value and as a
 * single-element tensor, since the count may have been written in either
 * form. Any previous content of `data` is replaced.
 *
 * Marked `inline` because a header-defined explicit specialization would
 * otherwise be defined once per including translation unit.
 *
 * @param archive Input archive.
 * @param key Key name prefix used in reading.
 * @param data Destination vector of strings.
 */
template <>
inline void read_from_archive<std::vector<std::string>>(
    torch::serialize::InputArchive& archive, const std::string& key,
    std::vector<std::string>& data) {
  torch::IValue size_value;
  archive.read(key + "/size", size_value);
  const int64_t size = size_value.isTensor()
                           ? size_value.toTensor().item<int64_t>()
                           : size_value.toInt();
  // A negative count would turn into an enormous value once cast to size_t.
  TORCH_CHECK(
      size >= 0, "Invalid size ", size, " when reading string vector \"", key,
      "\" from archive.");
  data.resize(static_cast<size_t>(size));
  for (int64_t index = 0; index < size; ++index) {
    read_from_archive<std::string>(
        archive, key + "/" + std::to_string(index),
        data[static_cast<size_t>(index)]);
  }
}
} // namespace utils
} // namespace graphbolt
#endif // GRAPHBOLT_INCLUDE_SERIALIZE_H_
......@@ -6,9 +6,34 @@
#include "csc_sampling_graph.h"
#include "serialize.h"
namespace graphbolt {
namespace sampling {
/**
 * @brief Populate this HeteroInfo from an archive.
 *
 * Verifies the stored magic number first, then reads node types, edge types,
 * the node type offsets and the per-edge types, all under the `HeteroInfo/`
 * key prefix.
 *
 * @param archive Input archive to read from.
 */
void HeteroInfo::Load(torch::serialize::InputArchive& archive) {
  // Check the format marker before reading any payload.
  int64_t stored_magic = 0;
  utils::read_from_archive(archive, "HeteroInfo/magic_num", stored_magic);
  TORCH_CHECK(
      stored_magic == kHeteroInfoSerializeMagic,
      "Magic numbers mismatch when loading HeteroInfo.");

  // Type names.
  utils::read_from_archive(archive, "HeteroInfo/node_types", node_types);
  utils::read_from_archive(archive, "HeteroInfo/edge_types", edge_types);

  // Type tensors.
  utils::read_from_archive(
      archive, "HeteroInfo/node_type_offset", node_type_offset);
  utils::read_from_archive(
      archive, "HeteroInfo/type_per_edge", type_per_edge);
}
/**
 * @brief Write this HeteroInfo to an archive.
 *
 * Emits the magic number followed by node types, edge types, the node type
 * offsets and the per-edge types, all under the `HeteroInfo/` key prefix.
 *
 * @param archive Output archive to write to.
 */
void HeteroInfo::Save(torch::serialize::OutputArchive& archive) const {
  // Format marker, passed through a local copy of the class constant.
  const int64_t magic = kHeteroInfoSerializeMagic;
  utils::write_to_archive(archive, "HeteroInfo/magic_num", magic);

  // Type names.
  utils::write_to_archive(archive, "HeteroInfo/node_types", node_types);
  utils::write_to_archive(archive, "HeteroInfo/edge_types", edge_types);

  // Type tensors.
  utils::write_to_archive(
      archive, "HeteroInfo/node_type_offset", node_type_offset);
  utils::write_to_archive(
      archive, "HeteroInfo/type_per_edge", type_per_edge);
}
CSCSamplingGraph::CSCSamplingGraph(
int64_t num_nodes, torch::Tensor& indptr, torch::Tensor& indices,
const std::shared_ptr<HeteroInfo>& hetero_info)
......@@ -48,5 +73,54 @@ c10::intrusive_ptr<CSCSamplingGraph> CSCSamplingGraph::FromCSCWithHeteroInfo(
num_nodes, indptr, indices, hetero_info);
}
/**
 * @brief Populate this graph from an archive.
 *
 * Verifies the stored magic number, then reads the node count, `indptr`,
 * `indices` and the `is_hetero` flag from the `CSCSamplingGraph/` keys. When
 * the flag is set, the heterogeneous metadata is loaded as well; otherwise
 * any heterogeneous metadata previously held by this object is dropped so
 * the graph does not keep stale type information.
 *
 * @param archive Input archive to read from.
 */
void CSCSamplingGraph::Load(torch::serialize::InputArchive& archive) {
  int64_t magic_num = 0x0;
  utils::read_from_archive(archive, "CSCSamplingGraph/magic_num", magic_num);
  TORCH_CHECK(
      magic_num == kCSCSamplingGraphSerializeMagic,
      "Magic numbers mismatch when loading CSCSamplingGraph.");
  utils::read_from_archive(archive, "CSCSamplingGraph/num_nodes", num_nodes_);
  utils::read_from_archive(archive, "CSCSamplingGraph/indptr", indptr_);
  utils::read_from_archive(archive, "CSCSamplingGraph/indices", indices_);
  bool is_heterogeneous = false;
  utils::read_from_archive(
      archive, "CSCSamplingGraph/is_hetero", is_heterogeneous);
  if (is_heterogeneous) {
    // Load into a fresh object and publish it only after a successful read,
    // so a failing load does not leave a half-filled HeteroInfo behind.
    auto loaded_info = std::make_shared<HeteroInfo>();
    loaded_info->Load(archive);
    hetero_info_ = loaded_info;
  } else {
    // The archive describes a homogeneous graph: discard leftover metadata.
    hetero_info_ = nullptr;
  }
}
/**
 * @brief Write this graph to an archive.
 *
 * Emits the magic number, the node count, `indptr`, `indices` and a flag
 * telling whether the graph is heterogeneous, all under the
 * `CSCSamplingGraph/` key prefix. For a heterogeneous graph the HeteroInfo is
 * written afterwards.
 *
 * @param archive Output archive to write to.
 */
void CSCSamplingGraph::Save(torch::serialize::OutputArchive& archive) const {
  const bool has_hetero_info = IsHeterogeneous();

  // Header: format marker and node count.
  archive.write(
      "CSCSamplingGraph/magic_num",
      torch::IValue(kCSCSamplingGraphSerializeMagic));
  archive.write("CSCSamplingGraph/num_nodes", torch::IValue(num_nodes_));

  // CSC structure.
  archive.write("CSCSamplingGraph/indptr", indptr_);
  archive.write("CSCSamplingGraph/indices", indices_);

  // Heterogeneous metadata, present only when the flag is true.
  archive.write("CSCSamplingGraph/is_hetero", torch::IValue(has_hetero_info));
  if (!has_hetero_info) {
    return;
  }
  hetero_info_->Save(archive);
}
} // namespace sampling
} // namespace graphbolt
namespace torch {
/**
 * @brief Deserialize a CSCSamplingGraph from an input archive.
 *
 * Delegates to `graph.Load(archive)`.
 *
 * @param archive Input archive to read from.
 * @param graph Graph object that receives the loaded content.
 *
 * @return The same archive, so that extractions can be chained.
 */
serialize::InputArchive& operator>>(
serialize::InputArchive& archive,
graphbolt::sampling::CSCSamplingGraph& graph) {
graph.Load(archive);
return archive;
}
/**
 * @brief Serialize a CSCSamplingGraph into an output archive.
 *
 * Delegates to `graph.Save(archive)`.
 *
 * @param archive Output archive to write to.
 * @param graph Graph object to be saved.
 *
 * @return The same archive, so that insertions can be chained.
 */
serialize::OutputArchive& operator<<(
serialize::OutputArchive& archive,
const graphbolt::sampling::CSCSamplingGraph& graph) {
graph.Save(archive);
return archive;
}
} // namespace torch
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment