Unverified Commit 38d292e5 authored by Chao Ma's avatar Chao Ma Committed by GitHub
Browse files

[KVStore] Fix memory leak bug. (#1174)

* API change of kvstore

* add demo for kvstore

* update

* remove duplicated log

* change queue size

* update

* update

* update

* update

* update

* update

* update

* update

* update

* fix lint

* change name

* update

* fix lint

* update

* update

* update

* update

* change message queue size to a python argument

* change default queue size to 2GB

* OMP_NUM_THREADS=1

* add multiple NICs support for kvstore

* test

* fix lint

* update

* update

* update

* update

* update

* update

* update

* fix lint

* fix lint

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* fix lint

* update

* fix lint

* delete msg

* clear kv msg

* update

* update

* update

* update

* update

* update

* is not None

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update
parent e705118b
......@@ -201,6 +201,8 @@ class KVServer(object):
# Create C communicator of sender and receiver
self._sender = _create_sender(net_type, msg_queue_size)
self._receiver = _create_receiver(net_type, msg_queue_size)
# A naive garbage collocetion for kvstore
self._garbage_msg = []
def __del__(self):
......@@ -311,7 +313,8 @@ class KVServer(object):
rank=self._server_id,
name=str(client_id),
id=None,
data=None)
data=None,
c_ptr=None)
_send_kv_msg(self._sender, msg, client_id)
# send serilaized shared-memory tensor information to clients
......@@ -329,7 +332,8 @@ class KVServer(object):
rank=self._server_id,
name=shared_tensor,
id=None,
data=None)
data=None,
c_ptr=None)
for client_id in range(len(self._client_namebook)):
_send_kv_msg(self._sender, msg, client_id)
......@@ -356,7 +360,8 @@ class KVServer(object):
rank=self._server_id,
name=msg.name,
id=msg.id,
data=res_tensor)
data=res_tensor,
c_ptr=None)
_send_kv_msg(self._sender, back_msg, msg.rank)
# Barrier message
elif msg.type == KVMsgType.BARRIER:
......@@ -367,7 +372,8 @@ class KVServer(object):
rank=self._server_id,
name=None,
id=None,
data=None)
data=None,
c_ptr=None)
for i in range(self._client_count):
_send_kv_msg(self._sender, back_msg, i)
self._barrier_count = 0
......@@ -378,7 +384,10 @@ class KVServer(object):
else:
raise RuntimeError('Unknown type of kvstore message: %d' % msg.type.value)
_clear_kv_msg(msg)
self._garbage_msg.append(msg)
if len(self._garbage_msg) > 1000:
_clear_kv_msg(self._garbage_msg)
self._garbage_msg = []
def _push_handler(self, name, ID, data, target):
......@@ -512,6 +521,9 @@ class KVClient(object):
# create C communicator of sender and receiver
self._sender = _create_sender(net_type, msg_queue_size)
self._receiver = _create_receiver(net_type, msg_queue_size)
# A naive garbage collocetion for kvstore
self._garbage_msg = []
def __del__(self):
......@@ -568,7 +580,8 @@ class KVClient(object):
rank=0,
name=self._addr,
id=None,
data=None)
data=None,
c_ptr=None)
for server_id in range(self._server_count):
_send_kv_msg(self._sender, msg, server_id)
......@@ -641,7 +654,8 @@ class KVClient(object):
rank=self._client_id,
name=name,
id=partial_id,
data=partial_data)
data=partial_data,
c_ptr=None)
_send_kv_msg(self._sender, msg, server[idx])
start += count[idx]
......@@ -665,6 +679,10 @@ class KVClient(object):
assert len(name) > 0, 'name cannot be empty.'
assert F.ndim(id_tensor) == 1, 'ID must be a vector.'
if len(self._garbage_msg) > 1000:
_clear_kv_msg(self._garbage_msg)
self._garbage_msg = []
# partition data (we can move this part of code into C-api if needed)
server_id = self._data_store[name+'-part-'][id_tensor]
# sort index by server id
......@@ -695,7 +713,8 @@ class KVClient(object):
rank=self._client_id,
name=name,
id=partial_id,
data=None)
data=None,
c_ptr=None)
_send_kv_msg(self._sender, msg, server[idx])
pull_count += 1
......@@ -708,20 +727,21 @@ class KVClient(object):
rank=server_id,
name=name,
id=None,
data=data)
data=data,
c_ptr=None)
msg_list.append(local_msg)
self._garbage_msg.append(local_msg)
# wait message from server nodes
for idx in range(pull_count):
msg_list.append(_recv_kv_msg(self._receiver))
remote_msg = _recv_kv_msg(self._receiver)
msg_list.append(remote_msg)
self._garbage_msg.append(remote_msg)
# sort msg by server id
msg_list.sort(key=self._takeId)
data_tensor = F.cat(seq=[msg.data for msg in msg_list], dim=0)
for msg in msg_list:
_clear_kv_msg(msg)
return data_tensor[back_sorted_id] # return data with original index order
......@@ -735,7 +755,8 @@ class KVClient(object):
rank=self._client_id,
name=None,
id=None,
data=None)
data=None,
c_ptr=None)
for server_id in range(self._server_count):
_send_kv_msg(self._sender, msg, server_id)
......@@ -756,7 +777,8 @@ class KVClient(object):
rank=self._client_id,
name=None,
id=None,
data=None)
data=None,
c_ptr=None)
_send_kv_msg(self._sender, msg, server_id)
......@@ -907,3 +929,4 @@ class KVClient(object):
tensor_shape = tuple(tensor_shape)
return tensor_name, tensor_shape, data_type
......@@ -186,7 +186,7 @@ class KVMsgType(Enum):
BARRIER = 6
IP_ID = 7
KVStoreMsg = namedtuple("KVStoreMsg", "type rank name id data")
KVStoreMsg = namedtuple("KVStoreMsg", "type rank name id data, c_ptr")
"""Message of DGL kvstore
Data Field
......@@ -273,7 +273,8 @@ def _recv_kv_msg(receiver):
rank=rank,
name=name,
id=tensor_id,
data=None)
data=None,
c_ptr=msg_ptr)
return msg
elif msg_type == KVMsgType.IP_ID:
name = _CAPI_ReceiverGetKVMsgName(msg_ptr)
......@@ -282,7 +283,8 @@ def _recv_kv_msg(receiver):
rank=rank,
name=name,
id=None,
data=None)
data=None,
c_ptr=msg_ptr)
return msg
elif msg_type in (KVMsgType.FINAL, KVMsgType.BARRIER):
msg = KVStoreMsg(
......@@ -290,7 +292,8 @@ def _recv_kv_msg(receiver):
rank=rank,
name=None,
id=None,
data=None)
data=None,
c_ptr=msg_ptr)
return msg
else:
name = _CAPI_ReceiverGetKVMsgName(msg_ptr)
......@@ -301,22 +304,19 @@ def _recv_kv_msg(receiver):
rank=rank,
name=name,
id=tensor_id,
data=data)
data=data,
c_ptr=msg_ptr)
return msg
raise RuntimeError('Unknown message type: %d' % msg_type.value)
def _clear_kv_msg(msg):
def _clear_kv_msg(garbage_msg):
"""Clear data of kvstore message
Parameters
----------
msg : KVStoreMsg
kvstore message
"""
if msg.data is not None:
F.sync()
data = F.zerocopy_to_dgl_ndarray(msg.data)
_CAPI_DeleteNDArrayData(data)
F.sync()
for msg in garbage_msg:
if msg.c_ptr is not None:
_CAPI_DeleteKVMsg(msg.c_ptr)
garbage_msg = []
\ No newline at end of file
......@@ -24,6 +24,14 @@ using namespace dgl::runtime;
namespace dgl {
namespace network {
static void NaiveDeleter(DLManagedTensor* managed_tensor) {
delete [] managed_tensor->dl_tensor.shape;
delete [] managed_tensor->dl_tensor.strides;
delete [] managed_tensor->dl_tensor.data;
delete managed_tensor;
}
NDArray CreateNDArrayFromRaw(std::vector<int64_t> shape,
DLDataType dtype,
DLContext ctx,
......@@ -46,6 +54,7 @@ NDArray CreateNDArrayFromRaw(std::vector<int64_t> shape,
tensor.data = raw;
DLManagedTensor *managed_tensor = new DLManagedTensor();
managed_tensor->dl_tensor = tensor;
managed_tensor->deleter = NaiveDeleter;
return NDArray::FromDLPack(managed_tensor);
}
......@@ -591,12 +600,14 @@ DGL_REGISTER_GLOBAL("network._CAPI_ReceiverGetKVMsgData")
*rv = msg->data;
});
DGL_REGISTER_GLOBAL("network._CAPI_DeleteNDArrayData")
DGL_REGISTER_GLOBAL("network._CAPI_DeleteKVMsg")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
NDArray data = args[0];
delete [] data->data;
KVMsgHandle chandle = args[0];
network::KVStoreMsg* msg = static_cast<KVStoreMsg*>(chandle);
delete msg;
});
} // namespace network
} // namespace dgl
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment