Unverified Commit b226fe01 authored by Quan (Andy) Gan, committed by GitHub

[Windows] Support NDArray in shared memory on Windows (#3615)

* support shared memory on windows

* Update shared_mem.cc
parent 9c106547
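For context on the substance of this commit: the POSIX path backs shared NDArrays with shm_open/mmap and must explicitly shm_unlink the segment, while this patch implements the Windows path with named file mappings, where the kernel drops the object once its last handle is closed. A minimal standalone sketch of the Win32 pattern adopted here (the name "Local\\dgl_demo" and the 4 KB size are illustrative, not taken from the patch):

#ifdef _WIN32
#include <windows.h>
#include <cstdio>

int main() {
  const size_t sz = 4096;
  // Back the mapping by the system paging file rather than a real file.
  HANDLE h = CreateFileMapping(INVALID_HANDLE_VALUE, nullptr, PAGE_READWRITE,
                               0, static_cast<DWORD>(sz), "Local\\dgl_demo");
  if (h == nullptr) { std::printf("Win32 error: %lu\n", GetLastError()); return 1; }
  void *p = MapViewOfFile(h, FILE_MAP_ALL_ACCESS, 0, 0, sz);
  if (p == nullptr) { CloseHandle(h); return 1; }
  static_cast<char*>(p)[0] = 42;  // use it like ordinary memory
  UnmapViewOfFile(p);
  CloseHandle(h);  // the object disappears with its last open handle
  return 0;
}
#endif  // _WIN32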
......@@ -259,9 +259,7 @@ class NDArray {
template<typename T>
std::vector<T> ToVector() const;
std::shared_ptr<SharedMemory> GetSharedMem() const;
/*!
* \brief Function to copy data from one array to another.
......@@ -313,9 +311,7 @@ struct NDArray::Container {
*/
DLTensor dl_tensor;
std::shared_ptr<SharedMemory> mem;
/*!
* \brief additional context, reserved for recycling
* \note We can attach additional content here
......
......@@ -6,6 +6,9 @@
#ifndef DGL_RUNTIME_SHARED_MEM_H_
#define DGL_RUNTIME_SHARED_MEM_H_
#ifdef _WIN32
#include <windows.h>
#endif // _WIN32
#include <string>
namespace dgl {
......@@ -28,7 +31,11 @@ class SharedMemory {
bool own_;
/* \brief the file descriptor of the shared memory. */
#ifndef _WIN32
int fd_;
#else // !_WIN32
HANDLE handle_;
#endif // _WIN32
/* \brief the address of the shared memory. */
void *ptr_;
/* \brief the size of the shared memory. */
......
......@@ -70,7 +70,9 @@ class ThreadGroup {
/*!
* \brief Platform-agnostic no-op.
*/
// This used to be Yield(); it was renamed to YieldThread() because later
// Windows SDKs define Yield as a macro in windows.h.
void YieldThread();
/*!
* \return the maximum number of effective workers for this system.
......
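Background for the rename above: recent Windows SDKs define Yield as a function-style macro in winbase.h (pulled in by windows.h), so a declaration spelled void Yield(); breaks in any translation unit that includes the header. A reduced illustration of the clash, not code from the patch:

// What newer winbase.h effectively does:
#define Yield()

// void Yield();        // would expand to "void ;" and fail to compile
void YieldThread() {}   // a distinct name, so the macro never fires

int main() {
  YieldThread();
  return 0;
}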
......@@ -58,10 +58,8 @@ struct NDArray::Internal {
using dgl::runtime::NDArray;
if (ptr->manager_ctx != nullptr) {
static_cast<NDArray::Container*>(ptr->manager_ctx)->DecRef();
} else if (ptr->mem) {
ptr->mem = nullptr;
} else if (ptr->dl_tensor.data != nullptr) {
dgl::runtime::DeviceAPI::Get(ptr->dl_tensor.ctx)->FreeDataSpace(
ptr->dl_tensor.ctx, ptr->dl_tensor.data);
......@@ -191,7 +189,6 @@ NDArray NDArray::EmptyShared(const std::string &name,
NDArray ret = Internal::Create(shape, dtype, ctx);
// setup memory content
size_t size = GetDataSize(ret.data_->dl_tensor);
auto mem = std::make_shared<SharedMemory>(name);
if (is_create) {
ret.data_->dl_tensor.data = mem->CreateNew(size);
......@@ -200,9 +197,6 @@ NDArray NDArray::EmptyShared(const std::string &name,
}
ret.data_->mem = mem;
return ret;
}
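A hypothetical usage sketch of the now platform-independent entry point; it assumes the full signature is EmptyShared(name, shape, dtype, ctx, is_create), consistent with the fields used in the hunk above, and the region name "demo" and include path are illustrative:

#include <dgl/runtime/ndarray.h>

void demo() {
  using dgl::runtime::NDArray;

  // Producer: create a 1024-element float32 array in shared memory.
  NDArray a = NDArray::EmptyShared(
      "demo", {1024}, DLDataType{kDLFloat, 32, 1}, DLContext{kDLCPU, 0},
      /*is_create=*/true);

  // Consumer (typically another process): attach to the same region.
  NDArray b = NDArray::EmptyShared(
      "demo", {1024}, DLDataType{kDLFloat, 32, 1}, DLContext{kDLCPU, 0},
      /*is_create=*/false);
}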
......@@ -310,11 +304,9 @@ template std::vector<uint64_t> NDArray::ToVector<uint64_t>() const;
template std::vector<float> NDArray::ToVector<float>() const;
template std::vector<double> NDArray::ToVector<double>() const;
std::shared_ptr<SharedMemory> NDArray::GetSharedMem() const {
return this->data_->mem;
}
void NDArray::Save(dmlc::Stream* strm) const {
......
......@@ -18,7 +18,6 @@
namespace dgl {
namespace runtime {
/*
* Shared memory is a resource that cannot be cleaned up if the process doesn't
* exit normally. We'll manage the resource with ResourceManager.
......@@ -33,21 +32,25 @@ class SharedMemoryResource: public Resource {
void Destroy() {
// LOG(INFO) << "remove " << name << " for shared memory";
#ifndef _WIN32
shm_unlink(name.c_str());
#else // _WIN32
// Nothing to do: Windows removes the shared memory object automatically once
// all views are unmapped and all handles to it are closed.
#endif
}
};
SharedMemory::SharedMemory(const std::string &name) {
this->name = name;
this->own_ = false;
#ifndef _WIN32
this->fd_ = -1;
#else
this->handle_ = nullptr;
#endif
this->ptr_ = nullptr;
this->size_ = 0;
}
SharedMemory::~SharedMemory() {
......@@ -61,7 +64,9 @@ SharedMemory::~SharedMemory() {
DeleteResource(name);
}
#else
CHECK(UnmapViewOfFile(ptr_)) << "Win32 error: " << GetLastError();
CloseHandle(handle_);
// Windows does not need a separate shm_unlink step.
#endif // _WIN32
}
......@@ -74,7 +79,7 @@ void *SharedMemory::CreateNew(size_t sz) {
int flag = O_RDWR|O_CREAT;
fd_ = shm_open(name.c_str(), flag, S_IRUSR | S_IWUSR);
CHECK_NE(fd_, -1) << "failed to open " << name << ": " << strerror(errno);
// On Linux, shared memory cannot be deleted if the process exits abnormally.
AddResource(name, std::shared_ptr<Resource>(new SharedMemoryResource(name)));
auto res = ftruncate(fd_, sz);
CHECK_NE(res, -1)
......@@ -85,8 +90,22 @@ void *SharedMemory::CreateNew(size_t sz) {
this->size_ = sz;
return ptr_;
#else
handle_ = CreateFileMapping(
INVALID_HANDLE_VALUE,
nullptr,
PAGE_READWRITE,
static_cast<DWORD>(sz >> 32),
static_cast<DWORD>(sz & 0xFFFFFFFF),
name.c_str());
CHECK(handle_ != nullptr) << "failed to create " << name << ", Win32 error: " << GetLastError();
ptr_ = MapViewOfFile(handle_, FILE_MAP_ALL_ACCESS, 0, 0, sz);
if (ptr_ == nullptr) {
  // Capture the error code before CloseHandle can overwrite it.
  DWORD err = GetLastError();
  CloseHandle(handle_);
  LOG(FATAL) << "Memory mapping failed, Win32 error: " << err;
  return nullptr;
}
this->size_ = sz;
return ptr_;
#endif // _WIN32
}
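CreateFileMapping takes the mapping size as two 32-bit halves, which is why the patch splits sz with >> 32 and & 0xFFFFFFFF; this keeps mappings larger than 4 GB working on 64-bit Windows. A tiny self-contained check of that split arithmetic (illustrative only):

#include <cassert>
#include <cstdint>

int main() {
  const uint64_t sz = (5ULL << 30) + 123;  // 5 GiB and change
  const uint32_t hi = static_cast<uint32_t>(sz >> 32);
  const uint32_t lo = static_cast<uint32_t>(sz & 0xFFFFFFFF);
  // Recombining the halves recovers the original 64-bit size exactly.
  assert(((static_cast<uint64_t>(hi) << 32) | lo) == sz);
  return 0;
}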
......@@ -101,23 +120,36 @@ void *SharedMemory::Open(size_t sz) {
this->size_ = sz;
return ptr_;
#else
handle_ = OpenFileMapping(FILE_MAP_ALL_ACCESS, FALSE, name.c_str());
CHECK(handle_ != nullptr) << "failed to open " << name << ", Win32 error: " << GetLastError();
ptr_ = MapViewOfFile(handle_, FILE_MAP_ALL_ACCESS, 0, 0, sz);
if (ptr_ == nullptr) {
  // Capture the error code before CloseHandle can overwrite it.
  DWORD err = GetLastError();
  CloseHandle(handle_);
  LOG(FATAL) << "Memory mapping failed, Win32 error: " << err;
  return nullptr;
}
this->size_ = sz;
return ptr_;
#endif // _WIN32
}
bool SharedMemory::Exist(const std::string &name) {
#ifndef _WIN32
int fd = shm_open(name.c_str(), O_RDONLY, S_IRUSR | S_IWUSR);
if (fd >= 0) {
close(fd);
return true;
} else {
return false;
}
#else
HANDLE handle = OpenFileMapping(FILE_MAP_ALL_ACCESS, FALSE, name.c_str());
if (handle != nullptr) {
CloseHandle(handle);
return true;
} else {
return false;
}
#endif // _WIN32
}
......
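Taken together, CreateNew/Open/Exist give the cross-process flow the NDArray layer relies on. A sketch against the class as shown (the region name is illustrative, the include path follows the header guard above, and error handling is elided):

#include <dgl/runtime/shared_mem.h>

// Process A: create and fill the region.
void producer() {
  dgl::runtime::SharedMemory shm("demo_region");
  int *p = static_cast<int*>(shm.CreateNew(sizeof(int)));
  *p = 42;
  // ... keep shm alive while consumers attach ...
}

// Process B: check for the region, then map the same pages.
void consumer() {
  if (dgl::runtime::SharedMemory::Exist("demo_region")) {
    dgl::runtime::SharedMemory shm("demo_region");
    int *q = static_cast<int*>(shm.Open(sizeof(int)));
    // *q reads 42: both processes see the same physical memory.
  }
}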
......@@ -68,7 +68,7 @@ class ParallelLauncher {
// Wait for n jobs to finish
int WaitForJobs() {
while (num_pending_.load() != 0) {
dgl::runtime::threading::YieldThread();
}
if (!has_error_.load()) return 0;
// the following is intended to use string due to
......@@ -143,7 +143,7 @@ class SpscTaskQueue {
*/
void Push(const Task& input) {
while (!Enqueue(input)) {
dgl::runtime::threading::YieldThread();
}
if (pending_.fetch_add(1) == -1) {
std::unique_lock<std::mutex> lock(mutex_);
......@@ -162,7 +162,7 @@ class SpscTaskQueue {
// If a new task arrives quickly, this spin wait keeps the worker from going to sleep.
// The default spin count follows the typical OpenMP convention.
for (uint32_t i = 0; i < spin_count && pending_.load() == 0; ++i) {
dgl::runtime::threading::YieldThread();
}
if (pending_.fetch_sub(1) == 0) {
std::unique_lock<std::mutex> lock(mutex_);
......@@ -374,7 +374,7 @@ int DGLBackendParallelBarrier(int task_id, DGLParallelGroupEnv* penv) {
if (i != task_id) {
while (sync_counter[i * kSyncStride].load(
std::memory_order_relaxed) <= old_counter) {
dgl::runtime::threading::YieldThread();
}
}
}
......
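The call sites above all follow the same spin-then-block idiom: a worker calls YieldThread() in a bounded loop hoping work arrives quickly, and only then falls back to a condition variable. A generic standalone sketch of that idiom (not the DGL queue itself; all names here are made up):

#include <atomic>
#include <condition_variable>
#include <cstdint>
#include <mutex>
#include <thread>

std::atomic<int> pending{0};
std::mutex mu;
std::condition_variable cv;

void wait_for_work(uint32_t spin_count) {
  // Cheap busy-wait first: yielding keeps latency low if a task is imminent.
  for (uint32_t i = 0; i < spin_count && pending.load() == 0; ++i)
    std::this_thread::yield();
  // Still nothing: block on the condition variable to stop burning CPU.
  std::unique_lock<std::mutex> lock(mu);
  cv.wait(lock, [] { return pending.load() != 0; });
}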
......@@ -191,7 +191,7 @@ int ThreadGroup::Configure(AffinityMode mode, int nthreads, bool exclude_worker0
return impl_->Configure(mode, nthreads, exclude_worker0);
}
void YieldThread() {
std::this_thread::yield();
}
......
......@@ -42,7 +42,6 @@ def _assert_is_identical_hetero(g, g2):
assert F.array_equal(src, src2)
assert F.array_equal(dst, dst2)
@unittest.skipIf(dgl.backend.backend_name == 'tensorflow', reason='Not support tensorflow for now')
@parametrize_dtype
def test_single_process(idtype):
......@@ -60,7 +59,6 @@ def sub_proc(hg_origin, name):
_assert_is_identical_hetero(hg_origin, hg_rebuild)
_assert_is_identical_hetero(hg_origin, hg_save_again)
@unittest.skipIf(dgl.backend.backend_name == 'tensorflow', reason='Not support tensorflow for now')
@parametrize_dtype
def test_multi_process(idtype):
......@@ -70,7 +68,6 @@ def test_multi_process(idtype):
p.start()
p.join()
@unittest.skipIf(F._default_context_str == 'cpu', reason="Need gpu for this test")
@unittest.skipIf(dgl.backend.backend_name == 'tensorflow', reason='Not support tensorflow for now')
def test_copy_from_gpu():
......