Unverified Commit 8a852530 authored by Rhett Ying's avatar Rhett Ying Committed by GitHub
Browse files

[GraphBolt] support node/edge_type_to_id in shared memory (#6693)

parent a19d5f3b
...@@ -378,9 +378,9 @@ class FusedCSCSamplingGraph : public torch::CustomClassHolder { ...@@ -378,9 +378,9 @@ class FusedCSCSamplingGraph : public torch::CustomClassHolder {
* @brief Maximum number of bytes used to serialize the metadata of the * @brief Maximum number of bytes used to serialize the metadata of the
* member tensors, including tensor shape and dtype. The constant is estimated * member tensors, including tensor shape and dtype. The constant is estimated
* by multiplying the number of tensors in this class and the maximum number * by multiplying the number of tensors in this class and the maximum number
* of bytes used to serialize the metadata of a tensor (4 * 8192 for now). * of bytes used to serialize the metadata of a tensor (10 * 8192 for now).
*/ */
static constexpr int64_t SERIALIZED_METAINFO_SIZE_MAX = 32768; static constexpr int64_t SERIALIZED_METAINFO_SIZE_MAX = 10 * 81920;
/** /**
* @brief Shared memory used to hold the tensor metadata and data of this * @brief Shared memory used to hold the tensor metadata and data of this
......
...@@ -19,6 +19,32 @@ ...@@ -19,6 +19,32 @@
#include "./random.h" #include "./random.h"
#include "./shared_memory_helper.h" #include "./shared_memory_helper.h"
namespace {
torch::optional<torch::Dict<std::string, torch::Tensor>> TensorizeDict(
const torch::optional<torch::Dict<std::string, int64_t>>& dict) {
if (!dict.has_value()) {
return torch::nullopt;
}
torch::Dict<std::string, torch::Tensor> result;
for (const auto& pair : dict.value()) {
result.insert(pair.key(), torch::tensor(pair.value(), torch::kInt64));
}
return result;
}
torch::optional<torch::Dict<std::string, int64_t>> DetensorizeDict(
const torch::optional<torch::Dict<std::string, torch::Tensor>>& dict) {
if (!dict.has_value()) {
return torch::nullopt;
}
torch::Dict<std::string, int64_t> result;
for (const auto& pair : dict.value()) {
result.insert(pair.key(), pair.value().item<int64_t>());
}
return result;
}
} // namespace
namespace graphbolt { namespace graphbolt {
namespace sampling { namespace sampling {
...@@ -556,10 +582,12 @@ BuildGraphFromSharedMemoryHelper(SharedMemoryHelper&& helper) { ...@@ -556,10 +582,12 @@ BuildGraphFromSharedMemoryHelper(SharedMemoryHelper&& helper) {
auto indices = helper.ReadTorchTensor(); auto indices = helper.ReadTorchTensor();
auto node_type_offset = helper.ReadTorchTensor(); auto node_type_offset = helper.ReadTorchTensor();
auto type_per_edge = helper.ReadTorchTensor(); auto type_per_edge = helper.ReadTorchTensor();
auto node_type_to_id = DetensorizeDict(helper.ReadTorchTensorDict());
auto edge_type_to_id = DetensorizeDict(helper.ReadTorchTensorDict());
auto edge_attributes = helper.ReadTorchTensorDict(); auto edge_attributes = helper.ReadTorchTensorDict();
auto graph = c10::make_intrusive<FusedCSCSamplingGraph>( auto graph = c10::make_intrusive<FusedCSCSamplingGraph>(
indptr.value(), indices.value(), node_type_offset, type_per_edge, indptr.value(), indices.value(), node_type_offset, type_per_edge,
torch::nullopt, torch::nullopt, edge_attributes); node_type_to_id, edge_type_to_id, edge_attributes);
auto shared_memory = helper.ReleaseSharedMemory(); auto shared_memory = helper.ReleaseSharedMemory();
graph->HoldSharedMemoryObject( graph->HoldSharedMemoryObject(
std::move(shared_memory.first), std::move(shared_memory.second)); std::move(shared_memory.first), std::move(shared_memory.second));
...@@ -574,6 +602,8 @@ FusedCSCSamplingGraph::CopyToSharedMemory( ...@@ -574,6 +602,8 @@ FusedCSCSamplingGraph::CopyToSharedMemory(
helper.WriteTorchTensor(indices_); helper.WriteTorchTensor(indices_);
helper.WriteTorchTensor(node_type_offset_); helper.WriteTorchTensor(node_type_offset_);
helper.WriteTorchTensor(type_per_edge_); helper.WriteTorchTensor(type_per_edge_);
helper.WriteTorchTensorDict(TensorizeDict(node_type_to_id_));
helper.WriteTorchTensorDict(TensorizeDict(edge_type_to_id_));
helper.WriteTorchTensorDict(edge_attributes_); helper.WriteTorchTensorDict(edge_attributes_);
helper.Flush(); helper.Flush();
return BuildGraphFromSharedMemoryHelper(std::move(helper)); return BuildGraphFromSharedMemoryHelper(std::move(helper));
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment