Unverified Commit 38d292e5 authored by Chao Ma's avatar Chao Ma Committed by GitHub
Browse files

[KVStore] Fix memory leak bug. (#1174)

* API change of kvstore

* add demo for kvstore

* update

* remove duplicated log

* change queue size

* update

* update

* update

* update

* update

* update

* update

* update

* update

* fix lint

* change name

* update

* fix lint

* update

* update

* update

* update

* change message queue size to a python argument

* change default queue size to 2GB

* OMP_NUM_THREADS=1

* add multiple NICs support for kvstore

* test

* fix lint

* update

* update

* update

* update

* update

* update

* update

* fix lint

* fix lint

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* fix lint

* update

* fix lint

* delete msg

* clear kv msg

* update

* update

* update

* update

* update

* update

* is not None

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update
parent e705118b
...@@ -201,6 +201,8 @@ class KVServer(object): ...@@ -201,6 +201,8 @@ class KVServer(object):
# Create C communicator of sender and receiver # Create C communicator of sender and receiver
self._sender = _create_sender(net_type, msg_queue_size) self._sender = _create_sender(net_type, msg_queue_size)
self._receiver = _create_receiver(net_type, msg_queue_size) self._receiver = _create_receiver(net_type, msg_queue_size)
# A naive garbage collocetion for kvstore
self._garbage_msg = []
def __del__(self): def __del__(self):
...@@ -311,7 +313,8 @@ class KVServer(object): ...@@ -311,7 +313,8 @@ class KVServer(object):
rank=self._server_id, rank=self._server_id,
name=str(client_id), name=str(client_id),
id=None, id=None,
data=None) data=None,
c_ptr=None)
_send_kv_msg(self._sender, msg, client_id) _send_kv_msg(self._sender, msg, client_id)
# send serilaized shared-memory tensor information to clients # send serilaized shared-memory tensor information to clients
...@@ -329,7 +332,8 @@ class KVServer(object): ...@@ -329,7 +332,8 @@ class KVServer(object):
rank=self._server_id, rank=self._server_id,
name=shared_tensor, name=shared_tensor,
id=None, id=None,
data=None) data=None,
c_ptr=None)
for client_id in range(len(self._client_namebook)): for client_id in range(len(self._client_namebook)):
_send_kv_msg(self._sender, msg, client_id) _send_kv_msg(self._sender, msg, client_id)
...@@ -356,7 +360,8 @@ class KVServer(object): ...@@ -356,7 +360,8 @@ class KVServer(object):
rank=self._server_id, rank=self._server_id,
name=msg.name, name=msg.name,
id=msg.id, id=msg.id,
data=res_tensor) data=res_tensor,
c_ptr=None)
_send_kv_msg(self._sender, back_msg, msg.rank) _send_kv_msg(self._sender, back_msg, msg.rank)
# Barrier message # Barrier message
elif msg.type == KVMsgType.BARRIER: elif msg.type == KVMsgType.BARRIER:
...@@ -367,7 +372,8 @@ class KVServer(object): ...@@ -367,7 +372,8 @@ class KVServer(object):
rank=self._server_id, rank=self._server_id,
name=None, name=None,
id=None, id=None,
data=None) data=None,
c_ptr=None)
for i in range(self._client_count): for i in range(self._client_count):
_send_kv_msg(self._sender, back_msg, i) _send_kv_msg(self._sender, back_msg, i)
self._barrier_count = 0 self._barrier_count = 0
...@@ -378,7 +384,10 @@ class KVServer(object): ...@@ -378,7 +384,10 @@ class KVServer(object):
else: else:
raise RuntimeError('Unknown type of kvstore message: %d' % msg.type.value) raise RuntimeError('Unknown type of kvstore message: %d' % msg.type.value)
_clear_kv_msg(msg) self._garbage_msg.append(msg)
if len(self._garbage_msg) > 1000:
_clear_kv_msg(self._garbage_msg)
self._garbage_msg = []
def _push_handler(self, name, ID, data, target): def _push_handler(self, name, ID, data, target):
...@@ -512,6 +521,9 @@ class KVClient(object): ...@@ -512,6 +521,9 @@ class KVClient(object):
# create C communicator of sender and receiver # create C communicator of sender and receiver
self._sender = _create_sender(net_type, msg_queue_size) self._sender = _create_sender(net_type, msg_queue_size)
self._receiver = _create_receiver(net_type, msg_queue_size) self._receiver = _create_receiver(net_type, msg_queue_size)
# A naive garbage collocetion for kvstore
self._garbage_msg = []
def __del__(self): def __del__(self):
...@@ -568,7 +580,8 @@ class KVClient(object): ...@@ -568,7 +580,8 @@ class KVClient(object):
rank=0, rank=0,
name=self._addr, name=self._addr,
id=None, id=None,
data=None) data=None,
c_ptr=None)
for server_id in range(self._server_count): for server_id in range(self._server_count):
_send_kv_msg(self._sender, msg, server_id) _send_kv_msg(self._sender, msg, server_id)
...@@ -641,7 +654,8 @@ class KVClient(object): ...@@ -641,7 +654,8 @@ class KVClient(object):
rank=self._client_id, rank=self._client_id,
name=name, name=name,
id=partial_id, id=partial_id,
data=partial_data) data=partial_data,
c_ptr=None)
_send_kv_msg(self._sender, msg, server[idx]) _send_kv_msg(self._sender, msg, server[idx])
start += count[idx] start += count[idx]
...@@ -665,6 +679,10 @@ class KVClient(object): ...@@ -665,6 +679,10 @@ class KVClient(object):
assert len(name) > 0, 'name cannot be empty.' assert len(name) > 0, 'name cannot be empty.'
assert F.ndim(id_tensor) == 1, 'ID must be a vector.' assert F.ndim(id_tensor) == 1, 'ID must be a vector.'
if len(self._garbage_msg) > 1000:
_clear_kv_msg(self._garbage_msg)
self._garbage_msg = []
# partition data (we can move this part of code into C-api if needed) # partition data (we can move this part of code into C-api if needed)
server_id = self._data_store[name+'-part-'][id_tensor] server_id = self._data_store[name+'-part-'][id_tensor]
# sort index by server id # sort index by server id
...@@ -695,7 +713,8 @@ class KVClient(object): ...@@ -695,7 +713,8 @@ class KVClient(object):
rank=self._client_id, rank=self._client_id,
name=name, name=name,
id=partial_id, id=partial_id,
data=None) data=None,
c_ptr=None)
_send_kv_msg(self._sender, msg, server[idx]) _send_kv_msg(self._sender, msg, server[idx])
pull_count += 1 pull_count += 1
...@@ -708,20 +727,21 @@ class KVClient(object): ...@@ -708,20 +727,21 @@ class KVClient(object):
rank=server_id, rank=server_id,
name=name, name=name,
id=None, id=None,
data=data) data=data,
c_ptr=None)
msg_list.append(local_msg) msg_list.append(local_msg)
self._garbage_msg.append(local_msg)
# wait message from server nodes # wait message from server nodes
for idx in range(pull_count): for idx in range(pull_count):
msg_list.append(_recv_kv_msg(self._receiver)) remote_msg = _recv_kv_msg(self._receiver)
msg_list.append(remote_msg)
self._garbage_msg.append(remote_msg)
# sort msg by server id # sort msg by server id
msg_list.sort(key=self._takeId) msg_list.sort(key=self._takeId)
data_tensor = F.cat(seq=[msg.data for msg in msg_list], dim=0) data_tensor = F.cat(seq=[msg.data for msg in msg_list], dim=0)
for msg in msg_list:
_clear_kv_msg(msg)
return data_tensor[back_sorted_id] # return data with original index order return data_tensor[back_sorted_id] # return data with original index order
...@@ -735,7 +755,8 @@ class KVClient(object): ...@@ -735,7 +755,8 @@ class KVClient(object):
rank=self._client_id, rank=self._client_id,
name=None, name=None,
id=None, id=None,
data=None) data=None,
c_ptr=None)
for server_id in range(self._server_count): for server_id in range(self._server_count):
_send_kv_msg(self._sender, msg, server_id) _send_kv_msg(self._sender, msg, server_id)
...@@ -756,7 +777,8 @@ class KVClient(object): ...@@ -756,7 +777,8 @@ class KVClient(object):
rank=self._client_id, rank=self._client_id,
name=None, name=None,
id=None, id=None,
data=None) data=None,
c_ptr=None)
_send_kv_msg(self._sender, msg, server_id) _send_kv_msg(self._sender, msg, server_id)
...@@ -907,3 +929,4 @@ class KVClient(object): ...@@ -907,3 +929,4 @@ class KVClient(object):
tensor_shape = tuple(tensor_shape) tensor_shape = tuple(tensor_shape)
return tensor_name, tensor_shape, data_type return tensor_name, tensor_shape, data_type
...@@ -186,7 +186,7 @@ class KVMsgType(Enum): ...@@ -186,7 +186,7 @@ class KVMsgType(Enum):
BARRIER = 6 BARRIER = 6
IP_ID = 7 IP_ID = 7
KVStoreMsg = namedtuple("KVStoreMsg", "type rank name id data") KVStoreMsg = namedtuple("KVStoreMsg", "type rank name id data, c_ptr")
"""Message of DGL kvstore """Message of DGL kvstore
Data Field Data Field
...@@ -273,7 +273,8 @@ def _recv_kv_msg(receiver): ...@@ -273,7 +273,8 @@ def _recv_kv_msg(receiver):
rank=rank, rank=rank,
name=name, name=name,
id=tensor_id, id=tensor_id,
data=None) data=None,
c_ptr=msg_ptr)
return msg return msg
elif msg_type == KVMsgType.IP_ID: elif msg_type == KVMsgType.IP_ID:
name = _CAPI_ReceiverGetKVMsgName(msg_ptr) name = _CAPI_ReceiverGetKVMsgName(msg_ptr)
...@@ -282,7 +283,8 @@ def _recv_kv_msg(receiver): ...@@ -282,7 +283,8 @@ def _recv_kv_msg(receiver):
rank=rank, rank=rank,
name=name, name=name,
id=None, id=None,
data=None) data=None,
c_ptr=msg_ptr)
return msg return msg
elif msg_type in (KVMsgType.FINAL, KVMsgType.BARRIER): elif msg_type in (KVMsgType.FINAL, KVMsgType.BARRIER):
msg = KVStoreMsg( msg = KVStoreMsg(
...@@ -290,7 +292,8 @@ def _recv_kv_msg(receiver): ...@@ -290,7 +292,8 @@ def _recv_kv_msg(receiver):
rank=rank, rank=rank,
name=None, name=None,
id=None, id=None,
data=None) data=None,
c_ptr=msg_ptr)
return msg return msg
else: else:
name = _CAPI_ReceiverGetKVMsgName(msg_ptr) name = _CAPI_ReceiverGetKVMsgName(msg_ptr)
...@@ -301,22 +304,19 @@ def _recv_kv_msg(receiver): ...@@ -301,22 +304,19 @@ def _recv_kv_msg(receiver):
rank=rank, rank=rank,
name=name, name=name,
id=tensor_id, id=tensor_id,
data=data) data=data,
c_ptr=msg_ptr)
return msg return msg
raise RuntimeError('Unknown message type: %d' % msg_type.value) raise RuntimeError('Unknown message type: %d' % msg_type.value)
def _clear_kv_msg(msg): def _clear_kv_msg(garbage_msg):
"""Clear data of kvstore message """Clear data of kvstore message
Parameters
----------
msg : KVStoreMsg
kvstore message
""" """
if msg.data is not None:
F.sync() F.sync()
data = F.zerocopy_to_dgl_ndarray(msg.data) for msg in garbage_msg:
_CAPI_DeleteNDArrayData(data) if msg.c_ptr is not None:
_CAPI_DeleteKVMsg(msg.c_ptr)
garbage_msg = []
\ No newline at end of file
...@@ -24,6 +24,14 @@ using namespace dgl::runtime; ...@@ -24,6 +24,14 @@ using namespace dgl::runtime;
namespace dgl { namespace dgl {
namespace network { namespace network {
static void NaiveDeleter(DLManagedTensor* managed_tensor) {
delete [] managed_tensor->dl_tensor.shape;
delete [] managed_tensor->dl_tensor.strides;
delete [] managed_tensor->dl_tensor.data;
delete managed_tensor;
}
NDArray CreateNDArrayFromRaw(std::vector<int64_t> shape, NDArray CreateNDArrayFromRaw(std::vector<int64_t> shape,
DLDataType dtype, DLDataType dtype,
DLContext ctx, DLContext ctx,
...@@ -46,6 +54,7 @@ NDArray CreateNDArrayFromRaw(std::vector<int64_t> shape, ...@@ -46,6 +54,7 @@ NDArray CreateNDArrayFromRaw(std::vector<int64_t> shape,
tensor.data = raw; tensor.data = raw;
DLManagedTensor *managed_tensor = new DLManagedTensor(); DLManagedTensor *managed_tensor = new DLManagedTensor();
managed_tensor->dl_tensor = tensor; managed_tensor->dl_tensor = tensor;
managed_tensor->deleter = NaiveDeleter;
return NDArray::FromDLPack(managed_tensor); return NDArray::FromDLPack(managed_tensor);
} }
...@@ -591,12 +600,14 @@ DGL_REGISTER_GLOBAL("network._CAPI_ReceiverGetKVMsgData") ...@@ -591,12 +600,14 @@ DGL_REGISTER_GLOBAL("network._CAPI_ReceiverGetKVMsgData")
*rv = msg->data; *rv = msg->data;
}); });
DGL_REGISTER_GLOBAL("network._CAPI_DeleteNDArrayData") DGL_REGISTER_GLOBAL("network._CAPI_DeleteKVMsg")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([] (DGLArgs args, DGLRetValue* rv) {
NDArray data = args[0]; KVMsgHandle chandle = args[0];
delete [] data->data; network::KVStoreMsg* msg = static_cast<KVStoreMsg*>(chandle);
delete msg;
}); });
} // namespace network } // namespace network
} // namespace dgl } // namespace dgl
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment