/**
 *  Copyright (c) 2023 by Contributors
 *
 * @file shared_memory_utils.h
 * @brief Share memory utilities.
 */
#ifndef GRAPHBOLT_SHM_UTILS_H_
#define GRAPHBOLT_SHM_UTILS_H_

#include <graphbolt/shared_memory.h>
#include <torch/torch.h>

#include <memory>
#include <sstream>
#include <string>
#include <tuple>
#include <vector>

namespace graphbolt {
namespace sampling {

/**
 * @brief SharedMemoryHelper is a helper class to write/read data structures
 * to/from shared memory.
 *
 * In order to write a data structure to shared memory, we need to serialize
 * the data structure to a binary buffer and then write the buffer to the
 * shared memory. However, the size of the binary buffer is not known in
 * advance. To solve this problem, we use two shared memory objects: one for
 * storing the metadata and the other for storing the binary buffer. The
 * metadata includes the meta information of data structures such as size and
 * shape. The size of the metadata is decided by the user via
 * `max_metadata_size`. The size of the binary buffer is decided by the size
 * of the data structures.
 *
 * To avoid repeated shared memory allocation, this helper class uses lazy data
 * structure writing. The data structures are written to the shared memory only
 * when `Flush` is called. The data structures are written in the order of
 * calling `WriteTorchArchive`, `WriteTorchTensor` and `WriteTorchTensorDict`,
 * and also read in the same order.
 *
 * The usage of this class as a writer is as follows:
 * @code{.cpp}
 * SharedMemoryHelper shm_helper("shm_name", 1024);
 * shm_helper.WriteTorchArchive(archive);
 * shm_helper.WriteTorchTensor(tensor);
 * shm_helper.WriteTorchTensorDict(tensor_dict);
 * shm_helper.Flush();
 * // After `Flush`, the data structures are written to the shared memory.
 * // Then the helper class can be used as a reader.
 * shm_helper.InitializeRead();
 * auto archive = shm_helper.ReadTorchArchive();
 * auto tensor = shm_helper.ReadTorchTensor();
 * auto tensor_dict = shm_helper.ReadTorchTensorDict();
 * @endcode
 *
 * The usage of this class as a reader is as follows:
 * @code{.cpp}
 * SharedMemoryHelper shm_helper("shm_name", 1024);
 * shm_helper.InitializeRead();
 * auto archive = shm_helper.ReadTorchArchive();
 * auto tensor = shm_helper.ReadTorchTensor();
 * auto tensor_dict = shm_helper.ReadTorchTensorDict();
 * @endcode
 */
class SharedMemoryHelper {
 public:
  /**
   * @brief Constructor of the shared memory helper.
   * @param name The name of the shared memory.
   * @param max_metadata_size The maximum size of metadata, i.e. the capacity
   * of the metadata shared memory region; exceeding it fails a TORCH_CHECK
   * (see `MoveMetadataPtr`).
   */
  SharedMemoryHelper(const std::string& name, int64_t max_metadata_size);

  /** @brief Initialize this helper class before reading. */
  void InitializeRead();

  // NOTE: writes are lazy — each Write* call only queues its argument (see
  // `metadata_to_write_` / `tensors_to_write_` below); the queued data is
  // serialized to shared memory when `Flush` is called. Reads must occur in
  // the same order as the corresponding writes.

  /** @brief Queue a torch archive (metadata) to be written on `Flush`. */
  void WriteTorchArchive(torch::serialize::OutputArchive&& archive);
  /** @brief Read the next torch archive from the metadata shared memory. */
  torch::serialize::InputArchive ReadTorchArchive();

  /**
   * @brief Queue an optional tensor to be written on `Flush`;
   * `torch::nullopt` records the absence of a tensor.
   */
  void WriteTorchTensor(torch::optional<torch::Tensor> tensor);
  /** @brief Read the next optional tensor from the shared memory. */
  torch::optional<torch::Tensor> ReadTorchTensor();

  /** @brief Queue an optional name->tensor dict to be written on `Flush`. */
  void WriteTorchTensorDict(
      torch::optional<torch::Dict<std::string, torch::Tensor>> tensor_dict);
  /** @brief Read the next optional name->tensor dict. */
  torch::optional<torch::Dict<std::string, torch::Tensor>>
  ReadTorchTensorDict();

  /** @brief Flush the queued data structures to the shared memory. */
  void Flush();

  /**
   * @brief Release ownership of the two shared memory objects and return
   * them to the caller as a (metadata, data) pair.
   */
  std::pair<SharedMemoryPtr, SharedMemoryPtr> ReleaseSharedMemory();

 private:
  /**
   * @brief Write the metadata to the shared memory. This function is
   * called by `Flush`.
   */
  void WriteTorchArchiveInternal(torch::serialize::OutputArchive& archive);
  /**
   * @brief Write the tensor data to the shared memory. This function is
   * called by `Flush`.
   */
  void WriteTorchTensorInternal(torch::optional<torch::Tensor> tensor);

  // Current read/write position within the metadata region.
  inline void* GetCurrentMetadataPtr() const {
    return static_cast<char*>(metadata_shared_memory_->GetMemory()) +
           metadata_offset_;
  }
  // Current read/write position within the tensor data region.
  inline void* GetCurrentDataPtr() const {
    return static_cast<char*>(data_shared_memory_->GetMemory()) + data_offset_;
  }
  // Advance the metadata offset. The metadata region has a fixed,
  // user-chosen capacity, so overruns are checked here.
  inline void MoveMetadataPtr(int64_t offset) {
    TORCH_CHECK(
        metadata_offset_ + offset <= max_metadata_size_,
        "The size of metadata exceeds the maximum size of shared memory.");
    metadata_offset_ += offset;
  }
  // Advance the data offset. No bounds check: the data region is sized from
  // the data structures themselves (presumably exactly large enough — the
  // allocation happens in the .cc implementation).
  inline void MoveDataPtr(int64_t offset) { data_offset_ += offset; }

  std::string name_;
  // Presumably true when this helper created the shared memory (writer role)
  // rather than attaching to an existing one (reader role); set in the .cc
  // implementation — TODO(review) confirm against the constructor definition.
  bool is_creator_;

  // Capacity of the metadata region, as passed to the constructor.
  int64_t max_metadata_size_;

  // The shared memory objects for storing metadata and tensor data.
  SharedMemoryPtr metadata_shared_memory_, data_shared_memory_;

  // The read/write offsets of the metadata and tensor data.
  int64_t metadata_offset_, data_offset_;

  // The data structures to write to the shared memory. They are written to the
  // shared memory only when `Flush` is called.
  std::vector<torch::serialize::OutputArchive> metadata_to_write_;
  std::vector<torch::optional<torch::Tensor>> tensors_to_write_;
};

}  // namespace sampling
}  // namespace graphbolt

#endif  // GRAPHBOLT_SHM_UTILS_H_