Unverified Commit 27520bc5 authored by Chao Ma, committed by GitHub

[KVStore] add init_data() on client (#1466)

* add init_data on client

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* fix lint

* fix lint

* update

* update

* update

* update

* update

* update
parent 0c0a8974
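In short, this change lets a KVClient create a new distributed tensor at runtime and reuse the partition book and g2l mapping of an existing tensor, while the servers allocate the zero-filled storage. A minimal usage sketch, modeled on the updated test at the end of this diff (F is the backend module used there; new_embed, num_entries and dim_size are placeholders, not part of the commit):

my_client = KVClient(server_namebook=server_namebook)
my_client.connect()
# Create a zero-initialized tensor that reuses 'data_0's partition book and g2l map.
my_client.init_data(name='new_embed', shape=[num_entries, dim_size],
                    dtype=F.float32, target_name='data_0')
# The new name can then be pushed to and pulled from like any pre-existing tensor.
my_client.push(name='new_embed', id_tensor=F.tensor([0, 1, 2]),
               data_tensor=F.tensor([[1.0] * dim_size] * 3))
res = my_client.pull(name='new_embed', id_tensor=F.tensor([0, 1, 2]))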
......@@ -418,6 +418,7 @@ class KVServer(object):
name=str(client_id),
id=None,
data=None,
shape=None,
c_ptr=None)
_send_kv_msg(self._sender, msg, client_id)
......@@ -435,6 +436,7 @@ class KVServer(object):
name=shared_tensor,
id=None,
data=None,
shape=None,
c_ptr=None)
for client_id in range(len(self._client_namebook)):
......@@ -471,8 +473,33 @@ class KVServer(object):
name=msg.name,
id=msg.id,
data=res_tensor,
shape=None,
c_ptr=None)
_send_kv_msg(self._sender, back_msg, msg.rank)
# Init new data
elif msg.type == KVMsgType.INIT:
assert msg.rank == 0
data_str, target_name = msg.name.split('|')
data_name, data_type = self._deserialize_shared_tensor(data_str)
dtype = F.data_type_dict[data_type]
data_shape = F.asnumpy(msg.shape).tolist()
if self._server_id % self._group_count == 0: # master server
data_tensor = F.zeros(data_shape, dtype, F.cpu())
self.init_data(name=data_name, data_tensor=data_tensor)
else: # backup server
self.init_data(name=data_name)
g2l = self._data_store[target_name+'-g2l-']
self._data_store[data_name+'-g2l-'] = g2l
self._has_data.add(data_name+'-g2l-')
back_msg = KVStoreMsg(
type=KVMsgType.INIT,
rank=self._server_id,
name=msg.name,
id=None,
data=None,
shape=msg.shape,
c_ptr=None)
_send_kv_msg(self._sender, back_msg, 0)
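# Note on the master/backup split above (inferred from this diff): only the
# master server on each machine allocates the zero-filled tensor; the backup
# servers call init_data(name=...) without a data tensor, which is assumed
# here to attach to the shared memory the master created -- the same shared
# memory the client maps in KVClient.init_data() further down.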
# Barrier message
elif msg.type == KVMsgType.BARRIER:
self._barrier_count += 1
......@@ -483,6 +510,7 @@ class KVServer(object):
name=None,
id=None,
data=None,
shape=None,
c_ptr=None)
for client_id in range(self._client_count):
_send_kv_msg(self._sender, back_msg, client_id)
......@@ -522,6 +550,28 @@ class KVServer(object):
return str_data
def _deserialize_shared_tensor(self, data):
"""Deserialize shared tensor information sent from server
Parameters
----------
data : str
serialized string
Returns
-------
str
tensor name
str
data type
"""
data_list = data.split('/')
tensor_name = data_list[0]
data_type = data_list[-1]
return tensor_name, data_type
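Taken together with _serialize_shared_tensor (shown further down in KVClient) and the INIT handler above, the strings exchanged for an INIT message look like the following hedged sketch (the concrete names are illustrative, and it assumes get_type_str(F.float32) yields 'float32'):

# Client side (KVClient.init_data below):
data_str = my_client._serialize_shared_tensor('data_2', F.float32)  # 'data_2/float32'
msg_name = data_str + '|' + 'data_0'                                # 'data_2/float32|data_0'
# Server side (INIT branch above):
data_str, target_name = msg_name.split('|')   # ('data_2/float32', 'data_0')
data_name, data_type = data_str.split('/')    # ('data_2', 'float32')
dtype = F.data_type_dict[data_type]           # backend dtype for 'float32'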
def _write_data_shape_type(self, filename, data):
"""Write data shape to a temp file.
......@@ -720,6 +770,7 @@ class KVClient(object):
name=self._addr,
id=None,
data=None,
shape=None,
c_ptr=None)
for server_id in range(self._server_count):
......@@ -757,6 +808,72 @@ class KVClient(object):
print("KVClient %d connect to kvstore successfully!" % self.get_id())
def init_data(self, name, shape, dtype, target_name):
"""Send message to kvserver to initialize new data and
get corresponded shared-tensor (e.g., partition_book, g2l) on kvclient.
The new data will be initialized to zeros.
Note that, this API must be invoked after the conenct() API.
Parameters
----------
name : str
data name
shape : list of int
data shape
dtype : dtype
data type
target_name : str
target_name is used to find the existing partition_book and g2l mapping.
"""
assert len(name) > 0, 'name cannot be empty.'
assert len(shape) > 0, 'shape cannot be empty.'
assert len(target_name) > 0, 'target_name cannot be empty.'
if self._client_id == 0: # only client_0 send message to server
partition_book = self._data_store[target_name+'-part-']
machines, count = np.unique(F.asnumpy(partition_book), return_counts=True)
assert shape[0] == len(partition_book)
# send message to all of the server nodes
for idx in range(len(machines)):
m_id = machines[idx]
data_str = self._serialize_shared_tensor(name, dtype)
data_str = data_str + '|' + target_name
partitioned_shape = shape.copy()
partitioned_shape[0] = count[idx]
for n in range(self._group_count):
server_id = m_id * self._group_count + n
msg = KVStoreMsg(
type=KVMsgType.INIT,
rank=0,
name=data_str,
id=None,
data=None,
shape=F.tensor(partitioned_shape),
c_ptr=None)
_send_kv_msg(self._sender, msg, server_id)
# recv confirmation message from server nodes
for server_id in range(self._server_count):
msg = _recv_kv_msg(self._receiver)
assert msg.type == KVMsgType.INIT
self.barrier() # wait all the client and server finish its job
g2l = self._data_store[target_name+'-g2l-']
partition_book = self._data_store[target_name+'-part-']
self._data_store[name+'-g2l-'] = g2l
self._data_store[name+'-part-'] = partition_book
self._has_data.add(name+'-g2l-')
self._has_data.add(name+'-part-')
# Read new data from shared-memory created by server
shape, data_type = self._read_data_shape_type(name+'-data-shape-'+str(self._machine_id))
assert data_type == get_type_str(dtype)
shared_data = empty_shared_mem(name+'-data-', False, shape, data_type)
dlpack = shared_data.to_dlpack()
self._data_store[name+'-data-'] = F.zerocopy_from_dlpack(dlpack)
self._has_data.add(name+'-data-')
self._data_name_list.append(name)
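The per-machine shape computation above is the crux of init_data(): the first dimension of the new tensor is split according to how many rows of target_name's partition book fall on each machine. A small hedged illustration with made-up values:

partition_book = F.tensor([0, 0, 0, 1, 1])   # machine id of each global row
machines, count = np.unique(F.asnumpy(partition_book), return_counts=True)
# machines == [0, 1], count == [3, 2]; for shape == [5, 3]:
#   servers on machine 0 receive partitioned_shape == [3, 3]
#   servers on machine 1 receive partitioned_shape == [2, 3]
# and every server in a machine's group (self._group_count of them) gets the
# same INIT message.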
def print(self):
"""Print client information (Used by debug)
"""
......@@ -886,6 +1003,7 @@ class KVClient(object):
name=name,
id=partial_id,
data=partial_data,
shape=None,
c_ptr=None)
# randomly select a server node in target machine for load-balance
s_id = random.randint(machine[idx]*self._group_count, (machine[idx]+1)*self._group_count-1)
......@@ -967,6 +1085,7 @@ class KVClient(object):
name=name,
id=partial_id,
data=None,
shape=None,
c_ptr=None)
# randomly select a server node in target machine for load-balance
s_id = random.randint(machine[idx]*self._group_count, (machine[idx]+1)*self._group_count-1)
......@@ -985,6 +1104,7 @@ class KVClient(object):
name=name,
id=None,
data=local_data,
shape=None,
c_ptr=None)
msg_list.append(local_msg)
self._garbage_msg.append(local_msg)
......@@ -1013,6 +1133,7 @@ class KVClient(object):
name=None,
id=None,
data=None,
shape=None,
c_ptr=None)
for server_id in range(self._server_count):
......@@ -1035,6 +1156,7 @@ class KVClient(object):
name=None,
id=None,
data=None,
shape=None,
c_ptr=None)
_send_kv_msg(self._sender, msg, server_id)
......@@ -1102,6 +1224,29 @@ class KVClient(object):
return nic
def _serialize_shared_tensor(self, name, dtype):
"""Serialize shared tensor information.
Parameters
----------
name : str
tensor name
dtype : dtype
data type
Returns
-------
str
serialized string
"""
assert len(name) > 0, 'data name cannot be empty.'
str_data = name
str_data += '/'
str_data += get_type_str(dtype)
return str_data
def _deserialize_shared_tensor(self, data):
"""Deserialize shared tensor information sent from server
......
......@@ -194,7 +194,7 @@ class KVMsgType(Enum):
IP_ID = 7
KVStoreMsg = namedtuple("KVStoreMsg", "type rank name id data c_ptr")
KVStoreMsg = namedtuple("KVStoreMsg", "type rank name id data shape c_ptr")
"""Message of DGL kvstore
Data Field
......@@ -234,6 +234,15 @@ def _send_kv_msg(sender, msg, recv_id):
msg.rank,
msg.name,
tensor_id)
elif msg.type == KVMsgType.INIT:
tensor_shape = F.zerocopy_to_dgl_ndarray(msg.shape)
_CAPI_SenderSendKVMsg(
sender,
int(recv_id),
msg.type.value,
msg.rank,
msg.name,
tensor_shape)
elif msg.type == KVMsgType.IP_ID:
_CAPI_SenderSendKVMsg(
sender,
......@@ -284,6 +293,19 @@ def _recv_kv_msg(receiver):
name=name,
id=tensor_id,
data=None,
shape=None,
c_ptr=msg_ptr)
return msg
elif msg_type == KVMsgType.INIT:
name = _CAPI_ReceiverGetKVMsgName(msg_ptr)
tensor_shape = F.zerocopy_from_dgl_ndarray(_CAPI_ReceiverGetKVMsgShape(msg_ptr))
msg = KVStoreMsg(
type=msg_type,
rank=rank,
name=name,
id=None,
data=None,
shape=tensor_shape,
c_ptr=msg_ptr)
return msg
elif msg_type == KVMsgType.IP_ID:
......@@ -294,6 +316,7 @@ def _recv_kv_msg(receiver):
name=name,
id=None,
data=None,
shape=None,
c_ptr=msg_ptr)
return msg
elif msg_type in (KVMsgType.FINAL, KVMsgType.BARRIER):
......@@ -303,6 +326,7 @@ def _recv_kv_msg(receiver):
name=None,
id=None,
data=None,
shape=None,
c_ptr=msg_ptr)
return msg
else:
......@@ -315,6 +339,7 @@ def _recv_kv_msg(receiver):
name=name,
id=tensor_id,
data=data,
shape=None,
c_ptr=msg_ptr)
return msg
......
......@@ -485,10 +485,18 @@ static void send_kv_message(network::Sender* sender,
kv_msg->msg_type != kIPIDMsg) {
// Send ArrayMeta
ArrayMeta meta(kv_msg->msg_type);
if (kv_msg->msg_type != kInitMsg) {
meta.AddArray(kv_msg->id);
if (kv_msg->msg_type != kPullMsg) {
}
if (kv_msg->msg_type != kPullMsg &&
kv_msg->msg_type != kInitMsg) {
meta.AddArray(kv_msg->data);
}
if (kv_msg->msg_type != kPullMsg &&
kv_msg->msg_type != kPushMsg &&
kv_msg->msg_type != kPullBackMsg) {
meta.AddArray(kv_msg->shape);
}
int64_t meta_size = 0;
char* meta_data = meta.Serialize(&meta_size);
Message send_meta_msg;
......@@ -499,14 +507,19 @@ static void send_kv_message(network::Sender* sender,
}
CHECK_EQ(sender->Send(send_meta_msg, recv_id), ADD_SUCCESS);
// Send ID NDArray
if (kv_msg->msg_type != kInitMsg) {
Message send_id_msg;
send_id_msg.data = static_cast<char*>(kv_msg->id->data);
send_id_msg.size = kv_msg->id.GetSize();
NDArray id = kv_msg->id;
if (auto_free) {
send_id_msg.deallocator = [id](Message*) {};
}
CHECK_EQ(sender->Send(send_id_msg, recv_id), ADD_SUCCESS);
}
// Send data NDArray
if (kv_msg->msg_type != kPullMsg) {
if (kv_msg->msg_type != kPullMsg &&
kv_msg->msg_type != kInitMsg) {
Message send_data_msg;
send_data_msg.data = static_cast<char*>(kv_msg->data->data);
send_data_msg.size = kv_msg->data.GetSize();
......@@ -516,6 +529,19 @@ static void send_kv_message(network::Sender* sender,
}
CHECK_EQ(sender->Send(send_data_msg, recv_id), ADD_SUCCESS);
}
// Send shape NDArray
if (kv_msg->msg_type != kPullMsg &&
kv_msg->msg_type != kPushMsg &&
kv_msg->msg_type != kPullBackMsg) {
Message send_shape_msg;
send_shape_msg.data = static_cast<char*>(kv_msg->shape->data);
send_shape_msg.size = kv_msg->shape.GetSize();
NDArray shape = kv_msg->shape;
if (auto_free) {
send_shape_msg.deallocator = [shape](Message*) {};
}
CHECK_EQ(sender->Send(send_shape_msg, recv_id), ADD_SUCCESS);
}
}
}
......@@ -538,6 +564,7 @@ static KVStoreMsg* recv_kv_message(network::Receiver* receiver) {
ArrayMeta meta(recv_meta_msg.data, recv_meta_msg.size);
recv_meta_msg.deallocator(&recv_meta_msg);
// Recv ID NDArray
if (kv_msg->msg_type != kInitMsg) {
Message recv_id_msg;
CHECK_EQ(receiver->RecvFrom(&recv_id_msg, send_id), REMOVE_SUCCESS);
CHECK_EQ(meta.data_shape_[0], 1);
......@@ -547,14 +574,17 @@ static KVStoreMsg* recv_kv_message(network::Receiver* receiver) {
DLContext{kDLCPU, 0},
recv_id_msg.data,
AUTO_FREE);
}
// Recv Data NDArray
if (kv_msg->msg_type != kPullMsg) {
if (kv_msg->msg_type != kPullMsg &&
kv_msg->msg_type != kInitMsg) {
Message recv_data_msg;
CHECK_EQ(receiver->RecvFrom(&recv_data_msg, send_id), REMOVE_SUCCESS);
CHECK_GE(meta.data_shape_[2], 1);
int64_t ndim = meta.data_shape_[2];
CHECK_GE(ndim, 1);
std::vector<int64_t> vec_shape;
for (int i = 3; i < meta.data_shape_.size(); ++i) {
vec_shape.push_back(meta.data_shape_[i]);
for (int i = 0; i < ndim; ++i) {
vec_shape.push_back(meta.data_shape_[3+i]);
}
kv_msg->data = CreateNDArrayFromRaw(
vec_shape,
......@@ -563,6 +593,25 @@ static KVStoreMsg* recv_kv_message(network::Receiver* receiver) {
recv_data_msg.data,
AUTO_FREE);
}
// Recv Shape
if (kv_msg->msg_type != kPullMsg &&
kv_msg->msg_type != kPushMsg &&
kv_msg->msg_type != kPullBackMsg) {
Message recv_shape_msg;
CHECK_EQ(receiver->RecvFrom(&recv_shape_msg, send_id), REMOVE_SUCCESS);
int64_t ndim = meta.data_shape_[0];
CHECK_GE(ndim, 1);
std::vector<int64_t> vec_shape;
for (int i = 0; i < ndim; ++i) {
vec_shape.push_back(meta.data_shape_[1+i]);
}
kv_msg->shape = CreateNDArrayFromRaw(
vec_shape,
DLDataType{kDLInt, 64, 1},
DLContext{kDLCPU, 0},
recv_shape_msg.data,
AUTO_FREE);
}
return kv_msg;
}
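For orientation, the ArrayMeta layout implied by the indices used above (inferred from this diff only, not from ArrayMeta's definition elsewhere):

// For an INIT message only the shape NDArray is added to the meta, so the
// receiver sees data_shape_ = { ndim, dim_0, ..., dim_{ndim-1} }, e.g.
// {2, 500, 3}, and rebuilds vec_shape from data_shape_[1 + i].
// For PUSH / PULL-back messages the 1-D id array is added first and the data
// array second, so the data dims start at data_shape_[3 + i], as in the
// "Recv Data NDArray" branch.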
......@@ -578,12 +627,21 @@ DGL_REGISTER_GLOBAL("network._CAPI_SenderSendKVMsg")
if (kv_msg.msg_type != kFinalMsg && kv_msg.msg_type != kBarrierMsg) {
std::string name = args[args_count++];
kv_msg.name = name;
if (kv_msg.msg_type != kIPIDMsg) {
if (kv_msg.msg_type != kIPIDMsg &&
kv_msg.msg_type != kInitMsg) {
kv_msg.id = args[args_count++];
}
if (kv_msg.msg_type != kPullMsg && kv_msg.msg_type != kIPIDMsg) {
if (kv_msg.msg_type != kPullMsg &&
kv_msg.msg_type != kIPIDMsg &&
kv_msg.msg_type != kInitMsg) {
kv_msg.data = args[args_count++];
}
if (kv_msg.msg_type != kIPIDMsg &&
kv_msg.msg_type != kPullMsg &&
kv_msg.msg_type != kPushMsg &&
kv_msg.msg_type != kPullBackMsg) {
kv_msg.shape = args[args_count++];
}
}
send_kv_message(sender, &kv_msg, recv_id, AUTO_FREE);
});
......@@ -630,6 +688,13 @@ DGL_REGISTER_GLOBAL("network._CAPI_ReceiverGetKVMsgData")
*rv = msg->data;
});
DGL_REGISTER_GLOBAL("network._CAPI_ReceiverGetKVMsgShape")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
KVMsgHandle chandle = args[0];
network::KVStoreMsg* msg = static_cast<KVStoreMsg*>(chandle);
*rv = msg->shape;
});
DGL_REGISTER_GLOBAL("network._CAPI_DeleteKVMsg")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
KVMsgHandle chandle = args[0];
......
......@@ -194,6 +194,10 @@ class KVStoreMsg {
* \brief data matrix
*/
NDArray data;
/*!
* \brief data shape
*/
NDArray shape;
};
} // namespace network
......
......@@ -40,10 +40,13 @@ def start_client():
my_client = KVClient(server_namebook=server_namebook)
my_client.connect()
my_client.init_data(name='data_2', shape=[num_entries, dim_size], dtype=F.float32, target_name='data_0')
name_list = my_client.get_data_name_list()
assert len(name_list) == 2
assert len(name_list) == 3
assert 'data_0' in name_list
assert 'data_1' in name_list
assert 'data_2' in name_list
meta_0 = my_client.get_data_meta('data_0')
assert meta_0[0] == F.float32
......@@ -53,12 +56,19 @@ def start_client():
assert meta_1[0] == F.float32
assert_array_equal(meta_1[2], partition_1)
my_client.push(name='data_0', id_tensor=F.tensor([0, 1, 2]), data_tensor=F.tensor([[1.,1.,1.],[2.,2.,2.],[3.,3.,3.]]))
meta_2 = my_client.get_data_meta('data_2')
assert meta_2[0] == F.float32
assert_array_equal(meta_2[2], partition_0)
res = my_client.pull(name='data_0', id_tensor=F.tensor([0, 1, 2]))
my_client.push(name='data_0', id_tensor=F.tensor([0, 1, 2]), data_tensor=F.tensor([[1.,1.,1.],[2.,2.,2.],[3.,3.,3.]]))
my_client.push(name='data_2', id_tensor=F.tensor([0, 1, 2]), data_tensor=F.tensor([[1.,1.,1.],[2.,2.,2.],[3.,3.,3.]]))
target = F.tensor([[1.,1.,1.],[2.,2.,2.],[3.,3.,3.]])
res = my_client.pull(name='data_0', id_tensor=F.tensor([0, 1, 2]))
assert_array_equal(res, target)
res = my_client.pull(name='data_2', id_tensor=F.tensor([0, 1, 2]))
assert_array_equal(res, target)
my_client.shut_down()
......