"git@developer.sourcefind.cn:OpenDAS/dgl.git" did not exist on "0aa8310a4db072daba51187bbc6d1d725467ba35"
Unverified commit 6731ea3a, authored by Chao Ma, committed by GitHub
Browse files

[KVStore] Reduce memory cost of kvstore (#1156)

* API change of kvstore

* add demo for kvstore

* update

* remove duplicated log

* change queue size

* update

* update

* update

* update

* update

* update

* update

* update

* update

* fix lint

* change name

* update

* fix lint

* update

* update

* update

* update

* change message queue size to a python argument

* change default queue size to 2GB

* OMP_NUM_THREADS=1

* add multiple NICs support for kvstore

* test

* fix lint

* update

* update

* update

* update

* update

* update

* update

* fix lint

* fix lint

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* fix lint
parent 4f02bb75
...@@ -24,7 +24,8 @@ def start_client(): ...@@ -24,7 +24,8 @@ def start_client():
client = dgl.contrib.start_client(ip_config='ip_config.txt', client = dgl.contrib.start_client(ip_config='ip_config.txt',
ndata_partition_book=ndata_partition_book, ndata_partition_book=ndata_partition_book,
edata_partition_book=edata_partition_book) edata_partition_book=edata_partition_book,
close_shared_mem=True)
tensor_edata = client.pull(name='edata', id_tensor=mx.nd.array([0,1,2,3,4,5,6,7], dtype='int64')) tensor_edata = client.pull(name='edata', id_tensor=mx.nd.array([0,1,2,3,4,5,6,7], dtype='int64'))
......
...@@ -24,8 +24,8 @@ def start_client(): ...@@ -24,8 +24,8 @@ def start_client():
client = dgl.contrib.start_client(ip_config='ip_config.txt', client = dgl.contrib.start_client(ip_config='ip_config.txt',
ndata_partition_book=ndata_partition_book, ndata_partition_book=ndata_partition_book,
edata_partition_book=edata_partition_book) edata_partition_book=edata_partition_book,
close_shared_mem=True)
tensor_edata = client.pull(name='edata', id_tensor=th.tensor([0,1,2,3,4,5,6,7])) tensor_edata = client.pull(name='edata', id_tensor=th.tensor([0,1,2,3,4,5,6,7]))
tensor_ndata = client.pull(name='ndata', id_tensor=th.tensor([0,1,2,3,4,5,6,7])) tensor_ndata = client.pull(name='ndata', id_tensor=th.tensor([0,1,2,3,4,5,6,7]))
......
...@@ -4,6 +4,7 @@ from ..network import _finalize_sender, _finalize_receiver ...@@ -4,6 +4,7 @@ from ..network import _finalize_sender, _finalize_receiver
from ..network import _network_wait, _add_receiver_addr from ..network import _network_wait, _add_receiver_addr
from ..network import _receiver_wait, _sender_connect from ..network import _receiver_wait, _sender_connect
from ..network import _send_kv_msg, _recv_kv_msg from ..network import _send_kv_msg, _recv_kv_msg
from ..network import _clear_kv_msg
from ..network import KVMsgType, KVStoreMsg from ..network import KVMsgType, KVStoreMsg
from .. import backend as F from .. import backend as F
...@@ -377,6 +378,8 @@ class KVServer(object): ...@@ -377,6 +378,8 @@ class KVServer(object):
else: else:
raise RuntimeError('Unknown type of kvstore message: %d' % msg.type.value) raise RuntimeError('Unknown type of kvstore message: %d' % msg.type.value)
_clear_kv_msg(msg)
def _push_handler(self, name, ID, data, target): def _push_handler(self, name, ID, data, target):
"""Default handler for PUSH message. """Default handler for PUSH message.
...@@ -418,6 +421,7 @@ class KVServer(object): ...@@ -418,6 +421,7 @@ class KVServer(object):
""" """
return target[name][ID] return target[name][ID]
def _serialize_shared_tensor(self, name, shape, dtype): def _serialize_shared_tensor(self, name, shape, dtype):
"""Serialize shared tensor """Serialize shared tensor
...@@ -715,6 +719,9 @@ class KVClient(object): ...@@ -715,6 +719,9 @@ class KVClient(object):
msg_list.sort(key=self._takeId) msg_list.sort(key=self._takeId)
data_tensor = F.cat(seq=[msg.data for msg in msg_list], dim=0) data_tensor = F.cat(seq=[msg.data for msg in msg_list], dim=0)
for msg in msg_list:
_clear_kv_msg(msg)
return data_tensor[back_sorted_id] # return data with original index order return data_tensor[back_sorted_id] # return data with original index order
......
...@@ -313,3 +313,18 @@ def _recv_kv_msg(receiver): ...@@ -313,3 +313,18 @@ def _recv_kv_msg(receiver):
return msg return msg
raise RuntimeError('Unknown message type: %d' % msg_type.value) raise RuntimeError('Unknown message type: %d' % msg_type.value)
def _clear_kv_msg(msg):
    """Release the data buffer carried by a kvstore message.

    A message whose ``data`` field is ``None`` is left untouched.

    Parameters
    ----------
    msg : KVStoreMsg
        kvstore message
    """
    if msg.data is None:
        return
    # Synchronize the backend first, then hand the tensor (viewed as a DGL
    # NDArray without copying) to the C API that deletes its data.
    F.sync()
    ndarray = F.zerocopy_to_dgl_ndarray(msg.data)
    _CAPI_DeleteNDArrayData(ndarray)
\ No newline at end of file
...@@ -591,5 +591,12 @@ DGL_REGISTER_GLOBAL("network._CAPI_ReceiverGetKVMsgData") ...@@ -591,5 +591,12 @@ DGL_REGISTER_GLOBAL("network._CAPI_ReceiverGetKVMsgData")
*rv = msg->data; *rv = msg->data;
}); });
// Deletes the raw data buffer behind the NDArray passed as args[0].
// Nothing is written to rv. The NDArray handle itself is not modified, so it
// still refers to the released memory and must not be used afterwards.
DGL_REGISTER_GLOBAL("network._CAPI_DeleteNDArrayData")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
    NDArray data = args[0];
    // data->data is an untyped pointer; applying delete[] to a void* is
    // undefined behavior, so cast to a complete type first.
    // NOTE(review): assumes the buffer was allocated with new char[] by the
    // code that received the message -- confirm against the allocation site.
    char* buffer = static_cast<char*>(data->data);
    delete [] buffer;
  });
} // namespace network } // namespace network
} // namespace dgl } // namespace dgl
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment