Unverified commit 5fc334fc, authored by Chao Ma and committed by GitHub
Browse files

[KVStore] Fix bug in the get_data_meta() API so that it reports the full data shape (#1491)

* fix get-shape

* update

* update

* update

* update

* update

* fix typo

* update

* update
parent 30b8074a
...@@ -500,6 +500,18 @@ class KVServer(object): ...@@ -500,6 +500,18 @@ class KVServer(object):
shape=msg.shape, shape=msg.shape,
c_ptr=None) c_ptr=None)
_send_kv_msg(self._sender, back_msg, 0) _send_kv_msg(self._sender, back_msg, 0)
# Get shape message
elif msg.type == KVMsgType.GET_SHAPE:
data_shape = F.tensor(F.shape(self._data_store[msg.name+'-data-']))
back_msg = KVStoreMsg(
type=KVMsgType.GET_SHAPE_BACK,
rank=self._server_id,
name=msg.name,
id=None,
data=None,
shape=data_shape,
c_ptr=None)
_send_kv_msg(self._sender, back_msg, msg.rank)
# Barrier message # Barrier message
elif msg.type == KVMsgType.BARRIER: elif msg.type == KVMsgType.BARRIER:
self._barrier_count += 1 self._barrier_count += 1
...@@ -704,6 +716,7 @@ class KVClient(object): ...@@ -704,6 +716,7 @@ class KVClient(object):
self._has_data = set() self._has_data = set()
# This is used to store local data, which can share memory with local KVServer. # This is used to store local data, which can share memory with local KVServer.
self._data_store = {} self._data_store = {}
self._full_data_shape = {}
self._data_name_list = [] self._data_name_list = []
# Server information # Server information
self._server_namebook = server_namebook self._server_namebook = server_namebook
...@@ -792,10 +805,9 @@ class KVClient(object): ...@@ -792,10 +805,9 @@ class KVClient(object):
tensor_name, dtype = self._deserialize_shared_tensor(data) tensor_name, dtype = self._deserialize_shared_tensor(data)
while True: while True:
if (os.path.exists(tensor_name+'shape-'+str(self._machine_id))): if (os.path.exists(tensor_name+'shape-'+str(self._machine_id))):
time.sleep(2) # wait writing finish
break break
else: else:
time.sleep(2) # wait until the file been created time.sleep(1) # wait until the file been created
shape, data_type = self._read_data_shape_type(tensor_name+'shape-'+str(self._machine_id)) shape, data_type = self._read_data_shape_type(tensor_name+'shape-'+str(self._machine_id))
assert data_type == dtype assert data_type == dtype
shared_data = empty_shared_mem(tensor_name, False, shape, dtype) shared_data = empty_shared_mem(tensor_name, False, shape, dtype)
...@@ -805,6 +817,29 @@ class KVClient(object): ...@@ -805,6 +817,29 @@ class KVClient(object):
self._data_name_list.append(tensor_name[0:-6]) self._data_name_list.append(tensor_name[0:-6])
self._has_data.add(tensor_name) self._has_data.add(tensor_name)
# Get full shape of each data
for name in self._data_name_list:
data_shape = list(F.shape(self._data_store[name+'-data-']))
data_shape[0] = 0
msg = KVStoreMsg(
type=KVMsgType.GET_SHAPE,
rank=self._client_id,
name=name,
id=None,
data=None,
shape=None,
c_ptr=None)
# send msg
for m_id in range(self._machine_count):
s_id = m_id * self._group_count
_send_kv_msg(self._sender, msg, s_id)
# recv msg
for m_id in range(self._machine_count):
back_msg = _recv_kv_msg(self._receiver)
assert back_msg.type == KVMsgType.GET_SHAPE_BACK
data_shape[0] += ((F.asnumpy(back_msg.shape)).tolist())[0]
self._full_data_shape[name] = tuple(data_shape)
print("KVClient %d connect to kvstore successfully!" % self.get_id()) print("KVClient %d connect to kvstore successfully!" % self.get_id())
...@@ -872,6 +907,7 @@ class KVClient(object): ...@@ -872,6 +907,7 @@ class KVClient(object):
self._data_store[name+'-data-'] = F.zerocopy_from_dlpack(dlpack) self._data_store[name+'-data-'] = F.zerocopy_from_dlpack(dlpack)
self._has_data.add(name+'-data-') self._has_data.add(name+'-data-')
self._data_name_list.append(name) self._data_name_list.append(name)
self._full_data_shape[name] = tuple(shape)
def print(self): def print(self):
...@@ -947,8 +983,8 @@ class KVClient(object): ...@@ -947,8 +983,8 @@ class KVClient(object):
assert name + '-data-' in self._has_data, 'Data (%s) does not exist!' % name assert name + '-data-' in self._has_data, 'Data (%s) does not exist!' % name
data_type = F.dtype(self._data_store[name+'-data-']) data_type = F.dtype(self._data_store[name+'-data-'])
data_shape = F.shape(self._data_store[name+'-data-'])
partition_book = self._data_store[name+'-part-'] partition_book = self._data_store[name+'-part-']
data_shape = self._full_data_shape[name]
return (data_type, data_shape, partition_book) return (data_type, data_shape, partition_book)
......
...@@ -192,6 +192,8 @@ class KVMsgType(Enum): ...@@ -192,6 +192,8 @@ class KVMsgType(Enum):
PULL_BACK = 5 PULL_BACK = 5
BARRIER = 6 BARRIER = 6
IP_ID = 7 IP_ID = 7
GET_SHAPE = 8
GET_SHAPE_BACK = 9
KVStoreMsg = namedtuple("KVStoreMsg", "type rank name id data shape c_ptr") KVStoreMsg = namedtuple("KVStoreMsg", "type rank name id data shape c_ptr")
...@@ -234,7 +236,7 @@ def _send_kv_msg(sender, msg, recv_id): ...@@ -234,7 +236,7 @@ def _send_kv_msg(sender, msg, recv_id):
msg.rank, msg.rank,
msg.name, msg.name,
tensor_id) tensor_id)
elif msg.type == KVMsgType.INIT: elif msg.type in (KVMsgType.INIT, KVMsgType.GET_SHAPE_BACK):
tensor_shape = F.zerocopy_to_dgl_ndarray(msg.shape) tensor_shape = F.zerocopy_to_dgl_ndarray(msg.shape)
_CAPI_SenderSendKVMsg( _CAPI_SenderSendKVMsg(
sender, sender,
...@@ -243,7 +245,7 @@ def _send_kv_msg(sender, msg, recv_id): ...@@ -243,7 +245,7 @@ def _send_kv_msg(sender, msg, recv_id):
msg.rank, msg.rank,
msg.name, msg.name,
tensor_shape) tensor_shape)
elif msg.type == KVMsgType.IP_ID: elif msg.type in (KVMsgType.IP_ID, KVMsgType.GET_SHAPE):
_CAPI_SenderSendKVMsg( _CAPI_SenderSendKVMsg(
sender, sender,
int(recv_id), int(recv_id),
...@@ -296,7 +298,7 @@ def _recv_kv_msg(receiver): ...@@ -296,7 +298,7 @@ def _recv_kv_msg(receiver):
shape=None, shape=None,
c_ptr=msg_ptr) c_ptr=msg_ptr)
return msg return msg
elif msg_type == KVMsgType.INIT: elif msg_type in (KVMsgType.INIT, KVMsgType.GET_SHAPE_BACK):
name = _CAPI_ReceiverGetKVMsgName(msg_ptr) name = _CAPI_ReceiverGetKVMsgName(msg_ptr)
tensor_shape = F.zerocopy_from_dgl_ndarray(_CAPI_ReceiverGetKVMsgShape(msg_ptr)) tensor_shape = F.zerocopy_from_dgl_ndarray(_CAPI_ReceiverGetKVMsgShape(msg_ptr))
msg = KVStoreMsg( msg = KVStoreMsg(
...@@ -308,7 +310,7 @@ def _recv_kv_msg(receiver): ...@@ -308,7 +310,7 @@ def _recv_kv_msg(receiver):
shape=tensor_shape, shape=tensor_shape,
c_ptr=msg_ptr) c_ptr=msg_ptr)
return msg return msg
elif msg_type == KVMsgType.IP_ID: elif msg_type in (KVMsgType.IP_ID, KVMsgType.GET_SHAPE):
name = _CAPI_ReceiverGetKVMsgName(msg_ptr) name = _CAPI_ReceiverGetKVMsgName(msg_ptr)
msg = KVStoreMsg( msg = KVStoreMsg(
type=msg_type, type=msg_type,
......
...@@ -498,14 +498,17 @@ static void send_kv_message(network::Sender* sender, ...@@ -498,14 +498,17 @@ static void send_kv_message(network::Sender* sender,
CHECK_EQ(sender->Send(send_kv_msg, recv_id), ADD_SUCCESS); CHECK_EQ(sender->Send(send_kv_msg, recv_id), ADD_SUCCESS);
if (kv_msg->msg_type != kFinalMsg && if (kv_msg->msg_type != kFinalMsg &&
kv_msg->msg_type != kBarrierMsg && kv_msg->msg_type != kBarrierMsg &&
kv_msg->msg_type != kIPIDMsg) { kv_msg->msg_type != kIPIDMsg &&
kv_msg->msg_type != kGetShapeMsg) {
// Send ArrayMeta // Send ArrayMeta
ArrayMeta meta(kv_msg->msg_type); ArrayMeta meta(kv_msg->msg_type);
if (kv_msg->msg_type != kInitMsg) { if (kv_msg->msg_type != kInitMsg &&
kv_msg->msg_type != kGetShapeBackMsg) {
meta.AddArray(kv_msg->id); meta.AddArray(kv_msg->id);
} }
if (kv_msg->msg_type != kPullMsg && if (kv_msg->msg_type != kPullMsg &&
kv_msg->msg_type != kInitMsg) { kv_msg->msg_type != kInitMsg &&
kv_msg->msg_type != kGetShapeBackMsg) {
meta.AddArray(kv_msg->data); meta.AddArray(kv_msg->data);
} }
if (kv_msg->msg_type != kPullMsg && if (kv_msg->msg_type != kPullMsg &&
...@@ -523,7 +526,9 @@ static void send_kv_message(network::Sender* sender, ...@@ -523,7 +526,9 @@ static void send_kv_message(network::Sender* sender,
} }
CHECK_EQ(sender->Send(send_meta_msg, recv_id), ADD_SUCCESS); CHECK_EQ(sender->Send(send_meta_msg, recv_id), ADD_SUCCESS);
// Send ID NDArray // Send ID NDArray
if (kv_msg->msg_type != kInitMsg) { if (kv_msg->msg_type != kInitMsg &&
kv_msg->msg_type != kGetShapeMsg &&
kv_msg->msg_type != kGetShapeBackMsg) {
Message send_id_msg; Message send_id_msg;
send_id_msg.data = static_cast<char*>(kv_msg->id->data); send_id_msg.data = static_cast<char*>(kv_msg->id->data);
send_id_msg.size = kv_msg->id.GetSize(); send_id_msg.size = kv_msg->id.GetSize();
...@@ -535,7 +540,9 @@ static void send_kv_message(network::Sender* sender, ...@@ -535,7 +540,9 @@ static void send_kv_message(network::Sender* sender,
} }
// Send data NDArray // Send data NDArray
if (kv_msg->msg_type != kPullMsg && if (kv_msg->msg_type != kPullMsg &&
kv_msg->msg_type != kInitMsg) { kv_msg->msg_type != kInitMsg &&
kv_msg->msg_type != kGetShapeMsg &&
kv_msg->msg_type != kGetShapeBackMsg) {
Message send_data_msg; Message send_data_msg;
send_data_msg.data = static_cast<char*>(kv_msg->data->data); send_data_msg.data = static_cast<char*>(kv_msg->data->data);
send_data_msg.size = kv_msg->data.GetSize(); send_data_msg.size = kv_msg->data.GetSize();
...@@ -571,7 +578,8 @@ static KVStoreMsg* recv_kv_message(network::Receiver* receiver) { ...@@ -571,7 +578,8 @@ static KVStoreMsg* recv_kv_message(network::Receiver* receiver) {
recv_kv_msg.deallocator(&recv_kv_msg); recv_kv_msg.deallocator(&recv_kv_msg);
if (kv_msg->msg_type == kFinalMsg || if (kv_msg->msg_type == kFinalMsg ||
kv_msg->msg_type == kBarrierMsg || kv_msg->msg_type == kBarrierMsg ||
kv_msg->msg_type == kIPIDMsg) { kv_msg->msg_type == kIPIDMsg ||
kv_msg->msg_type == kGetShapeMsg) {
return kv_msg; return kv_msg;
} }
// Recv ArrayMeta // Recv ArrayMeta
...@@ -580,7 +588,8 @@ static KVStoreMsg* recv_kv_message(network::Receiver* receiver) { ...@@ -580,7 +588,8 @@ static KVStoreMsg* recv_kv_message(network::Receiver* receiver) {
ArrayMeta meta(recv_meta_msg.data, recv_meta_msg.size); ArrayMeta meta(recv_meta_msg.data, recv_meta_msg.size);
recv_meta_msg.deallocator(&recv_meta_msg); recv_meta_msg.deallocator(&recv_meta_msg);
// Recv ID NDArray // Recv ID NDArray
if (kv_msg->msg_type != kInitMsg) { if (kv_msg->msg_type != kInitMsg &&
kv_msg->msg_type != kGetShapeBackMsg) {
Message recv_id_msg; Message recv_id_msg;
CHECK_EQ(receiver->RecvFrom(&recv_id_msg, send_id), REMOVE_SUCCESS); CHECK_EQ(receiver->RecvFrom(&recv_id_msg, send_id), REMOVE_SUCCESS);
CHECK_EQ(meta.data_shape_[0], 1); CHECK_EQ(meta.data_shape_[0], 1);
...@@ -593,7 +602,8 @@ static KVStoreMsg* recv_kv_message(network::Receiver* receiver) { ...@@ -593,7 +602,8 @@ static KVStoreMsg* recv_kv_message(network::Receiver* receiver) {
} }
// Recv Data NDArray // Recv Data NDArray
if (kv_msg->msg_type != kPullMsg && if (kv_msg->msg_type != kPullMsg &&
kv_msg->msg_type != kInitMsg) { kv_msg->msg_type != kInitMsg &&
kv_msg->msg_type != kGetShapeBackMsg) {
Message recv_data_msg; Message recv_data_msg;
CHECK_EQ(receiver->RecvFrom(&recv_data_msg, send_id), REMOVE_SUCCESS); CHECK_EQ(receiver->RecvFrom(&recv_data_msg, send_id), REMOVE_SUCCESS);
int64_t ndim = meta.data_shape_[2]; int64_t ndim = meta.data_shape_[2];
...@@ -644,18 +654,23 @@ DGL_REGISTER_GLOBAL("network._CAPI_SenderSendKVMsg") ...@@ -644,18 +654,23 @@ DGL_REGISTER_GLOBAL("network._CAPI_SenderSendKVMsg")
std::string name = args[args_count++]; std::string name = args[args_count++];
kv_msg.name = name; kv_msg.name = name;
if (kv_msg.msg_type != kIPIDMsg && if (kv_msg.msg_type != kIPIDMsg &&
kv_msg.msg_type != kInitMsg) { kv_msg.msg_type != kInitMsg &&
kv_msg.msg_type != kGetShapeMsg &&
kv_msg.msg_type != kGetShapeBackMsg) {
kv_msg.id = args[args_count++]; kv_msg.id = args[args_count++];
} }
if (kv_msg.msg_type != kPullMsg && if (kv_msg.msg_type != kPullMsg &&
kv_msg.msg_type != kIPIDMsg && kv_msg.msg_type != kIPIDMsg &&
kv_msg.msg_type != kInitMsg) { kv_msg.msg_type != kInitMsg &&
kv_msg.msg_type != kGetShapeMsg &&
kv_msg.msg_type != kGetShapeBackMsg) {
kv_msg.data = args[args_count++]; kv_msg.data = args[args_count++];
} }
if (kv_msg.msg_type != kIPIDMsg && if (kv_msg.msg_type != kIPIDMsg &&
kv_msg.msg_type != kPullMsg && kv_msg.msg_type != kPullMsg &&
kv_msg.msg_type != kPushMsg && kv_msg.msg_type != kPushMsg &&
kv_msg.msg_type != kPullBackMsg) { kv_msg.msg_type != kPullBackMsg &&
kv_msg.msg_type != kGetShapeMsg) {
kv_msg.shape = args[args_count++]; kv_msg.shape = args[args_count++];
} }
} }
......
...@@ -64,7 +64,15 @@ enum MessageType { ...@@ -64,7 +64,15 @@ enum MessageType {
/*! /*!
* \brief IP and ID msg for KVStore * \brief IP and ID msg for KVStore
*/ */
kIPIDMsg = 7 kIPIDMsg = 7,
/*!
* \brief Get data shape msg for KVStore
*/
kGetShapeMsg = 8,
/*!
* \brief Get data shape back msg for KVStore
*/
kGetShapeBackMsg = 9
}; };
......
...@@ -66,26 +66,32 @@ def start_client(): ...@@ -66,26 +66,32 @@ def start_client():
meta_0 = my_client.get_data_meta('data_0') meta_0 = my_client.get_data_meta('data_0')
assert meta_0[0] == F.float32 assert meta_0[0] == F.float32
assert meta_0[1] == tuple(F.shape(data_0))
assert_array_equal(meta_0[2], partition_0) assert_array_equal(meta_0[2], partition_0)
meta_1 = my_client.get_data_meta('data_1') meta_1 = my_client.get_data_meta('data_1')
assert meta_1[0] == F.float32 assert meta_1[0] == F.float32
assert meta_1[1] == tuple(F.shape(data_1))
assert_array_equal(meta_1[2], partition_1) assert_array_equal(meta_1[2], partition_1)
meta_2 = my_client.get_data_meta('data_2') meta_2 = my_client.get_data_meta('data_2')
assert meta_2[0] == F.float32 assert meta_2[0] == F.float32
assert meta_2[1] == tuple(F.shape(data_0))
assert_array_equal(meta_2[2], partition_0) assert_array_equal(meta_2[2], partition_0)
meta_3 = my_client.get_data_meta('data_3') meta_3 = my_client.get_data_meta('data_3')
assert meta_3[0] == F.int64 assert meta_3[0] == F.int64
assert meta_3[1] == tuple(F.shape(data_3))
assert_array_equal(meta_3[2], partition_0) assert_array_equal(meta_3[2], partition_0)
meta_4 = my_client.get_data_meta('data_4') meta_4 = my_client.get_data_meta('data_4')
assert meta_4[0] == F.float64 assert meta_4[0] == F.float64
assert meta_4[1] == tuple(F.shape(data_4))
assert_array_equal(meta_3[2], partition_0) assert_array_equal(meta_3[2], partition_0)
meta_5 = my_client.get_data_meta('data_5') meta_5 = my_client.get_data_meta('data_5')
assert meta_5[0] == F.int32 assert meta_5[0] == F.int32
assert meta_5[1] == tuple(F.shape(data_5))
assert_array_equal(meta_3[2], partition_0) assert_array_equal(meta_3[2], partition_0)
my_client.push(name='data_0', id_tensor=F.tensor([0, 1, 2]), data_tensor=F.tensor([[1.,1.,1.],[2.,2.,2.],[3.,3.,3.]])) my_client.push(name='data_0', id_tensor=F.tensor([0, 1, 2]), data_tensor=F.tensor([[1.,1.,1.],[2.,2.,2.],[3.,3.,3.]]))
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment