"docs/vscode:/vscode.git/clone" did not exist on "fe121c63f5b1ffe8cfa4d4c9b46530d1cb3035e6"
shared_memory_utils.cc 6.48 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
/**
 *  Copyright (c) 2023 by Contributors
 *
 * @file shared_memory_utils.cc
 * @brief Share memory utility function implementation.
 */
#include "./shared_memory_utils.h"

#include <graphbolt/serialize.h>
#include <graphbolt/shared_memory.h>
11
12
13
14
15
16
#include <torch/torch.h>

#include <cstring>
#include <string>
#include <tuple>
#include <vector>
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184

namespace graphbolt {
namespace sampling {

/**
 * @brief Serialize a torch archive into a newly created shared memory segment.
 *
 * Segment layout: the first `sizeof(int64_t)` bytes hold the length of the
 * serialized archive, followed by the serialized bytes themselves.
 *
 * @param name The name of the shared memory to create.
 * @param size The size in bytes of the shared memory to create. It must be
 * large enough for the length header plus the serialized archive.
 * @param archive The archive to serialize.
 *
 * @return The shared memory object holding the serialized archive.
 */
static SharedMemoryPtr CopyTorchArchiveToSharedMemory(
    const std::string& name, int64_t size,
    torch::serialize::OutputArchive& archive) {
  std::stringstream serialized;
  archive.save_to(serialized);
  auto serialized_str = serialized.str();
  // The header plus payload must fit in the segment; without this check the
  // memcpy below would write past the end of the mapped region.
  const int64_t required_size =
      static_cast<int64_t>(sizeof(int64_t) + serialized_str.size());
  TORCH_CHECK(
      required_size <= size, "The serialized metadata size (", required_size,
      " bytes) exceeds the shared memory size (", size, " bytes) of '", name,
      "'.");
  auto shm = std::make_unique<SharedMemory>(name);
  auto mem_buf = shm->Create(size);
  // Use the first 8 bytes to store the size of the serialized string.
  static_cast<int64_t*>(mem_buf)[0] =
      static_cast<int64_t>(serialized_str.size());
  memcpy(
      static_cast<char*>(mem_buf) + sizeof(int64_t), serialized_str.data(),
      serialized_str.size());
  return shm;
}

/**
 * @brief Deserialize a torch archive from an existing shared memory segment.
 *
 * Expects an int64_t length header at the start of the segment, followed by
 * that many serialized bytes.
 *
 * @param name The name of the shared memory to open.
 * @param max_meta_size The size in bytes of the shared memory to open.
 * @param archive The archive to load the serialized data into.
 *
 * @return The shared memory object holding the serialized archive.
 */
static SharedMemoryPtr LoadTorchArchiveFromSharedMemory(
    const std::string& name, int64_t max_meta_size,
    torch::serialize::InputArchive& archive) {
  auto shm = std::make_unique<SharedMemory>(name);
  auto mem_buf = shm->Open(max_meta_size);
  int64_t meta_size = static_cast<int64_t*>(mem_buf)[0];
  // The header comes from another writer; validate it before use so a bad
  // value cannot make load_from read beyond the mapped region. The bound is
  // written as a subtraction to avoid overflowing `meta_size + 8`.
  TORCH_CHECK(
      meta_size >= 0 &&
          meta_size <= max_meta_size - static_cast<int64_t>(sizeof(int64_t)),
      "Invalid serialized metadata size (", meta_size,
      " bytes) in shared memory '", name, "' of size ", max_meta_size,
      " bytes.");
  archive.load_from(
      static_cast<const char*>(mem_buf) + sizeof(int64_t), meta_size);
  return shm;
}

/**
 * @brief Copy the raw data of the given tensors, back to back, into a newly
 * created shared memory segment.
 *
 * Tensors without a value are skipped. Each tensor is made contiguous before
 * copying, so its bytes follow the row-major order of its shape.
 *
 * @param name The name of the shared memory to create.
 * @param tensors The optional tensors to copy; every present tensor must be
 * on the CPU.
 *
 * @return The shared memory object holding the concatenated tensor data.
 */
static SharedMemoryPtr CopyTensorsDataToSharedMemory(
    const std::string& name,
    const std::vector<torch::optional<torch::Tensor>>& tensors) {
  int64_t memory_size = 0;
  for (const auto& optional_tensor : tensors) {
    if (optional_tensor.has_value()) {
      const auto& tensor = optional_tensor.value();
      // memcpy below reads through a host pointer, so a device tensor would
      // crash; reject it with a clear error instead.
      TORCH_CHECK(
          tensor.device().is_cpu(),
          "Only CPU tensors can be copied to shared memory, got a tensor on ",
          tensor.device(), ".");
      memory_size += tensor.numel() * tensor.element_size();
    }
  }
  auto shm = std::make_unique<SharedMemory>(name);
  auto mem_buf = shm->Create(memory_size);
  for (const auto& optional_tensor : tensors) {
    if (optional_tensor.has_value()) {
      auto tensor = optional_tensor.value().contiguous();
      int64_t size = tensor.numel() * tensor.element_size();
      memcpy(mem_buf, tensor.data_ptr(), size);
      mem_buf = static_cast<char*>(mem_buf) + size;
    }
  }
  return shm;
}

/**
 * @brief Load tensors data from shared memory.
 * @param name The name of shared memory.
 * @param tensor_metas The meta info of tensors, including a flag indicating
 * whether the optional tensor has value, tensor shape and dtype.
 *
 * @return A pair of shared memory holding the tensors.
 */
static std::pair<SharedMemoryPtr, std::vector<torch::optional<torch::Tensor>>>
LoadTensorsDataFromSharedMemory(
    const std::string& name,
    const std::vector<
        std::tuple<bool, std::vector<int64_t>, torch::ScalarType>>&
        tensor_metas) {
  auto shm = std::make_unique<SharedMemory>(name);
  int64_t memory_size = 0;
  for (const auto& meta : tensor_metas) {
    if (std::get<0>(meta)) {
      int64_t size = std::accumulate(
          std::get<1>(meta).begin(), std::get<1>(meta).end(), 1,
          std::multiplies<int64_t>());
      memory_size += size * torch::elementSize(std::get<2>(meta));
    }
  }
  auto mem_buf = shm->Open(memory_size);
  std::vector<torch::optional<torch::Tensor>> optional_tensors;
  for (const auto& meta : tensor_metas) {
    if (std::get<0>(meta)) {
      auto tensor =
          torch::from_blob(mem_buf, std::get<1>(meta), std::get<2>(meta));
      optional_tensors.push_back(tensor);
      int64_t size = std::accumulate(
          std::get<1>(meta).begin(), std::get<1>(meta).end(), 1,
          std::multiplies<int64_t>());
      mem_buf = static_cast<char*>(mem_buf) +
                size * torch::elementSize(std::get<2>(meta));
    } else {
      optional_tensors.push_back(torch::nullopt);
    }
  }
  return std::make_pair(std::move(shm), std::move(optional_tensors));
}

/**
 * @brief Copy tensors into shared memory and return views of the copies.
 *
 * Two segments are created: `<name>_meta` holds a serialized archive with the
 * number of tensors and, for each one, whether it has a value and (if so) its
 * shape and dtype; `<name>_data` holds the tensors' raw data back to back.
 *
 * @param name Base name of the shared memory segments.
 * @param tensors The optional tensors to copy.
 * @param max_meta_memory_size Size in bytes of the metadata segment.
 *
 * @return A tuple of the metadata shared memory, the data shared memory and
 * tensors viewing the copied data (torch::nullopt where the input had no
 * value). The returned tensors are created with torch::from_blob and do not
 * own their memory.
 */
SharedMemoryTensors CopyTensorsToSharedMemory(
    const std::string& name,
    const std::vector<torch::optional<torch::Tensor>>& tensors,
    int64_t max_meta_memory_size) {
  torch::serialize::OutputArchive archive;
  archive.write("num_tensors", static_cast<int64_t>(tensors.size()));
  for (size_t i = 0; i < tensors.size(); ++i) {
    const auto& optional_tensor = tensors[i];
    const std::string key_prefix = "tensor_" + std::to_string(i);
    archive.write(key_prefix + "_has_value", optional_tensor.has_value());
    if (optional_tensor.has_value()) {
      archive.write(key_prefix + "_shape", optional_tensor->sizes());
      archive.write(key_prefix + "_dtype", optional_tensor->scalar_type());
    }
  }
  auto meta_shm = CopyTorchArchiveToSharedMemory(
      name + "_meta", max_meta_memory_size, archive);
  auto data_shm = CopyTensorsDataToSharedMemory(name + "_data", tensors);

  // Build non-owning views into the data segment, walking it in the same
  // order CopyTensorsDataToSharedMemory wrote it.
  std::vector<torch::optional<torch::Tensor>> ret_tensors;
  ret_tensors.reserve(tensors.size());
  auto mem_buf = data_shm->GetMemory();
  for (const auto& optional_tensor : tensors) {
    if (optional_tensor.has_value()) {
      const auto& tensor = optional_tensor.value();
      ret_tensors.push_back(
          torch::from_blob(mem_buf, tensor.sizes(), tensor.dtype()));
      int64_t size = tensor.numel() * tensor.element_size();
      mem_buf = static_cast<char*>(mem_buf) + size;
    } else {
      ret_tensors.push_back(torch::nullopt);
    }
  }
  return std::make_tuple(
      std::move(meta_shm), std::move(data_shm), std::move(ret_tensors));
}

/**
 * @brief Open the shared memory segments `<name>_meta` and `<name>_data` and
 * rebuild the tensors stored in them.
 *
 * The metadata segment is deserialized first to recover, for every tensor,
 * whether it has a value and its shape and dtype. Those descriptions are then
 * used to map the tensors' data from the data segment.
 *
 * @param name Base name of the shared memory segments.
 * @param meta_memory_size Size in bytes of the metadata segment.
 *
 * @return A tuple of the metadata shared memory, the data shared memory and
 * the loaded tensors (torch::nullopt for tensors without a value).
 */
SharedMemoryTensors LoadTensorsFromSharedMemory(
    const std::string& name, int64_t meta_memory_size) {
  torch::serialize::InputArchive archive;
  auto meta_shm = LoadTorchArchiveFromSharedMemory(
      name + "_meta", meta_memory_size, archive);

  std::vector<std::tuple<bool, std::vector<int64_t>, torch::ScalarType>>
      tensor_metas;
  const int64_t tensor_count =
      read_from_archive(archive, "num_tensors").toInt();
  for (int64_t idx = 0; idx < tensor_count; ++idx) {
    const std::string prefix = "tensor_" + std::to_string(idx);
    const bool has_value =
        read_from_archive(archive, prefix + "_has_value").toBool();
    if (!has_value) {
      // Placeholder keeps every entry aligned with its original position.
      tensor_metas.push_back({false, {}, torch::ScalarType::Undefined});
      continue;
    }
    auto shape = read_from_archive(archive, prefix + "_shape").toIntVector();
    auto dtype = read_from_archive(archive, prefix + "_dtype").toScalarType();
    tensor_metas.push_back({true, shape, dtype});
  }

  auto loaded = LoadTensorsDataFromSharedMemory(name + "_data", tensor_metas);
  return std::make_tuple(
      std::move(meta_shm), std::move(loaded.first), std::move(loaded.second));
}

}  // namespace sampling
}  // namespace graphbolt