/** * Copyright (c) 2023 by Contributors * @file graphbolt/include/serialize.h * @brief Utility functions for serialize and deserialize. */ #ifndef GRAPHBOLT_INCLUDE_SERIALIZE_H_ #define GRAPHBOLT_INCLUDE_SERIALIZE_H_ #include #include #include namespace graphbolt { namespace utils { /** * @brief Utility function to write to archive. * @param archive Output archive. * @param key Key name used in saving. * @param data Data that could be constructed as `torch::IValue`. **/ template void write_to_archive( torch::serialize::OutputArchive& archive, const std::string& key, const DataT& data) { archive.write(key, data); } /** * @brief Specialization utility function to save string vector. * @param archive Output archive. * @param key Key name used in saving. * @param data Vector of string. **/ template <> void write_to_archive>( torch::serialize::OutputArchive& archive, const std::string& key, const std::vector& data) { archive.write( key + "/size", torch::tensor(static_cast(data.size()))); for (const auto index : c10::irange(data.size())) { archive.write(key + "/" + std::to_string(index), data[index]); } } /** * @brief Utility function to read from archive. * @param archive Input archive. * @param key Key name used in reading. * @param data Data that could be constructed as `torch::IValue`. **/ template void read_from_archive( torch::serialize::InputArchive& archive, const std::string& key, DataT& data) { archive.read(key, data); } /** * @brief Specialization utility function to read from archive. * @param archive Input archive. * @param key Key name used in reading. * @param data Data that is `bool`. **/ template <> void read_from_archive( torch::serialize::InputArchive& archive, const std::string& key, bool& data) { torch::IValue iv_data; archive.read(key, iv_data); data = iv_data.toBool(); } /** * @brief Specialization utility function to read from archive. * @param archive Input archive. * @param key Key name used in reading. * @param data Data that is `int64_t`. **/ template <> void read_from_archive( torch::serialize::InputArchive& archive, const std::string& key, int64_t& data) { torch::IValue iv_data; archive.read(key, iv_data); data = iv_data.toInt(); } /** * @brief Specialization utility function to read from archive. * @param archive Input archive. * @param key Key name used in reading. * @param data Data that is `std::string`. **/ template <> void read_from_archive( torch::serialize::InputArchive& archive, const std::string& key, std::string& data) { torch::IValue iv_data; archive.read(key, iv_data); data = iv_data.toString(); } /** * @brief Specialization utility function to read from archive. * @param archive Input archive. * @param key Key name used in reading. * @param data Data that is `torch::Tensor`. **/ template <> void read_from_archive( torch::serialize::InputArchive& archive, const std::string& key, torch::Tensor& data) { torch::IValue iv_data; archive.read(key, iv_data); data = iv_data.toTensor(); } /** * @brief Specialization utility function to read to string vector. * @param archive Output archive. * @param key Key name used in saving. * @param data Vector of string. **/ template <> void read_from_archive>( torch::serialize::InputArchive& archive, const std::string& key, std::vector& data) { int64_t size = 0; read_from_archive(archive, key + "/size", size); data.resize(static_cast(size)); std::string element; for (int64_t index = 0; index < size; ++index) { read_from_archive( archive, key + "/" + std::to_string(index), element); data[index] = element; } } } // namespace utils } // namespace graphbolt #endif // GRAPHBOLT_INCLUDE_SERIALIZE_H_