Unverified commit 3d1f2e87 authored by Da Zheng, committed by GitHub

delete shared memory when receiving signals. (#2419)



* delete shared memory when receiving a signal.

* rename.

* fix lint.

* fix lint.

* fix compile.

* Fix.

* we need to report an error if the shared memory exists.

* disable tensorflow test for shared memory.

* revert.
Co-authored-by: Ubuntu <ubuntu@ip-172-31-2-202.us-west-1.compute.internal>
Co-authored-by: Jinjing Zhou <VoVAllen@users.noreply.github.com>
parent e8054701
@@ -986,10 +986,10 @@ def fast_pull(name, id_tensor, part_id, service_id,
         F.zerocopy_to_dgl_ndarray(local_data))
     return F.zerocopy_from_dgl_ndarray(res_tensor)
 
-def register_ctrl_c():
-    """HandleCtrlC Register for handling Ctrl+C event.
+def register_sig_handler():
+    """Register for handling signal event.
     """
-    _CAPI_DGLRPCHandleCtrlC()
+    _CAPI_DGLRPCHandleSignal()
 
 def copy_data_to_shared_memory(dst, source):
     """Copy tensor data to shared-memory tensor
......
@@ -133,7 +133,7 @@ def connect_to_server(ip_config, num_servers, max_queue_size=MAX_QUEUE_SIZE, net
     rpc.register_service(rpc.CLIENT_BARRIER,
                          rpc.ClientBarrierRequest,
                          rpc.ClientBarrierResponse)
-    rpc.register_ctrl_c()
+    rpc.register_sig_handler()
     server_namebook = rpc.read_ip_config(ip_config, num_servers)
     num_servers = len(server_namebook)
     rpc.set_num_server(num_servers)
......
@@ -38,8 +38,8 @@ def start_server(server_id, ip_config, num_servers, num_clients, server_state, \
     assert num_clients >= 0, 'num_client (%d) cannot be a negative number.' % num_client
     assert max_queue_size > 0, 'queue_size (%d) cannot be a negative number.' % queue_size
     assert net_type in ('socket'), 'net_type (%s) can only be \'socket\'' % net_type
-    # HandleCtrlC Register for handling Ctrl+C event
-    rpc.register_ctrl_c()
+    # Register signal handler.
+    rpc.register_sig_handler()
     # Register some basic services
     rpc.register_service(rpc.CLIENT_REGISTER,
                         rpc.ClientRegisterRequest,
......
@@ -15,6 +15,7 @@
 #include <dgl/array.h>
 #include <dgl/random.h>
 #include <dgl/zerocopy_serializer.h>
+#include "../runtime/resource_manager.h"
 #include "../c_api_common.h"
 
 using dgl::network::StringPrintf;
@@ -321,22 +322,24 @@ DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCMessageGetTensors")
 #if defined(__linux__)
 /*!
- * \brief CtrlCHandler, exits if Ctrl+C is pressed
+ * \brief The signal handler.
  * \param s signal
  */
-void CtrlCHandler(int s) {
+void SigHandler(int s) {
   LOG(INFO) << "\nUser pressed Ctrl+C, Exiting";
+  CleanupResources();
   exit(1);
 }
 
-DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCHandleCtrlC")
+DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCHandleSignal")
 .set_body([] (DGLArgs args, DGLRetValue* rv) {
   // Ctrl+C handler
-  struct sigaction sigIntHandler;
-  sigIntHandler.sa_handler = CtrlCHandler;
-  sigemptyset(&sigIntHandler.sa_mask);
-  sigIntHandler.sa_flags = 0;
-  sigaction(SIGINT, &sigIntHandler, nullptr);
+  struct sigaction sigHandler;
+  sigHandler.sa_handler = SigHandler;
+  sigemptyset(&sigHandler.sa_mask);
+  sigHandler.sa_flags = 0;
+  sigaction(SIGINT, &sigHandler, nullptr);
+  sigaction(SIGTERM, &sigHandler, nullptr);
 });
 #endif
......
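For context, the handler installation above is plain POSIX sigaction: one handler is shared by SIGINT and SIGTERM so tracked resources are released before the process exits. Below is a minimal standalone sketch of the same pattern; the cleanup() stub is hypothetical and stands in for DGL's CleanupResources().

#include <signal.h>
#include <unistd.h>
#include <cstdio>

// Hypothetical stand-in for DGL's CleanupResources(): release anything the
// OS won't reclaim on its own, e.g. POSIX shared-memory files in /dev/shm.
static void cleanup() {
  std::puts("cleaning up before exit");
}

static void sig_handler(int) {
  // Note: production handlers should restrict themselves to
  // async-signal-safe calls.
  cleanup();
  _exit(1);  // the default signal action would terminate without cleanup
}

int main() {
  struct sigaction sa;
  sa.sa_handler = sig_handler;
  sigemptyset(&sa.sa_mask);
  sa.sa_flags = 0;
  // One handler covers both Ctrl+C (SIGINT) and `kill` (SIGTERM).
  // SIGKILL cannot be caught, so `kill -9` can still leak the resources.
  sigaction(SIGINT, &sa, nullptr);
  sigaction(SIGTERM, &sa, nullptr);
  pause();  // wait for a signal
  return 0;
}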
/*!
* Copyright (c) 2020 by Contributors
* \file resource_manager.cc
* \brief Manage the resources.
*/
#include "resource_manager.h"
#include <dmlc/logging.h>
#include <utility>
namespace dgl {
namespace runtime {
/*
 * The runtime allocates resources during computation. Some of these resources
 * cannot be destroyed automatically when the process exits, especially when
 * the process doesn't exit normally. We need to keep track of such resources
 * and clean them up properly.
 */
class ResourceManager {
std::unordered_map<std::string, std::shared_ptr<Resource>> resources;
public:
void Add(const std::string &key, std::shared_ptr<Resource> resource) {
auto it = resources.find(key);
CHECK(it == resources.end()) << key << " already exists";
resources.insert(std::pair<std::string, std::shared_ptr<Resource>>(key, resource));
}
void Erase(const std::string &key) {
resources.erase(key);
}
void Cleanup() {
for (auto it = resources.begin(); it != resources.end(); it++) {
it->second->Destroy();
}
resources.clear();
}
};
static ResourceManager manager;
void AddResource(const std::string &key, std::shared_ptr<Resource> resource) {
manager.Add(key, resource);
}
void DeleteResource(const std::string &key) {
manager.Erase(key);
}
void CleanupResources() {
manager.Cleanup();
}
} // namespace runtime
} // namespace dgl
/*!
* Copyright (c) 2020 by Contributors
* \file resource_manager.h
* \brief Manage the resources in the runtime system.
*/
#ifndef DGL_RUNTIME_RESOURCE_MANAGER_H_
#define DGL_RUNTIME_RESOURCE_MANAGER_H_
#include <unordered_map>
#include <string>
#include <memory>
namespace dgl {
namespace runtime {
/*
* A class that provides the interface to describe a resource that can be managed by
 * a resource manager. Some of the resources cannot be freed automatically when
* the process exits, especially when the process doesn't exit normally. One example
* is shared memory. We can keep track of this kind of resources and manage them
* properly.
*/
class Resource {
public:
virtual ~Resource() {
}
virtual void Destroy() = 0;
};
// Add resource.
void AddResource(const std::string &key, std::shared_ptr<Resource> resource);
// Delete resource.
void DeleteResource(const std::string &key);
// Clean up all resources.
void CleanupResources();
} // namespace runtime
} // namespace dgl
#endif // DGL_RUNTIME_RESOURCE_MANAGER_H_
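As a usage illustration of the interface above, here is a hedged sketch of a custom Resource: a subclass whose Destroy() removes a temporary file, registered with AddResource() so that CleanupResources() can reclaim it from the signal handler. The TempFileResource name and the file path are invented for the example and are not part of this commit.

#include <cstdio>
#include <memory>
#include <string>
#include "resource_manager.h"

// Hypothetical resource: a temporary file that should not outlive the process.
class TempFileResource : public dgl::runtime::Resource {
  std::string path;

 public:
  explicit TempFileResource(const std::string &path) : path(path) {}

  void Destroy() override {
    // Remove the file; harmless if it is already gone.
    std::remove(path.c_str());
  }
};

void Example() {
  // Track the file as soon as it is created ...
  dgl::runtime::AddResource(
      "/tmp/dgl_example.tmp",
      std::make_shared<TempFileResource>("/tmp/dgl_example.tmp"));
  // ... and stop tracking it once it is removed through the normal path,
  // mirroring what SharedMemory's destructor does with DeleteResource().
  dgl::runtime::DeleteResource("/tmp/dgl_example.tmp");
}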
@@ -13,9 +13,31 @@
 #include <dmlc/logging.h>
 #include <dgl/runtime/shared_mem.h>
+#include "resource_manager.h"
 
 namespace dgl {
 namespace runtime {
 
+#ifndef _WIN32
+/*
+ * Shared memory is a resource that cannot be cleaned up if the process doesn't
+ * exit normally. We'll manage the resource with ResourceManager.
+ */
+class SharedMemoryResource : public Resource {
+  std::string name;
+
+ public:
+  explicit SharedMemoryResource(const std::string &name) {
+    this->name = name;
+  }
+
+  void Destroy() {
+    LOG(INFO) << "remove " << name << " for shared memory";
+    shm_unlink(name.c_str());
+  }
+};
+#endif  // _WIN32
+
 SharedMemory::SharedMemory(const std::string &name) {
 #ifndef _WIN32
   this->name = name;
......
@@ -35,6 +57,8 @@ SharedMemory::~SharedMemory() {
   if (own) {
     LOG(INFO) << "remove " << name << " for shared memory";
     shm_unlink(name.c_str());
+    // The resource has been deleted. We don't need to keep track of it any more.
+    DeleteResource(name);
   }
 #else
   LOG(FATAL) << "Shared memory is not supported on Windows.";
......
@@ -45,9 +69,13 @@ void *SharedMemory::CreateNew(size_t size) {
 #ifndef _WIN32
   this->own = true;
 
+  // We need to create a shared-memory file.
+  // TODO(zhengda) we need to report an error if the shared-memory file exists.
   int flag = O_RDWR|O_CREAT;
   fd = shm_open(name.c_str(), flag, S_IRUSR | S_IWUSR);
   CHECK_NE(fd, -1) << "fail to open " << name << ": " << strerror(errno);
+  // Shared memory cannot be deleted if the process exits abnormally.
+  AddResource(name, std::shared_ptr<Resource>(new SharedMemoryResource(name)));
   auto res = ftruncate(fd, size);
   CHECK_NE(res, -1)
     << "Failed to truncate the file. " << strerror(errno);
......
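For readers unfamiliar with the lifecycle this file manages: a POSIX shared-memory object is created with shm_open/ftruncate, mapped with mmap, and persists under /dev/shm until shm_unlink is called, even after the creating process dies. A minimal standalone sketch follows (the "/demo_shm" name is illustrative, not one DGL uses; link with -lrt on older glibc):

#include <fcntl.h>
#include <sys/mman.h>
#include <sys/stat.h>
#include <unistd.h>
#include <cstdio>
#include <cstring>

int main() {
  const char *name = "/demo_shm";  // illustrative name, not one DGL uses
  const size_t size = 4096;

  // Create the shared-memory file. Adding O_EXCL here would make shm_open
  // fail if the file already exists -- the behavior the TODO above asks for.
  int fd = shm_open(name, O_RDWR | O_CREAT, S_IRUSR | S_IWUSR);
  if (fd == -1) { std::perror("shm_open"); return 1; }
  if (ftruncate(fd, size) == -1) { std::perror("ftruncate"); return 1; }

  // Map it and use it like ordinary memory.
  void *ptr = mmap(nullptr, size, PROT_READ | PROT_WRITE, MAP_SHARED, fd, 0);
  if (ptr == MAP_FAILED) { std::perror("mmap"); return 1; }
  std::strcpy(static_cast<char *>(ptr), "hello");

  // Normal cleanup path. If the process dies before shm_unlink runs, the
  // file lingers under /dev/shm -- which is why CreateNew above registers a
  // SharedMemoryResource with the ResourceManager.
  munmap(ptr, size);
  close(fd);
  shm_unlink(name);
  return 0;
}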
@@ -43,6 +43,7 @@ def _assert_is_identical_hetero(g, g2):
         assert F.array_equal(dst, dst2)
 
 @unittest.skipIf(os.name == 'nt', reason='Do not support windows yet')
+@unittest.skipIf(dgl.backend.backend_name == 'tensorflow', reason='Not support tensorflow for now')
 @parametrize_dtype
 def test_single_process(idtype):
     hg = create_test_graph(idtype=idtype)
......
@@ -60,6 +61,7 @@ def sub_proc(hg_origin, name):
     _assert_is_identical_hetero(hg_origin, hg_save_again)
 
 @unittest.skipIf(os.name == 'nt', reason='Do not support windows yet')
+@unittest.skipIf(dgl.backend.backend_name == 'tensorflow', reason='Not support tensorflow for now')
 @parametrize_dtype
 def test_multi_process(idtype):
     hg = create_test_graph(idtype=idtype)
......
@@ -70,6 +72,7 @@ def test_multi_process(idtype):
 @unittest.skipIf(os.name == 'nt', reason='Do not support windows yet')
 @unittest.skipIf(F._default_context_str == 'cpu', reason="Need gpu for this test")
+@unittest.skipIf(dgl.backend.backend_name == 'tensorflow', reason='Not support tensorflow for now')
 def test_copy_from_gpu():
     hg = create_test_graph(idtype=F.int32)
     hg_gpu = hg.to(F.cuda())
......