shared_memory.cc 4.96 KB
Newer Older
1
2
3
4
5
6
7
8
/**
 *  Copyright (c) 2023 by Contributors
 * @file shared_memory.cc
 * @brief Source file of graphbolt shared memory.
 */
#ifndef _WIN32
#include <fcntl.h>
#include <sys/mman.h>
9
#include <sys/stat.h>
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
#include <unistd.h>
#endif  // !_WIN32

#include <graphbolt/shared_memory.h>
#include <stdio.h>
#include <string.h>
#include <torch/torch.h>

namespace graphbolt {
namespace sampling {

// Two processes opening the same path are guaranteed to access the same shared
// memory object if and only if path begins with a slash ('/') character.
constexpr char kSharedMemNamePrefix[] = "/dgl.graphbolt.";
constexpr char kSharedMemNameSuffix[] = ".lock";

// A prefix and a suffix are added to the name of the shared memory to create
// the name of the shared memory object.
inline std::string DecorateName(const std::string& name) {
  return kSharedMemNamePrefix + name + kSharedMemNameSuffix;
}

SharedMemory::SharedMemory(const std::string& name)
    : name_(name), size_(0), ptr_(nullptr) {
#ifdef _WIN32
  this->handle_ = nullptr;
#else   // _WIN32
  this->file_descriptor_ = -1;
  this->is_creator_ = false;
#endif  // _WIN32
}

#ifdef _WIN32

SharedMemory::~SharedMemory() {
  if (ptr_) CHECK(UnmapViewOfFile(ptr_)) << "Win32 Error: " << GetLastError();
  if (handle_) CloseHandle(handle_);
}

void* SharedMemory::Create(size_t size) {
  size_ = size;

  std::string decorated_name = DecorateName(name_);
  handle_ = CreateFileMapping(
      INVALID_HANDLE_VALUE, nullptr, PAGE_READWRITE,
      static_cast<DWORD>(size >> 32), static_cast<DWORD>(size & 0xFFFFFFFF),
      decorated_name.c_str());
  TORCH_CHECK(
      handle_ != nullptr, "Failed to open ", decorated_name,
      ", Win32 error: ", GetLastError());

  ptr_ = MapViewOfFile(handle_, FILE_MAP_ALL_ACCESS, 0, 0, size);
  TORCH_CHECK(
      ptr_ != nullptr, "Memory mapping failed, Win32 error: ", GetLastError());
  return ptr_;
}

67
void* SharedMemory::Open() {
68
69
70
71
72
73
  std::string decorated_name = DecorateName(name_);
  handle_ = OpenFileMapping(FILE_MAP_ALL_ACCESS, FALSE, decorated_name.c_str());
  TORCH_CHECK(
      handle_ != nullptr, "Failed to open ", decorated_name,
      ", Win32 Error: ", GetLastError());

74
  ptr_ = MapViewOfFile(handle_, FILE_MAP_ALL_ACCESS, 0, 0, 0);
75
76
  TORCH_CHECK(
      ptr_ != nullptr, "Memory mapping failed, Win32 error: ", GetLastError());
77
78
79
80
81
82
83
84

  // Obtain the size of the memory-mapped file.
  MEMORY_BASIC_INFORMATION memInfo;
  TORCH_CHECK(
      VirtualQuery(ptr_, &memInfo, sizeof(memInfo)) != 0,
      "Failed to get the size of shared memory: ", GetLastError());
  size_ = static_cast<size_t>(memInfo.RegionSize);

85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
  return ptr_;
}

bool SharedMemory::Exists(const std::string& name) {
  std::string decorated_name = DecorateName(name);
  HANDLE handle =
      OpenFileMapping(FILE_MAP_ALL_ACCESS, FALSE, decorated_name.c_str());
  bool exists = handle != nullptr;
  if (exists) {
    CloseHandle(handle);
  }
  return exists;
}

#else  // _WIN32

SharedMemory::~SharedMemory() {
  if (ptr_ && size_ != 0) CHECK(munmap(ptr_, size_) != -1) << strerror(errno);
  if (file_descriptor_ != -1) close(file_descriptor_);

  std::string decorated_name = DecorateName(name_);
  if (is_creator_ && decorated_name != "") shm_unlink(decorated_name.c_str());
}

void *SharedMemory::Create(size_t size) {
  size_ = size;
  is_creator_ = true;

  // TODO(zhenkun): handle the error properly if the shared memory object
  // already exists.
  std::string decorated_name = DecorateName(name_);
  file_descriptor_ =
      shm_open(decorated_name.c_str(), O_RDWR | O_CREAT, S_IRUSR | S_IWUSR);
  TORCH_CHECK(file_descriptor_ != -1, "Failed to open: ", strerror(errno));

  auto status = ftruncate(file_descriptor_, size);
  TORCH_CHECK(status != -1, "Failed to truncate the file: ", strerror(errno));

  ptr_ =
      mmap(NULL, size, PROT_READ | PROT_WRITE, MAP_SHARED, file_descriptor_, 0);
  TORCH_CHECK(
      ptr_ != MAP_FAILED,
      "Failed to map shared memory, mmap failed with error: ", strerror(errno));
  return ptr_;
}

131
void *SharedMemory::Open() {
132
133
134
135
136
137
138
  std::string decorated_name = DecorateName(name_);
  file_descriptor_ =
      shm_open(decorated_name.c_str(), O_RDWR, S_IRUSR | S_IWUSR);
  TORCH_CHECK(
      file_descriptor_ != -1, "Failed to open ", decorated_name, ": ",
      strerror(errno));

139
140
141
142
143
144
145
146
  struct stat shm_stat;
  TORCH_CHECK(
      fstat(file_descriptor_, &shm_stat) == 0,
      "Failed to get the size of shared memory: ", strerror(errno));
  size_ = shm_stat.st_size;

  ptr_ = mmap(
      NULL, size_, PROT_READ | PROT_WRITE, MAP_SHARED, file_descriptor_, 0);
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
  TORCH_CHECK(
      ptr_ != MAP_FAILED,
      "Failed to map shared memory, mmap failed with error: ", strerror(errno));
  return ptr_;
}

bool SharedMemory::Exists(const std::string &name) {
  std::string decorated_name = DecorateName(name);
  int file_descriptor =
      shm_open(decorated_name.c_str(), O_RDONLY, S_IRUSR | S_IWUSR);
  bool exists = file_descriptor > 0;
  if (exists) {
    close(file_descriptor);
  }
  return exists;
}

#endif  // _WIN32

}  // namespace sampling
}  // namespace graphbolt