Unverified commit b226fe01 authored by Quan (Andy) Gan, committed by GitHub

[Windows] Support NDArray in shared memory on Windows (#3615)

* support shared memory on windows

* Update shared_mem.cc
parent 9c106547
@@ -259,9 +259,7 @@ class NDArray {
   template<typename T>
   std::vector<T> ToVector() const;
-#ifndef _WIN32
   std::shared_ptr<SharedMemory> GetSharedMem() const;
-#endif  // _WIN32
   /*!
    * \brief Function to copy data from one array to another.
@@ -313,9 +311,7 @@ struct NDArray::Container {
    */
   DLTensor dl_tensor;
-#ifndef _WIN32
   std::shared_ptr<SharedMemory> mem;
-#endif  // _WIN32
   /*!
    * \brief addtional context, reserved for recycling
    * \note We can attach additional content here
...
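Side note on the `mem` member above: because `NDArray::Container` holds the mapping through a `std::shared_ptr<SharedMemory>`, the region stays mapped for as long as any array still references it; the deleter hunk further down (`ptr->mem = nullptr`) simply drops that reference. A simplified sketch of the lifetime behavior, using stand-in types rather than DGL's real ones:

```cpp
// Sketch of the lifetime idea behind Container::mem (stand-in types, not
// DGL's): each array holds a shared_ptr, so the mapping outlives any one copy.
#include <cstdio>
#include <memory>

struct SharedMemoryStub {
  ~SharedMemoryStub() { std::printf("mapping released\n"); }
};

struct ContainerStub {
  std::shared_ptr<SharedMemoryStub> mem;  // mirrors NDArray::Container::mem
};

int main() {
  ContainerStub a{std::make_shared<SharedMemoryStub>()};
  {
    ContainerStub b = a;  // a second array sharing the same region
  }                       // b destroyed: a still holds a reference
  std::printf("b gone, mapping still mapped\n");
  return 0;
}  // a destroyed -> "mapping released"
```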
@@ -6,6 +6,9 @@
 #ifndef DGL_RUNTIME_SHARED_MEM_H_
 #define DGL_RUNTIME_SHARED_MEM_H_

+#ifdef _WIN32
+#include <windows.h>
+#endif  // _WIN32
 #include <string>

 namespace dgl {
@@ -28,7 +31,11 @@ class SharedMemory {
   bool own_;

   /* \brief the file descripter of the shared memory. */
+#ifndef _WIN32
   int fd_;
+#else   // !_WIN32
+  HANDLE handle_;
+#endif  // _WIN32
   /* \brief the address of the shared memory. */
   void *ptr_;
   /* \brief the size of the shared memory. */
...
@@ -70,7 +70,9 @@ class ThreadGroup {
 /*!
  * \brief Platform-agnostic no-op.
  */
-void Yield();
+// This used to be Yield(), renaming to YieldThread() because windows.h defined it as a
+// macro in later SDKs.
+void YieldThread();

 /*!
  * \return the maximum number of effective workers for this system.
...
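Background on the rename above: winbase.h, which `<windows.h>` pulls in, defines `Yield` as a function-like macro that expands to nothing, so the declaration `void Yield();` preprocesses to the ill-formed `void ;` on Windows. A minimal illustration of the collision and the two usual escapes; the `#undef` below is shown only for demonstration, while the commit takes the rename route:

```cpp
// My sketch, not part of the patch. winbase.h contains roughly:
//
//   #define Yield()
//
// so after <windows.h>, the token sequence Yield() vanishes and a declaration
// `void Yield();` becomes `void ;`, which fails to compile.
#include <windows.h>

#ifdef Yield
#undef Yield  // one escape hatch; DGL chose the rename instead
#endif

namespace threading {
inline void Yield() {}        // only compiles because of the #undef above
inline void YieldThread() {}  // collision-free spelling used by the commit
}  // namespace threading

int main() {
  threading::Yield();
  threading::YieldThread();
  return 0;
}
```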
@@ -58,10 +58,8 @@ struct NDArray::Internal {
     using dgl::runtime::NDArray;
     if (ptr->manager_ctx != nullptr) {
       static_cast<NDArray::Container*>(ptr->manager_ctx)->DecRef();
-#ifndef _WIN32
     } else if (ptr->mem) {
       ptr->mem = nullptr;
-#endif  // _WIN32
     } else if (ptr->dl_tensor.data != nullptr) {
       dgl::runtime::DeviceAPI::Get(ptr->dl_tensor.ctx)->FreeDataSpace(
           ptr->dl_tensor.ctx, ptr->dl_tensor.data);
@@ -191,7 +189,6 @@ NDArray NDArray::EmptyShared(const std::string &name,
   NDArray ret = Internal::Create(shape, dtype, ctx);
   // setup memory content
   size_t size = GetDataSize(ret.data_->dl_tensor);
-#ifndef _WIN32
   auto mem = std::make_shared<SharedMemory>(name);
   if (is_create) {
     ret.data_->dl_tensor.data = mem->CreateNew(size);
@@ -200,9 +197,6 @@ NDArray NDArray::EmptyShared(const std::string &name,
   }
   ret.data_->mem = mem;
-#else
-  LOG(FATAL) << "Windows doesn't support NDArray with shared memory";
-#endif  // _WIN32
   return ret;
 }
@@ -310,11 +304,9 @@ template std::vector<uint64_t> NDArray::ToVector<uint64_t>() const;
 template std::vector<float> NDArray::ToVector<float>() const;
 template std::vector<double> NDArray::ToVector<double>() const;

-#ifndef _WIN32
 std::shared_ptr<SharedMemory> NDArray::GetSharedMem() const {
   return this->data_->mem;
 }
-#endif  // _WIN32

 void NDArray::Save(dmlc::Stream* strm) const {
...
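With the `#ifndef _WIN32` guards gone, shared-memory arrays are created the same way on both platforms. A hypothetical usage sketch: the `EmptyShared` signature is inferred from the hunks above, and `"demo_ndarray"` is a made-up region name:

```cpp
// Hypothetical usage sketch; parameter order is inferred from the hunks above.
#include <dgl/runtime/ndarray.h>

#include <cstdint>
#include <vector>

using dgl::runtime::NDArray;

void Example() {
  std::vector<int64_t> shape{1024};
  DLDataType dtype{kDLFloat, 32, 1};  // float32, lanes = 1
  DLContext ctx{kDLCPU, 0};

  // Creator side: allocates and maps the named region (is_create = true).
  NDArray a = NDArray::EmptyShared("demo_ndarray", shape, dtype, ctx, true);

  // Consumer side (typically another process): attaches to the same region.
  NDArray b = NDArray::EmptyShared("demo_ndarray", shape, dtype, ctx, false);

  // Both arrays now view the same memory; the mapping is released once the
  // last NDArray holding the SharedMemory pointer goes away.
}
```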
@@ -18,7 +18,6 @@
 namespace dgl {
 namespace runtime {

-#ifndef _WIN32
 /*
  * Shared memory is a resource that cannot be cleaned up if the process doesn't
  * exit normally. We'll manage the resource with ResourceManager.
@@ -33,21 +32,25 @@ class SharedMemoryResource: public Resource {
   void Destroy() {
     // LOG(INFO) << "remove " << name << " for shared memory";
+#ifndef _WIN32
     shm_unlink(name.c_str());
+#else   // _WIN32
+    // NOTHING; Windows automatically removes the shared memory object once all handles
+    // are unmapped.
+#endif
   }
 };
-#endif  // _WIN32

 SharedMemory::SharedMemory(const std::string &name) {
-#ifndef _WIN32
   this->name = name;
   this->own_ = false;
+#ifndef _WIN32
   this->fd_ = -1;
+#else
+  this->handle_ = nullptr;
+#endif
   this->ptr_ = nullptr;
   this->size_ = 0;
-#else
-  LOG(FATAL) << "Shared memory is not supported on Windows.";
-#endif  // _WIN32
 }

 SharedMemory::~SharedMemory() {
@@ -61,7 +64,9 @@ SharedMemory::~SharedMemory() {
     DeleteResource(name);
   }
 #else
-  LOG(FATAL) << "Shared memory is not supported on Windows.";
+  CHECK(UnmapViewOfFile(ptr_)) << "Win32 Error: " << GetLastError();
+  CloseHandle(handle_);
+  // Windows do not need a separate shm_unlink step.
 #endif  // _WIN32
 }
@@ -74,7 +79,7 @@ void *SharedMemory::CreateNew(size_t sz) {
   int flag = O_RDWR|O_CREAT;
   fd_ = shm_open(name.c_str(), flag, S_IRUSR | S_IWUSR);
   CHECK_NE(fd_, -1) << "fail to open " << name << ": " << strerror(errno);
-  // Shared memory cannot be deleted if the process exits abnormally.
+  // Shared memory cannot be deleted if the process exits abnormally in Linux.
   AddResource(name, std::shared_ptr<Resource>(new SharedMemoryResource(name)));
   auto res = ftruncate(fd_, sz);
   CHECK_NE(res, -1)
@@ -85,8 +90,22 @@ void *SharedMemory::CreateNew(size_t sz) {
   this->size_ = sz;
   return ptr_;
 #else
-  LOG(FATAL) << "Shared memory is not supported on Windows.";
-  return nullptr;
+  handle_ = CreateFileMapping(
+    INVALID_HANDLE_VALUE,
+    nullptr,
+    PAGE_READWRITE,
+    static_cast<DWORD>(sz >> 32),
+    static_cast<DWORD>(sz & 0xFFFFFFFF),
+    name.c_str());
+  CHECK(handle_ != nullptr) << "fail to open " << name << ", Win32 error: " << GetLastError();
+  ptr_ = MapViewOfFile(handle_, FILE_MAP_ALL_ACCESS, 0, 0, sz);
+  if (ptr_ == nullptr) {
+    LOG(FATAL) << "Memory mapping failed, Win32 error: " << GetLastError();
+    CloseHandle(handle_);
+    return nullptr;
+  }
+  this->size_ = sz;
+  return ptr_;
 #endif  // _WIN32
 }
@@ -101,23 +120,36 @@ void *SharedMemory::Open(size_t sz) {
   this->size_ = sz;
   return ptr_;
 #else
-  LOG(FATAL) << "Shared memory is not supported on Windows.";
-  return nullptr;
+  handle_ = OpenFileMapping(FILE_MAP_ALL_ACCESS, FALSE, name.c_str());
+  CHECK(handle_ != nullptr) << "fail to open " << name << ", Win32 Error: " << GetLastError();
+  ptr_ = MapViewOfFile(handle_, FILE_MAP_ALL_ACCESS, 0, 0, sz);
+  if (ptr_ == nullptr) {
+    LOG(FATAL) << "Memory mapping failed, Win32 error: " << GetLastError();
+    CloseHandle(handle_);
+    return nullptr;
+  }
+  this->size_ = sz;
+  return ptr_;
 #endif  // _WIN32
 }

 bool SharedMemory::Exist(const std::string &name) {
 #ifndef _WIN32
-  int fd_ = shm_open(name.c_str(), O_RDONLY, S_IRUSR | S_IWUSR);
-  if (fd_ >= 0) {
-    close(fd_);
+  int fd = shm_open(name.c_str(), O_RDONLY, S_IRUSR | S_IWUSR);
+  if (fd >= 0) {
+    close(fd);
     return true;
   } else {
     return false;
   }
 #else
-  LOG(FATAL) << "Shared memory is not supported on Windows.";
-  return false;
+  HANDLE handle = OpenFileMapping(FILE_MAP_ALL_ACCESS, FALSE, name.c_str());
+  if (handle != nullptr) {
+    CloseHandle(handle);
+    return true;
+  } else {
+    return false;
+  }
 #endif  // _WIN32
 }
...
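Two Windows details in the hunks above are worth spelling out: `CreateFileMapping` takes the 64-bit size as a high/low `DWORD` pair (hence the `sz >> 32` / `sz & 0xFFFFFFFF` split), and the kernel destroys a mapping object automatically once its last handle closes, which is why the Windows `Destroy()` branch is a deliberate no-op. A standalone sketch of the same pattern outside DGL (`"Local\\demo_shm"` is a made-up object name):

```cpp
// Standalone sketch of the Win32 named file-mapping pattern the patch adopts.
#include <windows.h>

#include <cstdint>
#include <cstdio>
#include <cstring>

int main() {
  const char* kName = "Local\\demo_shm";  // made-up, session-local name
  uint64_t sz = 4096;

  // A 64-bit size is passed as two DWORDs, mirroring the split above.
  HANDLE h = CreateFileMappingA(
      INVALID_HANDLE_VALUE,                 // back the region with the page file
      nullptr, PAGE_READWRITE,
      static_cast<DWORD>(sz >> 32),         // dwMaximumSizeHigh
      static_cast<DWORD>(sz & 0xFFFFFFFF),  // dwMaximumSizeLow
      kName);
  if (h == nullptr) {
    std::printf("CreateFileMapping failed: %lu\n", GetLastError());
    return 1;
  }

  void* p = MapViewOfFile(h, FILE_MAP_ALL_ACCESS, 0, 0, sz);
  if (p == nullptr) {
    std::printf("MapViewOfFile failed: %lu\n", GetLastError());
    CloseHandle(h);
    return 1;
  }

  std::memcpy(p, "hello", 6);  // another process could OpenFileMapping + read

  UnmapViewOfFile(p);
  CloseHandle(h);  // last handle closed -> kernel discards the object, which
                   // is why the Windows Destroy() above needs no shm_unlink
  return 0;
}
```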
@@ -68,7 +68,7 @@ class ParallelLauncher {
   // Wait n jobs to finish
   int WaitForJobs() {
     while (num_pending_.load() != 0) {
-      dgl::runtime::threading::Yield();
+      dgl::runtime::threading::YieldThread();
     }
     if (!has_error_.load()) return 0;
     // the following is intended to use string due to
@@ -143,7 +143,7 @@ class SpscTaskQueue {
    */
   void Push(const Task& input) {
     while (!Enqueue(input)) {
-      dgl::runtime::threading::Yield();
+      dgl::runtime::threading::YieldThread();
     }
     if (pending_.fetch_add(1) == -1) {
       std::unique_lock<std::mutex> lock(mutex_);
@@ -162,7 +162,7 @@ class SpscTaskQueue {
     // If a new task comes to the queue quickly, this wait avoid the worker from sleeping.
     // The default spin count is set by following the typical omp convention
     for (uint32_t i = 0; i < spin_count && pending_.load() == 0; ++i) {
-      dgl::runtime::threading::Yield();
+      dgl::runtime::threading::YieldThread();
     }
     if (pending_.fetch_sub(1) == 0) {
       std::unique_lock<std::mutex> lock(mutex_);
@@ -374,7 +374,7 @@ int DGLBackendParallelBarrier(int task_id, DGLParallelGroupEnv* penv) {
     if (i != task_id) {
       while (sync_counter[i * kSyncStride].load(
           std::memory_order_relaxed) <= old_counter) {
-        dgl::runtime::threading::Yield();
+        dgl::runtime::threading::YieldThread();
       }
     }
   }
...
@@ -191,7 +191,7 @@ int ThreadGroup::Configure(AffinityMode mode, int nthreads, bool exclude_worker0
   return impl_->Configure(mode, nthreads, exclude_worker0);
 }

-void Yield() {
+void YieldThread() {
   std::this_thread::yield();
 }
...
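As the last hunk shows, `YieldThread()` is only a wrapper around `std::this_thread::yield()`, so every spin loop in the thread-pool hunks reduces to the pattern below. A minimal self-contained sketch; the `done` flag and worker thread are illustrative, not DGL internals:

```cpp
// Minimal sketch of the spin-wait-with-yield pattern used by the thread pool.
#include <atomic>
#include <chrono>
#include <cstdio>
#include <thread>

int main() {
  std::atomic<bool> done{false};

  std::thread worker([&] {
    std::this_thread::sleep_for(std::chrono::milliseconds(10));
    done.store(true);
  });

  // Busy-wait, but hand the core back to the scheduler on every iteration,
  // which is exactly what dgl::runtime::threading::YieldThread() forwards to.
  while (!done.load()) {
    std::this_thread::yield();
  }

  worker.join();
  std::printf("worker finished\n");
  return 0;
}
```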
@@ -42,7 +42,6 @@ def _assert_is_identical_hetero(g, g2):
         assert F.array_equal(src, src2)
         assert F.array_equal(dst, dst2)

-@unittest.skipIf(os.name == 'nt', reason='Do not support windows yet')
 @unittest.skipIf(dgl.backend.backend_name == 'tensorflow', reason='Not support tensorflow for now')
 @parametrize_dtype
 def test_single_process(idtype):
@@ -60,7 +59,6 @@ def sub_proc(hg_origin, name):
     _assert_is_identical_hetero(hg_origin, hg_rebuild)
     _assert_is_identical_hetero(hg_origin, hg_save_again)

-@unittest.skipIf(os.name == 'nt', reason='Do not support windows yet')
 @unittest.skipIf(dgl.backend.backend_name == 'tensorflow', reason='Not support tensorflow for now')
 @parametrize_dtype
 def test_multi_process(idtype):
@@ -70,7 +68,6 @@ def test_multi_process(idtype):
     p.start()
     p.join()

-@unittest.skipIf(os.name == 'nt', reason='Do not support windows yet')
 @unittest.skipIf(F._default_context_str == 'cpu', reason="Need gpu for this test")
 @unittest.skipIf(dgl.backend.backend_name == 'tensorflow', reason='Not support tensorflow for now')
 def test_copy_from_gpu():
...