// !!! This is a file automatically generated by hipify!!!
/**
 *  Copyright (c) 2023 by Contributors
 *
 * @file shared_memory_helper.cc
 * @brief Share memory helper implementation.
 */
#include "shared_memory_helper.h"

#include <graphbolt/serialize.h>
#include <graphbolt/shared_memory.h>

#include <torch/torch.h>

#include <cstring>
#include <numeric>
#include <sstream>
#include <string>
#include <tuple>
#include <vector>

namespace graphbolt {
namespace sampling {

// Name of the shared memory segment that holds the serialized metadata
// records for the helper instance identified by `name`.
static std::string GetSharedMemoryMetadataName(const std::string& name) {
  std::string segment_name = name;
  segment_name += "_metadata";
  return segment_name;
}

// Name of the shared memory segment that holds the raw tensor payloads
// for the helper instance identified by `name`.
static std::string GetSharedMemoryDataName(const std::string& name) {
  std::string segment_name = name;
  segment_name += "_data";
  return segment_name;
}

// To avoid unaligned memory access, we round the size of the binary buffer to
// the nearest multiple of 8 bytes.
inline static int64_t GetRoundedSize(int64_t size) {
  constexpr int64_t kAlignment = 8;
  const int64_t num_chunks = (size + kAlignment - 1) / kAlignment;
  return num_chunks * kAlignment;
}

37
SharedMemoryHelper::SharedMemoryHelper(const std::string& name)
38
    : name_(name),
39
40
      metadata_size_(0),
      data_size_(0),
41
42
43
44
45
46
47
48
49
50
51
52
      metadata_shared_memory_(nullptr),
      data_shared_memory_(nullptr),
      metadata_offset_(0),
      data_offset_(0) {}

void SharedMemoryHelper::InitializeRead() {
  metadata_offset_ = 0;
  data_offset_ = 0;
  if (metadata_shared_memory_ == nullptr) {
    // Reader process opens the shared memory.
    metadata_shared_memory_ =
        std::make_unique<SharedMemory>(GetSharedMemoryMetadataName(name_));
53
54
    metadata_shared_memory_->Open();
    metadata_size_ = metadata_shared_memory_->GetSize();
55
56
    data_shared_memory_ =
        std::make_unique<SharedMemory>(GetSharedMemoryDataName(name_));
57
58
    data_shared_memory_->Open();
    data_size_ = data_shared_memory_->GetSize();
59
60
61
  }
}

62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
// Queue a serialized archive for writing. Actual serialization and the
// copy into shared memory are deferred until Flush().
void SharedMemoryHelper::WriteTorchArchive(
    torch::serialize::OutputArchive&& archive) {
  metadata_to_write_.push_back(std::move(archive));
}

// Read the next archive record from the metadata segment and advance the
// metadata cursor past it.
// Record layout: [int64_t payload_size][payload bytes, padded to 8 bytes].
torch::serialize::InputArchive SharedMemoryHelper::ReadTorchArchive() {
  void* record = this->GetCurrentMetadataPtr();
  const int64_t payload_size = *static_cast<int64_t*>(record);
  const char* payload = static_cast<const char*>(record) + sizeof(int64_t);
  torch::serialize::InputArchive archive;
  archive.load_from(payload, payload_size);
  // Skip the size header plus the 8-byte-aligned payload.
  this->MoveMetadataPtr(sizeof(int64_t) + GetRoundedSize(payload_size));
  return archive;
}

// Queue an optional tensor for writing. Shape and dtype travel through a
// metadata archive; the raw element data is queued separately and copied
// into the data segment at Flush() time.
void SharedMemoryHelper::WriteTorchTensor(
    torch::optional<torch::Tensor> tensor) {
  torch::serialize::OutputArchive archive;
  const bool has_value = tensor.has_value();
  archive.write("has_value", has_value);
  if (has_value) {
    const auto& value = tensor.value();
    archive.write("shape", value.sizes());
    archive.write("dtype", value.scalar_type());
  }
  this->WriteTorchArchive(std::move(archive));
  tensors_to_write_.push_back(tensor);
}

// Read the next optional tensor: its shape/dtype come from the next
// metadata archive, its elements are viewed in place in the data segment.
// NOTE: the returned tensor is a zero-copy view (torch::from_blob) that
// does NOT own the buffer — the shared memory must outlive it.
torch::optional<torch::Tensor> SharedMemoryHelper::ReadTorchTensor() {
  auto archive = this->ReadTorchArchive();
  if (!read_from_archive<bool>(archive, "has_value")) {
    return torch::nullopt;
  }
  const auto shape = read_from_archive<std::vector<int64_t>>(archive, "shape");
  const auto dtype = read_from_archive<torch::ScalarType>(archive, "dtype");
  void* data_ptr = this->GetCurrentDataPtr();
  auto tensor = torch::from_blob(data_ptr, shape, dtype);
  // Payloads are stored 8-byte aligned; skip the padded size.
  this->MoveDataPtr(GetRoundedSize(tensor.numel() * tensor.element_size()));
  return tensor;
}

// Queue an optional string->tensor dict for writing. One archive carries
// the presence flag, the entry count, and the keys ("key_0", "key_1", ...
// in iteration order); the tensors themselves are then queued in the same
// order so the reader can re-associate key i with tensor i.
void SharedMemoryHelper::WriteTorchTensorDict(
    torch::optional<torch::Dict<std::string, torch::Tensor>> tensor_dict) {
  torch::serialize::OutputArchive archive;
  if (!tensor_dict.has_value()) {
    archive.write("has_value", false);
    this->WriteTorchArchive(std::move(archive));
    return;
  }
  archive.write("has_value", true);
  auto dict_value = tensor_dict.value();
  archive.write("num_tensors", static_cast<int64_t>(dict_value.size()));
  int64_t index = 0;
  for (const auto& entry : dict_value) {
    archive.write(std::string("key_") + std::to_string(index), entry.key());
    ++index;
  }
  this->WriteTorchArchive(std::move(archive));
  for (const auto& entry : dict_value) {
    this->WriteTorchTensor(entry.value());
  }
}

// Read back an optional string->tensor dict written by
// WriteTorchTensorDict: first the archive with the presence flag, count
// and keys, then one tensor per key in the same order.
torch::optional<torch::Dict<std::string, torch::Tensor>>
SharedMemoryHelper::ReadTorchTensorDict() {
  auto archive = this->ReadTorchArchive();
  if (!read_from_archive<bool>(archive, "has_value")) {
    return torch::nullopt;
  }
  const int64_t num_tensors =
      read_from_archive<int64_t>(archive, "num_tensors");
  torch::Dict<std::string, torch::Tensor> result;
  for (int64_t index = 0; index < num_tensors; ++index) {
    const auto key = read_from_archive<std::string>(
        archive, std::string("key_") + std::to_string(index));
    // Tensors were written right after the key archive, in key order.
    auto tensor = this->ReadTorchTensor();
    result.insert(key, tensor.value());
  }
  return result;
}

// Turn every queued OutputArchive into its binary string form, appending
// to metadata_strings_to_write_. The archive queue is consumed (cleared).
void SharedMemoryHelper::SerializeMetadata() {
  for (auto& archive : metadata_to_write_) {
    std::stringstream buffer;
    archive.save_to(buffer);
    metadata_strings_to_write_.push_back(buffer.str());
  }
  metadata_to_write_.clear();
}

void SharedMemoryHelper::WriteMetadataToSharedMemory() {
  metadata_offset_ = 0;
  for (const auto& str : metadata_strings_to_write_) {
    auto metadata_ptr = this->GetCurrentMetadataPtr();
    static_cast<int64_t*>(metadata_ptr)[0] = str.size();
    memcpy(
        static_cast<char*>(metadata_ptr) + sizeof(int64_t), str.data(),
        str.size());
    int64_t rounded_size = GetRoundedSize(str.size());
    this->MoveMetadataPtr(sizeof(int64_t) + rounded_size);
  }
  metadata_strings_to_write_.clear();
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
}

// Copy one tensor's elements into the data segment at the current cursor
// and advance the cursor by the 8-byte-padded payload size. Empty
// optionals occupy no space.
void SharedMemoryHelper::WriteTorchTensorInternal(
    torch::optional<torch::Tensor> tensor) {
  if (!tensor.has_value()) {
    return;
  }
  // Ensure the source buffer is contiguous before the raw byte copy.
  auto contiguous = tensor.value().contiguous();
  const size_t nbytes = contiguous.numel() * contiguous.element_size();
  std::memcpy(this->GetCurrentDataPtr(), contiguous.data_ptr(), nbytes);
  this->MoveDataPtr(GetRoundedSize(nbytes));
}

void SharedMemoryHelper::Flush() {
  size_t data_size = 0;
  for (auto tensor : tensors_to_write_) {
    if (tensor.has_value()) {
      auto tensor_size = tensor.value().numel() * tensor.value().element_size();
      data_size += GetRoundedSize(tensor_size);
185
186
    }
  }
187
188
189
190
191
192
193
194
195
196

  // Serialize the metadata archives.
  SerializeMetadata();

  // Create the shared memory objects.
  const size_t metadata_size = std::accumulate(
      metadata_strings_to_write_.begin(), metadata_strings_to_write_.end(), 0,
      [](size_t sum, const std::string& str) {
        return sum + sizeof(int64_t) + GetRoundedSize(str.size());
      });
197
198
  metadata_shared_memory_ =
      std::make_unique<SharedMemory>(GetSharedMemoryMetadataName(name_));
199
200
  metadata_shared_memory_->Create(metadata_size);
  metadata_size_ = metadata_size;
201

202
203
  // Write the metadata and tensor data to the shared memory.
  WriteMetadataToSharedMemory();
204
205
206
  data_shared_memory_ =
      std::make_unique<SharedMemory>(GetSharedMemoryDataName(name_));
  data_shared_memory_->Create(data_size);
207
  data_size_ = data_size;
208
209
210
211
  data_offset_ = 0;
  for (auto tensor : tensors_to_write_) {
    this->WriteTorchTensorInternal(tensor);
  }
212

213
214
215
216
217
218
219
220
  metadata_to_write_.clear();
  tensors_to_write_.clear();
}

std::pair<SharedMemoryPtr, SharedMemoryPtr>
SharedMemoryHelper::ReleaseSharedMemory() {
  return std::make_pair(
      std::move(metadata_shared_memory_), std::move(data_shared_memory_));
221
222
223
224
}

}  // namespace sampling
}  // namespace graphbolt