Unverified Commit a5e5f11a authored by Rhett Ying's avatar Rhett Ying Committed by GitHub
Browse files

[GraphBolt] de-duplicate code for reading data from archive (#6761)

parent e181ef15
......@@ -57,14 +57,19 @@ inline serialize::OutputArchive& operator<<(
namespace graphbolt {
/**
 * @brief Read data from archive and convert it to the requested type.
 *
 * De-duplicates the read-then-convert pattern repeated at call sites.
 *
 * @tparam T Target type of the stored value (e.g. int64_t, bool,
 * torch::Tensor). The conversion is performed by torch::IValue::to<T>().
 * @param archive Input archive.
 * @param key Key name of data.
 *
 * @return The value stored under @p key, converted to T.
 */
template <typename T>
T read_from_archive(
    torch::serialize::InputArchive& archive, const std::string& key) {
  // Read the raw IValue, then let IValue::to<T>() perform the checked
  // conversion; a stored-type mismatch surfaces as a torch error.
  torch::IValue data;
  archive.read(key, data);
  return data.to<T>();
}
} // namespace graphbolt
......
......@@ -109,91 +109,46 @@ c10::intrusive_ptr<FusedCSCSamplingGraph> FusedCSCSamplingGraph::Create(
/**
 * @brief Deserialize the graph's state from @p archive.
 *
 * Required fields (magic number, indptr, indices) are always read; each
 * optional field is guarded by a boolean "has_*" flag stored alongside it.
 *
 * @param archive Input archive positioned at a serialized
 * FusedCSCSamplingGraph.
 */
void FusedCSCSamplingGraph::Load(torch::serialize::InputArchive& archive) {
  // Guard against loading an incompatible or corrupted serialization.
  const int64_t magic_num =
      read_from_archive<int64_t>(archive, "FusedCSCSamplingGraph/magic_num");
  TORCH_CHECK(
      magic_num == kCSCSamplingGraphSerializeMagic,
      "Magic numbers mismatch when loading FusedCSCSamplingGraph.");
  indptr_ =
      read_from_archive<torch::Tensor>(archive, "FusedCSCSamplingGraph/indptr");
  indices_ = read_from_archive<torch::Tensor>(
      archive, "FusedCSCSamplingGraph/indices");
  if (read_from_archive<bool>(
          archive, "FusedCSCSamplingGraph/has_node_type_offset")) {
    node_type_offset_ = read_from_archive<torch::Tensor>(
        archive, "FusedCSCSamplingGraph/node_type_offset");
  }
  if (read_from_archive<bool>(
          archive, "FusedCSCSamplingGraph/has_type_per_edge")) {
    type_per_edge_ = read_from_archive<torch::Tensor>(
        archive, "FusedCSCSamplingGraph/type_per_edge");
  }
  if (read_from_archive<bool>(
          archive, "FusedCSCSamplingGraph/has_node_type_to_id")) {
    node_type_to_id_ = read_from_archive<NodeTypeToIDMap>(
        archive, "FusedCSCSamplingGraph/node_type_to_id");
  }
  if (read_from_archive<bool>(
          archive, "FusedCSCSamplingGraph/has_edge_type_to_id")) {
    edge_type_to_id_ = read_from_archive<EdgeTypeToIDMap>(
        archive, "FusedCSCSamplingGraph/edge_type_to_id");
  }
  // Optional node attributes.
  // NOTE(review): an earlier revision used archive.try_read() for this flag,
  // tolerating archives written before the key existed; reading it
  // unconditionally assumes every archive now contains it — confirm old
  // serialized graphs are no longer supported.
  if (read_from_archive<bool>(
          archive, "FusedCSCSamplingGraph/has_node_attributes")) {
    node_attributes_ = read_from_archive<NodeAttrMap>(
        archive, "FusedCSCSamplingGraph/node_attributes");
  }
  // Optional edge attributes (same backward-compatibility caveat as above).
  if (read_from_archive<bool>(
          archive, "FusedCSCSamplingGraph/has_edge_attributes")) {
    edge_attributes_ = read_from_archive<EdgeAttrMap>(
        archive, "FusedCSCSamplingGraph/edge_attributes");
  }
}
......
......@@ -24,14 +24,3 @@ serialize::OutputArchive& operator<<(
}
} // namespace torch
namespace graphbolt {
// Fetch the value stored under `key` and hand it back untyped; callers
// convert the returned IValue to whichever type they expect.
torch::IValue read_from_archive(
    torch::serialize::InputArchive& archive, const std::string& key) {
  torch::IValue value;
  archive.read(key, value);
  return value;
}
} // namespace graphbolt
......@@ -88,10 +88,10 @@ void SharedMemoryHelper::WriteTorchTensor(
torch::optional<torch::Tensor> SharedMemoryHelper::ReadTorchTensor() {
auto archive = this->ReadTorchArchive();
bool has_value = read_from_archive(archive, "has_value").toBool();
bool has_value = read_from_archive<bool>(archive, "has_value");
if (has_value) {
auto shape = read_from_archive(archive, "shape").toIntVector();
auto dtype = read_from_archive(archive, "dtype").toScalarType();
auto shape = read_from_archive<std::vector<int64_t>>(archive, "shape");
auto dtype = read_from_archive<torch::ScalarType>(archive, "dtype");
auto data_ptr = this->GetCurrentDataPtr();
auto tensor = torch::from_blob(data_ptr, shape, dtype);
auto rounded_size = GetRoundedSize(tensor.numel() * tensor.element_size());
......@@ -127,15 +127,14 @@ void SharedMemoryHelper::WriteTorchTensorDict(
torch::optional<torch::Dict<std::string, torch::Tensor>>
SharedMemoryHelper::ReadTorchTensorDict() {
auto archive = this->ReadTorchArchive();
if (!read_from_archive(archive, "has_value").toBool()) {
if (!read_from_archive<bool>(archive, "has_value")) {
return torch::nullopt;
}
int64_t num_tensors = read_from_archive(archive, "num_tensors").toInt();
int64_t num_tensors = read_from_archive<int64_t>(archive, "num_tensors");
torch::Dict<std::string, torch::Tensor> tensor_dict;
for (int64_t i = 0; i < num_tensors; ++i) {
auto key =
read_from_archive(archive, std::string("key_") + std::to_string(i))
.toStringRef();
auto key = read_from_archive<std::string>(
archive, std::string("key_") + std::to_string(i));
auto tensor = this->ReadTorchTensor();
tensor_dict.insert(key, tensor.value());
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment