Unverified Commit 77fd747f authored by Rhett Ying's avatar Rhett Ying Committed by GitHub
Browse files

[GraphBolt] enable load/save for CSCSamplingGraph (#5702)

parent 46af76c3
...@@ -49,6 +49,9 @@ struct HeteroInfo { ...@@ -49,6 +49,9 @@ struct HeteroInfo {
node_type_offset(node_type_offset), node_type_offset(node_type_offset),
type_per_edge(type_per_edge) {} type_per_edge(type_per_edge) {}
/** @brief Default constructor. */
HeteroInfo() = default;
/** @brief List of node types in the graph.*/ /** @brief List of node types in the graph.*/
StringList node_types; StringList node_types;
...@@ -66,6 +69,24 @@ struct HeteroInfo { ...@@ -66,6 +69,24 @@ struct HeteroInfo {
* edge_types. The length of it is equal to the number of edges. * edge_types. The length of it is equal to the number of edges.
*/ */
torch::Tensor type_per_edge; torch::Tensor type_per_edge;
/**
** @brief Magic number to indicate Hetero info version in serialize/
** deserialize stages.
**/
static constexpr int64_t kHeteroInfoSerializeMagic = 0xDD2E60F0F6B4A129;
/**
** @brief Load hetero info from stream.
** @param archive Input stream for deserializing.
**/
void Load(torch::serialize::InputArchive& archive);
/**
** @brief Save hetero info to stream.
** @param archive Output stream for serializing.
**/
void Save(torch::serialize::OutputArchive& archive) const;
}; };
/** /**
...@@ -148,6 +169,24 @@ class CSCSamplingGraph : public torch::CustomClassHolder { ...@@ -148,6 +169,24 @@ class CSCSamplingGraph : public torch::CustomClassHolder {
return hetero_info_->type_per_edge; return hetero_info_->type_per_edge;
} }
/**
** @brief Magic number to indicate graph version in serialize/deserialize
** stage.
**/
static constexpr int64_t kCSCSamplingGraphSerializeMagic = 0xDD2E60F0F6B4A128;
/**
** @brief Load graph from stream.
** @param archive Input stream for deserializing.
**/
void Load(torch::serialize::InputArchive& archive);
/**
** @brief Save graph to stream.
** @param archive Output stream for serializing.
**/
void Save(torch::serialize::OutputArchive& archive) const;
private: private:
/** @brief The number of nodes of the graph. */ /** @brief The number of nodes of the graph. */
int64_t num_nodes_ = 0; int64_t num_nodes_ = 0;
...@@ -161,3 +200,33 @@ class CSCSamplingGraph : public torch::CustomClassHolder { ...@@ -161,3 +200,33 @@ class CSCSamplingGraph : public torch::CustomClassHolder {
} // namespace sampling } // namespace sampling
} // namespace graphbolt } // namespace graphbolt
/**
** @brief Overload stream operator to enable `torch::save()` and `torch.load()`
** for CSCSamplingGraph.
**/
namespace torch {
/**
** @brief Overload input stream operator for CSCSamplingGraph deserialization.
** @param archive Input stream for deserializing.
** @param graph CSCSamplingGraph.
**
** @return archive
**/
inline serialize::InputArchive& operator>>(
serialize::InputArchive& archive,
graphbolt::sampling::CSCSamplingGraph& graph);
/**
* @brief Overload output stream operator for CSCSamplingGraph serialization.
* @param archive Output stream for serializing.
* @param graph CSCSamplingGraph.
*
* @return archive
*/
inline serialize::OutputArchive& operator<<(
serialize::OutputArchive& archive,
const graphbolt::sampling::CSCSamplingGraph& graph);
} // namespace torch
/**
* Copyright (c) 2023 by Contributors
* @file graphbolt/include/serialize.h
* @brief Utility functions for serialize and deserialize.
*/
#ifndef GRAPHBOLT_INCLUDE_SERIALIZE_H_
#define GRAPHBOLT_INCLUDE_SERIALIZE_H_
#include <torch/torch.h>
#include <string>
#include <vector>
namespace graphbolt {
namespace utils {
/**
* @brief Utility function to write to archive.
* @param archive Output archive.
* @param key Key name used in saving.
* @param data Data that could be constructed as `torch::IValue`.
**/
template <typename DataT>
void write_to_archive(
torch::serialize::OutputArchive& archive, const std::string& key,
const DataT& data) {
archive.write(key, data);
}
/**
* @brief Specialization utility function to save string vector.
* @param archive Output archive.
* @param key Key name used in saving.
* @param data Vector of string.
**/
template <>
void write_to_archive<std::vector<std::string>>(
torch::serialize::OutputArchive& archive, const std::string& key,
const std::vector<std::string>& data) {
archive.write(
key + "/size", torch::tensor(static_cast<int64_t>(data.size())));
for (const auto index : c10::irange(data.size())) {
archive.write(key + "/" + std::to_string(index), data[index]);
}
}
/**
* @brief Utility function to read from archive.
* @param archive Input archive.
* @param key Key name used in reading.
* @param data Data that could be constructed as `torch::IValue`.
**/
template <typename DataT = torch::IValue>
void read_from_archive(
torch::serialize::InputArchive& archive, const std::string& key,
DataT& data) {
archive.read(key, data);
}
/**
* @brief Specialization utility function to read from archive.
* @param archive Input archive.
* @param key Key name used in reading.
* @param data Data that is `bool`.
**/
template <>
void read_from_archive<bool>(
torch::serialize::InputArchive& archive, const std::string& key,
bool& data) {
torch::IValue iv_data;
archive.read(key, iv_data);
data = iv_data.toBool();
}
/**
* @brief Specialization utility function to read from archive.
* @param archive Input archive.
* @param key Key name used in reading.
* @param data Data that is `int64_t`.
**/
template <>
void read_from_archive<int64_t>(
torch::serialize::InputArchive& archive, const std::string& key,
int64_t& data) {
torch::IValue iv_data;
archive.read(key, iv_data);
data = iv_data.toInt();
}
/**
* @brief Specialization utility function to read from archive.
* @param archive Input archive.
* @param key Key name used in reading.
* @param data Data that is `std::string`.
**/
template <>
void read_from_archive<std::string>(
torch::serialize::InputArchive& archive, const std::string& key,
std::string& data) {
torch::IValue iv_data;
archive.read(key, iv_data);
data = iv_data.toString();
}
/**
* @brief Specialization utility function to read from archive.
* @param archive Input archive.
* @param key Key name used in reading.
* @param data Data that is `torch::Tensor`.
**/
template <>
void read_from_archive<torch::Tensor>(
torch::serialize::InputArchive& archive, const std::string& key,
torch::Tensor& data) {
torch::IValue iv_data;
archive.read(key, iv_data);
data = iv_data.toTensor();
}
/**
* @brief Specialization utility function to read to string vector.
* @param archive Output archive.
* @param key Key name used in saving.
* @param data Vector of string.
**/
template <>
void read_from_archive<std::vector<std::string>>(
torch::serialize::InputArchive& archive, const std::string& key,
std::vector<std::string>& data) {
int64_t size = 0;
read_from_archive<int64_t>(archive, key + "/size", size);
data.resize(static_cast<size_t>(size));
std::string element;
for (int64_t index = 0; index < size; ++index) {
read_from_archive<std::string>(
archive, key + "/" + std::to_string(index), element);
data[index] = element;
}
}
} // namespace utils
} // namespace graphbolt
#endif // GRAPHBOLT_INCLUDE_SERIALIZE_H_
...@@ -6,9 +6,34 @@ ...@@ -6,9 +6,34 @@
#include "csc_sampling_graph.h" #include "csc_sampling_graph.h"
#include "serialize.h"
namespace graphbolt { namespace graphbolt {
namespace sampling { namespace sampling {
void HeteroInfo::Load(torch::serialize::InputArchive& archive) {
int64_t magic_num = 0x0;
utils::read_from_archive(archive, "HeteroInfo/magic_num", magic_num);
TORCH_CHECK(
magic_num == kHeteroInfoSerializeMagic,
"Magic numbers mismatch when loading HeteroInfo.");
utils::read_from_archive(archive, "HeteroInfo/node_types", node_types);
utils::read_from_archive(archive, "HeteroInfo/edge_types", edge_types);
utils::read_from_archive(
archive, "HeteroInfo/node_type_offset", node_type_offset);
utils::read_from_archive(archive, "HeteroInfo/type_per_edge", type_per_edge);
}
void HeteroInfo::Save(torch::serialize::OutputArchive& archive) const {
utils::write_to_archive(
archive, "HeteroInfo/magic_num", kHeteroInfoSerializeMagic);
utils::write_to_archive(archive, "HeteroInfo/node_types", node_types);
utils::write_to_archive(archive, "HeteroInfo/edge_types", edge_types);
utils::write_to_archive(
archive, "HeteroInfo/node_type_offset", node_type_offset);
utils::write_to_archive(archive, "HeteroInfo/type_per_edge", type_per_edge);
}
CSCSamplingGraph::CSCSamplingGraph( CSCSamplingGraph::CSCSamplingGraph(
int64_t num_nodes, torch::Tensor& indptr, torch::Tensor& indices, int64_t num_nodes, torch::Tensor& indptr, torch::Tensor& indices,
const std::shared_ptr<HeteroInfo>& hetero_info) const std::shared_ptr<HeteroInfo>& hetero_info)
...@@ -48,5 +73,54 @@ c10::intrusive_ptr<CSCSamplingGraph> CSCSamplingGraph::FromCSCWithHeteroInfo( ...@@ -48,5 +73,54 @@ c10::intrusive_ptr<CSCSamplingGraph> CSCSamplingGraph::FromCSCWithHeteroInfo(
num_nodes, indptr, indices, hetero_info); num_nodes, indptr, indices, hetero_info);
} }
void CSCSamplingGraph::Load(torch::serialize::InputArchive& archive) {
int64_t magic_num = 0x0;
utils::read_from_archive(archive, "CSCSamplingGraph/magic_num", magic_num);
TORCH_CHECK(
magic_num == kCSCSamplingGraphSerializeMagic,
"Magic numbers mismatch when loading CSCSamplingGraph.");
utils::read_from_archive(archive, "CSCSamplingGraph/num_nodes", num_nodes_);
utils::read_from_archive(archive, "CSCSamplingGraph/indptr", indptr_);
utils::read_from_archive(archive, "CSCSamplingGraph/indices", indices_);
bool is_heterogeneous = false;
utils::read_from_archive(
archive, "CSCSamplingGraph/is_hetero", is_heterogeneous);
if (is_heterogeneous) {
hetero_info_ = std::make_shared<HeteroInfo>();
hetero_info_->Load(archive);
}
}
void CSCSamplingGraph::Save(torch::serialize::OutputArchive& archive) const {
archive.write(
"CSCSamplingGraph/magic_num",
torch::IValue(kCSCSamplingGraphSerializeMagic));
archive.write("CSCSamplingGraph/num_nodes", torch::IValue(num_nodes_));
archive.write("CSCSamplingGraph/indptr", indptr_);
archive.write("CSCSamplingGraph/indices", indices_);
archive.write("CSCSamplingGraph/is_hetero", torch::IValue(IsHeterogeneous()));
if (IsHeterogeneous()) {
hetero_info_->Save(archive);
}
}
} // namespace sampling } // namespace sampling
} // namespace graphbolt } // namespace graphbolt
namespace torch {
serialize::InputArchive& operator>>(
serialize::InputArchive& archive,
graphbolt::sampling::CSCSamplingGraph& graph) {
graph.Load(archive);
return archive;
}
serialize::OutputArchive& operator<<(
serialize::OutputArchive& archive,
const graphbolt::sampling::CSCSamplingGraph& graph) {
graph.Save(archive);
return archive;
}
} // namespace torch
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment