/**
 *  Copyright (c) 2023 by Contributors
 *
 * @file shared_memory_helper.h
 * @brief Shared memory helper.
 */
#ifndef GRAPHBOLT_SHARED_MEMORY_HELPER_H_
#define GRAPHBOLT_SHARED_MEMORY_HELPER_H_

#include <graphbolt/shared_memory.h>
#include <torch/torch.h>

#include <memory>
#include <sstream>
#include <string>
#include <tuple>
#include <vector>

namespace graphbolt {
namespace sampling {

/**
 * @brief SharedMemoryHelper is a helper class to write/read data structures
 * to/from shared memory.
 *
 * In order to write a data structure to shared memory, we need to serialize
 * the data structure to a binary buffer and then write the buffer to the
 * shared memory. However, the size of the binary buffer is not known in
 * advance. To solve this problem, we use two shared memory objects: one for
 * storing the metadata and the other for storing the binary buffer. The
 * metadata describes the data structures, e.g. their sizes and shapes. The
 * size of the metadata shared memory is decided by the size of the metadata,
 * and the size of the binary buffer is decided by the size of the data
 * structures.
 *
 * To avoid repeated shared memory allocation, this helper class writes data
 * structures lazily: the data structures are written to the shared memory
 * only when `Flush` is called. They are written in the order in which
 * `WriteTorchArchive`, `WriteTorchTensor` and `WriteTorchTensorDict` are
 * called, and must be read back in the same order.
 *
 * The usage of this class as a writer is as follows:
 * @code{.cpp}
 * SharedMemoryHelper shm_helper("shm_name");
 * shm_helper.WriteTorchArchive(std::move(archive));
 * shm_helper.WriteTorchTensor(tensor);
 * shm_helper.WriteTorchTensorDict(tensor_dict);
 * shm_helper.Flush();
 * // After `Flush`, the data structures are written to the shared memory.
 * // Then the helper class can be used as a reader.
 * shm_helper.InitializeRead();
 * auto archive = shm_helper.ReadTorchArchive();
 * auto tensor = shm_helper.ReadTorchTensor();
 * auto tensor_dict = shm_helper.ReadTorchTensorDict();
 * @endcode
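 *
 * A minimal sketch of building the archive passed to `WriteTorchArchive`
 * (illustrative only; the key name below is hypothetical):
 * @code{.cpp}
 * torch::serialize::OutputArchive archive;
 * archive.write("num_nodes", static_cast<int64_t>(100));
 * shm_helper.WriteTorchArchive(std::move(archive));
 * @endcode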
 *
 * The usage of this class as a reader is as follows:
 * @code{.cpp}
 * SharedMemoryHelper shm_helper("shm_name");
 * shm_helper.InitializeRead();
 * auto archive = shm_helper.ReadTorchArchive();
 * auto tensor = shm_helper.ReadTorchTensor();
 * auto tensor_dict = shm_helper.ReadTorchTensorDict();
 * @endcode
 */
class SharedMemoryHelper {
 public:
  /**
   * @brief Constructor of the shared memory helper.
   * @param name The name of the shared memory.
   */
  SharedMemoryHelper(const std::string& name);

  /** @brief Initialize this helper class before reading. */
  void InitializeRead();

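  /** @brief Queue a torch archive to write to shared memory on `Flush`. */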
  void WriteTorchArchive(torch::serialize::OutputArchive&& archive);
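  /** @brief Read the next torch archive from the shared memory. */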
  torch::serialize::InputArchive ReadTorchArchive();

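  /** @brief Queue an optional tensor to write to shared memory on `Flush`. */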
  void WriteTorchTensor(torch::optional<torch::Tensor> tensor);
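  /** @brief Read the next optional tensor from the shared memory. */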
  torch::optional<torch::Tensor> ReadTorchTensor();

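  /** @brief Queue an optional tensor dict to write on `Flush`. */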
  void WriteTorchTensorDict(
      torch::optional<torch::Dict<std::string, torch::Tensor>> tensor_dict);
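  /** @brief Read the next optional tensor dict from the shared memory. */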
  torch::optional<torch::Dict<std::string, torch::Tensor>>
  ReadTorchTensorDict();

  /** @brief Flush the data structures to the shared memory. */
  void Flush();

  /** @brief Release and return the owned shared memory objects. */
  std::pair<SharedMemoryPtr, SharedMemoryPtr> ReleaseSharedMemory();
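  // A caller-side sketch (illustrative; assumes `SharedMemoryPtr` keeps the
  // underlying shared memory alive while it is held):
  //
  //   auto shared_memory = shm_helper.ReleaseSharedMemory();
  //   // Keep the returned pair alive for as long as tensors backed by the
  //   // shared memory are in use.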

 private:
  /**
   * @brief Serialize the metadata to strings.
   */
  void SerializeMetadata();
  /**
   * @brief Write the metadata to the shared memory. This function is
   * called by `Flush`.
   */
  void WriteMetadataToSharedMemory();
  /**
   * @brief Write the tensor data to the shared memory. This function is
   * called by `Flush`.
   */
  void WriteTorchTensorInternal(torch::optional<torch::Tensor> tensor);

  inline void* GetCurrentMetadataPtr() const {
    return static_cast<char*>(metadata_shared_memory_->GetMemory()) +
           metadata_offset_;
  }
  inline void* GetCurrentDataPtr() const {
    return static_cast<char*>(data_shared_memory_->GetMemory()) + data_offset_;
  }
  inline void MoveMetadataPtr(int64_t offset) {
    TORCH_CHECK(
        metadata_offset_ + offset <= metadata_size_,
        "The size of metadata exceeds the maximum size of shared memory.");
    metadata_offset_ += offset;
  }
  inline void MoveDataPtr(int64_t offset) {
    TORCH_CHECK(
        data_offset_ + offset <= data_size_,
        "The size of data exceeds the maximum size of shared memory.");
    data_offset_ += offset;
  }

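  // The shared memory name and whether this helper created the shared memory.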
  std::string name_;
  bool is_creator_;

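  // The sizes of the metadata and data shared memory regions.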
  size_t metadata_size_;
  size_t data_size_;

  // The shared memory objects for storing metadata and tensor data.
  SharedMemoryPtr metadata_shared_memory_, data_shared_memory_;

  // The read/write offsets of the metadata and tensor data.
  size_t metadata_offset_, data_offset_;

  // The data structures to write to the shared memory. They are written to the
  // shared memory only when `Flush` is called.
  std::vector<torch::serialize::OutputArchive> metadata_to_write_;
  std::vector<std::string> metadata_strings_to_write_;
  std::vector<torch::optional<torch::Tensor>> tensors_to_write_;
};

}  // namespace sampling
}  // namespace graphbolt

#endif  // GRAPHBOLT_SHARED_MEMORY_HELPER_H_