"git@developer.sourcefind.cn:OpenDAS/dgl.git" did not exist on "1193e2e8ec29a8094d11b6551ab8c57b062cdb68"
Unverified commit 15e3ff87, authored by Chao Ma and committed by GitHub
Browse files

[Distributed] Fix memory leak in kvstore (#2016)

* Fix memory leak

* update

* update

* update

* update

* update

* update

* update
parent e6f6ce27
...@@ -737,7 +737,7 @@ class KVServer(object): ...@@ -737,7 +737,7 @@ class KVServer(object):
shared_data = empty_shared_mem(name+'-kvdata-', True, data_tensor.shape, data_type) shared_data = empty_shared_mem(name+'-kvdata-', True, data_tensor.shape, data_type)
dlpack = shared_data.to_dlpack() dlpack = shared_data.to_dlpack()
self._data_store[name] = F.zerocopy_from_dlpack(dlpack) self._data_store[name] = F.zerocopy_from_dlpack(dlpack)
self._data_store[name][:] = data_tensor[:] rpc.copy_data_to_shared_memory(self._data_store[name], data_tensor)
assert self._part_policy[name].get_part_size() == data_tensor.shape[0], \ assert self._part_policy[name].get_part_size() == data_tensor.shape[0], \
'kvserver expect partition {} for {} has {} rows, but gets {} rows'.format( 'kvserver expect partition {} for {} has {} rows, but gets {} rows'.format(
self._part_policy[name].part_id, self._part_policy[name].part_id,
......
...@@ -16,7 +16,7 @@ __all__ = ['set_rank', 'get_rank', 'Request', 'Response', 'register_service', \ ...@@ -16,7 +16,7 @@ __all__ = ['set_rank', 'get_rank', 'Request', 'Response', 'register_service', \
'get_num_machines', 'set_num_machines', 'get_machine_id', 'set_machine_id', \ 'get_num_machines', 'set_num_machines', 'get_machine_id', 'set_machine_id', \
'send_request', 'recv_request', 'send_response', 'recv_response', 'remote_call', \ 'send_request', 'recv_request', 'send_response', 'recv_response', 'remote_call', \
'send_request_to_machine', 'remote_call_to_machine', 'fast_pull', \ 'send_request_to_machine', 'remote_call_to_machine', 'fast_pull', \
'get_num_client', 'set_num_client', 'client_barrier'] 'get_num_client', 'set_num_client', 'client_barrier', 'copy_data_to_shared_memory']
REQUEST_CLASS_TO_SERVICE_ID = {} REQUEST_CLASS_TO_SERVICE_ID = {}
RESPONSE_CLASS_TO_SERVICE_ID = {} RESPONSE_CLASS_TO_SERVICE_ID = {}
...@@ -991,6 +991,12 @@ def register_ctrl_c(): ...@@ -991,6 +991,12 @@ def register_ctrl_c():
""" """
_CAPI_DGLRPCHandleCtrlC() _CAPI_DGLRPCHandleCtrlC()
def copy_data_to_shared_memory(source, dst):
    """Copy tensor data into a shared-memory tensor via the C API.

    Both tensors are converted to DGL NDArrays without copying and handed,
    in positional order, to ``_CAPI_DGLCopyDataToSharedMemory``.

    NOTE(review): the parameter names read backwards relative to the data
    flow. The C routine overwrites the buffer of its *first* argument with
    the bytes of its *second* argument, and the KVServer caller passes the
    shared-memory tensor first and the input data tensor second. In other
    words, ``source`` is the tensor that gets written and ``dst`` is the
    tensor that is read. The names are kept so keyword callers keep working.

    Parameters
    ----------
    source : tensor
        First argument forwarded to the C API (the buffer that is written).
    dst : tensor
        Second argument forwarded to the C API (the buffer that is read).
    """
    written_nd = F.zerocopy_to_dgl_ndarray(source)
    read_nd = F.zerocopy_to_dgl_ndarray(dst)
    _CAPI_DGLCopyDataToSharedMemory(written_nd, read_nd)
############### Some basic services will be defined here ############# ############### Some basic services will be defined here #############
CLIENT_REGISTER = 22451 CLIENT_REGISTER = 22451
......
...@@ -477,5 +477,15 @@ DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCFastPull") ...@@ -477,5 +477,15 @@ DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCFastPull")
*rv = res_tensor; *rv = res_tensor;
}); });
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLCopyDataToSharedMemory")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
  // Raw byte copy between two NDArrays of identical byte size.
  //
  // args[0] is the array that is WRITTEN (the shared-memory tensor at the
  // Python call site); args[1] is the array that is READ. The locals are
  // named after those roles so the memcpy(dest, src, n) call reads in the
  // same direction as the data actually flows.
  //
  // No value is returned through rv.
  NDArray target = args[0];
  NDArray payload = args[1];
  CHECK_EQ(target.GetSize(), payload.GetSize())
    << "Cannot copy between arrays of different byte sizes.";
  // DLTensor keeps the start of the data at data + byte_offset, so honor
  // the offset instead of assuming it is zero.
  char* target_ptr = static_cast<char*>(target->data) + target->byte_offset;
  const char* payload_ptr =
    static_cast<const char*>(payload->data) + payload->byte_offset;
  memcpy(target_ptr, payload_ptr, payload.GetSize());
});
} // namespace rpc } // namespace rpc
} // namespace dgl } // namespace dgl
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment