Unverified commit feaeb1c2, authored by Hongzhi (Steve), Chen, committed by GitHub

[Misc] Polish the shared memory helper. (#6279)


Co-authored-by: Ubuntu <ubuntu@ip-172-31-28-63.ap-northeast-1.compute.internal>
parent aa562f7e
@@ -31,8 +31,6 @@ struct SamplerArgs<SamplerType::LABOR> {
   int64_t num_nodes;
 };
 
-class SharedMemoryHelper;
-
 /**
  * @brief A sampling oriented csc format graph.
  *
@@ -238,22 +236,26 @@ class CSCSamplingGraph : public torch::CustomClassHolder {
   static c10::intrusive_ptr<CSCSamplingGraph> LoadFromSharedMemory(
       const std::string& shared_memory_name);
 
+  /**
+   * @brief Hold the shared memory objects of the tensor metadata and data.
+   * @note Shared memory is used to hold the tensor metadata and data of this
+   * class. By storing its shared memory objects, the graph controls the
+   * resources of shared memory, which will be released automatically when the
+   * graph is destroyed. This function is for internal use by CopyToSharedMemory
+   * and LoadFromSharedMemory. Please contact the DGL team if you need to use
+   * it.
+   * @param tensor_metadata_shm The shared memory objects of tensor metadata.
+   * @param tensor_data_shm The shared memory objects of tensor data.
+   */
+  void HoldSharedMemoryObject(
+      SharedMemoryPtr tensor_metadata_shm, SharedMemoryPtr tensor_data_shm);
+
  private:
   template <typename NumPickFn, typename PickFn>
   c10::intrusive_ptr<SampledSubgraph> SampleNeighborsImpl(
       const torch::Tensor& nodes, bool return_eids, NumPickFn num_pick_fn,
       PickFn pick_fn) const;
 
-  /**
-   * @brief Build a CSCSamplingGraph from a shared memory helper. This function
-   * takes ownership of the shared memory objects in the helper.
-   *
-   * @param shared_memory_helper The shared memory helper.
-   * @return A new CSCSamplingGraph on shared memory.
-   */
-  static c10::intrusive_ptr<CSCSamplingGraph> BuildGraphFromSharedMemoryHelper(
-      SharedMemoryHelper&& shared_memory_helper);
-
   /** @brief CSC format index pointer array. */
   torch::Tensor indptr_;
@@ -292,12 +294,12 @@ class CSCSamplingGraph : public torch::CustomClassHolder {
   static constexpr int64_t SERIALIZED_METAINFO_SIZE_MAX = 32768;
 
   /**
-   * @brief Shared memory used to hold the tensor meta information and data of
-   * this class. By storing its shared memory objects, the graph controls the
+   * @brief Shared memory used to hold the tensor metadata and data of this
+   * class. By storing its shared memory objects, the graph controls the
    * resources of shared memory, which will be released automatically when the
    * graph is destroyed.
    */
-  SharedMemoryPtr tensor_meta_shm_, tensor_data_shm_;
+  SharedMemoryPtr tensor_metadata_shm_, tensor_data_shm_;
 };
 
 /**
...
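The header change above makes the ownership contract explicit: the graph itself holds the two shared memory objects, so they are released exactly when the graph is destroyed. Below is a minimal, self-contained sketch of that RAII hand-off pattern; Segment, SegmentPtr, and GraphLike are illustrative stand-ins (the real code uses graphbolt's SharedMemoryPtr and CSCSamplingGraph), not library APIs.

// ownership_sketch.cc -- illustrative only; "Segment" and "GraphLike" are
// stand-ins, not graphbolt types. SharedMemoryPtr is assumed to behave like a
// std::unique_ptr to a shared memory handle.
#include <iostream>
#include <memory>
#include <utility>

struct Segment {  // stand-in for one shared memory object
  explicit Segment(const char* name) : name_(name) {}
  ~Segment() { std::cout << "release " << name_ << "\n"; }
  const char* name_;
};
using SegmentPtr = std::unique_ptr<Segment>;

class GraphLike {
 public:
  // Mirrors the shape of HoldSharedMemoryObject: take both handles by value
  // and move them into members so the graph controls their lifetime.
  void HoldSharedMemoryObject(SegmentPtr metadata_shm, SegmentPtr data_shm) {
    metadata_shm_ = std::move(metadata_shm);
    data_shm_ = std::move(data_shm);
  }

 private:
  SegmentPtr metadata_shm_, data_shm_;
};

int main() {
  auto graph = std::make_unique<GraphLike>();
  graph->HoldSharedMemoryObject(
      std::make_unique<Segment>("metadata"), std::make_unique<Segment>("data"));
  graph.reset();  // both segments are released together with the graph
}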
@@ -17,7 +17,7 @@
 #include <vector>
 
 #include "./random.h"
-#include "./shared_memory_utils.h"
+#include "./shared_memory_helper.h"
 
 namespace graphbolt {
 namespace sampling {
@@ -438,8 +438,7 @@ CSCSamplingGraph::SampleNegativeEdgesUniform(
   return std::make_tuple(neg_src, neg_dst);
 }
 
-c10::intrusive_ptr<CSCSamplingGraph>
-CSCSamplingGraph::BuildGraphFromSharedMemoryHelper(
+static c10::intrusive_ptr<CSCSamplingGraph> BuildGraphFromSharedMemoryHelper(
     SharedMemoryHelper&& helper) {
   helper.InitializeRead();
   auto indptr = helper.ReadTorchTensor();
@@ -450,8 +449,9 @@ CSCSamplingGraph::BuildGraphFromSharedMemoryHelper(
   auto graph = c10::make_intrusive<CSCSamplingGraph>(
       indptr.value(), indices.value(), node_type_offset, type_per_edge,
       edge_attributes);
-  std::tie(graph->tensor_meta_shm_, graph->tensor_data_shm_) =
-      helper.ReleaseSharedMemory();
+  auto shared_memory = helper.ReleaseSharedMemory();
+  graph->HoldSharedMemoryObject(
+      std::move(shared_memory.first), std::move(shared_memory.second));
   return graph;
 }
@@ -473,6 +473,12 @@ c10::intrusive_ptr<CSCSamplingGraph> CSCSamplingGraph::LoadFromSharedMemory(
   return BuildGraphFromSharedMemoryHelper(std::move(helper));
 }
 
+void CSCSamplingGraph::HoldSharedMemoryObject(
+    SharedMemoryPtr tensor_metadata_shm, SharedMemoryPtr tensor_data_shm) {
+  tensor_metadata_shm_ = std::move(tensor_metadata_shm);
+  tensor_data_shm_ = std::move(tensor_data_shm);
+}
+
 int64_t NumPick(
     int64_t fanout, bool replace,
     const torch::optional<torch::Tensor>& probs_or_mask, int64_t offset,
...
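In the .cc above, the builder becomes a file-local static function, so it can no longer assign the graph's private members directly; instead it takes the pair returned by ReleaseSharedMemory() and moves each element into the graph through the new public hook. The sketch below shows only that handoff idiom with move-only handles; Handle, HelperLike, and OwnerLike are made-up stand-ins, not graphbolt types.

// release_pair_sketch.cc -- illustrative only; "Handle" and "HelperLike" are
// stand-ins for graphbolt's SharedMemoryPtr and SharedMemoryHelper.
#include <memory>
#include <string>
#include <utility>

using Handle = std::unique_ptr<std::string>;  // any move-only handle type

class HelperLike {
 public:
  HelperLike()
      : metadata_(std::make_unique<std::string>("metadata segment")),
        data_(std::make_unique<std::string>("data segment")) {}

  // Mirrors ReleaseSharedMemory(): hand both move-only handles to the caller
  // in one pair, leaving the helper empty.
  std::pair<Handle, Handle> ReleaseSharedMemory() {
    return {std::move(metadata_), std::move(data_)};
  }

 private:
  Handle metadata_, data_;
};

struct OwnerLike {
  void Hold(Handle metadata, Handle data) {
    metadata_ = std::move(metadata);
    data_ = std::move(data);
  }
  Handle metadata_, data_;
};

int main() {
  HelperLike helper;
  OwnerLike owner;
  // The named pair is an lvalue, so each member must be moved individually,
  // just as the patch does with shared_memory.first / shared_memory.second.
  auto shared_memory = helper.ReleaseSharedMemory();
  owner.Hold(std::move(shared_memory.first), std::move(shared_memory.second));
}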
 /**
  * Copyright (c) 2023 by Contributors
  *
- * @file shared_memory_utils.cc
- * @brief Share memory utility function implementation.
+ * @file shared_memory_helper.cc
+ * @brief Shared memory helper implementation.
  */
-#include "./shared_memory_utils.h"
+#include "./shared_memory_helper.h"
 
 #include <graphbolt/serialize.h>
 #include <graphbolt/shared_memory.h>
@@ -19,7 +19,7 @@ namespace graphbolt {
 namespace sampling {
 
 static std::string GetSharedMemoryMetadataName(const std::string& name) {
-  return name + "_meta";
+  return name + "_metadata";
 }
 
 static std::string GetSharedMemoryDataName(const std::string& name) {
...
 /**
  * Copyright (c) 2023 by Contributors
  *
- * @file shared_memory_utils.h
- * @brief Share memory utilities.
+ * @file shared_memory_helper.h
+ * @brief Shared memory helper.
  */
-#ifndef GRAPHBOLT_SHM_UTILS_H_
-#define GRAPHBOLT_SHM_UTILS_H_
+#ifndef GRAPHBOLT_SHARED_MEMORY_HELPER_H_
+#define GRAPHBOLT_SHARED_MEMORY_HELPER_H_
 
 #include <graphbolt/shared_memory.h>
 #include <torch/torch.h>
@@ -28,9 +28,9 @@ namespace sampling {
  * memory. However, the size of the binary buffer is not known in advance. To
  * solve this problem, we use two shared memory objects: one for storing the
  * metadata and the other for storing the binary buffer. The metadata includes
- * the meta information of data structures such as size and shape. The size of
- * the metadata is decided by the user via `max_metadata_size`. The size of
- * the binary buffer is decided by the size of the data structures.
+ * the metadata of data structures such as size and shape. The size of the
+ * metadata is decided by the user via `max_metadata_size`. The size of the
+ * binary buffer is decided by the size of the data structures.
  *
  * To avoid repeated shared memory allocation, this helper class uses lazy data
  * structure writing. The data structures are written to the shared memory only
@@ -140,4 +140,4 @@ class SharedMemoryHelper {
 }  // namespace sampling
 }  // namespace graphbolt
 
-#endif  // GRAPHBOLT_SHM_UTILS_H_
+#endif  // GRAPHBOLT_SHARED_MEMORY_HELPER_H_
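The renamed helper's design, as its header comment describes, splits each stored object across two shared memory segments: a metadata segment of user-chosen maximum size (max_metadata_size) recording sizes and shapes, and a data segment sized from the payload itself once it is known. A rough sketch of that split follows, using plain byte buffers instead of real shared memory; none of the names below are graphbolt APIs.

// two_segment_sketch.cc -- illustrative only; plain byte buffers stand in for
// the metadata and data shared memory segments the helper manages.
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <iostream>
#include <vector>

struct TensorMeta {       // what the fixed-size metadata segment records
  int64_t num_elements;   // enough to reconstruct the tensor's extent
  int64_t element_bytes;  // and to compute its byte size in the data segment
};

int main() {
  constexpr std::size_t kMaxMetadataSize = 4096;  // user-chosen, like max_metadata_size
  std::vector<char> metadata_segment(kMaxMetadataSize);  // fixed size up front
  std::vector<char> data_segment;  // sized only once the payload is known

  // "Write" one tensor: record its metadata, then append its bytes.
  std::vector<int64_t> payload = {1, 2, 3, 4, 5};
  TensorMeta meta{static_cast<int64_t>(payload.size()), sizeof(int64_t)};
  std::memcpy(metadata_segment.data(), &meta, sizeof(meta));
  data_segment.resize(meta.num_elements * meta.element_bytes);
  std::memcpy(data_segment.data(), payload.data(), data_segment.size());

  // "Read" it back from the two segments.
  TensorMeta read_meta;
  std::memcpy(&read_meta, metadata_segment.data(), sizeof(read_meta));
  std::cout << "tensor of " << read_meta.num_elements << " elements, "
            << read_meta.num_elements * read_meta.element_bytes << " bytes\n";
}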