/**
 *  Copyright (c) 2023 by Contributors
 * @file shared_memory_helper.cc
 * @brief Share memory helper implementation.
 */
#include "./shared_memory_helper.h"

#include <graphbolt/serialize.h>
#include <graphbolt/shared_memory.h>
#include <torch/torch.h>

#include <cstring>
#include <string>
#include <tuple>
#include <vector>

namespace graphbolt {
namespace sampling {

// Name of the shared-memory segment that stores the serialized metadata.
static std::string GetSharedMemoryMetadataName(const std::string& name) {
  std::string metadata_name(name);
  metadata_name += "_metadata";
  return metadata_name;
}

// Name of the shared-memory segment that stores the raw tensor data.
static std::string GetSharedMemoryDataName(const std::string& name) {
  std::string data_name(name);
  data_name += "_data";
  return data_name;
}

// Chunks in the metadata/data segments are packed back to back; rounding each
// chunk's size up to a multiple of 8 bytes keeps every chunk start 8-byte
// aligned and thus avoids unaligned memory access.
inline static int64_t GetRoundedSize(int64_t size) {
  constexpr int64_t kAlignment = 8;
  return ((size + kAlignment - 1) / kAlignment) * kAlignment;
}

// Constructs a helper bound to the segments derived from `name`
// ("<name>_metadata" / "<name>_data"). No shared memory is created or opened
// here; that happens in Flush() (writer side) or InitializeRead() (reader
// side). `max_metadata_size` caps the size of the metadata segment.
SharedMemoryHelper::SharedMemoryHelper(
    const std::string& name, int64_t max_metadata_size)
    : name_(name),
      max_metadata_size_(max_metadata_size),
      metadata_shared_memory_(nullptr),
      data_shared_memory_(nullptr),
      metadata_offset_(0),
      data_offset_(0) {}

// Prepares the helper for reading: rewinds both offsets, attaches to the
// shared memory if this process has not done so yet, and consumes the leading
// archive that records the total tensor-data size.
void SharedMemoryHelper::InitializeRead() {
  metadata_offset_ = 0;
  data_offset_ = 0;
  if (metadata_shared_memory_ != nullptr) {
    // Writer process already has the shared memory.
    // Skip the first archive recording data size before read.
    this->ReadTorchArchive();
    return;
  }
  // Reader process opens the shared memory: metadata first, then the data
  // segment, whose size is recorded in the leading metadata archive.
  metadata_shared_memory_ =
      std::make_unique<SharedMemory>(GetSharedMemoryMetadataName(name_));
  metadata_shared_memory_->Open(max_metadata_size_);
  auto size_archive = this->ReadTorchArchive();
  const int64_t data_size = read_from_archive(size_archive, "data_size").toInt();
  data_shared_memory_ =
      std::make_unique<SharedMemory>(GetSharedMemoryDataName(name_));
  data_shared_memory_->Open(data_size);
}

// Queues an archive for writing; nothing reaches shared memory until Flush().
void SharedMemoryHelper::WriteTorchArchive(
    torch::serialize::OutputArchive&& archive) {
  metadata_to_write_.push_back(std::move(archive));
}

// Reads the next archive from the metadata segment. Each archive is laid out
// as [int64 payload size][payload bytes][padding to an 8-byte boundary].
torch::serialize::InputArchive SharedMemoryHelper::ReadTorchArchive() {
  void* base = this->GetCurrentMetadataPtr();
  const int64_t payload_size = *static_cast<int64_t*>(base);
  const char* payload = static_cast<const char*>(base) + sizeof(int64_t);
  torch::serialize::InputArchive archive;
  archive.load_from(payload, payload_size);
  // Advance past the size prefix and the padded payload.
  this->MoveMetadataPtr(sizeof(int64_t) + GetRoundedSize(payload_size));
  return archive;
}

// Stages an optional tensor for writing. Only presence/shape/dtype go into
// the metadata archive; the raw bytes are copied to the data segment later,
// during Flush().
void SharedMemoryHelper::WriteTorchTensor(
    torch::optional<torch::Tensor> tensor) {
  torch::serialize::OutputArchive archive;
  const bool has_value = tensor.has_value();
  archive.write("has_value", has_value);
  if (has_value) {
    const auto& value = tensor.value();
    archive.write("shape", value.sizes());
    archive.write("dtype", value.scalar_type());
  }
  this->WriteTorchArchive(std::move(archive));
  tensors_to_write_.push_back(tensor);
}

// Reads the next optional tensor: metadata archive first, then (if present)
// a view over the data segment at the current offset.
torch::optional<torch::Tensor> SharedMemoryHelper::ReadTorchTensor() {
  auto archive = this->ReadTorchArchive();
  if (!read_from_archive(archive, "has_value").toBool()) {
    return torch::nullopt;
  }
  const auto shape = read_from_archive(archive, "shape").toIntVector();
  const auto dtype = read_from_archive(archive, "dtype").toScalarType();
  // from_blob wraps the shared memory without copying, so the returned tensor
  // is only valid while the data segment remains mapped.
  auto tensor = torch::from_blob(this->GetCurrentDataPtr(), shape, dtype);
  this->MoveDataPtr(GetRoundedSize(tensor.numel() * tensor.element_size()));
  return tensor;
}

// Stages an optional dict of tensors. The metadata archive records the key
// names as key_0..key_{n-1} in iteration order; the tensors are then staged
// in that same order so ReadTorchTensorDict() can pair them back up.
void SharedMemoryHelper::WriteTorchTensorDict(
    torch::optional<torch::Dict<std::string, torch::Tensor>> tensor_dict) {
  torch::serialize::OutputArchive archive;
  if (!tensor_dict.has_value()) {
    archive.write("has_value", false);
    this->WriteTorchArchive(std::move(archive));
    return;
  }
  archive.write("has_value", true);
  const auto dict_value = tensor_dict.value();
  archive.write("num_tensors", static_cast<int64_t>(dict_value.size()));
  int64_t index = 0;
  for (const auto& entry : dict_value) {
    archive.write(std::string("key_") + std::to_string(index), entry.key());
    ++index;
  }
  this->WriteTorchArchive(std::move(archive));
  for (const auto& entry : dict_value) {
    this->WriteTorchTensor(entry.value());
  }
}

// Reads an optional dict of tensors written by WriteTorchTensorDict(): keys
// key_0..key_{n-1} come from the metadata archive, and the tensors follow in
// the same order.
torch::optional<torch::Dict<std::string, torch::Tensor>>
SharedMemoryHelper::ReadTorchTensorDict() {
  auto archive = this->ReadTorchArchive();
  if (!read_from_archive(archive, "has_value").toBool()) {
    return torch::nullopt;
  }
  const int64_t num_tensors = read_from_archive(archive, "num_tensors").toInt();
  torch::Dict<std::string, torch::Tensor> tensor_dict;
  for (int64_t i = 0; i < num_tensors; ++i) {
    // Copy the key out: toStringRef() refers into a temporary IValue.
    const std::string key =
        read_from_archive(archive, std::string("key_") + std::to_string(i))
            .toStringRef();
    auto tensor = this->ReadTorchTensor();
    tensor_dict.insert(key, tensor.value());
  }
  return tensor_dict;
}

// Serializes `archive` into the metadata segment at the current offset using
// the layout [int64 payload size][payload bytes][padding to 8-byte boundary].
// NOTE(review): nothing here checks against max_metadata_size_ — presumably
// the caller guarantees the metadata fits; confirm upstream.
void SharedMemoryHelper::WriteTorchArchiveInternal(
    torch::serialize::OutputArchive& archive) {
  std::stringstream serialized;
  archive.save_to(serialized);
  auto serialized_str = serialized.str();
  auto metadata_ptr = this->GetCurrentMetadataPtr();
  // Store the payload size first so ReadTorchArchive() knows how much to load.
  static_cast<int64_t*>(metadata_ptr)[0] = serialized_str.size();
  memcpy(
      static_cast<char*>(metadata_ptr) + sizeof(int64_t), serialized_str.data(),
      serialized_str.size());
  // Advance past the padded payload to keep the next chunk 8-byte aligned.
  int64_t rounded_size = GetRoundedSize(serialized_str.size());
  this->MoveMetadataPtr(sizeof(int64_t) + rounded_size);
}

// Copies a staged tensor's raw bytes into the data segment at the current
// offset; a nullopt tensor is a no-op (its absence is already recorded in
// the metadata).
void SharedMemoryHelper::WriteTorchTensorInternal(
    torch::optional<torch::Tensor> tensor) {
  if (!tensor.has_value()) {
    return;
  }
  // Materialize a contiguous layout before the flat byte copy.
  auto contiguous_tensor = tensor.value().contiguous();
  const size_t memory_size =
      tensor.value().numel() * tensor.value().element_size();
  memcpy(this->GetCurrentDataPtr(), contiguous_tensor.data_ptr(), memory_size);
  this->MoveDataPtr(GetRoundedSize(memory_size));
}

/**
 * @brief Creates both shared-memory segments and writes out everything that
 * was staged via WriteTorchArchive/WriteTorchTensor/WriteTorchTensorDict.
 *
 * The metadata segment begins with one extra archive recording the total
 * (aligned) tensor-data size so a reader can size the data segment before
 * opening it (see InitializeRead). The staging queues are cleared afterwards.
 */
void SharedMemoryHelper::Flush() {
  // The first archive records the size of the tensor data.
  torch::serialize::OutputArchive archive;
  size_t data_size = 0;
  // Iterate by const reference: copying each torch::optional<torch::Tensor>
  // bumps an atomic refcount per element for no benefit
  // (clang-tidy performance-for-range-copy).
  for (const auto& tensor : tensors_to_write_) {
    if (tensor.has_value()) {
      auto tensor_size = tensor.value().numel() * tensor.value().element_size();
      // Each tensor is stored padded, so sum the rounded sizes.
      data_size += GetRoundedSize(tensor_size);
    }
  }
  archive.write("data_size", static_cast<int64_t>(data_size));
  metadata_shared_memory_ =
      std::make_unique<SharedMemory>(GetSharedMemoryMetadataName(name_));
  metadata_shared_memory_->Create(max_metadata_size_);
  metadata_offset_ = 0;
  this->WriteTorchArchiveInternal(archive);
  for (auto& queued_archive : metadata_to_write_) {
    this->WriteTorchArchiveInternal(queued_archive);
  }

  data_shared_memory_ =
      std::make_unique<SharedMemory>(GetSharedMemoryDataName(name_));
  data_shared_memory_->Create(data_size);
  data_offset_ = 0;
  for (const auto& tensor : tensors_to_write_) {
    this->WriteTorchTensorInternal(tensor);
  }
  metadata_to_write_.clear();
  tensors_to_write_.clear();
}

std::pair<SharedMemoryPtr, SharedMemoryPtr>
SharedMemoryHelper::ReleaseSharedMemory() {
  return std::make_pair(
      std::move(metadata_shared_memory_), std::move(data_shared_memory_));
209
210
211
212
}

}  // namespace sampling
}  // namespace graphbolt