"dgl_sparse/git@developer.sourcefind.cn:OpenDAS/dgl.git" did not exist on "a23b490df468cf86c94c1d05ccb864beea3ec2a8"
Unverified commit a5e5f11a authored by Rhett Ying, committed by GitHub

[GraphBolt] de-duplicate code for reading data from archive (#6761)

parent e181ef15
@@ -57,14 +57,19 @@ inline serialize::OutputArchive& operator<<(
 namespace graphbolt {
 /**
- * @brief Read data from archive.
+ * @brief Read data from archive and format to specified type.
  * @param archive Input archive.
  * @param key Key name of data.
  *
  * @return data.
  */
-torch::IValue read_from_archive(
-    torch::serialize::InputArchive& archive, const std::string& key);
+template <typename T>
+T read_from_archive(
+    torch::serialize::InputArchive& archive, const std::string& key) {
+  torch::IValue data;
+  archive.read(key, data);
+  return data.to<T>();
+}
 }  // namespace graphbolt
...
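For reference, a minimal standalone sketch of the new helper in action. This is not the GraphBolt sources; it assumes only libtorch, and the archive keys and values here are made up for illustration.

#include <torch/torch.h>

#include <sstream>
#include <string>

// Same shape as the templated helper added above: one read() plus one
// typed unboxing via IValue::to<T>() replaces the per-call-site .toX().
template <typename T>
T read_from_archive(
    torch::serialize::InputArchive& archive, const std::string& key) {
  torch::IValue data;
  archive.read(key, data);
  return data.to<T>();
}

int main() {
  // Write a few typed values, then read them back through the template.
  torch::serialize::OutputArchive out;
  out.write("magic_num", torch::IValue(static_cast<int64_t>(42)));
  out.write("indptr", torch::IValue(torch::arange(5)));
  std::stringstream stream;
  out.save_to(stream);

  torch::serialize::InputArchive in;
  in.load_from(stream);
  auto magic = read_from_archive<int64_t>(in, "magic_num");
  auto indptr = read_from_archive<torch::Tensor>(in, "indptr");
  TORCH_CHECK(magic == 42 && indptr.numel() == 5);
  return 0;
}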
@@ -109,91 +109,46 @@ c10::intrusive_ptr<FusedCSCSamplingGraph> FusedCSCSamplingGraph::Create(
 void FusedCSCSamplingGraph::Load(torch::serialize::InputArchive& archive) {
   const int64_t magic_num =
-      read_from_archive(archive, "FusedCSCSamplingGraph/magic_num").toInt();
+      read_from_archive<int64_t>(archive, "FusedCSCSamplingGraph/magic_num");
   TORCH_CHECK(
       magic_num == kCSCSamplingGraphSerializeMagic,
       "Magic numbers mismatch when loading FusedCSCSamplingGraph.");
   indptr_ =
-      read_from_archive(archive, "FusedCSCSamplingGraph/indptr").toTensor();
-  indices_ =
-      read_from_archive(archive, "FusedCSCSamplingGraph/indices").toTensor();
-  if (read_from_archive(archive, "FusedCSCSamplingGraph/has_node_type_offset")
-          .toBool()) {
-    node_type_offset_ =
-        read_from_archive(archive, "FusedCSCSamplingGraph/node_type_offset")
-            .toTensor();
+      read_from_archive<torch::Tensor>(archive, "FusedCSCSamplingGraph/indptr");
+  indices_ = read_from_archive<torch::Tensor>(
+      archive, "FusedCSCSamplingGraph/indices");
+  if (read_from_archive<bool>(
+          archive, "FusedCSCSamplingGraph/has_node_type_offset")) {
+    node_type_offset_ = read_from_archive<torch::Tensor>(
+        archive, "FusedCSCSamplingGraph/node_type_offset");
   }
-  if (read_from_archive(archive, "FusedCSCSamplingGraph/has_type_per_edge")
-          .toBool()) {
-    type_per_edge_ =
-        read_from_archive(archive, "FusedCSCSamplingGraph/type_per_edge")
-            .toTensor();
+  if (read_from_archive<bool>(
+          archive, "FusedCSCSamplingGraph/has_type_per_edge")) {
+    type_per_edge_ = read_from_archive<torch::Tensor>(
+        archive, "FusedCSCSamplingGraph/type_per_edge");
   }
-  if (read_from_archive(archive, "FusedCSCSamplingGraph/has_node_type_to_id")
-          .toBool()) {
-    torch::Dict<torch::IValue, torch::IValue> generic_dict =
-        read_from_archive(archive, "FusedCSCSamplingGraph/node_type_to_id")
-            .toGenericDict();
-    NodeTypeToIDMap node_type_to_id;
-    for (const auto& pair : generic_dict) {
-      std::string key = pair.key().toStringRef();
-      int64_t value = pair.value().toInt();
-      node_type_to_id.insert(std::move(key), value);
-    }
-    node_type_to_id_ = std::move(node_type_to_id);
+  if (read_from_archive<bool>(
+          archive, "FusedCSCSamplingGraph/has_node_type_to_id")) {
+    node_type_to_id_ = read_from_archive<NodeTypeToIDMap>(
+        archive, "FusedCSCSamplingGraph/node_type_to_id");
   }
-  if (read_from_archive(archive, "FusedCSCSamplingGraph/has_edge_type_to_id")
-          .toBool()) {
-    torch::Dict<torch::IValue, torch::IValue> generic_dict =
-        read_from_archive(archive, "FusedCSCSamplingGraph/edge_type_to_id")
-            .toGenericDict();
-    EdgeTypeToIDMap edge_type_to_id;
-    for (const auto& pair : generic_dict) {
-      std::string key = pair.key().toStringRef();
-      int64_t value = pair.value().toInt();
-      edge_type_to_id.insert(std::move(key), value);
-    }
-    edge_type_to_id_ = std::move(edge_type_to_id);
+  if (read_from_archive<bool>(
+          archive, "FusedCSCSamplingGraph/has_edge_type_to_id")) {
+    edge_type_to_id_ = read_from_archive<EdgeTypeToIDMap>(
+        archive, "FusedCSCSamplingGraph/edge_type_to_id");
   }
-  // Optional node attributes.
-  torch::IValue has_node_attributes;
-  if (archive.try_read(
-          "FusedCSCSamplingGraph/has_node_attributes", has_node_attributes) &&
-      has_node_attributes.toBool()) {
-    torch::Dict<torch::IValue, torch::IValue> generic_dict =
-        read_from_archive(archive, "FusedCSCSamplingGraph/node_attributes")
-            .toGenericDict();
-    NodeAttrMap target_dict;
-    for (const auto& pair : generic_dict) {
-      std::string key = pair.key().toStringRef();
-      torch::Tensor value = pair.value().toTensor();
-      // Use move to avoid copy.
-      target_dict.insert(std::move(key), std::move(value));
-    }
-    // Same as above.
-    node_attributes_ = std::move(target_dict);
+  if (read_from_archive<bool>(
+          archive, "FusedCSCSamplingGraph/has_node_attributes")) {
+    node_attributes_ = read_from_archive<NodeAttrMap>(
+        archive, "FusedCSCSamplingGraph/node_attributes");
   }
-
-  // Optional edge attributes.
-  torch::IValue has_edge_attributes;
-  if (archive.try_read(
-          "FusedCSCSamplingGraph/has_edge_attributes", has_edge_attributes) &&
-      has_edge_attributes.toBool()) {
-    torch::Dict<torch::IValue, torch::IValue> generic_dict =
-        read_from_archive(archive, "FusedCSCSamplingGraph/edge_attributes")
-            .toGenericDict();
-    EdgeAttrMap target_dict;
-    for (const auto& pair : generic_dict) {
-      std::string key = pair.key().toStringRef();
-      torch::Tensor value = pair.value().toTensor();
-      // Use move to avoid copy.
-      target_dict.insert(std::move(key), std::move(value));
-    }
-    // Same as above.
-    edge_attributes_ = std::move(target_dict);
+  if (read_from_archive<bool>(
+          archive, "FusedCSCSamplingGraph/has_edge_attributes")) {
+    edge_attributes_ = read_from_archive<EdgeAttrMap>(
+        archive, "FusedCSCSamplingGraph/edge_attributes");
   }
 }
...
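The manual toGenericDict() conversion loops can be dropped because IValue::to<T>() already unboxes a generic dict into a typed torch::Dict. A minimal sketch, assuming NodeTypeToIDMap is torch::Dict<std::string, int64_t> as the removed loop suggests (the sample keys are made up):

#include <torch/torch.h>

#include <string>

// Assumption for this sketch: the map type used by FusedCSCSamplingGraph.
using NodeTypeToIDMap = torch::Dict<std::string, int64_t>;

int main() {
  NodeTypeToIDMap type_to_id;
  type_to_id.insert("user", 0);
  type_to_id.insert("item", 1);

  // Boxing into an IValue yields a generic dict; to<T>() recovers the typed
  // dict directly, so no per-entry key/value conversion loop is needed.
  torch::IValue boxed(type_to_id);
  auto unboxed = boxed.to<NodeTypeToIDMap>();
  TORCH_CHECK(unboxed.at("item") == 1);
  return 0;
}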
@@ -24,14 +24,3 @@ serialize::OutputArchive& operator<<(
 }
 }  // namespace torch
-namespace graphbolt {
-torch::IValue read_from_archive(
-    torch::serialize::InputArchive& archive, const std::string& key) {
-  torch::IValue data;
-  archive.read(key, data);
-  return data;
-}
-}  // namespace graphbolt
@@ -88,10 +88,10 @@ void SharedMemoryHelper::WriteTorchTensor(
 torch::optional<torch::Tensor> SharedMemoryHelper::ReadTorchTensor() {
   auto archive = this->ReadTorchArchive();
-  bool has_value = read_from_archive(archive, "has_value").toBool();
+  bool has_value = read_from_archive<bool>(archive, "has_value");
   if (has_value) {
-    auto shape = read_from_archive(archive, "shape").toIntVector();
-    auto dtype = read_from_archive(archive, "dtype").toScalarType();
+    auto shape = read_from_archive<std::vector<int64_t>>(archive, "shape");
+    auto dtype = read_from_archive<torch::ScalarType>(archive, "dtype");
     auto data_ptr = this->GetCurrentDataPtr();
     auto tensor = torch::from_blob(data_ptr, shape, dtype);
     auto rounded_size = GetRoundedSize(tensor.numel() * tensor.element_size());
@@ -127,15 +127,14 @@ void SharedMemoryHelper::WriteTorchTensorDict(
 torch::optional<torch::Dict<std::string, torch::Tensor>>
 SharedMemoryHelper::ReadTorchTensorDict() {
   auto archive = this->ReadTorchArchive();
-  if (!read_from_archive(archive, "has_value").toBool()) {
+  if (!read_from_archive<bool>(archive, "has_value")) {
     return torch::nullopt;
   }
-  int64_t num_tensors = read_from_archive(archive, "num_tensors").toInt();
+  int64_t num_tensors = read_from_archive<int64_t>(archive, "num_tensors");
   torch::Dict<std::string, torch::Tensor> tensor_dict;
   for (int64_t i = 0; i < num_tensors; ++i) {
-    auto key =
-        read_from_archive(archive, std::string("key_") + std::to_string(i))
-            .toStringRef();
+    auto key = read_from_archive<std::string>(
+        archive, std::string("key_") + std::to_string(i));
     auto tensor = this->ReadTorchTensor();
     tensor_dict.insert(key, tensor.value());
   }
...
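The shared-memory read path above only pulls metadata (has_value, shape, dtype) from the archive; the tensor data itself is wrapped in place with torch::from_blob. A sketch under stated assumptions: a local buffer stands in for the shared-memory segment, and the shape/dtype literals play the role of the values read_from_archive<...> would return.

#include <torch/torch.h>

#include <vector>

int main() {
  std::vector<float> buffer = {1.f, 2.f, 3.f, 4.f, 5.f, 6.f};
  std::vector<int64_t> shape = {2, 3};        // stand-in for read_from_archive<std::vector<int64_t>>(archive, "shape")
  torch::ScalarType dtype = torch::kFloat32;  // stand-in for read_from_archive<torch::ScalarType>(archive, "dtype")

  // from_blob wraps existing memory as a tensor view without copying,
  // which is what lets the helper hand out tensors backed by shared memory.
  auto tensor = torch::from_blob(buffer.data(), shape, dtype);
  TORCH_CHECK(tensor.size(0) == 2 && tensor.size(1) == 3);
  TORCH_CHECK(tensor[1][2].item<float>() == 6.f);
  return 0;
}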