Unverified Commit feaeb1c2 authored by Hongzhi (Steve), Chen's avatar Hongzhi (Steve), Chen Committed by GitHub
Browse files

[Misc] Polish the shared memory helper. (#6279)


Co-authored-by: default avatarUbuntu <ubuntu@ip-172-31-28-63.ap-northeast-1.compute.internal>
parent aa562f7e
......@@ -31,8 +31,6 @@ struct SamplerArgs<SamplerType::LABOR> {
int64_t num_nodes;
};
class SharedMemoryHelper;
/**
* @brief A sampling oriented csc format graph.
*
......@@ -238,22 +236,26 @@ class CSCSamplingGraph : public torch::CustomClassHolder {
static c10::intrusive_ptr<CSCSamplingGraph> LoadFromSharedMemory(
const std::string& shared_memory_name);
/**
* @brief Hold the shared memory objects of the the tensor metadata and data.
* @note Shared memory used to hold the tensor metadata and data of this
* class. By storing its shared memory objects, the graph controls the
* resources of shared memory, which will be released automatically when the
* graph is destroyed. This function is for internal use by CopyToSharedMemory
* and LoadFromSharedMemory. Please contact the DGL team if you need to use
* it.
* @param tensor_metadata_shm The shared memory objects of tensor metadata.
* @param tensor_data_shm The shared memory objects of tensor data.
*/
void HoldSharedMemoryObject(
SharedMemoryPtr tensor_metadata_shm, SharedMemoryPtr tensor_data_shm);
private:
template <typename NumPickFn, typename PickFn>
c10::intrusive_ptr<SampledSubgraph> SampleNeighborsImpl(
const torch::Tensor& nodes, bool return_eids, NumPickFn num_pick_fn,
PickFn pick_fn) const;
/**
* @brief Build a CSCSamplingGraph from a shared memory helper. This function
* takes ownership of the shared memory objects in the helper.
*
* @param shared_memory_helper The shared memory helper.
* @return A new CSCSamplingGraph on shared memory.
*/
static c10::intrusive_ptr<CSCSamplingGraph> BuildGraphFromSharedMemoryHelper(
SharedMemoryHelper&& shared_memory_helper);
/** @brief CSC format index pointer array. */
torch::Tensor indptr_;
......@@ -292,12 +294,12 @@ class CSCSamplingGraph : public torch::CustomClassHolder {
static constexpr int64_t SERIALIZED_METAINFO_SIZE_MAX = 32768;
/**
* @brief Shared memory used to hold the tensor meta information and data of
* this class. By storing its shared memory objects, the graph controls the
* @brief Shared memory used to hold the tensor metadata and data of this
* class. By storing its shared memory objects, the graph controls the
* resources of shared memory, which will be released automatically when the
* graph is destroyed.
*/
SharedMemoryPtr tensor_meta_shm_, tensor_data_shm_;
SharedMemoryPtr tensor_metadata_shm_, tensor_data_shm_;
};
/**
......
......@@ -17,7 +17,7 @@
#include <vector>
#include "./random.h"
#include "./shared_memory_utils.h"
#include "./shared_memory_helper.h"
namespace graphbolt {
namespace sampling {
......@@ -438,8 +438,7 @@ CSCSamplingGraph::SampleNegativeEdgesUniform(
return std::make_tuple(neg_src, neg_dst);
}
c10::intrusive_ptr<CSCSamplingGraph>
CSCSamplingGraph::BuildGraphFromSharedMemoryHelper(
static c10::intrusive_ptr<CSCSamplingGraph> BuildGraphFromSharedMemoryHelper(
SharedMemoryHelper&& helper) {
helper.InitializeRead();
auto indptr = helper.ReadTorchTensor();
......@@ -450,8 +449,9 @@ CSCSamplingGraph::BuildGraphFromSharedMemoryHelper(
auto graph = c10::make_intrusive<CSCSamplingGraph>(
indptr.value(), indices.value(), node_type_offset, type_per_edge,
edge_attributes);
std::tie(graph->tensor_meta_shm_, graph->tensor_data_shm_) =
helper.ReleaseSharedMemory();
auto shared_memory = helper.ReleaseSharedMemory();
graph->HoldSharedMemoryObject(
std::move(shared_memory.first), std::move(shared_memory.second));
return graph;
}
......@@ -473,6 +473,12 @@ c10::intrusive_ptr<CSCSamplingGraph> CSCSamplingGraph::LoadFromSharedMemory(
return BuildGraphFromSharedMemoryHelper(std::move(helper));
}
void CSCSamplingGraph::HoldSharedMemoryObject(
SharedMemoryPtr tensor_metadata_shm, SharedMemoryPtr tensor_data_shm) {
tensor_metadata_shm_ = std::move(tensor_metadata_shm);
tensor_data_shm_ = std::move(tensor_data_shm);
}
int64_t NumPick(
int64_t fanout, bool replace,
const torch::optional<torch::Tensor>& probs_or_mask, int64_t offset,
......
/**
* Copyright (c) 2023 by Contributors
*
* @file shared_memory_utils.cc
* @brief Share memory utility function implementation.
* @file shared_memory_helper.cc
* @brief Share memory helper implementation.
*/
#include "./shared_memory_utils.h"
#include "./shared_memory_helper.h"
#include <graphbolt/serialize.h>
#include <graphbolt/shared_memory.h>
......@@ -19,7 +19,7 @@ namespace graphbolt {
namespace sampling {
static std::string GetSharedMemoryMetadataName(const std::string& name) {
return name + "_meta";
return name + "_metadata";
}
static std::string GetSharedMemoryDataName(const std::string& name) {
......
/**
* Copyright (c) 2023 by Contributors
*
* @file shared_memory_utils.h
* @brief Share memory utilities.
* @file shared_memory_helper.h
* @brief Share memory helper.
*/
#ifndef GRAPHBOLT_SHM_UTILS_H_
#define GRAPHBOLT_SHM_UTILS_H_
#ifndef GRAPHBOLT_SHARED_MEMORY_HELPER_H_
#define GRAPHBOLT_SHARED_MEMORY_HELPER_H_
#include <graphbolt/shared_memory.h>
#include <torch/torch.h>
......@@ -28,9 +28,9 @@ namespace sampling {
* memory. However, the size of the binary buffer is not known in advance. To
* solve this problem, we use two shared memory objects: one for storing the
* metadata and the other for storing the binary buffer. The metadata includes
* the meta information of data structures such as size and shape. The size of
* the metadata is decided by the user via `max_metadata_size`. The size of
* the binary buffer is decided by the size of the data structures.
* the metadata of data structures such as size and shape. The size of the
* metadata is decided by the user via `max_metadata_size`. The size of the
* binary buffer is decided by the size of the data structures.
*
* To avoid repeated shared memory allocation, this helper class uses lazy data
* structure writing. The data structures are written to the shared memory only
......@@ -140,4 +140,4 @@ class SharedMemoryHelper {
} // namespace sampling
} // namespace graphbolt
#endif // GRAPHBOLT_SHM_UTILS_H_
#endif // GRAPHBOLT_SHARED_MEMORY_HELPER_H_
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment