shared_memory_helper.cc 7.3 KB
Newer Older
1
2
3
/**
 *  Copyright (c) 2023 by Contributors
 *
 * @file shared_memory_helper.cc
 * @brief Shared memory helper implementation.
 */
7
#include "./shared_memory_helper.h"

#include <graphbolt/serialize.h>
#include <graphbolt/shared_memory.h>
#include <torch/torch.h>

#include <cstdint>
#include <cstring>
#include <numeric>
#include <sstream>
#include <string>
#include <tuple>
#include <utility>
#include <vector>
namespace graphbolt {
namespace sampling {

21
// Derive the name of the metadata shared-memory segment for `name`.
static std::string GetSharedMemoryMetadataName(const std::string& name) {
  std::string segment_name{name};
  segment_name += "_metadata";
  return segment_name;
}

25
26
// Derive the name of the tensor-data shared-memory segment for `name`.
static std::string GetSharedMemoryDataName(const std::string& name) {
  std::string segment_name{name};
  segment_name += "_data";
  return segment_name;
}

29
30
31
32
33
34
35
// Round `size` up to the next multiple of 8 bytes so that every record
// written to shared memory keeps the following record 8-byte aligned,
// avoiding unaligned memory access.
inline static int64_t GetRoundedSize(int64_t size) {
  constexpr int64_t kAlignment = 8;
  const int64_t num_chunks = (size + kAlignment - 1) / kAlignment;
  return num_chunks * kAlignment;
}

36
// Bind the helper to `name`; no shared memory is created or opened yet.
// The writer side attaches via Flush(), the reader side via
// InitializeRead().
SharedMemoryHelper::SharedMemoryHelper(const std::string& name)
    : name_(name),
      metadata_size_(0),
      data_size_(0),
      metadata_shared_memory_(nullptr),
      data_shared_memory_(nullptr),
      metadata_offset_(0),
      data_offset_(0) {}

void SharedMemoryHelper::InitializeRead() {
  metadata_offset_ = 0;
  data_offset_ = 0;
  if (metadata_shared_memory_ == nullptr) {
    // Reader process opens the shared memory.
    metadata_shared_memory_ =
        std::make_unique<SharedMemory>(GetSharedMemoryMetadataName(name_));
52
53
    metadata_shared_memory_->Open();
    metadata_size_ = metadata_shared_memory_->GetSize();
54
55
    data_shared_memory_ =
        std::make_unique<SharedMemory>(GetSharedMemoryDataName(name_));
56
57
    data_shared_memory_->Open();
    data_size_ = data_shared_memory_->GetSize();
58
59
60
  }
}

61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
// Queue an archive for serialization; nothing is written to shared memory
// until Flush() is called.
void SharedMemoryHelper::WriteTorchArchive(
    torch::serialize::OutputArchive&& archive) {
  metadata_to_write_.push_back(std::move(archive));
}

// Deserialize the next metadata archive from the metadata segment and
// advance the metadata cursor past it.
//
// Record layout (written by WriteMetadataToSharedMemory): an int64_t byte
// count, immediately followed by the serialized archive payload, padded to
// the next 8-byte boundary (GetRoundedSize) so the next record's size
// header stays aligned.
torch::serialize::InputArchive SharedMemoryHelper::ReadTorchArchive() {
  auto metadata_ptr = this->GetCurrentMetadataPtr();
  // First sizeof(int64_t) bytes hold the payload size.
  int64_t metadata_size = static_cast<int64_t*>(metadata_ptr)[0];
  torch::serialize::InputArchive archive;
  archive.load_from(
      static_cast<const char*>(metadata_ptr) + sizeof(int64_t), metadata_size);
  // Skip the size header plus the padded payload, not just the raw size.
  auto rounded_size = GetRoundedSize(metadata_size);
  this->MoveMetadataPtr(sizeof(int64_t) + rounded_size);
  return archive;
}

// Queue an optional tensor for writing. Only shape/dtype metadata is
// recorded here; the raw tensor bytes are copied into shared memory by
// Flush().
void SharedMemoryHelper::WriteTorchTensor(
    torch::optional<torch::Tensor> tensor) {
  torch::serialize::OutputArchive archive;
  const bool has_value = tensor.has_value();
  archive.write("has_value", has_value);
  if (has_value) {
    const auto& value = tensor.value();
    archive.write("shape", value.sizes());
    archive.write("dtype", value.scalar_type());
  }
  this->WriteTorchArchive(std::move(archive));
  tensors_to_write_.push_back(tensor);
}

// Read the next optional tensor: its shape/dtype come from the next
// metadata archive, its data from the current position in the data
// segment. Returns torch::nullopt when the writer queued an empty tensor.
torch::optional<torch::Tensor> SharedMemoryHelper::ReadTorchTensor() {
  auto archive = this->ReadTorchArchive();
  if (!read_from_archive<bool>(archive, "has_value")) {
    return torch::nullopt;
  }
  const auto shape = read_from_archive<std::vector<int64_t>>(archive, "shape");
  const auto dtype = read_from_archive<torch::ScalarType>(archive, "dtype");
  // The tensor aliases the shared memory directly; no copy is made.
  auto tensor = torch::from_blob(this->GetCurrentDataPtr(), shape, dtype);
  // Advance past the padded payload so the next record stays aligned.
  this->MoveDataPtr(GetRoundedSize(tensor.numel() * tensor.element_size()));
  return tensor;
}

105
106
// Queue an optional dict of tensors. One archive records the key set
// ("key_0", "key_1", ... in iteration order); the values follow as
// individual tensor records in that same order.
void SharedMemoryHelper::WriteTorchTensorDict(
    torch::optional<torch::Dict<std::string, torch::Tensor>> tensor_dict) {
  torch::serialize::OutputArchive archive;
  if (!tensor_dict.has_value()) {
    archive.write("has_value", false);
    this->WriteTorchArchive(std::move(archive));
    return;
  }
  archive.write("has_value", true);
  const auto dict_value = tensor_dict.value();
  archive.write("num_tensors", static_cast<int64_t>(dict_value.size()));
  int64_t index = 0;
  for (const auto& entry : dict_value) {
    archive.write(std::string("key_") + std::to_string(index), entry.key());
    ++index;
  }
  this->WriteTorchArchive(std::move(archive));
  // Emit the values after the key archive, preserving iteration order so
  // ReadTorchTensorDict can pair them back up.
  for (const auto& entry : dict_value) {
    this->WriteTorchTensor(entry.value());
  }
}

127
128
129
// Read back an optional dict written by WriteTorchTensorDict: first the
// key archive, then one tensor record per key in the recorded order.
torch::optional<torch::Dict<std::string, torch::Tensor>>
SharedMemoryHelper::ReadTorchTensorDict() {
  auto archive = this->ReadTorchArchive();
  if (!read_from_archive<bool>(archive, "has_value")) {
    return torch::nullopt;
  }
  const int64_t num_tensors =
      read_from_archive<int64_t>(archive, "num_tensors");
  torch::Dict<std::string, torch::Tensor> result;
  for (int64_t index = 0; index < num_tensors; ++index) {
    const auto key = read_from_archive<std::string>(
        archive, std::string("key_") + std::to_string(index));
    // Each key is paired with the next tensor record in the data segment.
    result.insert(key, this->ReadTorchTensor().value());
  }
  return result;
}

144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
// Turn every queued OutputArchive into its serialized string form so the
// total metadata size can be computed before the segment is created.
void SharedMemoryHelper::SerializeMetadata() {
  for (auto& archive : metadata_to_write_) {
    std::stringstream buffer;
    archive.save_to(buffer);
    metadata_strings_to_write_.push_back(buffer.str());
  }
  metadata_to_write_.clear();
}

void SharedMemoryHelper::WriteMetadataToSharedMemory() {
  metadata_offset_ = 0;
  for (const auto& str : metadata_strings_to_write_) {
    auto metadata_ptr = this->GetCurrentMetadataPtr();
    static_cast<int64_t*>(metadata_ptr)[0] = str.size();
    memcpy(
        static_cast<char*>(metadata_ptr) + sizeof(int64_t), str.data(),
        str.size());
    int64_t rounded_size = GetRoundedSize(str.size());
    this->MoveMetadataPtr(sizeof(int64_t) + rounded_size);
  }
  metadata_strings_to_write_.clear();
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
}

// Copy one queued tensor's bytes into the data segment and advance the
// data cursor; empty optionals occupy no space (their metadata archive
// already records has_value == false).
void SharedMemoryHelper::WriteTorchTensorInternal(
    torch::optional<torch::Tensor> tensor) {
  if (!tensor.has_value()) {
    return;
  }
  // Contiguify first so a single memcpy captures the whole payload.
  const auto contiguous_tensor = tensor.value().contiguous();
  const size_t memory_size =
      tensor.value().numel() * tensor.value().element_size();
  memcpy(this->GetCurrentDataPtr(), contiguous_tensor.data_ptr(), memory_size);
  this->MoveDataPtr(GetRoundedSize(memory_size));
}

void SharedMemoryHelper::Flush() {
  size_t data_size = 0;
  for (auto tensor : tensors_to_write_) {
    if (tensor.has_value()) {
      auto tensor_size = tensor.value().numel() * tensor.value().element_size();
      data_size += GetRoundedSize(tensor_size);
184
185
    }
  }
186
187
188
189
190
191
192
193
194
195

  // Serialize the metadata archives.
  SerializeMetadata();

  // Create the shared memory objects.
  const size_t metadata_size = std::accumulate(
      metadata_strings_to_write_.begin(), metadata_strings_to_write_.end(), 0,
      [](size_t sum, const std::string& str) {
        return sum + sizeof(int64_t) + GetRoundedSize(str.size());
      });
196
197
  metadata_shared_memory_ =
      std::make_unique<SharedMemory>(GetSharedMemoryMetadataName(name_));
198
199
  metadata_shared_memory_->Create(metadata_size);
  metadata_size_ = metadata_size;
200

201
202
  // Write the metadata and tensor data to the shared memory.
  WriteMetadataToSharedMemory();
203
204
205
  data_shared_memory_ =
      std::make_unique<SharedMemory>(GetSharedMemoryDataName(name_));
  data_shared_memory_->Create(data_size);
206
  data_size_ = data_size;
207
208
209
210
  data_offset_ = 0;
  for (auto tensor : tensors_to_write_) {
    this->WriteTorchTensorInternal(tensor);
  }
211

212
213
214
215
216
217
218
219
  metadata_to_write_.clear();
  tensors_to_write_.clear();
}

std::pair<SharedMemoryPtr, SharedMemoryPtr>
SharedMemoryHelper::ReleaseSharedMemory() {
  return std::make_pair(
      std::move(metadata_shared_memory_), std::move(data_shared_memory_));
220
221
222
223
}

}  // namespace sampling
}  // namespace graphbolt